<a href="https://colab.research.google.com/github/simux0072/Machine-Learning/blob/master/Snake_DQN(Single_no_improv).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [156]:
import numpy as np
import random
import math

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

from itertools import count
from collections import namedtuple

import IPython

# Snake game

In [157]:
# Board size in cells, given as (columns, rows): index 0 bounds x, index 1 bounds y.
DIMENSIONS = (16, 12)
# Copied onto every player as `size`; presumably the pixel size of one cell
# for rendering -- TODO confirm.
SIZE = 30

# Heading index -> (dx, dy) unit step. The indices run clockwise and y grows
# downward, which is why "up" is a negative y step.
DIRECTIONS = dict(enumerate((
    (0, -1),   # 0: UP
    (1, 0),    # 1: RIGHT
    (0, 1),    # 2: DOWN
    (-1, 0),   # 3: LEFT
)))

class player():
    """Snake state and game rules for one episode.

    ``self.snake`` is a list of segment dicts ordered tail first, so the head
    is always ``self.snake[-1]``.  Each segment has a ``'Name'`` (``'Head'`` or
    ``'Body'``), ``'Coordinates'`` as a mutable ``[x, y]`` list and a
    ``'Direction'`` that is a key into ``DIRECTIONS``.
    """

    def __init__(self):
        self.size = SIZE
        self.color = {
            'Outer': (0, 0, 255),
            'Inner': (0, 100, 255)
        }
        self.points = 0
        # The snake starts as a lone head in the middle of the board, heading LEFT.
        self.head = {
            'Name': 'Head',
            'Coordinates': [round(DIMENSIONS[0]/2), round(DIMENSIONS[1]/2)],
            'Direction': 3
        }
        self.snake = [self.head]
        self.iter = 0   # moves made since the last food was eaten

    def _move(self, move):
        """Advance every segment by one cell.

        ``move`` is relative to the current heading: 0 turns left, 1 keeps
        going straight, 2 turns right.
        """
        head_index = len(self.snake) - 1
        # Walk tail -> head so each body segment copies the heading its
        # successor had *before* that successor is updated.
        for index, segment in enumerate(self.snake):
            if index == head_index:
                # Headings are 0..3 in clockwise order; wrap around both ends.
                segment['Direction'] = (segment['Direction'] + move - 1) % 4
            else:
                segment['Direction'] = self.snake[index + 1]['Direction']

            dx, dy = DIRECTIONS[segment['Direction']]
            segment['Coordinates'][0] += dx
            segment['Coordinates'][1] += dy

        # Count one step per move.  Incrementing once per segment would make
        # the counter grow with the snake's length and cancel out the
        # length-scaled limit in _too_many_moves.
        self.iter += 1

    def _collision(self):
        """Return ``(True, -1)`` if the head left the board or hit the body, else ``(False, 0)``."""
        head = self.snake[-1]
        x, y = head['Coordinates']
        if not (0 <= x <= DIMENSIONS[0] - 1 and 0 <= y <= DIMENSIONS[1] - 1):
            return True, -1

        for segment in self.snake:
            if segment is not head and segment['Coordinates'] == head['Coordinates']:
                return True, -1
        return False, 0

    def _ate_food(self, food_):
        """Grow by one segment and respawn the food if the head is on it.

        Returns ``(True, 1)`` when food was eaten, otherwise ``(False, 0)``.
        """
        if self.snake[-1]['Coordinates'] != food_.coordinates:
            return False, 0

        tail = self.snake[0]
        dx, dy = DIRECTIONS[tail['Direction']]
        # The new segment goes one cell behind the current tail, with the same heading.
        new_tail = {
            'Name': 'Body',
            'Coordinates': [tail['Coordinates'][0] - dx, tail['Coordinates'][1] - dy],
            'Direction': tail['Direction']
        }
        self.snake.insert(0, new_tail)
        self.iter = 0
        self.points += 1
        food_._generate()
        return True, 1

    def _too_many_moves(self):
        """Return ``(True, -50)`` once more than 50 moves per segment passed without eating."""
        if self.iter > len(self.snake) * 50:
            return True, -50
        return False, 0

    def _play(self, food_, move):
        """Apply one move and return ``(reward, points, game_end)``."""
        self._move(move)

        game_end, reward = self._too_many_moves()
        if game_end:
            return reward, self.points, game_end

        ate_food, reward = self._ate_food(food_)
        if ate_food:
            return reward, self.points, False

        game_end, reward = self._collision()
        if game_end:
            return reward, self.points, game_end

        return 0, self.points, False

