# Training an AI to play the Snake game

Train a neural network to play the Snake game using reinforcement learning with a Deep Q-Network (DQN) and PyTorch.

## Steps to Implement a DQN
1. Import dependencies
2. Define hyperparameters
3. Load Environment
4. Define Q-Network and the Target Q-Network
5. Define experience replay buffer
6. Define loss function
7. Define agent learn function
8. Define training loop
9. Test

### Import dependencies

In [1]:
import time
import os
import random
import numpy as np
import torch
from torch import nn
from collections import deque, namedtuple
from environment import Environment

### Define hyperparameters

In [2]:
# Replay buffer capacity and number of transitions sampled per update.
MEMORY_SIZE = 100_000
MINI_BATCH_SIZE = 1_000
# A learning update is attempted every NUM_STEPS_FOR_UPDATE steps of an episode.
NUM_STEPS_FOR_UPDATE = 4
# Discount factor, Adam learning rate, and soft-update blend factor.
GAMMA = 0.995
ALPHA = 0.001
TAU = 0.01
# Epsilon-greedy schedule: multiplied by EPS_DECAY after each episode, floored at EPS_MIN.
EPS_DECAY = 0.9995
EPS_MIN = 0.01
# Number of training episodes and the per-episode step cap.
NUM_EPISODES = 10_000
MAX_TIMESTEPS = 1_000

### Load Environment

In [3]:
# Create the Snake environment and read the network's input/output sizes from it.
env = Environment()
state_size, number_of_actions = env.state_size, env.number_of_actions

### Define Q-Network and the Target Q-Network

In [4]:
# Pick a compute backend: CUDA if present, else Apple MPS, else the CPU.
device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"


class QNetwork(nn.Module):
    """Multilayer perceptron mapping a state vector to one Q-value per action.

    Architecture: Linear(in, 256) -> ReLU -> Linear(256, 128) -> ReLU ->
    Linear(128, out).

    Args:
        input_size: Length of the state vector. Defaults to the global
            ``state_size`` (read when the network is constructed).
        output_size: Number of actions, i.e. Q-values produced. Defaults to
            the global ``number_of_actions``.
    """

    def __init__(self, input_size=None, output_size=None):
        super().__init__()
        if input_size is None:
            input_size = state_size
        if output_size is None:
            output_size = number_of_actions
        self.layers = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, output_size),
        )

    def forward(self, x):
        """Return Q-values; the last dimension of ``x`` is mapped to ``output_size``."""
        return self.layers(x)


class TargetQNetwork(QNetwork):
    """Target network for DQN, with exactly the same architecture as ``QNetwork``.

    Defined as a subclass instead of a copy-pasted layer stack, so the two
    networks cannot silently diverge (``load_state_dict`` between them relies
    on identical layers). The class name still appears in ``repr``, and the
    ``state_dict`` keys are unchanged.
    """

In [5]:
# Online network (the one Adam optimises) and its target copy, both on `device`.
q_network = QNetwork().to(device)
target_q_network = TargetQNetwork().to(device)

optimizer = torch.optim.Adam(q_network.parameters(), lr=ALPHA)

# Show both architectures and the optimiser configuration.
for shown in (q_network, target_q_network, optimizer):
    print(shown)

QNetwork(
  (layers): Sequential(
    (0): Linear(in_features=11, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=3, bias=True)
  )
)
TargetQNetwork(
  (layers): Sequential(
    (0): Linear(in_features=11, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=3, bias=True)
  )
)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)


### Define experience replay buffer

In [6]:
# A single replay-buffer transition: (s, a, r, s', done).
Experience = namedtuple(
    "Experience",
    "state action reward next_state done",
)

### Define Loss Function

In [7]:
# experiences is the mini batch of [states, actions, rewards, next_state, done]
def loss_fn(experiences):
    states, actions, rewards, next_states, dones = experiences

    actions = torch.unsqueeze(actions, 1)
    
    max_qsa, _ = torch.max(target_q_network(next_states), dim=1)
    
    y = rewards + (1 - dones) * (GAMMA) * max_qsa
    
    pred = torch.gather(q_network(states), 1, actions)
    pred = torch.squeeze(pred)
    
    MSE = nn.MSELoss()

    return MSE(y, pred)

