<a href="https://colab.research.google.com/github/rrl7012005/Reinforcement-Learning/blob/main/AI_plays_Snake_DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Deep Q Learning

Recall that reinforcement learning focuses on maximizing the cumulative reward, not just the immediate one.

There are 3 actions: go straight, turn right, and turn left. The action is therefore a 3-dimensional one-hot vector.
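As a minimal sketch of how a one-hot action can be turned into a new heading, the snippet below uses plain strings for the directions and the same clockwise ordering as the `_move` method later in this notebook. The function name `turn` and the string labels are illustrative assumptions, not part of the game code.

```python
# Headings listed clockwise on screen (y grows downward), as in the game's _move method.
CLOCKWISE = ["RIGHT", "DOWN", "LEFT", "UP"]

def turn(direction, action):
    """Return the new heading after applying a one-hot action.

    [1, 0, 0] = go straight, [0, 1, 0] = turn right, [0, 0, 1] = turn left.
    """
    idx = CLOCKWISE.index(direction)
    if action == [1, 0, 0]:
        return CLOCKWISE[idx]
    if action == [0, 1, 0]:
        return CLOCKWISE[(idx + 1) % 4]  # one step clockwise
    if action == [0, 0, 1]:
        return CLOCKWISE[(idx - 1) % 4]  # one step counter-clockwise
    raise ValueError("action must be a one-hot list of length 3")

print(turn("RIGHT", [0, 1, 0]))  # DOWN
print(turn("RIGHT", [0, 0, 1]))  # UP
```

Because the actions are relative to the current heading, the snake can never reverse directly into itself.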

The reward function we choose is +10 for collecting a food pellet, -10 for losing the game, and 0 for every other step.

The state is described by an 11-dimensional binary vector: 3 entries flag whether there is danger straight ahead, to the right, or to the left (several can be set at once), 4 one-hot entries give the direction we are going (left, right, up, down), and 4 entries describe where the food is relative to the head (left, right, up, down).
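To make the layout concrete, here is a small sketch that names each of the 11 entries in the order used by `get_state` later in this notebook (danger, then direction, then food). The helper `describe_state` and the field names are hypothetical and exist only for illustration.

```python
# Field order follows get_state: 3 danger flags, 4 direction flags, 4 food flags.
STATE_FIELDS = [
    "danger_straight", "danger_right", "danger_left",
    "dir_left", "dir_right", "dir_up", "dir_down",
    "food_left", "food_right", "food_up", "food_down",
]

def describe_state(state):
    """Return the names of the entries that are set in an 11-element 0/1 vector."""
    if len(state) != len(STATE_FIELDS):
        raise ValueError("expected 11 entries")
    return [name for name, bit in zip(STATE_FIELDS, state) if bit]

# Moving right, something dangerous on the right-hand side, food above the head.
example = [0, 1, 0,  0, 1, 0, 0,  0, 0, 1, 0]
print(describe_state(example))
# ['danger_right', 'dir_right', 'food_up']
```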

The Q value measures the quality of an action. Q(s, a) is the expected cumulative future reward of taking action a in state s and following the optimal policy from then on.

Q-Learning:

-Initialize the Q value  
-Choose an action either via our Q function or randomly (exploration vs exploitation)  
-Perform the action  
-Measure the reward  
-Update the Q value  

Repeat iteratively

The Q value is updated according to the Bellman equation:

$$Q_{new}(s, a) = Q(s, a) + \alpha \left[ R(s, a) + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

-$\alpha$ (alpha) is the learning rate, controlling how much new information overrides the old information  
-$R(s, a)$ is the reward for taking action $a$ in state $s$  
-$\gamma$ (gamma) is the discount factor (how much future rewards count compared to immediate rewards)  
-$\max_{a'} Q(s', a')$ is the best expected future reward available from the new state $s'$, maximized over all actions $a'$ possible there
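A single application of this update can be worked through numerically. The sketch below assumes alpha = 0.1 (an illustrative value) and gamma = 0.9 (the value used by the agent later in this notebook); the function name `q_update` is hypothetical.

```python
def q_update(q_sa, reward, next_q_values, alpha=0.1, gamma=0.9):
    """One Q-learning update for a single (state, action) pair."""
    td_target = reward + gamma * max(next_q_values)  # R + gamma * max over a' of Q(s', a')
    return q_sa + alpha * (td_target - q_sa)

# Old estimate 2.0, reward 10, best next-state value 4.0:
# target = 10 + 0.9 * 4 = 13.6, new value = 2 + 0.1 * (13.6 - 2) = 3.16
new_q = q_update(q_sa=2.0, reward=10.0, next_q_values=[1.0, 4.0, 3.0])
print(round(new_q, 3))  # 3.16
```

The estimate moves only a fraction alpha of the way toward the target, which is what keeps a single noisy reward from overwriting everything learned so far.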

The next action is chosen as the one that maximizes the Q function i.e. the expected cumulative future reward.

For regular Q learning, a Q table is used and it would typically be initialized with zeros. The Q table has a value for every state-action pair. Neural networks approximate the Q value in DQN.
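For comparison with the network-based approach, here is a minimal tabular Q-learning sketch on a toy problem, not on Snake. The 5-cell corridor environment, the hyperparameters, and all names (`step`, `train_q_table`) are assumptions made purely for illustration: the agent starts in cell 0 and receives +10 on reaching cell 4.

```python
import random

N_STATES = 5
ACTIONS = (0, 1)          # 0 = move left, 1 = move right
GOAL = N_STATES - 1

def step(state, action):
    """Toy corridor: moving into the last cell gives +10 and ends the episode."""
    next_state = max(0, state - 1) if action == 0 else state + 1
    done = next_state == GOAL
    return next_state, (10.0 if done else 0.0), done

def train_q_table(episodes=300, alpha=0.5, gamma=0.9, epsilon=0.2, seed=0):
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(N_STATES)]  # Q table initialized with zeros
    for _ in range(episodes):
        state, done = 0, False
        while not done:
            if rng.random() < epsilon:
                action = rng.choice(ACTIONS)             # explore
            else:
                best = max(q[state])                     # exploit, random tie-break
                action = rng.choice([a for a in ACTIONS if q[state][a] == best])
            next_state, reward, done = step(state, action)
            future = 0.0 if done else max(q[next_state])
            q[state][action] += alpha * (reward + gamma * future - q[state][action])
            state = next_state
    return q

q = train_q_table()
greedy = [max(ACTIONS, key=lambda a: q[s][a]) for s in range(GOAL)]
print(greedy)  # [1, 1, 1, 1]: move right in every non-terminal cell
```

A table like this needs one row per state, which is what stops scaling once the state space gets large and motivates replacing the table with a network.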

The Q value is unknown before training; it represents the expected cumulative future reward. In each episode (a single run until the agent finishes or fails) we take actions and observe the resulting rewards and states. Every time we take an action in a state, we update the Q value for that state-action pair. Because each update also includes the discounted value of the next state, good rewards propagate backwards: a state-action pair ends up with a high Q value if it gave a good reward immediately or led to good rewards later. After many iterations, each Q value carries information far into the future about which states and actions were good, and acting greedily on it gives a policy that maximizes cumulative reward.

Deep Q-learning approximates the Q function with a neural network, so the Q values are encoded in the network's weights and biases. There are 2 components: the Q-network and the target Q-network. The Q-network is the main network that predicts Q values. The target Q-network is a separate network with the same architecture whose weights are only periodically updated to match the Q-network. Computing the Bellman targets from this slower-moving copy stabilizes training by reducing oscillations.
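The periodic sync can be sketched without any deep learning library. In the snippet below a plain dict stands in for a network's parameters, and the class `TargetSync` and its `sync_every` argument are hypothetical names. With PyTorch modules the copy step would typically be `target.load_state_dict(online.state_dict())`. Note that the `QTrainer` written later in this notebook bootstraps from the same network it trains, so this sketch shows the general idea rather than that code.

```python
import copy

class TargetSync:
    """Keep a lagged copy of the online parameters, refreshed every `sync_every` steps."""

    def __init__(self, online_params, sync_every=100):
        self.online = online_params
        self.target = copy.deepcopy(online_params)  # frozen snapshot
        self.sync_every = sync_every
        self.steps = 0

    def after_train_step(self):
        self.steps += 1
        if self.steps % self.sync_every == 0:
            self.target = copy.deepcopy(self.online)  # hard update

params = {"w": [0.0, 0.0]}
sync = TargetSync(params, sync_every=3)
for _ in range(5):
    params["w"][0] += 1.0      # stand-in for one gradient update
    sync.after_train_step()

# The online value has moved 5 times; the target was last copied after step 3.
print(params["w"][0], sync.target["w"][0])  # 5.0 3.0
```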

Past experiences are stored in a replay memory (buffer). Each experience (essentially one frame) is a tuple of state, action taken, reward received, next state, and a done flag indicating whether the episode has finished. During training, a mini-batch is sampled from the buffer. The target Q value for each experience, also termed Q_new, is calculated with the Bellman equation, since that is how the Q value should be updated. The loss is then computed between the target Q value and the Q value predicted by the Q-network, and this loss is used to update the weights of the Q-network. In this way the Q-network learns to satisfy the Bellman equation and approximates how the Q function should evolve. DQN is used for high-dimensional state spaces. For the snake
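The buffer and the target computation can be sketched with the standard library alone. Here `Experience`, `sample_batch`, and `td_targets` are hypothetical names, and `max_next_q` is a stand-in callable for "the largest Q value the network predicts for the next state"; gamma = 0.9 matches the agent defined later.

```python
import random
from collections import deque, namedtuple

Experience = namedtuple("Experience", "state action reward next_state done")

def sample_batch(memory, batch_size, rng=random):
    """Return up to batch_size experiences drawn without replacement."""
    if len(memory) > batch_size:
        return rng.sample(list(memory), batch_size)
    return list(memory)

def td_targets(batch, max_next_q, gamma=0.9):
    """Bellman target per experience: r if done, else r + gamma * max Q(next_state)."""
    return [
        e.reward if e.done else e.reward + gamma * max_next_q(e.next_state)
        for e in batch
    ]

memory = deque(maxlen=3)               # oldest experiences are dropped automatically
for t in range(5):
    memory.append(Experience(t, 0, 1.0, t + 1, done=(t == 4)))

batch = sample_batch(memory, batch_size=2, rng=random.Random(0))
targets = td_targets(list(memory), max_next_q=lambda s: 2.0)
print(len(memory), len(batch), targets)
```

Terminal experiences use the reward alone as their target because there is no next state to bootstrap from.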

We decay epsilon (the exploration probability) over time, so the agent explores a lot early on and increasingly exploits what it has learned.
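The schedule used by `get_action` later in this notebook sets `epsilon = 80 - ngames` and explores when `random.randint(0, 200) < epsilon`. The sketch below computes the resulting exploration probability per game; the function name and its parameters are hypothetical, and the only fact relied on is that `randint(0, 200)` draws uniformly from 201 integers.

```python
def exploration_probability(n_games, start=80, draw_max=200):
    """Probability that randint(0, draw_max) < start - n_games."""
    epsilon = start - n_games
    favourable = min(max(epsilon, 0), draw_max + 1)  # count of draws below epsilon
    return favourable / (draw_max + 1)

for n in (0, 40, 80, 120):
    print(n, round(exploration_probability(n), 3))
# 0 0.398
# 40 0.199
# 80 0.0
# 120 0.0
```

Under this schedule the agent stops exploring entirely after 80 games and acts greedily from then on.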


In [None]:
import torch
import matplotlib.pyplot

#Building the Environment

Building the environment means building the snake game. The environment exposes a step function that takes an action as input and returns the reward, whether or not the game is over, and the current score. The agent reads the state from the game object itself.

pygame is used to build games like this in Python. Enum is used to define named constants, here the movement directions.
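The step contract can be illustrated with a tiny stand-in environment that needs no pygame. `StubEnv` and `run_episode` are hypothetical names, and the stub's rules (every "straight" action scores, the episode ends after a fixed number of steps) are invented for the example. Only the `(reward, game_over, score)` return shape mirrors the game built below.

```python
class StubEnv:
    """Hypothetical stand-in that follows the same play_step contract as the game."""

    def __init__(self, max_steps=5):
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        self.steps = 0
        self.score = 0

    def play_step(self, action):
        self.steps += 1
        if self.steps >= self.max_steps:      # episode ends with a penalty
            return -10, True, self.score
        reward = 0
        if action == [1, 0, 0]:               # pretend "straight" always finds food
            self.score += 1
            reward = 10
        return reward, False, self.score

def run_episode(env, policy):
    """Drive the environment until game over and return (total reward, final score)."""
    env.reset()
    total_reward, done, score = 0, False, 0
    while not done:
        reward, done, score = env.play_step(policy())
        total_reward += reward
    return total_reward, score

total, score = run_episode(StubEnv(max_steps=5), policy=lambda: [1, 0, 0])
print(total, score)  # 30 4
```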

In [None]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

pygame 2.6.0 (SDL 2.28.4, Python 3.10.12)
Hello from the pygame community. https://www.pygame.org/contribute.html


Set the font size to 25

In [None]:
pygame.init()
font = pygame.font.SysFont('arial', 25)

In [None]:
class Direction(Enum):
  RIGHT, LEFT, UP, DOWN = 1, 2, 3, 4

Point = namedtuple('Point', 'x, y')

WHITE = (255, 255, 255)
RED = (200, 0, 0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0, 0, 0)

BLOCK_SIZE = 20
SPEED = 40

Now create the actual snake game. We define the display with pygame.display.set_mode and set the caption to 'Snake'. We then start the clock, place the snake's head at the center moving right, and initialize the remaining variables.



In [None]:
class SnakeGameAI():
  def __init__(self, w = 640, h = 480):
    self.w = w
    self.h = h

    self.display = pygame.display.set_mode((self.w, self.h))
    pygame.display.set_caption('Snake')

    #Initialize the time
    self.clock = pygame.time.Clock()
    self.reset()

  def reset(self):

    #Initialize the direction moving
    self.direction = Direction.RIGHT

    #Initialize a head at the center
    self.head = Point(self.w/2, self.h/2)

    #Initialize the snake with 3 blocks
    self.snake = [self.head, Point(self.head.x - BLOCK_SIZE, self.head.y),
                  Point(self.head.x - 2 * BLOCK_SIZE, self.head.y)]

    self.score = 0
    self.food = None
    self._place_food()

    self.frame_iteration = 0

  #Randomly place the food on a grid cell; if it lands inside the snake, pick again
  def _place_food(self):
    x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
    y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
    self.food = Point(x, y)

    if self.food in self.snake:
      self._place_food()


  def play_step(self, action):

    self.frame_iteration += 1
    for event in pygame.event.get():

      if event.type == pygame.QUIT:
        pygame.quit()
        quit()

    #Move the head
    self._move(action)

    #Grow the snake
    self.snake.insert(0, self.head)

    #Check if game is over
    reward = 0
    game_over = False
    if self._is_collision() or self.frame_iteration > 100 * len(self.snake):
      game_over = True
      reward = -10
      return reward, game_over, self.score

    #Place new food or dont grow

    if self.head == self.food:
      self.score += 1
      reward = 10
      self._place_food()
    else:
      self.snake.pop()

    #update the ui and clock
    self._update_ui()

    #the timestep
    self.clock.tick(SPEED)

    return reward, game_over, self.score

  def _is_collision(self, pt=None):

    if pt is None:
      pt = self.head

    if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
      return True

    if pt in self.snake[1:]:
      return True

    return False

  def _update_ui(self):

    self.display.fill(BLACK)

    #Draw the blue snake
    for pt in self.snake:
      pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
      pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x + 4, pt.y + 4, 12, 12))

    #Draw the red food
    pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

    #Update the font
    text = font.render("Score: " + str(self.score), True, WHITE)
    self.display.blit(text, [0, 0])
    pygame.display.flip()

  def _move(self, action):

    clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
    idx = clock_wise.index(self.direction)

    if np.array_equal(action, [1, 0, 0]):
      new_dir = clock_wise[idx]
    elif np.array_equal(action, [0, 1, 0]):
      next_idx = (idx + 1) % 4
      new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
    else: # [0, 0, 1]
      next_idx = (idx - 1) % 4
      new_dir = clock_wise[next_idx]

    self.direction = new_dir

    x = self.head.x
    y = self.head.y

    if self.direction == Direction.RIGHT:
        x += BLOCK_SIZE
    elif self.direction == Direction.LEFT:
        x -= BLOCK_SIZE
    elif self.direction == Direction.DOWN:
        y += BLOCK_SIZE
    elif self.direction == Direction.UP:
        y -= BLOCK_SIZE

    self.head = Point(x, y)

Final Score 0


#Creating the agent

The agent class dictates how the agent interacts with the environment; it holds the processing and the bulk of the code. It defines how we choose actions, how we read the next state, and so on.

We need a function to get the state of the game, a function to remember experiences, functions to train on a single step (short memory) and on a replayed batch (long memory), and a function to choose the next action.

In [None]:
from collections import deque #queue data structure


MAX_MEMORY = 100000
BATCH_SIZE = 1000
LR = 0.001

class Agent():
  def __init__(self):

    self.ngames = 0
    self.epsilon = 0 #randomness
    self.gamma = 0.9
    self.memory = deque(maxlen=MAX_MEMORY)
    self.model = Linear_QNet(11, 256, 3) #We'll create the Q nets below
    self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

  def get_state(self, game):
    #game is an instance of the SnakeGameAI
    head = game.snake[0]

    point_l = Point(head.x - 20, head.y)
    point_r = Point(head.x + 20, head.y)
    point_u = Point(head.x, head.y - 20)
    point_d = Point(head.x, head.y + 20)

    #one hot encode the direction of movement
    dir_l = game.direction == Direction.LEFT
    dir_r = game.direction == Direction.RIGHT
    dir_u = game.direction == Direction.UP
    dir_d = game.direction == Direction.DOWN

    state = [
            # Danger straight
            (dir_r and game._is_collision(point_r)) or
            (dir_l and game._is_collision(point_l)) or
            (dir_u and game._is_collision(point_u)) or
            (dir_d and game._is_collision(point_d)),

            # Danger right
            (dir_u and game._is_collision(point_r)) or
            (dir_d and game._is_collision(point_l)) or
            (dir_l and game._is_collision(point_u)) or
            (dir_r and game._is_collision(point_d)),

            # Danger left
            (dir_d and game._is_collision(point_r)) or
            (dir_u and game._is_collision(point_l)) or
            (dir_r and game._is_collision(point_u)) or
            (dir_l and game._is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y  # food down

    ]

    return np.array(state, dtype=int)

  def remember(self, state, action, reward, next_state, done):
    #Store state, action, reward, next_state, done in a queue for training later
    self.memory.append((state, action, reward, next_state, done))

  def train_long_memory(self):
    if len(self.memory) > BATCH_SIZE:
      #Sample a random mini-batch from the replay memory
      mini_sample = random.sample(self.memory, BATCH_SIZE)
    else:
      mini_sample = self.memory

    #zip * separates out the entries by index of each entry
    states, actions, rewards, next_states, dones = zip(*mini_sample)
    self.trainer.train_step(states, actions, rewards, next_states, dones)

  def train_short_memory(self, state, action, reward, next_state, done):
    self.trainer.train_step(state, action, reward, next_state, done)

  def get_action(self, state):
    self.epsilon = 80 - self.ngames #slowly decay epsilon
    final_move = [0, 0, 0]
    if random.randint(0, 200) < self.epsilon:
      #explore in random direction
      move = random.randint(0, 2)
      final_move[move] = 1
    else:
      state0 = torch.tensor(state, dtype=torch.float)
      prediction = self.model(state0)
      move = torch.argmax(prediction).item()
      final_move[move] = 1

    return final_move

def train():

  plot_scores = []
  plot_mean_scores = []
  total_score = 0
  record = 0
  agent = Agent()
  game = SnakeGameAI()

  while True:

    state_old = agent.get_state(game)
    final_move = agent.get_action(state_old)

    #interact with environment
    reward, done, score = game.play_step(final_move)

    state_new = agent.get_state(game)

    #Update the Q network
    agent.train_short_memory(state_old, final_move, reward, state_new, done)

    agent.remember(state_old, final_move, reward, state_new, done)

    if done:
      #if finish an episode
      game.reset()
      agent.ngames += 1
      agent.train_long_memory()

      if score > record:
        record = score
        agent.model.save()

      print('Game', agent.ngames, 'Score', score, 'Record:', record)

      plot_scores.append(score)
      total_score += score
      mean_score = total_score / agent.ngames
      plot_mean_scores.append(mean_score)

      plot(plot_scores, plot_mean_scores)

if __name__ == '__main__':
  train()

#Training

Now let's build the Q-network and the code to train it. The Q-net contains only 1 hidden layer.
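As a quick size check, the number of trainable parameters of a fully connected net follows directly from the layer widths. The sketch below assumes the 11-256-3 layout that the agent passes to `Linear_QNet`; the helper names are hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

def mlp_params(sizes):
    """Total parameters of a stack of fully connected layers with the given widths."""
    return sum(linear_params(a, b) for a, b in zip(sizes, sizes[1:]))

# 11 inputs -> 256 hidden -> 3 outputs: (11*256 + 256) + (256*3 + 3)
print(mlp_params([11, 256, 3]))  # 3843
```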

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

class Linear_QNet(nn.Module):
  #Q network: one hidden layer with ReLU, one output per action
  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    x = self.linear2(F.relu(self.linear1(x)))
    return x

  def save(self, filename='model.pth'):
    model_path = './model'
    if not os.path.exists(model_path):
      os.makedirs(model_path)

    filename = os.path.join(model_path, filename)

    torch.save(self.state_dict(), filename)

class QTrainer():
  def __init__(self, model, lr, gamma):
    self.lr = lr
    self.gamma = gamma
    self.model = model
    self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
    self.criterion = nn.MSELoss()

  def train_step(self, state, action, reward, next_state, done):

    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)

    if len(state.shape) == 1:
      state = torch.unsqueeze(state, 0)
      next_state = torch.unsqueeze(next_state, 0)
      action = torch.unsqueeze(action, 0)
      reward = torch.unsqueeze(reward, 0)
      done = (done, )

    pred = self.model(state)

    target = pred.clone()

    for idx in range(len(done)):

      Q_new = reward[idx]
      if not done[idx]:
        Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

      #Overwrite only the entry of the action actually taken in this sample
      target[idx][torch.argmax(action[idx]).item()] = Q_new

    self.optimizer.zero_grad()
    loss = self.criterion(target, pred)
    loss.backward()

    self.optimizer.step()

We need one more function to plot the learning curve

In [None]:
import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(ymin=0)
    plt.text(len(scores)-1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(.1)

Everything is done. Now put it all into scripts that can be downloaded.

#Scripts

In [None]:
%%writefile helper.py

import matplotlib.pyplot as plt
from IPython import display

plt.ion()

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Number of Games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(ymin=0)
    plt.text(len(scores)-1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores)-1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(.1)

Writing helper.py


In [None]:
%%writefile model.py

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

class Linear_QNet(nn.Module):
  #Q network: one hidden layer with ReLU, one output per action
  def __init__(self, input_size, hidden_size, output_size):
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.linear2 = nn.Linear(hidden_size, output_size)

  def forward(self, x):
    x = self.linear2(F.relu(self.linear1(x)))
    return x

  def save(self, filename='model.pth'):
    model_path = './model'
    if not os.path.exists(model_path):
      os.makedirs(model_path)

    filename = os.path.join(model_path, filename)

    torch.save(self.state_dict(), filename)

class QTrainer():
  def __init__(self, model, lr, gamma):
    self.lr = lr
    self.gamma = gamma
    self.model = model
    self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
    self.criterion = nn.MSELoss()

  def train_step(self, state, action, reward, next_state, done):

    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)

    if len(state.shape) == 1:
      state = torch.unsqueeze(state, 0)
      next_state = torch.unsqueeze(next_state, 0)
      action = torch.unsqueeze(action, 0)
      reward = torch.unsqueeze(reward, 0)
      done = (done, )

    pred = self.model(state)

    target = pred.clone()

    for idx in range(len(done)):

      Q_new = reward[idx]
      if not done[idx]:
        Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

      #Overwrite only the entry of the action actually taken in this sample
      target[idx][torch.argmax(action[idx]).item()] = Q_new

    self.optimizer.zero_grad()
    loss = self.criterion(target, pred)
    loss.backward()

    self.optimizer.step()

Writing model.py


In [None]:
%%writefile agent.py

#our main function

import torch
import random
import numpy as np
from collections import deque #queue data structure
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot


MAX_MEMORY = 100000
BATCH_SIZE = 1000
LR = 0.001

class Agent():
  def __init__(self):

    self.ngames = 0
    self.epsilon = 0 #randomness
    self.gamma = 0.9
    self.memory = deque(maxlen=MAX_MEMORY)
    self.model = Linear_QNet(11, 256, 3) #We'll create the Q nets below
    self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)

  def get_state(self, game):
    #game is an instance of the SnakeGameAI
    head = game.snake[0]

    point_l = Point(head.x - 20, head.y)
    point_r = Point(head.x + 20, head.y)
    point_u = Point(head.x, head.y - 20)
    point_d = Point(head.x, head.y + 20)

    #one hot encode the direction of movement
    dir_l = game.direction == Direction.LEFT
    dir_r = game.direction == Direction.RIGHT
    dir_u = game.direction == Direction.UP
    dir_d = game.direction == Direction.DOWN

    state = [
            # Danger straight
            (dir_r and game._is_collision(point_r)) or
            (dir_l and game._is_collision(point_l)) or
            (dir_u and game._is_collision(point_u)) or
            (dir_d and game._is_collision(point_d)),

            # Danger right
            (dir_u and game._is_collision(point_r)) or
            (dir_d and game._is_collision(point_l)) or
            (dir_l and game._is_collision(point_u)) or
            (dir_r and game._is_collision(point_d)),

            # Danger left
            (dir_d and game._is_collision(point_r)) or
            (dir_u and game._is_collision(point_l)) or
            (dir_r and game._is_collision(point_u)) or
            (dir_l and game._is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y  # food down

    ]

    return np.array(state, dtype=int)

  def remember(self, state, action, reward, next_state, done):
    #Store state, action, reward, next_state, done in a queue for training later
    self.memory.append((state, action, reward, next_state, done))

  def train_long_memory(self):
    if len(self.memory) > BATCH_SIZE:
      #Sample a random mini-batch from the replay memory
      mini_sample = random.sample(self.memory, BATCH_SIZE)
    else:
      mini_sample = self.memory

    #zip * separates out the entries by index of each entry
    states, actions, rewards, next_states, dones = zip(*mini_sample)
    self.trainer.train_step(states, actions, rewards, next_states, dones)

  def train_short_memory(self, state, action, reward, next_state, done):
    self.trainer.train_step(state, action, reward, next_state, done)

  def get_action(self, state):
    self.epsilon = 80 - self.ngames #slowly decay epsilon
    final_move = [0, 0, 0]
    if random.randint(0, 200) < self.epsilon:
      #explore in random direction
      move = random.randint(0, 2)
      final_move[move] = 1
    else:
      state0 = torch.tensor(state, dtype=torch.float)
      prediction = self.model(state0)
      move = torch.argmax(prediction).item()
      final_move[move] = 1

    return final_move

def train():

  plot_scores = []
  plot_mean_scores = []
  total_score = 0
  record = 0
  agent = Agent()
  game = SnakeGameAI()

  while True:

    state_old = agent.get_state(game)
    final_move = agent.get_action(state_old)

    #interact with environment
    reward, done, score = game.play_step(final_move)

    state_new = agent.get_state(game)

    #Update the Q network
    agent.train_short_memory(state_old, final_move, reward, state_new, done)

    agent.remember(state_old, final_move, reward, state_new, done)

    if done:
      #if finish an episode
      game.reset()
      agent.ngames += 1
      agent.train_long_memory()

      if score > record:
        record = score
        agent.model.save()

      print('Game', agent.ngames, 'Score', score, 'Record:', record)

      plot_scores.append(score)
      total_score += score
      mean_score = total_score / agent.ngames
      plot_mean_scores.append(mean_score)

      plot(plot_scores, plot_mean_scores)

if __name__ == '__main__':
  train()

Writing agent.py


In [None]:
%%writefile game.py

import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

pygame.init()
font = pygame.font.SysFont('arial', 25)

class Direction(Enum):
  RIGHT, LEFT, UP, DOWN = 1, 2, 3, 4

Point = namedtuple('Point', 'x, y')

WHITE = (255, 255, 255)
RED = (200, 0, 0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0, 0, 0)

BLOCK_SIZE = 20
SPEED = 40

class SnakeGameAI():
  def __init__(self, w = 640, h = 480):
    self.w = w
    self.h = h

    self.display = pygame.display.set_mode((self.w, self.h))
    pygame.display.set_caption('Snake')

    #Initialize the time
    self.clock = pygame.time.Clock()
    self.reset()

  def reset(self):

    #Initialize the direction moving
    self.direction = Direction.RIGHT

    #Initialize a head at the center
    self.head = Point(self.w/2, self.h/2)

    #Initialize the snake with 3 blocks
    self.snake = [self.head, Point(self.head.x - BLOCK_SIZE, self.head.y),
                  Point(self.head.x - 2 * BLOCK_SIZE, self.head.y)]

    self.score = 0
    self.food = None
    self._place_food()

    self.frame_iteration = 0

  #Randomly place the food on a grid cell; if it lands inside the snake, pick again
  def _place_food(self):
    x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
    y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
    self.food = Point(x, y)

    if self.food in self.snake:
      self._place_food()


  def play_step(self, action):

    self.frame_iteration += 1
    for event in pygame.event.get():

      if event.type == pygame.QUIT:
        pygame.quit()
        quit()

    #Move the head
    self._move(action)

    #Grow the snake
    self.snake.insert(0, self.head)

    #Check if game is over
    reward = 0
    game_over = False
    if self._is_collision() or self.frame_iteration > 100 * len(self.snake):
      game_over = True
      reward = -10
      return reward, game_over, self.score

    #Place new food or dont grow

    if self.head == self.food:
      self.score += 1
      reward = 10
      self._place_food()
    else:
      self.snake.pop()

    #update the ui and clock
    self._update_ui()

    #the timestep
    self.clock.tick(SPEED)

    return reward, game_over, self.score

  def _is_collision(self, pt=None):

    if pt is None:
      pt = self.head

    if pt.x > self.w - BLOCK_SIZE or pt.x < 0 or pt.y > self.h - BLOCK_SIZE or pt.y < 0:
      return True

    if pt in self.snake[1:]:
      return True

    return False

  def _update_ui(self):

    self.display.fill(BLACK)

    #Draw the blue snake
    for pt in self.snake:
      pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
      pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x + 4, pt.y + 4, 12, 12))

    #Draw the red food
    pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))

    #Update the font
    text = font.render("Score: " + str(self.score), True, WHITE)
    self.display.blit(text, [0, 0])
    pygame.display.flip()

  def _move(self, action):

    clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
    idx = clock_wise.index(self.direction)

    if np.array_equal(action, [1, 0, 0]):
      new_dir = clock_wise[idx]
    elif np.array_equal(action, [0, 1, 0]):
      next_idx = (idx + 1) % 4
      new_dir = clock_wise[next_idx] # right turn r -> d -> l -> u
    else: # [0, 0, 1]
      next_idx = (idx - 1) % 4
      new_dir = clock_wise[next_idx]

    self.direction = new_dir

    x = self.head.x
    y = self.head.y

    if self.direction == Direction.RIGHT:
        x += BLOCK_SIZE
    elif self.direction == Direction.LEFT:
        x -= BLOCK_SIZE
    elif self.direction == Direction.DOWN:
        y += BLOCK_SIZE
    elif self.direction == Direction.UP:
        y -= BLOCK_SIZE

    self.head = Point(x, y)

Writing game.py