class food():
    """A single food pellet placed on a free cell of the board."""

    def __init__(self, player_):
        self.coordinates = []
        self.player = player_
        self._generate()
        self.color = {
            'Red': (200, 0, 0)
        }

    def _generate(self):
        """Move the food to a random cell not occupied by the snake.

        Uses rejection sampling, so it only terminates while at least one
        cell is free.
        """
        occupied = [segment['Coordinates'] for segment in self.player.snake]
        while True:
            # randrange excludes its upper bound, so passing the full dimension
            # covers every column/row including the last one.
            candidate = [random.randrange(DIMENSIONS[0]),
                         random.randrange(DIMENSIONS[1])]
            if candidate not in occupied:
                self.coordinates = candidate
                break

class environment():
    """Turns the game objects into a stacked-frame observation.

    Keeps the three most recent head frames, three most recent body frames and
    two most recent food frames (newest first), each a rows x columns grid.
    """

    def __init__(self, player, food):
        self.player_ = player
        self.food_ = food
        grid = (DIMENSIONS[1], DIMENSIONS[0])
        self.state_head = np.zeros((3,) + grid, dtype=np.float32)
        self.state_body = np.zeros((3,) + grid, dtype=np.float32)
        self.state_food = np.zeros((2,) + grid, dtype=np.float32)

    @staticmethod
    def _push_frame(history, frame):
        """Return a copy of ``history`` with ``frame`` in front and the oldest frame dropped."""
        newest = frame[np.newaxis].astype(history.dtype)
        return np.concatenate((newest, history[:-1]))

    def _get_current_state(self):
        """Record the current board and return the 8-channel float32 observation.

        Channel order is head history, body history, food history.
        """
        grid = (DIMENSIONS[1], DIMENSIONS[0])
        head_frame = np.zeros(grid)
        body_frame = np.zeros(grid)
        food_frame = np.zeros(grid)

        for segment in self.player_.snake:
            x, y = segment['Coordinates']
            if segment['Name'] == 'Body':
                body_frame[y][x] = 1
            elif segment['Name'] == 'Head':
                head_frame[y][x] = 1

        food_x, food_y = self.food_.coordinates
        food_frame[food_y][food_x] = 1

        self.state_body = self._push_frame(self.state_body, body_frame)
        self.state_head = self._push_frame(self.state_head, head_frame)
        self.state_food = self._push_frame(self.state_food, food_frame)
        return np.concatenate((self.state_head, self.state_body, self.state_food))

# Agent class