### Define Learning Function

In [8]:
def agent_learn(experiences):
    """Take one Adam step on ``q_network``, then soft-update the target network.

    Args:
        experiences: Mini-batch ``(states, actions, rewards, next_states,
            dones)``, passed unchanged to ``loss_fn``.

    After the optimiser step, each target parameter is blended toward its
    online counterpart: ``target <- TAU * online + (1 - TAU) * target``.
    """
    td_loss = loss_fn(experiences)

    # Backpropagate, apply the update, then clear the online gradients.
    td_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Soft (Polyak) update of the target network, done on the raw tensors.
    param_pairs = zip(target_q_network.parameters(), q_network.parameters())
    for tgt, src in param_pairs:
        blended = TAU * src.data + (1.0 - TAU) * tgt.data
        tgt.data.copy_(blended)

### Define training loop

In [9]:
def get_action(state, epsilon, n_actions=None):
    """Choose an action with an epsilon-greedy policy.

    With probability ``epsilon`` a uniformly random action is returned.
    Otherwise the action with the highest Q-value under ``q_network`` is
    returned.

    Args:
        state: Batched state tensor. Only the first row's Q-values are used
            for the greedy choice; callers in this notebook pass a batch of one.
        epsilon: Exploration probability.
        n_actions: Size of the action space. Defaults to the global
            ``number_of_actions`` (previously hard-coded as 3).

    Returns:
        The chosen action index as a Python ``int``.
    """
    if random.random() <= epsilon:
        if n_actions is None:
            n_actions = number_of_actions
        # randint is inclusive on both ends.
        return random.randint(0, n_actions - 1)

    with torch.no_grad():
        # Evaluate on the network's device, then copy back to the CPU so that
        # .numpy() also works when running on CUDA or MPS.
        q_values = q_network(state.to(device)).cpu().numpy()
    return int(np.argmax(q_values[0]))

def get_experiences(memory_buffer, batch_size=None, target_device=None):
    """Sample a random mini-batch from the replay buffer as batched tensors.

    Args:
        memory_buffer: Sequence of ``Experience`` records.
        batch_size: Number of transitions sampled without replacement.
            Defaults to ``MINI_BATCH_SIZE``.
        target_device: Device the tensors are placed on. Defaults to the
            global ``device``, which the networks were moved to. Without this,
            training on CUDA/MPS fails with a device-mismatch error.

    Returns:
        Tuple ``(states, actions, rewards, next_states, done_vals)`` with
        dtypes float32, int64, float32, float32 and int32 respectively.

    Raises:
        ValueError: If the buffer holds fewer than ``batch_size`` items
            (raised by ``random.sample``).
    """
    if batch_size is None:
        batch_size = MINI_BATCH_SIZE
    if target_device is None:
        target_device = device

    experiences = random.sample(memory_buffer, k=batch_size)
    states = torch.from_numpy(np.array([e.state for e in experiences])).float()
    actions = torch.from_numpy(np.array([e.action for e in experiences])).long()
    rewards = torch.from_numpy(np.array([e.reward for e in experiences])).float()
    next_states = torch.from_numpy(np.array([e.next_state for e in experiences])).float()
    done_vals = torch.from_numpy(np.array([e.done for e in experiences])).int()

    return tuple(
        t.to(target_device)
        for t in (states, actions, rewards, next_states, done_vals)
    )

