In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe
from IPython.display import clear_output
import time 
from matplotlib import colors

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import copy
from collections import deque

%matplotlib inline

In [None]:
# Select the GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

# Additional info when using CUDA.
# torch.cuda.memory_cached() is a deprecated alias; memory_reserved() is the
# current name for memory held by PyTorch's caching allocator.
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
def render_env(grid, snake, food, sleep_time : float = 0.1) :
    """Draw the gridworld, snake and food into the "env_render" figure.

    Adapted from a lab rendering helper. Values on the board are coloured by
    index (0 grey, 1 black, 2 white, 3 green). Snake cells are painted as 2 and
    the food cell as 3 on a copy, so ``grid`` itself is not modified.

    Args:
        grid: 2-D integer array (in this notebook 0 = wall, 1 = floor).
        snake: iterable of [row, col] cells.
        food: [row, col] cell.
        sleep_time: seconds to pause after drawing; skipped when <= 0.
    """
    # Interactive mode plus a named figure means repeated calls redraw the same plot.
    plt.ion()
    plt.figure(num="env_render")
    ax = plt.gca()
    ax.clear()
    clear_output(wait=True)

    board = np.copy(grid)
    for segment in snake:
        board[tuple(segment)] = 2
    board[tuple(food)] = 3

    palette = colors.ListedColormap(["grey", "black", "white", "green"])
    # Bin edges 0..4 give exactly one colour per integer value 0..3.
    norm = colors.BoundaryNorm([0, 1, 2, 3, 4], palette.N)
    ax.imshow(board, cmap=palette, norm=norm, zorder=0)

    # Gridlines on the half-integer ticks outline every cell; the labels are hidden.
    ax.grid(which='major', axis='both', linestyle='-', color='k', linewidth=2, zorder=1)
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    ax.set_xticks(np.arange(-0.5, n_cols, 1))
    ax.set_xticklabels([])
    ax.set_yticks(np.arange(-0.5, n_rows, 1))
    ax.set_yticklabels([])

    plt.show()

    if sleep_time > 0:
        time.sleep(sleep_time)