In [158]:
class Agent():
    """Epsilon-greedy DQN agent that owns the optimizer of the policy network.

    Args:
        strategy: object exposing ``get_exploration_rate(current_step)``.
        num_actions: number of discrete actions to choose from when exploring.
        device: torch device that returned action tensors are moved to.
        target_net: network used to bootstrap the learning targets.
        policy_net: network that is trained.
        lr: learning rate for Adam.
        gamma: discount factor.
        checkpoint: ``None`` or a dict with ``'optimizer_state_dict'`` and
            ``'current_step'`` to resume from.
    """

    def __init__(self, strategy, num_actions, device, target_net, policy_net, lr, gamma, checkpoint):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device
        self.target_net = target_net
        self.policy_net = policy_net
        self.lr = lr
        self.gamma = gamma
        self.optimizer = optim.Adam(self.policy_net.parameters(), self.lr)
        if checkpoint is not None:
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.current_step = checkpoint['current_step']

    def select_action(self, state, policy_net):
        """Return a 1-element action tensor, random with the current exploration rate."""
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device)  # explore

        with torch.no_grad():
            return policy_net(state).argmax(dim=1).to(self.device)  # exploit

    def train_memory(self, states, actions, rewards, next_states, mask):
        """Run one optimisation step on a batch and return the loss as a float.

        Args:
            states: batch of observations fed to the policy network.
            actions: 1-D tensor with the index of the action taken per sample.
            rewards: tensor of shape ``(batch, 1)``.
            next_states: batch of follow-up observations.
            mask: per-sample end-of-episode flags (truthy where the transition
                ended the game), as stored in the replay experiences.
        """
        # squeeze(1) rather than squeeze() so a batch of one keeps its batch axis.
        current_q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # The target network only provides constants for the target.
        with torch.no_grad():
            next_q_values = self.target_net(next_states).max(1)[0]

        # Bootstrap from the next state only when the episode did NOT end there;
        # terminal transitions are worth just their immediate reward.
        not_terminal = 1.0 - mask.type(torch.float32)
        discounted_q_values = rewards.squeeze(1) + next_q_values * self.gamma * not_terminal

        loss = torch.square(discounted_q_values - current_q_values).mean()
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def net_update(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

class EpsilonGreedyStrat():
    """Exponentially decaying exploration schedule from ``start`` towards ``end``."""

    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return the exploration probability to use at ``current_step``."""
        span = self.start - self.end
        decay_factor = math.exp(-1. * current_step * self.decay)
        return self.end + span * decay_factor

# DQN Class

In [159]:
class DQN(nn.Module):
    """Small residual CNN mapping a stacked board observation to 3 Q-values.

    Args:
        num_layer: number of input channels; the grouped convolutions use
            ``groups=4``, so it must be divisible by 4.

    The fully connected head assumes the pooled feature map is 5x5, which is
    what the (6, 10) / stride 2 / padding 1 average pool produces for 12x16
    inputs.
    """

    def __init__(self, num_layer):
        super().__init__()
        self.conv1_1 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=3, groups=4, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=3, groups=4, padding=1)
        self.conv2_1 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=3, groups=4, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=3, groups=4, padding=1)

        self.norm1_1 = nn.BatchNorm2d(num_features=num_layer)
        self.norm1_2 = nn.BatchNorm2d(num_features=num_layer)
        self.norm2_1 = nn.BatchNorm2d(num_features=num_layer)
        self.norm2_2 = nn.BatchNorm2d(num_features=num_layer)

        # 1x1 convolutions on the skip paths of the two residual blocks.
        self.conv_res1 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=1, groups=4)
        self.conv_res2 = nn.Conv2d(in_channels=num_layer, out_channels=num_layer, kernel_size=1, groups=4)

        self.globpool = nn.AvgPool2d(kernel_size=(6, 10), stride=2, padding=1)

        # Channel count follows num_layer instead of a hard-coded 8.
        self.fc1 = nn.Linear(in_features=num_layer*5*5, out_features=112)
        self.fc2 = nn.Linear(in_features=112, out_features=56)
        self.fc3 = nn.Linear(in_features=56, out_features=28)
        self.fc4 = nn.Linear(in_features=28, out_features=14)
        self.out = nn.Linear(in_features=14, out_features=3)

    def forward(self, t):
        """Return a ``(batch, 3)`` tensor of unbounded Q-values."""
        skip = t
        t = F.relu(self.norm1_1(self.conv1_1(t)))
        t = F.relu(self.norm1_2(self.conv1_2(t)) + self.conv_res1(skip))

        skip = t
        t = F.relu(self.norm2_1(self.conv2_1(t)))
        t = F.relu(self.norm2_2(self.conv2_2(t)) + self.conv_res2(skip))

        t = self.globpool(t)

        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = F.relu(self.fc3(t))
        t = F.relu(self.fc4(t))
        # No activation on the output layer: Q-values must be able to go
        # negative, and a ReLU here would clamp them at zero.
        return self.out(t)


# Environment manager

In [160]:
class EnvManager():
    """Wraps the snake game behind a tensor-based step/observe interface."""

    def __init__(self, device):
        self.device = device
        # A fresh manager starts in exactly the state reset() produces.
        self.reset()

    def reset(self):
        """Start a new game with a fresh player, food and frame history."""
        self.player = player()
        self.food = food(self.player)
        self.env = environment(self.player, self.food)
        self.done = False

    def take_action(self, action):
        """Play one move; return ``(reward tensor, points, done tensor)``."""
        reward, points, self.done = self.player._play(self.food, action.item())
        reward_t = torch.tensor([reward], device=self.device)
        done_t = torch.tensor([self.done], device=self.device)
        return reward_t, points, done_t

    def get_state(self):
        """Return the current observation with a leading batch axis.

        Once the game is over an all-zero observation is returned instead of
        reading the board.
        """
        if not self.done:
            frames = self.env._get_current_state()
            return torch.from_numpy(frames).unsqueeze(dim=0).to(self.device)
        blank = torch.zeros((1, 8, DIMENSIONS[1], DIMENSIONS[0]), dtype=torch.float32)
        return blank.to(self.device)


# Replay Memory

In [161]:
  class ReplayMemory():
    """Fixed-capacity store of experiences that overwrites the oldest entry when full.

    Implemented as a ring buffer so that pushing stays O(1) even at large
    capacities (deleting index 0 of a list shifts every remaining element).
    After the buffer has wrapped, ``self.memory`` is no longer in insertion order.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self._next_slot = 0  # index of the oldest entry once the buffer is full

    def push(self, experience):
        """Store ``experience``, replacing the oldest one if the buffer is full."""
        if self.capacity <= 0:
            # Nothing can be retained; the buffer stays empty.
            return
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self._next_slot] = experience
            self._next_slot = (self._next_slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True once at least three batches' worth of experiences are stored."""
        return len(self.memory) >= batch_size * 3

# Data saving and information output to screen


In [162]:
class bar_update():
  """Text progress bar plus training statistics, rendered into a notebook output.

  NOTE(review): `display` is not imported in this file; it is presumably the
  notebook-provided builtin -- confirm before running outside a notebook.
  """

  def __init__(self):
    self.bar_lenght = 50
    self.bar_prog = 0
    self.bar = list('[' + ' ' * (self.bar_lenght) + ']')
    self.out = display(IPython.display.Pretty(''), display_id=True)

  def new_bar(self):
    """Start an empty bar in a new output area."""
    self.bar = list('[' + ' ' * (self.bar_lenght) + ']')
    self.bar_prog = 0
    self.out = display(IPython.display.Pretty(''), display_id=True)

  def print_info(self, num_ep, loss, score, high_schore, update, num_update):
    """Redraw the bar and statistics line.

    num_ep / update is the completed fraction of the current bar; loss, score,
    high_schore and num_update are only printed.
    """
    # Clamp so that num_ep > update cannot index past the closing bracket.
    base = min(int(num_ep / (update / self.bar_lenght)), self.bar_lenght)
    if base != self.bar_prog:
      self.bar_prog = base
      if base > 0:
        # Fill every cell up to the marker so progress that jumps several
        # cells at once does not leave gaps in the bar.
        self.bar[1:base] = ['='] * (base - 1)
        self.bar[base] = '>'
    # Use the value passed in rather than reaching for a module-level global.
    text = (''.join(self.bar) + ' Episode: ' + str(num_update) + ' Game: ' + str(num_ep)
            + ' loss: ' + str(loss) + ' score: ' + str(score) + ' High Score: ' + str(high_schore))
    self.out.update(IPython.display.Pretty(text))

In [163]:
# Mount Google Drive at /content/drive.
# NOTE: the name `drive` is rebound further down to a model_drive instance,
# so this module reference is only usable up to that point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [164]:
from os.path import exists

class model_drive():
  """Saves and restores training checkpoints at ``path + model_name``.

  ``path`` is concatenated with ``model_name`` as-is, so it should end with a
  path separator.
  """

  def __init__(self, path, model_name):
    self.path = path
    self.model_name = model_name

  def _full_path(self):
    # Single place that defines where the checkpoint lives.
    return self.path + self.model_name

  def upload(self, model, optimizer, update, current_step):
    """Write model/optimizer state and the training counters to the checkpoint file."""
    import os  # local import: the file only imports `exists` from os.path

    target = self._full_path()
    folder = os.path.dirname(target)
    if folder:
      # torch.save does not create missing directories.
      os.makedirs(folder, exist_ok=True)
    torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'update': update,
            'current_step': current_step
            }, target)

  def does_exist(self):
    """Return True if the checkpoint file is present."""
    return exists(self._full_path())

  def download(self, map_location=None):
    """Load and return the checkpoint dict.

    map_location is forwarded to torch.load. When it is left as None and CUDA
    is unavailable, tensors are mapped to the CPU so a checkpoint saved from a
    GPU runtime can still be loaded.
    """
    if map_location is None and not torch.cuda.is_available():
      map_location = 'cpu'
    return torch.load(self._full_path(), map_location=map_location)

