This PR checks in the following parts of the Snake game-based DQN example:

- snake_game.js: Game logic (without any graphics)
- dqn.js: DQN network definition, along with some utility functions
- replay_memory.js: The replay buffer used for DQN training
- agent.js: The agent based on the epsilon-greedy algorithm
- train.js: Training logic

All modules are accompanied by unit tests.
Showing 17 changed files with 8,309 additions and 0 deletions.
@@ -0,0 +1,18 @@
{
  "presets": [
    [
      "env",
      {
        "esmodules": false,
        "targets": {
          "browsers": [
            "> 3%"
          ]
        }
      }
    ]
  ],
  "plugins": [
    "transform-runtime"
  ]
}
@@ -0,0 +1,61 @@
# Using Deep Q-Learning to Solve the Snake Game

Deep Q-Learning is a reinforcement-learning (RL) algorithm. It is
frequently used to solve arcade-style games like the Snake game used in
this example.
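
The deep Q-network (DQN) maps the board state to one predicted Q-value per
possible action. As a point of reference for reading the code, the network
is trained to regress toward the standard Q-learning (Bellman) target,
which is what `trainOnReplayBatch()` in `agent.js` computes:

```
target = reward + gamma * max_a'( Q_target(nextState, a') )
```

where `gamma` is the reward discount rate and the second term is zeroed
for game-ending (terminal) steps.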

## The Snake game

The Snake game is a grid-world action game in which the player controls
a virtual snake that keeps moving on the game board (9x9 by default).
At each step, there are four possible actions: left, right, up, and down.
To achieve higher scores (rewards), the player should guide the snake
to the fruits on the screen and "eat" them, while avoiding
- letting its head go off the board, and
- letting its head bump into its own body.
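
For reference, here is a minimal sketch of driving a game through the API
defined in `snake_game.js`, the way `agent.js` and the unit tests use it
(the constructor options below are the ones used in the tests):

```js
import {SnakeGame, getRandomAction} from './snake_game';

const game = new SnakeGame({height: 9, width: 9, numFruits: 1, initLen: 2});

let done = false;
while (!done) {
  // Pick one of the four actions (left, right, up, down) at random.
  const action = getRandomAction();
  // step() also returns the new `state` and the `reward` for this step.
  ({done} = game.step(action));
}
game.reset();  // Start a fresh game on the same board.
```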

This example consists of two parts:
1. Training the Deep Q-Network (DQN) in Node.js
2. Live demo in the browser

## Training the Deep Q-Network in Node.js

To train the DQN, use the following commands:

```sh
yarn
yarn train
```

If you have a CUDA-enabled GPU installed on your system, along with all
the required drivers and libraries, append the `--gpu` flag to the command
above to use the GPU for training, which leads to a significant increase
in training speed:

```sh
yarn train --gpu
```

To monitor the training progress using TensorBoard, use the `--logDir` flag
and point it to a log directory, e.g.,

```sh
yarn train --logDir /tmp/snake_logs
```

During training, you can use TensorBoard to visualize curves such as:
- the cumulative reward values from the games,
- the training speed (game frames per second), and
- the value of epsilon from the epsilon-greedy algorithm.

Specifically, open a separate terminal, install TensorBoard in it, and
launch the TensorBoard backend server:

```sh
pip install tensorboard
tensorboard --logdir /tmp/snake_logs
```

Once started, the TensorBoard backend process will print an `http://` URL
to the console. Open your browser and navigate to the URL to see the
logged curves.
@@ -0,0 +1,156 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

import {createDeepQNetwork} from './dqn';
import {getRandomAction, SnakeGame, NUM_ACTIONS, ALL_ACTIONS, getStateTensor} from './snake_game';
import {ReplayMemory} from './replay_memory';
import {assertPositiveInteger} from './utils';

export class SnakeGameAgent {
  /**
   * Constructor of SnakeGameAgent.
   *
   * @param {SnakeGame} game A game object.
   * @param {object} config The configuration object with the following keys:
   *   - `replayBufferSize` {number} Size of the replay memory. Must be a
   *     positive integer.
   *   - `epsilonInit` {number} Initial value of epsilon (for the epsilon-
   *     greedy algorithm). Must be >= 0 and <= 1.
   *   - `epsilonFinal` {number} The final value of epsilon. Must be >= 0 and
   *     <= 1.
   *   - `epsilonDecayFrames` {number} The # of frames over which the value of
   *     `epsilon` decreases from `epsilonInit` to `epsilonFinal`, via a linear
   *     schedule.
   *   - `learningRate` {number} The learning rate of the optimizer used to
   *     train the online network.
   */
  constructor(game, config) {
    assertPositiveInteger(config.epsilonDecayFrames);

    this.game = game;

    this.epsilonInit = config.epsilonInit;
    this.epsilonFinal = config.epsilonFinal;
    this.epsilonDecayFrames = config.epsilonDecayFrames;
    this.epsilonIncrement_ = (this.epsilonFinal - this.epsilonInit) /
        this.epsilonDecayFrames;

    this.onlineNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    this.targetNetwork =
        createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
    // Freeze target network: its weights are updated only through copying
    // from the online network.
    this.targetNetwork.trainable = false;

    this.optimizer = tf.train.adam(config.learningRate);

    this.replayBufferSize = config.replayBufferSize;
    this.replayMemory = new ReplayMemory(config.replayBufferSize);
    this.frameCount = 0;
    this.reset();
  }

  reset() {
    this.cumulativeReward_ = 0;
    this.game.reset();
  }

  /**
   * Play one step of the game.
   *
   * @returns {object} An object with the following keys:
   *   - `action` {number} The action taken at this step.
   *   - `cumulativeReward` {number} The cumulative reward of the game so
   *     far, including this step.
   *   - `done` {boolean} Whether this step ends the game.
   */
  playStep() {
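    // Linearly anneal epsilon from `epsilonInit` to `epsilonFinal` over the
    // first `epsilonDecayFrames` frames, then hold it at `epsilonFinal`.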
    this.epsilon = this.frameCount >= this.epsilonDecayFrames ?
        this.epsilonFinal :
        this.epsilonInit + this.epsilonIncrement_ * this.frameCount;
    this.frameCount++;

    // The epsilon-greedy algorithm.
    let action;
    const state = this.game.getState();
    if (Math.random() < this.epsilon) {
      // Pick an action at random.
      action = getRandomAction();
    } else {
      // Greedily pick an action based on online DQN output.
      tf.tidy(() => {
        const stateTensor =
            getStateTensor(state, this.game.height, this.game.width);
        action = ALL_ACTIONS[
            this.onlineNetwork.predict(stateTensor).argMax(-1).dataSync()[0]];
      });
    }

    const {state: nextState, reward, done} = this.game.step(action);

    this.replayMemory.append([state, action, reward, done, nextState]);

    this.cumulativeReward_ += reward;
    const output = {
      action,
      cumulativeReward: this.cumulativeReward_,
      done
    };
    if (done) {
      this.reset();
    }
    return output;
  }

  /**
   * Perform training on a randomly sampled batch from the replay buffer.
   *
   * @param {number} batchSize Batch size.
   * @param {number} gamma Reward discount rate. Must be >= 0 and <= 1.
   * @param {tf.train.Optimizer} optimizer The optimizer object used to update
   *   the weights of the online network.
   */
  trainOnReplayBatch(batchSize, gamma, optimizer) {
    // Get a batch of examples from the replay buffer.
    const batch = this.replayMemory.sample(batchSize);
    const lossFunction = () => tf.tidy(() => {
      const stateTensor = getStateTensor(
          batch.map(example => example[0]), this.game.height, this.game.width);
      const actionTensor = tf.tensor1d(
          batch.map(example => example[1]), 'int32');
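      // One-hot masking keeps, for each example, only the predicted Q-value
      // of the action that was actually taken.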
      const qs = this.onlineNetwork.predict(
          stateTensor).mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);

      const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
      const nextStateTensor = getStateTensor(
          batch.map(example => example[4]), this.game.height, this.game.width);
      const nextMaxQTensor =
          this.targetNetwork.predict(nextStateTensor).max(-1);
      const doneMask = tf.scalar(1).sub(
          tf.tensor1d(batch.map(example => example[3])).asType('float32'));
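      // Bellman target: reward + gamma * max_a' Q_target(nextState, a').
      // `doneMask` zeroes the bootstrap term for terminal transitions.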
      const targetQs =
          rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma));
      return tf.losses.meanSquaredError(targetQs, qs);
    });

    // TODO(cais): Remove the second argument when `variableGrads()` obeys the
    // trainable flag.
    const grads =
        tf.variableGrads(lossFunction, this.onlineNetwork.getWeights());
    optimizer.applyGradients(grads.grads);
    tf.dispose(grads);
    // TODO(cais): Return the loss value here?
  }
}
@@ -0,0 +1,122 @@
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs-node';

import {SnakeGameAgent} from "./agent";
import {SnakeGame} from "./snake_game";

describe('SnakeGameAgent', () => {
  it('playStep', () => {
    const game = new SnakeGame({
      height: 9,
      width: 9,
      numFruits: 1,
      initLen: 2
    });
    const agent = new SnakeGameAgent(game, {
      replayBufferSize: 100,
      epsilonInit: 1,
      epsilonFinal: 0.1,
      epsilonDecayFrames: 10
    });

    const numGames = 40;
    let bufferIndex = 0;
    for (let n = 0; n < numGames; ++n) {
      // At the beginning of a game, the cumulative reward ought to be 0.
      expect(agent.cumulativeReward_).toEqual(0);
      let out = null;
      let outPrev = null;
      for (let m = 0; m < 10; ++m) {
        const currentState = agent.game.getState();
        out = agent.playStep();
        // Check the content of the replay buffer.
        expect(agent.replayMemory.buffer[bufferIndex % 100][0])
            .toEqual(currentState);
        expect(agent.replayMemory.buffer[bufferIndex % 100][1])
            .toEqual(out.action);

        expect(agent.replayMemory.buffer[bufferIndex % 100][2]).toEqual(
            outPrev == null ? out.cumulativeReward :
                out.cumulativeReward - outPrev.cumulativeReward);
        expect(agent.replayMemory.buffer[bufferIndex % 100][3])
            .toEqual(out.done);
        expect(agent.replayMemory.buffer[bufferIndex % 100][4])
            .toEqual(out.done ? undefined : agent.game.getState());
        bufferIndex++;
        if (out.done) {
          break;
        }
        outPrev = out;
      }
      agent.reset();
    }
  });

  it('trainOnReplayBatch', () => {
    const game = new SnakeGame({
      height: 9,
      width: 9,
      numFruits: 1,
      initLen: 2
    });
    const replayBufferSize = 1000;
    const agent = new SnakeGameAgent(game, {
      replayBufferSize,
      epsilonInit: 1,
      epsilonFinal: 0.1,
      epsilonDecayFrames: 1000,
      learningRate: 1e-2
    });

    const oldOnlineWeights =
        agent.onlineNetwork.getWeights().map(x => x.dataSync());
    const oldTargetWeights =
        agent.targetNetwork.getWeights().map(x => x.dataSync());

    for (let i = 0; i < replayBufferSize; ++i) {
      agent.playStep();
    }
    // Burn-in run for the memory-leak check below.
    const batchSize = 512;
    const gamma = 0.99;
    const optimizer = tf.train.adam();
    agent.trainOnReplayBatch(batchSize, gamma, optimizer);

    const numTensors0 = tf.memory().numTensors;
    agent.trainOnReplayBatch(batchSize, gamma, optimizer);
    expect(tf.memory().numTensors).toEqual(numTensors0);

    const newOnlineWeights =
        agent.onlineNetwork.getWeights().map(x => x.dataSync());
    const newTargetWeights =
        agent.targetNetwork.getWeights().map(x => x.dataSync());

    // Verify that the online network's weights are updated.
    for (let i = 0; i < oldOnlineWeights.length; ++i) {
      expect(tf.tensor1d(newOnlineWeights[i])
          .sub(tf.tensor1d(oldOnlineWeights[i]))
          .abs().max().arraySync()).toBeGreaterThan(0);
    }
    // Verify that the target network's weights have not changed.
    for (let i = 0; i < oldTargetWeights.length; ++i) {
      expect(tf.tensor1d(newTargetWeights[i])
          .sub(tf.tensor1d(oldTargetWeights[i]))
          .abs().max().arraySync()).toEqual(0);
    }
  });
});