In [None]:
class SnakeEnv():
    """Snake on a 20x20 grid whose outer ring is wall, with relative actions.

    Cells are ``[row, col]`` lists. ``grid`` holds 0 for wall and 1 for floor
    (used for rendering), and ``walls`` lists every wall cell.

    ``snake_orientation`` encodes the heading: 0 = up (row - 1),
    1 = right (col + 1), 2 = down (row + 1), 3 = left (col - 1).
    Actions are relative to that heading: 0 = turn left, 1 = forward,
    2 = turn right.
    """

    # Row/col offset of one move for each orientation 0..3.
    _MOVES = ((-1, 0), (0, 1), (1, 0), (0, -1))
    # Orientation change for actions 0 (left), 1 (forward), 2 (right).
    _TURNS = (-1, 0, 1)
    # The 8 cells around the head, clockwise starting from "up". get_state
    # rotates this ring by 2 * orientation so it always starts at the heading.
    _RING = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))

    def __init__(self):
        # game board, for rendering
        self.grid_height = 20
        self.grid_width = 20
        self.grid = np.zeros((self.grid_height, self.grid_width), dtype=int)
        self.grid[1:-1, 1:-1] = 1

        # Add extra walls here, before ``walls`` is derived, e.g.
        # self.grid[4, 6] = 0

        self.walls = np.argwhere(self.grid == 0).tolist()
        self._spawn()

    def _spawn(self):
        """Randomise the heading, then place a length-1 snake and the food on free cells."""
        self.snake_orientation = np.random.randint(0, 4, 1)[0]
        self.snake = []  # must exist before _random_free_cell checks membership
        self.snake = [self._random_free_cell()]
        self.food = self._random_free_cell()

    def _random_free_cell(self):
        """Rejection-sample an interior cell that is neither snake nor wall.

        NOTE(review): loops forever if no such cell exists (snake fills the board).
        """
        while True:
            cell = [random.randint(1, self.grid_height - 2),
                    random.randint(1, self.grid_width - 2)]
            if cell not in self.snake and cell not in self.walls:
                return cell

    def reset(self):
        """Start a new episode and return its initial ``get_state()`` vector.

        The heading is re-randomised exactly as in ``__init__``. Previously it
        carried over from the final step of the last episode.
        """
        self._spawn()
        return self.get_state()

    def sample(self):
        """Return a uniformly random action index in {0, 1, 2}."""
        return np.random.randint(0,3,1)[0]

    def step(self, action):
        """Apply one relative action and move the snake one cell.

        Args:
            action: 0 = turn left, 1 = forward, 2 = turn right.

        Returns:
            (reward, terminal):
            (-1., True) if the new head hits a wall or the body;
            (100., False) if it reaches the food (the snake grows and new food spawns);
            (-0.01, False) otherwise (the tail cell is dropped).
        """
        # Tuple indexing keeps the original lookup semantics (IndexError for action >= 3).
        self.snake_orientation = (self.snake_orientation + self._TURNS[action]) % 4
        d_row, d_col = self._MOVES[self.snake_orientation]
        new_head = [self.snake[0][0] + d_row, self.snake[0][1] + d_col]
        self.snake.insert(0, new_head)

        if new_head in self.walls or new_head in self.snake[1:]:
            return (-1., True)

        if new_head == self.food:
            # Growing: keep the tail and respawn the food elsewhere.
            self.food = self._random_free_cell()
            return (100., False)

        self.snake.pop()
        return (-0.01, False)

    def get_state(self):
        """Return a 24-dim feature vector for the current snake.

        For each of the 8 cells around the head, starting at the heading and
        going clockwise, three Euclidean distances are appended from that cell:
        to the nearest wall cell, to the food, and to the snake's last segment.
        """
        head = np.array(self.snake[0])
        tail = np.array(self.snake[-1])
        food = np.array(self.food)
        walls = np.asarray(self.walls)

        start = 2 * self.snake_orientation
        ring = self._RING[start:] + self._RING[:start]

        state = []
        for offset in ring:
            pos = head + offset
            # Vectorised nearest-wall distance; inf when there are no walls.
            nearest_wall = np.linalg.norm(walls - pos, axis=1).min() if len(walls) else np.inf
            state.append(nearest_wall)
            state.append(np.linalg.norm(pos - food))
            state.append(np.linalg.norm(pos - tail))

        return np.array(state)

    def render(self):
        """Draw the current board via ``render_env``."""
        render_env(self.grid, self.snake, self.food)

In [None]:
class ReplayMemory():
    """Fixed-capacity FIFO buffer of transitions for experience replay.

    Each transition is a 5-tuple ``(state, next_state, action, reward, terminal)``.
    In this notebook, action, reward and terminal are stored as one-element lists,
    so the sampled batches for those fields have shape ``(batch, 1)``.
    """

    def __init__(self, max_length, device=None):
        """
        Args:
            max_length: capacity; the oldest transitions are discarded once full.
            device: torch device for sampled tensors. Defaults to CUDA when
                available, otherwise CPU (previously ``.cuda()`` was hard-coded).
        """
        self.memory = deque(maxlen = max_length)
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

    def append_mem(self, transition):
        """Store one transition tuple (state, next_state, action, reward, terminal)."""
        self.memory.append(transition)

    def sample_minibatch(self, minibatch_length):
        """Uniformly sample ``minibatch_length`` transitions with replacement.

        Returns:
            (states, next_states, actions, rewards, terminals) as float32 tensors
            on ``self.device``, all on the same device (rewards used to stay on CPU).

        Raises:
            ValueError: if the memory is empty.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty ReplayMemory")
        # randint's upper bound is exclusive, so len(memory) makes every stored
        # transition eligible, including the newest (len-1 skipped it).
        indices = np.random.randint(0, len(self.memory), size=minibatch_length)
        batch = [self.memory[i] for i in indices]
        # Stack each field through np.array first: building a tensor directly
        # from a list of ndarrays is very slow.
        fields = [[transition[k] for transition in batch] for k in range(5)]
        return tuple(
            torch.as_tensor(np.array(field), dtype=torch.float32, device=self.device)
            for field in fields
        )


In [None]:
class Net11(nn.Module):
    """Q-network with one hidden layer: 24 features -> 18 ReLU units -> 3 action values.

    The output layer is linear by default so it can represent unbounded
    Q-value regression targets (the Snake food reward alone is +100).
    The original version ended in ``nn.Softmax(dim=1)``, which forces each
    output row onto the probability simplex and cannot fit such targets.
    Pass ``softmax_output=True`` to reproduce that behaviour. Parameter
    names in ``state_dict`` are the same either way, so saved weights load
    into both variants.
    """
    def __init__(self, softmax_output=False):
        super(Net11, self).__init__()

        layers = [
            nn.Linear(in_features = 24, out_features = 18, bias = True),
            nn.ReLU(),
            nn.Linear(in_features = 18, out_features = 3, bias = True),
        ]
        if softmax_output:
            layers.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, 24) input to (batch, 3) action values."""
        return self.fc(x)

