[snake-dqn] Initial commit (#265)
This PR checks in the following parts of the Snake game-based DQN example:
- snake_game.js: Game logic (without any graphics) 
- dqn.js: DQN network definition, along with some utility functions
- replay_memory.js: The replay buffer used for DQN training
- agent.js: The agent based on the epsilon-greedy algorithm
- train.js: Training logic

All modules are accompanied by unit tests.
caisq committed Apr 23, 2019
1 parent 59b606d commit e6d0b76
Showing 17 changed files with 8,309 additions and 0 deletions.
18 changes: 18 additions & 0 deletions snake-dqn/.babelrc
@@ -0,0 +1,18 @@
{
"presets": [
[
"env",
{
"esmodules": false,
"targets": {
"browsers": [
"> 3%"
]
}
}
]
],
"plugins": [
"transform-runtime"
]
}
61 changes: 61 additions & 0 deletions snake-dqn/README.md
@@ -0,0 +1,61 @@
# Using Deep Q-Learning to Solve the Snake Game

Deep Q-Learning is a reinforcement-learning (RL) algorithm that learns an
action-value (Q) function with a deep neural network. It is frequently used to
solve arcade-style games such as the Snake game in this example.

## The Snake game

The Snake game is a grid-world action game in which the player controls
a virtual snake that keeps moving on the game board (9x9 by default).
At each step, there are four possible actions: left, right, up, and down.
To achieve higher scores (rewards), the player should guide the snake
to the fruits on the screen and "eat" them, while avoiding
- its head going off the board, and
- its head bumping into its own body.
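
As a concrete illustration, the snippet below plays one game with random
actions. It is a minimal sketch based on the `SnakeGame` constructor options
used in `agent_test.js` and the `step()` / `getRandomAction()` APIs that
`agent.js` relies on later in this commit; the exact return shape of `step()`
is inferred from that code.

```js
import {SnakeGame, getRandomAction} from './snake_game';

// A 9x9 board with one fruit and an initial snake length of 2 (the same
// options used in agent_test.js).
const game = new SnakeGame({height: 9, width: 9, numFruits: 1, initLen: 2});

let cumulativeReward = 0;
let done = false;
while (!done) {
  const action = getRandomAction();  // One of: left, right, up, down.
  // Advance the game by one step. `done` becomes true when the snake's head
  // goes off the board or runs into its own body.
  const result = game.step(action);
  cumulativeReward += result.reward;
  done = result.done;
}
console.log(`Game over. Cumulative reward: ${cumulativeReward}`);
```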

This example consists of two parts:
1. Training the Deep Q-Network (DQN) in Node.js
2. Live demo in the browser

## Training the Deep Q-Network in Node.js

To train the DQN, run the following commands:

```sh
yarn
yarn train
```

If you have a CUDA-enabled GPU on your system, along with the required
drivers and libraries, append the `--gpu` flag to the command above to use
the GPU for training, which leads to a significant increase in training
speed:

```sh
yarn train --gpu
```

To monitor the training progress using TensorBoard, use the `--logDir` flag
and point it to a log directory, e.g.,

```sh
yarn train --logDir /tmp/snake_logs
```

During training, you can use TensorBoard to visualize curves such as:
- Cumulative reward values from the games
- Training speed (game frames per second)
- The value of epsilon from the epsilon-greedy algorithm

Specifically, open a separate terminal, install TensorBoard in it, and launch
the TensorBoard backend server:

```sh
pip install tensorboard
tensorboard --logdir /tmp/snake_logs
```

Once started, the TensorBoard backend process will print an `http://` URL to
the console. Navigate to that URL in your browser to see the logged curves.
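
On the Node.js side, these scalar curves are written by the training script,
presumably via tfjs-node's summary-writer API. The sketch below shows that
style of logging; the log directory, tag names, and values are illustrative
placeholders, not necessarily the exact ones used by `train.js`.

```js
import * as tf from '@tensorflow/tfjs-node';

// Illustrative only: tags and values are placeholders.
const summaryWriter = tf.node.summaryFileWriter('/tmp/snake_logs');
const frameCount = 1000;  // Hypothetical training frame count.
summaryWriter.scalar('cumulativeReward', 12.0, frameCount);
summaryWriter.scalar('framesPerSecond', 250, frameCount);
summaryWriter.scalar('epsilon', 0.52, frameCount);
```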
156 changes: 156 additions & 0 deletions snake-dqn/agent.js
@@ -0,0 +1,156 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs';

import {createDeepQNetwork} from './dqn';
import {getRandomAction, SnakeGame, NUM_ACTIONS, ALL_ACTIONS, getStateTensor} from './snake_game';
import {ReplayMemory} from './replay_memory';
import { assertPositiveInteger } from './utils';

export class SnakeGameAgent {
/**
* Constructor of SnakeGameAgent.
*
* @param {SnakeGame} game A game object.
* @param {object} config The configuration object with the following keys:
* - `replayBufferSize` {number} Size of the replay memory. Must be a
* positive integer.
* - `epsilonInit` {number} Initial value of epsilon (for the epsilon-
* greedy algorithm). Must be >= 0 and <= 1.
* - `epsilonFinal` {number} The final value of epsilon. Must be >= 0 and
* <= 1.
* - `epsilonDecayFrames` {number} The # of frames over which the value of
* `epsilon` decreases from `epsilonInit` to `epsilonFinal`, via a linear
* schedule.
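* - `learningRate` {number} Learning rate of the Adam optimizer used for
* training (corresponding to the `tf.train.adam(config.learningRate)` call
* in the constructor).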
*/
constructor(game, config) {
assertPositiveInteger(config.epsilonDecayFrames);

this.game = game;

this.epsilonInit = config.epsilonInit;
this.epsilonFinal = config.epsilonFinal;
this.epsilonDecayFrames = config.epsilonDecayFrames;
this.epsilonIncrement_ = (this.epsilonFinal - this.epsilonInit) /
this.epsilonDecayFrames;
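// Note: epsilonIncrement_ is negative when epsilonFinal < epsilonInit, so
// epsilon decreases linearly as frameCount grows.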

this.onlineNetwork =
createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
this.targetNetwork =
createDeepQNetwork(game.height, game.width, NUM_ACTIONS);
// Freeze target network: its weights are updated only through copying from
// the online network.
this.targetNetwork.trainable = false;

this.optimizer = tf.train.adam(config.learningRate);

this.replayBufferSize = config.replayBufferSize;
this.replayMemory = new ReplayMemory(config.replayBufferSize);
this.frameCount = 0;
this.reset();
}

reset() {
this.cumulativeReward_ = 0;
this.game.reset();
}

/**
* Play one step of the game.
*
* @returns {number | null} If this step leads to the end of the game,
* the total reward from the game as a plain number. Else, `null`.
*/
playStep() {
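// Linearly anneal epsilon from epsilonInit toward epsilonFinal over the
// first epsilonDecayFrames frames; after that, hold it at epsilonFinal.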
this.epsilon = this.frameCount >= this.epsilonDecayFrames ?
this.epsilonFinal :
this.epsilonInit + this.epsilonIncrement_ * this.frameCount;
this.frameCount++;

// The epsilon-greedy algorithm.
let action;
const state = this.game.getState();
if (Math.random() < this.epsilon) {
// Pick an action at random.
action = getRandomAction();
} else {
// Greedily pick an action based on online DQN output.
tf.tidy(() => {
const stateTensor =
getStateTensor(state, this.game.height, this.game.width);
action = ALL_ACTIONS[
this.onlineNetwork.predict(stateTensor).argMax(-1).dataSync()[0]];
});
}

const {state: nextState, reward, done} = this.game.step(action);

this.replayMemory.append([state, action, reward, done, nextState]);

this.cumulativeReward_ += reward;
const output = {
action,
cumulativeReward: this.cumulativeReward_,
done
};
if (done) {
this.reset();
}
return output;
}

/**
* Perform training on a randomly sampled batch from the replay buffer.
*
* @param {number} batchSize Batch size.
* @param {number} gamma Reward discount rate. Must be >= 0 and <= 1.
* @param {tf.train.Optimizer} optimizer The optimizer object used to update
* the weights of the online network.
*/
trainOnReplayBatch(batchSize, gamma, optimizer) {
// Get a batch of examples from the replay buffer.
const batch = this.replayMemory.sample(batchSize);
const lossFunction = () => tf.tidy(() => {
const stateTensor = getStateTensor(
batch.map(example => example[0]), this.game.height, this.game.width);
const actionTensor = tf.tensor1d(
batch.map(example => example[1]), 'int32');
const qs = this.onlineNetwork.predict(
stateTensor).mul(tf.oneHot(actionTensor, NUM_ACTIONS)).sum(-1);

const rewardTensor = tf.tensor1d(batch.map(example => example[2]));
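// Compute the Bellman target for each example: reward + gamma * max over
// actions of the target network's Q-value at the next state, with the
// bootstrapped term zeroed out for terminal transitions via doneMask.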
const nextStateTensor = getStateTensor(
batch.map(example => example[4]), this.game.height, this.game.width);
const nextMaxQTensor =
this.targetNetwork.predict(nextStateTensor).max(-1);
const doneMask = tf.scalar(1).sub(
tf.tensor1d(batch.map(example => example[3])).asType('float32'));
const targetQs =
rewardTensor.add(nextMaxQTensor.mul(doneMask).mul(gamma));
return tf.losses.meanSquaredError(targetQs, qs);
});

// TODO(cais): Remove the second argument when `variableGrads()` obeys the
// trainable flag.
const grads =
tf.variableGrads(lossFunction, this.onlineNetwork.getWeights());
optimizer.applyGradients(grads.grads);
tf.dispose(grads);
// TODO(cais): Return the loss value here?
}
}
122 changes: 122 additions & 0 deletions snake-dqn/agent_test.js
@@ -0,0 +1,122 @@
/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '@tensorflow/tfjs-node';

import {SnakeGameAgent} from "./agent";
import {SnakeGame} from "./snake_game";

describe('SnakeGameAgent', () => {
it('playStep', () => {
const game = new SnakeGame({
height: 9,
width: 9,
numFruits: 1,
initLen: 2
});
const agent = new SnakeGameAgent(game, {
replayBufferSize: 100,
epsilonInit: 1,
epsilonFinal: 0.1,
epsilonDecayFrames: 10
});

const numGames = 40;
let bufferIndex = 0;
for (let n = 0; n < numGames; ++n) {
// At the beginning of a game, the cumulative reward ought to be 0.
expect(agent.cumulativeReward_).toEqual(0);
let out = null;
let outPrev = null;
for (let m = 0; m < 10; ++m) {
const currentState = agent.game.getState();
out = agent.playStep();
// Check the content of the replay buffer.
expect(agent.replayMemory.buffer[bufferIndex % 100][0])
.toEqual(currentState);
expect(agent.replayMemory.buffer[bufferIndex % 100][1])
.toEqual(out.action);

expect(agent.replayMemory.buffer[bufferIndex % 100][2]).toEqual(
outPrev == null ? out.cumulativeReward :
out.cumulativeReward - outPrev.cumulativeReward);
expect(agent.replayMemory.buffer[bufferIndex % 100][3]).toEqual(out.done);
expect(agent.replayMemory.buffer[bufferIndex % 100][4])
.toEqual(out.done ? undefined : agent.game.getState());
bufferIndex++;
if (out.done) {
break;
}
outPrev = out;
}
agent.reset();
}
});

it('trainOnReplayBatch', () => {
const game = new SnakeGame({
height: 9,
width: 9,
numFruits: 1,
initLen: 2
});
const replayBufferSize = 1000;
const agent = new SnakeGameAgent(game, {
replayBufferSize,
epsilonInit: 1,
epsilonFinal: 0.1,
epsilonDecayFrames: 1000,
learningRate: 1e-2
});

const oldOnlineWeights =
agent.onlineNetwork.getWeights().map(x => x.dataSync());
const oldTargetWeights =
agent.targetNetwork.getWeights().map(x => x.dataSync());

for (let i = 0; i < replayBufferSize; ++i) {
agent.playStep();
}
// Burn-in run for memory leak check below.
const batchSize = 512;
const gamma = 0.99;
const optimizer = tf.train.adam();
agent.trainOnReplayBatch(batchSize, gamma, optimizer);

const numTensors0 = tf.memory().numTensors;
agent.trainOnReplayBatch(batchSize, gamma, optimizer);
expect(tf.memory().numTensors).toEqual(numTensors0);

const newOnlineWeights =
agent.onlineNetwork.getWeights().map(x => x.dataSync());
const newTargetWeights =
agent.targetNetwork.getWeights().map(x => x.dataSync());

// Verify that the online network's weights are updated.
for (let i = 0; i < oldOnlineWeights.length; ++i) {
expect(tf.tensor1d(newOnlineWeights[i])
.sub(tf.tensor1d(oldOnlineWeights[i]))
.abs().max().arraySync()).toBeGreaterThan(0);
}
// Verify that the target network's weights have not changed.
for (let i = 0; i < oldTargetWeights.length; ++i) {
expect(tf.tensor1d(newTargetWeights[i])
.sub(tf.tensor1d(oldTargetWeights[i]))
.abs().max().arraySync()).toEqual(0);
}
});
});