def train():
    """Run the DQN training loop for NUM_EPISODES episodes.

    Each step:
      - picks an epsilon-greedy action;
      - stores the transition in a replay buffer bounded by MEMORY_SIZE;
      - once the buffer holds at least MINI_BATCH_SIZE transitions, learns from
        a sampled mini-batch whenever the step index is a multiple of
        NUM_STEPS_FOR_UPDATE.

    Epsilon is multiplied by EPS_DECAY after each episode, with EPS_MIN as a
    floor. The 100-episode moving average of episode scores is printed as
    training progresses, followed by the total wall-clock runtime.
    """
    t0 = time.time()
    scores = []
    window = 100
    eps = 1.0
    replay = deque(maxlen=MEMORY_SIZE)
    # Start the target network as an exact copy of the online network.
    target_q_network.load_state_dict(q_network.state_dict())

    for episode in range(NUM_EPISODES):
        obs = env.reset()
        episode_score = 0

        for step in range(MAX_TIMESTEPS):
            obs_batch = torch.tensor([obs]).float()
            act = get_action(obs_batch, eps)

            nxt, rew, finished = env.step(act)
            replay.append(Experience(obs, act, rew, nxt, finished))

            obs = nxt
            episode_score += rew

            buffer_ready = len(replay) >= MINI_BATCH_SIZE
            if buffer_ready and step % NUM_STEPS_FOR_UPDATE == 0:
                agent_learn(get_experiences(replay))

            if finished:
                break

        scores.append(episode_score)
        recent_avg = np.mean(scores[-window:])

        eps = max(EPS_MIN, eps * EPS_DECAY)

        progress = f"\rEpisode {episode + 1} | Total point average of the last {window} episodes: {recent_avg:.2f}"
        # Overwrite the same console line each episode...
        print(progress, end="")
        # ...and commit it to its own line every `window` episodes.
        if (episode + 1) % window == 0:
            print(progress)

    elapsed = time.time() - t0
    print(f"\nTotal Runtime: {elapsed:.2f} s ({(elapsed/60):.2f} min)")

In [10]:
# Run the full training loop; progress is printed after every episode.
train()

Episode 100 | Total point average of the last 100 episodes: -8.90
Episode 200 | Total point average of the last 100 episodes: -8.95
Episode 300 | Total point average of the last 100 episodes: -8.20
Episode 400 | Total point average of the last 100 episodes: -7.25
Episode 500 | Total point average of the last 100 episodes: -7.35
Episode 600 | Total point average of the last 100 episodes: -6.50
Episode 700 | Total point average of the last 100 episodes: -5.95
Episode 800 | Total point average of the last 100 episodes: -4.35
Episode 900 | Total point average of the last 100 episodes: -5.25
Episode 1000 | Total point average of the last 100 episodes: -4.15
Episode 1100 | Total point average of the last 100 episodes: -4.30
Episode 1200 | Total point average of the last 100 episodes: -3.30
Episode 1300 | Total point average of the last 100 episodes: -2.30
Episode 1400 | Total point average of the last 100 episodes: -2.00
Episode 1500 | Total point average of the last 100 episodes: -2.00
Epis

### Dumb Snake In Action

In [18]:
def test(max_steps=None):
    """Play one episode greedily with ``q_network`` and render every frame.

    Intended for Jupyter: the cell output is cleared before each frame so the
    board redraws in place, with a 0.5 s pause between frames.

    Args:
        max_steps: Upper bound on the number of moves. Defaults to
            ``MAX_TIMESTEPS``. Without a bound, the cell would never finish if
            the environment never reports ``done`` (presumably when the snake
            circles without dying).

    Returns:
        The total reward collected during the episode.
    """
    from IPython.display import clear_output

    if max_steps is None:
        max_steps = MAX_TIMESTEPS

    state = env.reset()
    total_points = 0

    env.render()
    time.sleep(0.5)
    for _ in range(max_steps):
        with torch.no_grad():
            state_tensor = torch.tensor([state]).float().to(device)
            # Bring the Q-values back to the CPU so .numpy() works on CUDA/MPS too.
            action = np.argmax(q_network(state_tensor).cpu().numpy()[0])

        state, reward, done = env.step(action)
        total_points += reward

        clear_output(wait=True)
        env.render()
        time.sleep(0.5)

        if done:
            break

    return total_points

In [22]:
# Watch one greedy episode played by the trained network, frame by frame.
test()

....****
..***..*
..*....*
..*....*
..*....*
.@*...**
..*..***
..****o*