In [None]:
class Net13(nn.Module):
    """Q-network with two hidden layers: 24 -> 18 -> 18 (ReLU) -> 3 action values.

    The output layer is linear by default so it can represent unbounded
    Q-value regression targets (the Snake food reward alone is +100).
    The original version ended in ``nn.Softmax(dim=1)``, which forces each
    output row onto the probability simplex and cannot fit such targets.
    Pass ``softmax_output=True`` to reproduce that behaviour. Parameter
    names in ``state_dict`` are the same either way, so saved weights load
    into both variants.
    """
    def __init__(self, softmax_output=False):
        super(Net13, self).__init__()

        layers = [
            nn.Linear(in_features = 24, out_features = 18, bias = True),
            nn.ReLU(),
            nn.Linear(in_features = 18, out_features = 18, bias = True),
            nn.ReLU(),
            nn.Linear(in_features = 18, out_features = 3, bias = True),
        ]
        if softmax_output:
            layers.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, 24) input to (batch, 3) action values."""
        return self.fc(x)

In [None]:
def init_weights(m):
    """Initialise a Linear layer in place: Xavier-uniform weights, bias 0.01.

    Intended for ``model.apply(init_weights)``. Modules that are not
    ``nn.Linear`` (or a subclass) are left untouched, as are missing biases
    (``bias=False``).
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            # nn.init runs under no_grad, unlike writing through ``.data``.
            nn.init.constant_(m.bias, 0.01)

