# Training an AI to play snake with deep reinforcement learning

This notebook documents our experience learning how to develop an AI that plays Snake with deep reinforcement learning.  

We used TensorFlow as our deep learning toolbox.  For those new to TensorFlow, we recommend first completing their [tutorial](https://www.tensorflow.org/tutorials/mnist/beginners/).

# The snake game

### Description

For those unfamiliar with the snake game,
```
$ python /snake/play_ascii_snake.py 10 10
```
will load up an example of the game for a 10 x 10 board.  Use the W-A-S-D keys to move North-West-South-East (in that order).

![example](ascii-snake.gif)

In short, you play as a **snake** (represented by `O`) and move along the board to eat **pellets** (represented by `*`).  

**The snake always moves in the direction it faces (at a constant rate) until a command is given forcing it to turn.**  For example, the game begins with the snake moving North.  Then the `A` or `D` commands would cause the snake to move West or East, while the `W` command would be ignored since the snake is already moving North.  **The snake cannot make 180-degree turns,** meaning the `S` command to move South would also be ignored.

**The snake consumes the pellet when it collides with it.  When this happens, the snake grows in length by 1 and the Score is increased by 1.**  Specifically, when the snake grows, an additional `O` character replaces the pellet; as the snake moves away from that coordinate, the `O` is appended to the end of the snake body.

**The game ends when the snake collides with a wall** (represented by `|` and `-`) **or with itself.**
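The turning rules above can be sketched in a few lines of Python. This is only an illustration: `KEY_TO_DIRECTION` and `next_direction` are hypothetical names, not taken from `play_ascii_snake.py`, and the coordinate convention (y grows downward, so North is `(0, -1)`) is an assumption.

```python
# Hypothetical names; the real game logic lives in play_ascii_snake.py.
# Headings are (dx, dy) steps with y growing downward, so North is (0, -1).
KEY_TO_DIRECTION = {"w": (0, -1), "a": (-1, 0), "s": (0, 1), "d": (1, 0)}


def next_direction(current, key):
    """Return the snake's heading after a key press."""
    requested = KEY_TO_DIRECTION.get(key.lower())
    if requested is None:
        return current  # unknown key: keep moving in the same direction
    if requested == current:
        return current  # already facing that way, so the command is ignored
    if requested == (-current[0], -current[1]):
        return current  # 180-degree turns are not allowed
    return requested


north = (0, -1)
print(next_direction(north, "a"))  # turns West
print(next_direction(north, "s"))  # ignored, still North
```

Out of the four keys, only two ever change the heading at any given moment, which is the property the AI has to discover on its own.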

### Motivation

We chose the snake game for several reasons.  It's simple to code up a working version, and most people are familiar with how the game works.  

The game also allowed us to explore a few interesting questions:

- Will the AI have difficulty learning as a result of the snake beginning as a single point and growing as the game progresses?

- How well will the AI learn if we don't tell it at the start of the game which point is the snake and which point is the pellet?

- The game can be described as having only two actions (Turn Left and Turn Right) among four buttons.  Can the AI learn that two of the available actions at any point in the game are ignored?

Of course, there were also standard questions we wanted to answer:

- How much is the AI impacted by the size of the game board?

- How does performance differ between different network structures?

# Q-learning and deep Q-learning

For those new to Q-learning (or deep Q-learning), we highly recommend the [blog post](https://www.nervanasys.com/demystifying-deep-reinforcement-learning/) by Tambet Matiisen.

In short, Q-learning involves learning the entries in the Q matrix, where:
- each row corresponds to a state
- each column corresponds to a possible action
- the $(s,a)^{th}$ entry contains the expected long-term reward for taking action $a$ in state $s$, assuming we continue to take the "optimal" (in the long-term expected reward sense) actions as the game progresses.  That is:

$$Q[s,a] = r + \gamma \max_{a'} Q[s',a'] $$

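where $r$ is the immediate reward gained from making the transition from state $s$ to $s'$, and $\gamma \in [0, 1]$ is the discount factor, which controls how much future rewards count relative to immediate ones.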

If such a matrix were given to us, we'd be able to play a game by looking up each new state we enter and selecting the action with the maximum Q-value.  Such a strategy should (on average) maximize the total reward we accumulate as we play the game.
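For a game small enough to enumerate, the lookup and the update rule can be sketched with a plain list-of-lists Q table. This is a minimal sketch, not our training code: the function names, the toy 2-state table, and the learning rate `alpha` are assumptions introduced for illustration (with `alpha=1.0` the update reduces to the equation above), and terminal states are not handled.

```python
def greedy_action(q_table, state):
    """Pick the action (column index) with the largest Q-value in this state."""
    row = q_table[state]
    return max(range(len(row)), key=lambda a: row[a])


def q_update(q_table, state, action, reward, next_state, gamma=0.9, alpha=1.0):
    """Move Q[s, a] toward r + gamma * max_a' Q[s', a'].

    alpha is a learning rate (an assumption, not part of the equation above);
    alpha=1.0 overwrites the entry with the target.
    """
    target = reward + gamma * max(q_table[next_state])
    q_table[state][action] += alpha * (target - q_table[state][action])


# Toy table: 2 states (rows) x 2 actions (columns).
q_table = [[0.0, 0.0], [1.0, 3.0]]
q_update(q_table, state=0, action=1, reward=1.0, next_state=1, gamma=0.5)
print(q_table[0])                 # [0.0, 2.5]
print(greedy_action(q_table, 0))  # 1
```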



Unfortunately, even a simple 10 x 10 Snake board, where each cell has 3 possible values among [snake, pellet, empty], has $3^{100}$ possible states. (The true number is smaller than that, since there are certain states that are impossible to reach; for example, snake cells must be contiguous.  But it's still a lot!)

We'll never explore all the states in our lifetime, meaning we won't be able to update the Q matrix properly.
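To get a feel for the size of that upper bound, Python's arbitrary-precision integers can compute it directly. The variable names are only illustrative.

```python
board_cells = 10 * 10   # cells on a 10 x 10 board
values_per_cell = 3     # snake, pellet, or empty

# Upper bound on the number of board states (includes unreachable ones).
upper_bound = values_per_cell ** board_cells

print(upper_bound)
print(len(str(upper_bound)))  # 48 decimal digits, i.e. roughly 5 x 10^47
```

A Q matrix with one row per state would need that many rows, each with one column per action.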





# Deep convolutional neural networks



# Defining network structure with TensorFlow

Let's look at the network structure described in the blog post:

- The network consists of 3 convolutional layers followed by 2 fully connected layers.
- The first 4 layers use ReLU activations while the final output layer uses a linear activation.

### Input

The input is 4 frames of an 84 x 84 pixel image, which can be represented by the tensor:
```
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)
q_input = tf.placeholder(dtype=tf.float32, shape=[None, 84, 84, 4])  # [batch, height, width, frames]
```

### Convolutional kernel

The first convolutional layer uses 32 (8 x 8) kernels.  We initialize each kernel weight with Gaussian(mu=0, sigma=0.1) of corresponding dimension:

```
w_conv1 = util.init_conv_filter(filter_height=8, filter_width=8,
                                in_channels=int(q_input.get_shape()[-1]), out_channels=32)
```

Note that `in_channels` is set to the channel count of the layer's input (here, the 4 stacked frames), so each kernel output combines the values in an 8 x 8 x 4 block.

### Hidden conv layer 1

With a stride of 4 and no padding, each 8 x 8 kernel fits at (84 - 8) / 4 + 1 = 20 positions along each dimension, so the result is a 20 x 20 x 32 tensor.  Finally, the ReLU activation applies an element-wise non-linear mapping, which results in the hidden layer:

```
h_conv1 = tf.nn.relu(util.conv2d(input=q_input, filter=w_conv1, stride=4))
```

### Hidden conv layers 2 and 3

The next two convolutional layers use 64 (4 x 4) and 64 (3 x 3) kernels with strides of 2 and 1, respectively:

```
w_conv2 = util.init_conv_filter(filter_height=4, filter_width=4,
                                in_channels=int(h_conv1.get_shape()[-1]), out_channels=64)
h_conv2 = tf.nn.relu(util.conv2d(input=h_conv1, filter=w_conv2, stride=2))

w_conv3 = util.init_conv_filter(filter_height=3, filter_width=3,
                                in_channels=int(h_conv2.get_shape()[-1]), out_channels=64)
h_conv3 = tf.nn.relu(util.conv2d(input=h_conv2, filter=w_conv3, stride=1))
```

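Assuming no padding, this produces (20 - 4) / 2 + 1 = 9 and (9 - 3) / 1 + 1 = 7 positions along each dimension, i.e. tensors of shape 9 x 9 x 64 and 7 x 7 x 64.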
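The layer sizes can be double-checked with the standard output-size formula for a convolution without padding, `(input - kernel) // stride + 1`. The sketch below assumes that `util.conv2d` applies no padding (TensorFlow's `'VALID'` mode), which is consistent with the 20 x 20 output of the first layer; `conv_output_size` is a hypothetical helper, not part of `util`.

```python
def conv_output_size(input_size, kernel_size, stride):
    """Output width (or height) of a convolution with no padding."""
    return (input_size - kernel_size) // stride + 1


# (kernel size, stride, number of kernels) for the three conv layers.
layers = [(8, 4, 32), (4, 2, 64), (3, 1, 64)]

size = 84  # input frames are 84 x 84
layer_sizes = []
for kernel, stride, channels in layers:
    size = conv_output_size(size, kernel, stride)
    layer_sizes.append((size, size, channels))
    print(f"{size} x {size} x {channels}")
```

Under that assumption the three layers come out as 20 x 20 x 32, 9 x 9 x 64, and 7 x 7 x 64.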

### Hidden fully connected layer 1

Once we get to the first fully connected layer, we simply treat the last convolutional layer output as the input to a standard regression problem.  To do this, we flatten the final convolutional output from a 4-D tensor into a 2-D tensor with one row per example:

```
h_conv3_flat = util.flatten_4d_to_2d(h_conv3)
len_h_conv3 = int(h_conv3_flat.get_shape()[-1])

w_fc1 = util.init_fc_weights(height=len_h_conv3, width=len_h_conv3)
b_fc1 = util.init_fc_bias(length=len_h_conv3)
h_fc1 = tf.nn.relu(tf.matmul(h_conv3_flat, w_fc1) + b_fc1)
```
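To see what the flattening step does to the shapes, the number of features per example is just the product of the non-batch dimensions. This is a sketch with a hypothetical helper (`flattened_length` is not part of `util`), and the 7 x 7 x 64 input shape assumes the convolutions use no padding.

```python
import math


def flattened_length(shape):
    """Features per example after flattening [batch, h, w, c] to [batch, h*w*c]."""
    return math.prod(shape[1:])  # skip the batch dimension


# Assumed shape of h_conv3 when the convolutions use no padding.
h_conv3_shape = [None, 7, 7, 64]
len_flat = flattened_length(h_conv3_shape)

print(len_flat)             # 3136 features per example
print(len_flat * len_flat)  # weights in a square w_fc1 of that size
```

With a square weight matrix of this size, the first fully connected layer alone holds close to ten million weights, so the width of this layer is worth choosing deliberately.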