# Main Class

In [None]:
# Training hyperparameters.
mini_batch = 256        # transitions drawn from replay memory per training step
gamma = 0.99            # discount factor
# Exploration schedule: initial rate, floor, and exponential decay per step.
eps_start, eps_end, eps_decay = 1, 0.01, 1e-5
target_update = 10      # episodes between target-network syncs / checkpoints
memory_size = 1_000_000  # replay memory capacity
lr = 1e-5               # Adam learning rate

# One replay-memory record: a transition plus whether it ended the episode.
Experience = namedtuple(
  'Experience',
  ['state', 'action', 'reward', 'next_state', 'end_state'],
)

def extract_tensors(experience):
  """Collate a sequence of Experience records into five batched tensors.

  Returns (states, actions, rewards, next_states, end_states); each field is
  concatenated along dim 0 and rewards additionally gets a trailing axis.
  """
  columns = Experience(*zip(*experience))

  states = torch.cat(columns.state)
  actions = torch.cat(columns.action)
  rewards = torch.cat(columns.reward).unsqueeze(dim=-1)
  next_states = torch.cat(columns.next_state)
  end_states = torch.cat(columns.end_state)
  return (states, actions, rewards, next_states, end_states)

# Train on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint location.
path = '/content/drive/MyDrive/models/'
model_name = 'DQN_Snake.mdl'