In [None]:
class DQN:
    """Deep Q-learning agent with a replay buffer and a soft-updated target network.

    Notable attributes:
        network / target_network: online and target Q-networks (same architecture).
        gamma: discount factor.
        tau: weight of the *online* parameters in ``soft_target_update``
            (target <- tau * online + (1 - tau) * target). With 0.99 the target
            tracks the online network almost exactly.
        epsilon: stored exploration rate; not read by this class's methods
            (callers pass epsilon to ``epsilon_greedy_action`` explicitly).
        target_network_update: NOTE(review): not read anywhere in this class.
    """

    def __init__(self, network = None):
        """
        Args:
            network: Q-network mapping (batch, features) -> (batch, n_actions).
                When None, a fresh Xavier-initialised ``Net13`` is built on CUDA
                if available, otherwise CPU. The old default was evaluated once
                at class definition, so every instance shared one network.
        """
        if network is None:
            default_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            network = Net13().apply(init_weights).to(default_device)

        self.name = 'net13_{}'.format(time.time())

        # network hyperparameters
        self.n_learning_rate = 0.001

        # q-learning hyperparameters
        self.gamma = 0.99

        # memory
        self.memory_capacity = 50000
        self.replay_memory = ReplayMemory(self.memory_capacity)

        # optional update target net every N episodes
        self.target_network_update = 20
        self.network = network
        # deepcopy keeps the target on the same device as the online network.
        self.target_network = copy.deepcopy(self.network)

        self.network_optimiser = torch.optim.Adam(self.network.parameters(), lr=self.n_learning_rate)
        self.MSELoss_function = nn.MSELoss()

        self.tau = 0.99
        self.epsilon = 0.1

    def _device(self):
        """Device that holds the online network's parameters."""
        return next(self.network.parameters()).device

    def epsilon_greedy_action(self, state, epsilon):
        """Pick an action: uniformly random from {0, 1, 2} with probability ``epsilon``,
        otherwise the argmax of the online network's output.

        ``state`` is a float tensor batch; callers here pass shape (1, features).
        The argmax is taken over the flattened output, so it is a single action
        index only for a one-row batch.
        """
        if np.random.uniform(0, 1) < epsilon:
            return random.choice([0,1,2])
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            q_values = self.network(state.to(self._device()))
        return np.argmax(q_values.cpu().numpy())

    def update_QN(self, state, next_state, action, reward, terminals):
        """Take one optimiser step on MSE(Q(s, a), r + gamma * (1 - done) * max_a' Q_target(s', a')).

        Args:
            state, next_state: (batch, features) float tensors.
            action: (batch, 1) action indices (any numeric dtype; cast to long).
            reward, terminals: (batch, 1) float tensors; terminals are 1.0 for
                terminal transitions.
        All inputs may live on any device; they are moved to the network's device.
        """
        device = self._device()
        state = state.to(device)
        next_state = next_state.to(device)
        action = action.to(device).long()
        reward = reward.to(device)
        terminals = terminals.to(device)

        qsa = torch.gather(self.network(state), dim=1, index=action)

        # The bootstrapped target is a constant for this step: no gradient
        # flows through the target network.
        with torch.no_grad():
            next_q, _ = torch.max(self.target_network(next_state), dim=1, keepdim=True)
            target = reward + (1 - terminals) * self.gamma * next_q

        q_network_loss = self.MSELoss_function(qsa, target)

        self.network_optimiser.zero_grad()
        q_network_loss.backward()
        self.network_optimiser.step()

    def soft_target_update(self, network, target_network, tau):
        """Blend parameters in place: target <- network * tau + target * (1 - tau)."""
        with torch.no_grad():
            for net_params, target_net_params in zip(network.parameters(), target_network.parameters()):
                target_net_params.copy_(net_params * tau + target_net_params * (1 - tau))

    def update(self, update_rate):
        """Run ``update_rate`` rounds of (sample 128 transitions, Q step, soft target update)."""
        for _ in range(update_rate):
            states, next_states, actions, rewards, terminals = self.replay_memory.sample_minibatch(128)
            self.update_QN(states, next_states, actions, rewards, terminals)
            self.soft_target_update(self.network, self.target_network, self.tau)

In [None]:
def train_agent(number_of_episodes, agent = None, max_time_steps = 500, env = None, epsilon = 0.05):
    """Train ``agent`` on ``env`` with epsilon-greedy exploration.

    Every step stores ``(state, next_state, [action], [reward], [terminal])``
    in the agent's replay memory. After every episode the agent runs
    ``update(40)``. An episode ends on a terminal transition or after
    ``max_time_steps`` consecutive steps without food (the counter resets
    whenever food is eaten).

    Args:
        number_of_episodes: number of episodes to run.
        agent: DQN-like agent; a fresh ``DQN()`` is built when None. The old
            default was built once when the function was defined, so repeated
            calls silently kept training the same agent.
        max_time_steps: step budget between food pickups.
        env: environment; a fresh ``SnakeEnv()`` is built when None (same reason).
        epsilon: exploration rate passed to ``epsilon_greedy_action``
            (previously hard-coded to 0.05).

    Returns:
        (agent, food_sum_list), where food_sum_list[i] is the food eaten in episode i.
    """
    if agent is None:
        agent = DQN()
    if env is None:
        env = SnakeEnv()
    # Feed inputs to wherever the network lives instead of assuming CUDA.
    device = next(agent.network.parameters()).device

    food_sum_list = []
    for episode in range(number_of_episodes):
        reward_sum = 0
        food_sum = 0
        time_step = 0
        state = env.reset()

        while time_step < max_time_steps:
            nn_input = torch.as_tensor(state, dtype=torch.float32).view(1, -1).to(device)
            action = agent.epsilon_greedy_action(nn_input, epsilon)

            reward, terminal = env.step(action)
            next_state = env.get_state()

            reward_sum += reward
            agent.replay_memory.append_mem((state, next_state, [action], [reward], [terminal]))

            state = next_state

            time_step += 1
            # In SnakeEnv only the food reward (+100) exceeds 5.
            if reward > 5:
                food_sum += 1
                time_step = 0

            if terminal:
                clear_output(wait=True)
                print('episode:', episode, 'sum_of_rewards_for_episode:', reward_sum, 'food collected in episode:', food_sum)
                break

        print('Updating Target Network')
        agent.update(40)
        food_sum_list.append(food_sum)

    return agent, food_sum_list