bar = bar_update()
drive = model_drive(path, model_name)

model_exists = drive.does_exist()

policy_net = DQN(8).to(device)

# Resume from the saved checkpoint if there is one, otherwise start at update 1.
checkpoint = drive.download() if model_exists else None
if checkpoint is None:
  update = 1
else:
  update = checkpoint['update']
  policy_net.load_state_dict(checkpoint['model_state_dict'])

# The target network starts as an exact copy of the policy network and is
# kept in eval mode.
target_net = DQN(8).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

em = EnvManager(device)
strategy = EpsilonGreedyStrat(eps_start, eps_end, eps_decay)
agent = Agent(
  strategy=strategy,
  num_actions=3,
  device=device,
  target_net=target_net,
  policy_net=policy_net,
  lr=lr,
  gamma=gamma,
  checkpoint=checkpoint,
)
memory = ReplayMemory(memory_size)


# Running statistics for the current update window.
points_all = loss_all = score_all = High_score = episode = 0


# Endless training loop: one iteration of the outer loop plays one game.
while True:
  episode += 1
  train_steps = 0   # gradient steps taken during this game (avoids shadowing builtin `iter`)
  loss = 0
  points = 0
  state = em.get_state()

  for _ in count():

    action = agent.select_action(state, policy_net)
    reward, points, end_state = em.take_action(action)
    next_state = em.get_state()
    memory.push(Experience(state, action, reward, next_state, end_state))
    state = next_state

    # Start learning only once the replay memory holds enough transitions.
    if memory.can_provide_sample(mini_batch):
      experience = memory.sample(mini_batch)
      states, actions, rewards, next_states, mask = extract_tensors(experience)

      loss += agent.train_memory(states, actions, rewards, next_states, mask)
      train_steps += 1

    if em.done:
      em.reset()
      if train_steps:
        # Average loss over the gradient steps of this game.
        loss_all += loss / train_steps
      score_all += points
      if points > High_score:
        High_score = points
      bar.print_info(episode, loss_all / episode, score_all / episode, High_score, target_update, update)
      break

  # Every `target_update` games: sync the target network, checkpoint, and
  # restart the per-window statistics and progress bar.
  if episode % target_update == 0:
    agent.net_update()
    update += 1

    drive.upload(agent.policy_net, agent.optimizer, update, agent.current_step)

    bar.new_bar()

    episode = 0
    loss_all = 0
    score_all = 0
    High_score = 0