In [None]:
def show_agent(agent, show = True, n_episodes = 100, max_time_steps = 500, env = None):
    """Evaluate ``agent`` greedily (epsilon = 0) and return the food eaten per episode.

    An episode ends on a terminal transition or after ``max_time_steps``
    consecutive steps without food (the counter resets whenever food is eaten).
    Previously the step counter was never incremented, so a greedy agent that
    circled forever without dying never returned.

    Args:
        agent: agent exposing ``network`` and ``epsilon_greedy_action``.
        show: render the board before evaluation and after every step.
        n_episodes: number of evaluation episodes (was fixed at 100).
        max_time_steps: step budget between food pickups (was fixed at 500).
        env: environment to evaluate on; a fresh ``SnakeEnv()`` when None.

    Returns:
        List with the food count of each episode.
    """
    if env is None:
        env = SnakeEnv()
    device = next(agent.network.parameters()).device
    if show:
        env.render()

    food_sum_list = []
    for _ in range(n_episodes):
        food_sum = 0
        time_step = 0
        env.reset()

        while time_step < max_time_steps:
            state = env.get_state()
            nn_input = torch.as_tensor(state, dtype=torch.float32).view(1, -1).to(device)

            action = agent.epsilon_greedy_action(nn_input, 0)
            reward, terminal = env.step(action)

            time_step += 1
            # In SnakeEnv only the food reward (+100) exceeds 5.
            if reward > 5:
                food_sum += 1
                time_step = 0

            if show:
                env.render()

            if terminal:
                break

        food_sum_list.append(food_sum)

    return food_sum_list

In [None]:
def save_agent(agent, network_name):
    """Write the agent's online-network weights to ``<network_name>.pth``.

    Only the ``state_dict`` is saved, so reloading requires constructing a
    network of the same architecture and calling ``load_state_dict``.
    """
    weights = agent.network.state_dict()
    checkpoint_path = network_name + '.pth'
    torch.save(weights, checkpoint_path)

In [None]:
# Train for 10,000 episodes, then persist the network weights and the
# per-episode food counts so the run can be analysed without retraining.
new_snake0_05, food_list = train_agent(10000)
save_agent(new_snake0_05, 'best_10000')
with open("best.txt", mode="w") as output:
    print(str(food_list), end="", file=output)

In [None]:
# Running mean of food collected per training episode: element i is the mean
# over episodes 0..i. An explicit figure avoids drawing into whatever pyplot
# considers the "current" figure, and the title lets the plot stand alone.
running_mean_food = np.cumsum(food_list) / np.arange(1, len(food_list) + 1)
fig, ax = plt.subplots()
ax.plot(running_mean_food)
ax.set_xlabel('Episodes')
ax.set_ylabel('Cumulative Mean Food Collected')
ax.set_title('DQN training: running mean of food collected per episode')
fig.savefig('goodone_epsilon05')

In [None]:
# Greedy evaluation of the trained agent with rendering switched off,
# followed by summary statistics of the food eaten per episode.
total_food_sum = show_agent(new_snake0_05, False)

for label, summarise in (('Mean', np.mean), ('Min', min), ('Max', max)):
    print(label, summarise(total_food_sum))

In [None]:
# Baseline: a uniformly random policy over 100 episodes, each capped at 500
# steps between food pickups. Results go in ``random_food_list`` so the
# trained agent's ``food_list`` (used by the plotting cell) is not overwritten.
# Call env.render() inside the loop to watch an episode.
env = SnakeEnv()
random_food_list = []
for _ in range(100):
    food_sum = 0
    time_step = 0
    env.reset()

    while time_step < 500:
        reward, terminal = env.step(random.choice([0, 1, 2]))
        # Previously never incremented, which made the 500-step cap dead code.
        time_step += 1

        if terminal:
            break

        # Food eaten (+100 in SnakeEnv): count it and reset the step budget.
        if reward > 5:
            food_sum += 1
            time_step = 0

    random_food_list.append(food_sum)

print('Mean', np.mean(random_food_list))
print('Min', min(random_food_list))
print('Max', max(random_food_list))