# Snake game with AI

### This project is about making a snake game and using Reinforcement Learning to train a model to play

**For anyone who finds this on GitHub, the best way to run it is to create a venv and install all packages from `requirements.txt`.**
To do that, run:

```
python -m venv gameenv
gameenv\Scripts\activate
pip install -r requirements.txt
```

So... In order to make an AI model to play the game, first we need to code up the actual game.

For the game, the libraries we will need are:
```python
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np
```

The render font that I chose for the UI is arial.ttf, so we need to download that into the game folder.

To set up the game we need the font, an enum class for the directions, a namedtuple that defines each point in the game window, some color constants, the block size used for the objects in the game (snake body and food), and the speed, which is the framerate.

**Helpers we have for the game:**
1. `arial.ttf` – font for game text
2. `Direction` (Enum) – movement directions
3. `Point` – named tuple for clean point management
4. Color constants – `WHITE`, `RED`, etc. for `pygame`
5. `BLOCK_SIZE` – the size of a single block (snake body / food)
6. `SPEED` – frames per second (FPS) of the game

**Game methods:**
1. `__init__` - the constructor, in which we specify the display width and height
2. `_place_food()` - we place the food on a random cell that the snake doesn't occupy
3. `play_step()` - we run one frame: read the user input, move, check for game over, handle food, and update the UI
4. `_is_collision()` - we check if the snake bumped into a wall or into itself
5. `_move()` - we compute the new head position from the current direction
6. `_update_ui()` - we display and update the UI

### How the game works

First, we start by placing the snake horizontally and placing the food on a random place. The snake's head is at the center of the screen with its body extending to the left along the _x_ axis, and the default starting movement is to the right _relative to the screen_.

Then we start moving. The way we simulate movement is simple. We don't move each part of the snake each frame. Instead, every step, we place a new head in the direction that the player specified and remove the tail if there was no food at that place. If there is food there, we don't remove the tail, so the snake grows by one block, and we continue the loop. When we repeat that process really fast it looks like the whole snake is moving, but really... we simply add a head and cut the tail really fast.
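
The add-a-head, cut-the-tail idea can be sketched without `pygame` at all. This is a minimal illustration that uses plain tuples instead of the game's `Point`; the `step` helper and its parameters are hypothetical and not part of the project:

```python
BLOCK_SIZE = 20

def step(snake, dx, dy, food):
    """Return the snake after one move; the head is snake[0]."""
    head_x, head_y = snake[0]
    new_head = (head_x + dx * BLOCK_SIZE, head_y + dy * BLOCK_SIZE)
    new_snake = [new_head] + snake
    if new_head != food:
        new_snake.pop()  # no food eaten: drop the tail so the length stays the same
    return new_snake

snake = [(320, 240), (300, 240), (280, 240)]
moved = step(snake, 1, 0, food=(0, 0))      # plain move to the right
grown = step(snake, 1, 0, food=(340, 240))  # food sits right in front of the head
print(moved)
print(grown)
```

In the first call the snake keeps its length of 3. In the second it keeps the old tail and grows to 4 blocks.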

**Note - movement:** Game libraries don't use a normal coordinate system like in math. In games, in order to simulate monitor pixels, _x_ grows to the right, but _y_ grows downwards, not upwards.

Each step/frame we check if there is a collision and stop the game if there is.

And of course we update the UI each frame too.

In [None]:
import pygame
import random
from enum import Enum
from collections import namedtuple

pygame.init()
font = pygame.font.Font('arial.ttf', 25)

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4
    
Point = namedtuple('Point', 'x, y')

# rgb colors
WHITE = (255, 255, 255)
RED = (200,0,0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0,0,0)

BLOCK_SIZE = 20
SPEED = 10

class SnakeGame:
    
    def __init__(self, w=640, h=480):
        self.w = w
        self.h = h
        # init display
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        
        # init game state
        self.direction = Direction.RIGHT
        
        self.head = Point(self.w/2, self.h/2)
        self.snake = [self.head, 
                      Point(self.head.x-BLOCK_SIZE, self.head.y),
                      Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]
        
        self.score = 0
        self.food = None
        self._place_food()
        
    def _place_food(self):
        x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE 
        y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()
        
    def play_step(self):
        # 1. collect user input
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
            if event.type == pygame.KEYDOWN:
                if event.key == pygame.K_LEFT:
                    self.direction = Direction.LEFT
                elif event.key == pygame.K_RIGHT:
                    self.direction = Direction.RIGHT
                elif event.key == pygame.K_UP:
                    self.direction = Direction.UP
                elif event.key == pygame.K_DOWN:
                    self.direction = Direction.DOWN
        
        # 2. move
        self._move(self.direction) # update the head
        self.snake.insert(0, self.head)
        
        # 3. check if game over
        game_over = False
        if self._is_collision():
            game_over = True
            return game_over, self.score
            
        # 4. place new food or just move
        if self.head == self.food:
            self.score += 1
            self._place_food()
        else:
            self.snake.pop()
        
        # 5. update ui and clock
        self._update_ui()
        self.clock.tick(SPEED)
        # 6. return game over and score
        return game_over, self.score
    
    def _is_collision(self):
        # hits boundary
        if self.head.x > self.w - BLOCK_SIZE or self.head.x < 0 or self.head.y > self.h - BLOCK_SIZE or self.head.y < 0:
            return True
        # hits itself
        if self.head in self.snake[1:]:
            return True
        
        return False
        
    def _update_ui(self):
        self.display.fill(BLACK)
        
        for pt in self.snake:
            pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
            pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))
            
        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))
        
        text = font.render("Score: " + str(self.score), True, WHITE)
        self.display.blit(text, [0, 0])
        pygame.display.flip()
        
    def _move(self, direction):
        x = self.head.x
        y = self.head.y
        if direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif direction == Direction.UP:
            y -= BLOCK_SIZE
            
        self.head = Point(x, y)

game = SnakeGame()
    
# game loop
while True:
    game_over, score = game.play_step()
        
    if game_over:
        break
        
print('Final Score', score)

pygame.quit()


### Next Steps: Hooking up an AI

Now that the game mechanics are coded, the next step is to build an agent that plays the game using **Reinforcement Learning (RL)**.

In order for that to happen, first we need to extend the `play_step()` and `_move()` methods, and then we need to remove the `if __name__ == '__main__'` check at the end, as we won't be starting up the game directly anymore.

#### Modifying the `play_step()`:

This is where the **Reinforcement Learning** magic happens.

Reinforcement learning is simple at its core. During training, we define rewards and penalties. The model is rewarded for good actions and penalized for bad ones. It learns through feedback:

``I ran into a wall and got a -10 reward --- I shouldn't do that.``

``I ate the food and got +10 --- eating food is good!``

Since the agent will now decide the snake’s direction, we add a second parameter to ``play_step()``: an ``action``, represented as a vector of type ``[float, float, float]`` (more on that later). This replaces keyboard input — we no longer listen to ``KEYDOWN`` events from the user.

Now... back to the sauce. We calculate the reward based on a couple of things. The main ones are obviously food and death. We reward the model +10 for eating food and -10 for dying.
But as you can guess, food alone is a very rare signal. Since the display width is 640px, the height is 480px and the block size is 20x20, the grid has: $$\frac{640}{20} \times \frac{480}{20} = 32 \times 24 = 768 \text{ total cells}$$ Meaning that if the snake just wanders around randomly, the chance that the cell it steps onto holds the food is roughly: $$ \frac{1}{768} \approx 0.0013 \text{ or } 0.13\% $$
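
The same arithmetic as a quick check in Python, using the default window size and block size from the game code:

```python
WIDTH, HEIGHT, BLOCK_SIZE = 640, 480, 20

columns = WIDTH // BLOCK_SIZE      # 32
rows = HEIGHT // BLOCK_SIZE        # 24
total_cells = columns * rows       # 768
chance_per_cell = 1 / total_cells

print(total_cells)                 # 768
print(f"{chance_per_cell:.2%}")    # 0.13%
```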

#### So what do we do?

**Well... we simply need to calculate the reward better and more precisely. We introduce _immediate rewards_:**
1. We compare the snake's distance to the food (Manhattan distance, `|dx| + |dy|`) before and after the move.
    1. If the new distance is shorter, we give **+1 reward**.
    2. Otherwise, we apply a **−0.5** penalty.
2. We check for body parts in 8 surrounding directions (excluding the neck).
    1. For each nearby segment, we apply **−0.1**.
    2. This encourages the snake to avoid trapping itself.
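3. If the snake loops around for too long, we end the game and apply the same **−10** penalty as for dying. The frame limit grows with the snake:
    $$ \text{max frames} = 100 \times \text{len(snake)} $$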
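
The distance and body-proximity terms can be sketched as a pure function. The name `shaped_reward` and its parameters are illustrative only. The arithmetic mirrors what `play_step()` does for a move that does not end the game:

```python
def shaped_reward(prev_dist, new_dist, nearby_body, ate_food):
    """Per-step reward for a move that did not end the game."""
    reward = 10.0 if ate_food else 0.0               # main reward: food
    reward += 1.0 if new_dist < prev_dist else -0.5  # did we get closer to the food?
    reward += -0.1 * nearby_body                     # small penalty per nearby body segment
    return reward

closer = shaped_reward(prev_dist=100, new_dist=80, nearby_body=0, ate_food=False)
crowded = shaped_reward(prev_dist=80, new_dist=100, nearby_body=2, ate_food=False)
print(closer, round(crowded, 2))
```

Moving closer with no body segments nearby yields `1.0`. Moving away with two segments nearby yields roughly `-0.7`.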

#### Modifying the `_move()`:

Movement logic also needs to change. Instead of using fixed directions like UP or LEFT, we now handle **relative directions** based on the snake’s current heading.

Imagine the directions as a clock:
`[RIGHT, DOWN, LEFT, UP]`

If a snake is heading right:
  1. A **right** turn = DOWN
  2. A **left** turn = UP
  3. **Straight** = continue RIGHT

Since the model doesn’t choose absolute directions, it outputs a 3-element array:
 1. `[1, 0, 0]` = go straight
 2. `[0, 1, 0]` = go right
 3. `[0, 0, 1]` = go left

 We translate it into a new direction like so:

```python
# [[1, 0, 0], [0, 1, 0], [0, 0, 1]] = [straight, right, left]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        index = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_direction = clock_wise[index] # straight - no change
        elif np.array_equal(action, [0, 1, 0]):
            next_index = (index + 1) % 4
            new_direction = clock_wise[next_index] # turn right
        else: # [0, 0, 1] = left
            next_index = (index - 1) % 4
            new_direction = clock_wise[next_index] # turn left

        self.direction = new_direction
```

This logic ensures that the snake moves relative to its current direction — just like how a real animal (or robot) might navigate.
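
To try the turning logic in isolation, here is a self-contained sketch. It redefines `Direction` so it runs on its own, and the `turn` function is a hypothetical helper rather than a method of the game class:

```python
from enum import Enum

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

CLOCKWISE = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]

def turn(current, action):
    """Map a one-hot [straight, right, left] action to a new absolute direction."""
    index = CLOCKWISE.index(current)
    if action == [1, 0, 0]:
        return CLOCKWISE[index]            # straight: no change
    if action == [0, 1, 0]:
        return CLOCKWISE[(index + 1) % 4]  # right turn: one step clockwise
    return CLOCKWISE[(index - 1) % 4]      # left turn: one step counter-clockwise

for action in ([1, 0, 0], [0, 1, 0], [0, 0, 1]):
    print(action, turn(Direction.UP, action).name)
```

For a snake heading UP this prints UP, RIGHT and LEFT. The `% 4` is what lets the index wrap around the ends of the list.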


In [None]:
import pygame
import random
from enum import Enum
from collections import namedtuple
import numpy as np

pygame.init()
font = pygame.font.Font('arial.ttf', 25)

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4
    
Point = namedtuple('Point', 'x, y')

# rgb colors
WHITE = (255, 255, 255)
RED = (200,0,0)
BLUE1 = (0, 0, 255)
BLUE2 = (0, 100, 255)
BLACK = (0,0,0)

BLOCK_SIZE = 20
SPEED = 40

class SnakeGameAI:
    
    def __init__(self, w=640, h=480):
        self.w = w
        self.h = h
        # init display
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake')
        self.clock = pygame.time.Clock()
        self.reset()

    def reset(self):
        # init game state
        self.direction = Direction.RIGHT
        
        self.head = Point(self.w/2, self.h/2)
        self.snake = [self.head, 
                      Point(self.head.x-BLOCK_SIZE, self.head.y),
                      Point(self.head.x-(2*BLOCK_SIZE), self.head.y)]
        
        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0
        
    def _place_food(self):
        x = random.randint(0, (self.w-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE 
        y = random.randint(0, (self.h-BLOCK_SIZE )//BLOCK_SIZE )*BLOCK_SIZE
        self.food = Point(x, y)
        if self.food in self.snake:
            self._place_food()
        
    def play_step(self, action):
        self.frame_iteration += 1
        # 1. collect user input
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                quit()
        
        # Store previous distance to food (for directional reward)
        prev_dist = abs(self.head.x - self.food.x) + abs(self.head.y - self.food.y)
        
        # 2. move
        self._move(action) # update the head
        self.snake.insert(0, self.head)
        
        # 3. check if game over
        reward = 0
        game_over = False
        # Prevents endless loops (bigger snake more moves allowed)
        if self.is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over = True
            reward = -10
            return reward, game_over, self.score
        
        nearby_body = 0
        directions = [
            (-20, 0), (20, 0), (0, -20), (0, 20),  # Adjacent
            (-20, -20), (20, -20), (-20, 20), (20, 20)  # Diagonals
        ]

        for dx, dy in directions:
            if Point(self.head.x + dx, self.head.y + dy) in self.snake[2:]:
                nearby_body += 1
            
        # 4. place new food or just move
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            self.snake.pop()

        # Directional rewards
        new_dist = abs(self.head.x - self.food.x) + abs(self.head.y - self.food.y)
        reward += 1.0 if new_dist < prev_dist else -0.5
        reward += -0.1 * nearby_body
        
        print(f"Move: {'Straight' if action[0] else 'Right' if action[1] else 'Left'} | " 
            f"Dist: {prev_dist:.1f}→{new_dist:.1f} | "
            f"Body blocks: {nearby_body} | Reward: {reward:.1f}")
        
        # 5. update ui and clock
        self._update_ui()
        self.clock.tick(SPEED)
        # 6. return game over and score
        return reward, game_over, self.score
    
    def is_collision(self, point = None):
        if point is None:
            point = self.head
        # hits boundary
        if point.x > self.w - BLOCK_SIZE or point.x < 0 or point.y > self.h - BLOCK_SIZE or point.y < 0:
            return True
        # hits itself
        if point in self.snake[1:]:
            return True
        
        return False
        
    def _update_ui(self):
        self.display.fill(BLACK)
        
        for pt in self.snake:
            pygame.draw.rect(self.display, BLUE1, pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
            pygame.draw.rect(self.display, BLUE2, pygame.Rect(pt.x+4, pt.y+4, 12, 12))
            
        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))
        
        text = font.render("Score: " + str(self.score), True, WHITE)
        self.display.blit(text, [0, 0])
        pygame.display.flip()
        
    def _move(self, action):
        # [[1, 0, 0], [0, 1, 0], [0, 0, 1]] = [straight, right, left]

        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        index = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_direction = clock_wise[index] # straight - no change
        elif np.array_equal(action, [0, 1, 0]):
            next_index = (index + 1) % 4
            new_direction = clock_wise[next_index] # turn right
        else: # [0, 0, 1] = left
            next_index = (index - 1) % 4
            new_direction = clock_wise[next_index] # turn left

        self.direction = new_direction

        x = self.head.x
        y = self.head.y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE
            
        self.head = Point(x, y)

### Next: we need to build the model

For the model we will be **extending Reinforcement Learning** by using a **Deep Neural Network** to predict the actions.

In the `model.py` we have **2** classes. The `Linear_QNet` and the `QTrainer`.

**Linear_QNet**:
1. `__init__()` - We initialize the **NN** structure
2. `forward()` - Calculate the output result
3. `save()` - Save the model
4. `load()` - Load the model

**QTrainer**:
1. `__init__()` - Injects the model, parameters and other dependencies
2. `train_step()` - Trains the **NN**

So... let's dive a little more.

What does the `Linear_QNet` actually do?

The model has 4 layers: the input layer, which is the state of the game (19 values), two hidden layers of 256 neurons each (to help the model learn more complex patterns), and the output layer (the 3 possible actions: straight, right, left).

The `forward()` function takes in the state and returns 3 **Q-values**, e.g. `[0.3, -0.2, 1.1]`, which say how good each action is in this situation. In this case the highest value is at index 2, so turning left is the best option.
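
Picking the best action from a list of Q-values is just an argmax. A plain-Python sketch with the example values above (the real code does this on a tensor with `torch.argmax`; the `ACTIONS` list here is only for readability):

```python
ACTIONS = ["straight", "right", "left"]
q_values = [0.3, -0.2, 1.1]

# Index of the largest Q-value
best_index = max(range(len(q_values)), key=lambda i: q_values[i])

# One-hot action in the [straight, right, left] format the game expects
one_hot = [1 if i == best_index else 0 for i in range(len(q_values))]

print(ACTIONS[best_index], one_hot)  # left [0, 0, 1]
```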

Now, what does the `QTrainer` do?

The Trainer teaches the model by feeding in experiences and letting it adjust its weights.

An **experience** is: STATE → took ACTION → got REWARD → ended up in NEXT_STATE

First, we turn everything into _tensors_, which is the data type torch works with.
```python
state = torch.tensor(state, dtype=torch.float)
```

We add a batch dimension if we have only one sample, because PyTorch expects batches. Batching also allows training on multiple experiences at once.
```python
if len(state.shape) == 1:
    state = torch.unsqueeze(state, 0)
    ...
```

Then we predict the Q-values for the current state
```python
pred = self.model(state)
```

This gives Q-values for all actions, for example `[0.2, 0.1, -0.3]`.

We clone it to make a target, in which we will overwrite the value of the action that was taken to show the model what it should have predicted
```python
target = pred.clone()
```

Calculate the target Q-value using a formula derived from the _Bellman Equation_: $ Q_{\text{new}} = \text{reward} + \gamma \cdot \max(Q(\text{next\_state})) $

Gamma ($\gamma$) is a discount factor that balances immediate vs future rewards.
```python
for index in range(len(done)):
    Q_new = reward[index]  # terminal step: nothing comes after this reward
    if not done[index]:
        Q_new = reward[index] + self.gamma * torch.max(self.model(next_state[index]))
```

The formula means: The value of this action = what I got now + how good things might be if I keep going.
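
A worked example with made-up numbers may help. The helper `q_target` is hypothetical, the next-state Q-values are invented for illustration, and `gamma=0.9` is the discount rate the agent uses:

```python
def q_target(reward, next_q_values, gamma=0.9, done=False):
    """Target Q-value for the action that was taken."""
    if done:
        return reward                       # game over: no future to look at
    return reward + gamma * max(next_q_values)

alive = q_target(1.0, [0.5, 2.0, -1.0])             # 1.0 + 0.9 * 2.0
dead = q_target(-10.0, [0.5, 2.0, -1.0], done=True)
print(round(alive, 2), dead)
```

With a reward of 1 and a best next Q-value of 2.0, the target is about 2.8. When the game is over, the target is just the reward, here -10.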

Then we replace only the action taken:
```python
target[index][torch.argmax(action[index]).item()] = Q_new
```

Let's say the action taken was **2**. Then we update:
```python
target[index][2] = Q_new
```

So now:
```python
target = [0.2, 0.1, Q_new]
```

This is what we want the model to learn to predict next time.

And now we train it!
```python
self.optimizer.zero_grad()
loss = self.criterion(target, pred)
loss.backward()

self.optimizer.step()
```

Those four lines compare the prediction with the target, compute the gradients, and update the model:

`.zero_grad()`: clears out the gradients from the previous step

`loss`: the mean squared error between what the model predicted and what it _should_ have predicted

`.backward()`: computes the gradients of the loss via backpropagation

`.step()`: updates the model's weights using those gradients
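
Because the target is a copy of the prediction with a single entry replaced, only the action that was taken contributes to the loss. Here is a plain-Python check of that, assuming the default mean reduction of `nn.MSELoss`; the `mse` helper and the value of `q_new` are illustrative:

```python
def mse(target, pred):
    """Mean squared error over all entries."""
    return sum((t - p) ** 2 for t, p in zip(target, pred)) / len(pred)

pred = [0.2, 0.1, -0.3]
q_new = 2.8                 # illustrative target value
target = list(pred)         # same role as pred.clone()
target[2] = q_new           # only the action taken (index 2) changes

loss = mse(target, pred)
print(round(loss, 4))
```

The two untouched entries have zero error, so the loss is `(2.8 - (-0.3))**2 / 3`, roughly 3.2.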

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os

class Linear_QNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Two hidden layers for better pattern recognition
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)

    def forward(self, state):
        # ReLU(x) = max(0, x) -> activation function
        state = F.relu(self.linear1(state))
        state = F.relu(self.linear2(state))
        state = self.linear3(state)
        return state
    
    def save(self, file_name='model.pth'):
        model_folder_path = './model'
        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)

        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)

    def load(self, file_name='model.pth'):
        model_folder_path = './model'
        file_name = os.path.join(model_folder_path, file_name)

        if os.path.exists(file_name):
            self.load_state_dict(torch.load(file_name))
            self.eval()  # Set the model to evaluation mode
            print(f'Model loaded from {file_name}')
            return True
        else:
            print(f'No saved model found at {file_name}')
            return False


class QTrainer:
    def __init__(self, model, learning_rate, gamma):
        self.learning_rate = learning_rate
        self.model = model
        self.gamma = gamma
        self.optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.long)
        reward = torch.tensor(reward, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)

        # Checks shape of the state, It's 1 if it's a single state and 2 if it's a batch of states
        # If it's a single state, we need to add a batch dimension at the beginning
        if len(state.shape) == 1:
            state = torch.unsqueeze(state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            next_state = torch.unsqueeze(next_state, 0)
            done = (done, )

        # 1: predicted Q values with current state
        pred = self.model(state)

        target = pred.clone()
        for index in range(len(done)): # Iterating over all games (all inputs have the same size/length)
            Q_new = reward[index]
            if not done[index]:
                # Q_new = r + y * max(next_predicted Q value) -> only do this if not done
                Q_new = reward[index] + self.gamma * torch.max(self.model(next_state[index]))

            # target[batch_index][max Q value index] = Q_new | argmax returns the index of the max Q value
            target[index][torch.argmax(action[index]).item()] = Q_new

        self.optimizer.zero_grad() # clear previous gradients
        loss = self.criterion(target, pred)
        loss.backward() # backpropagation

        self.optimizer.step() # update weights

### The Agent

Now we need to connect everything. We do that by creating an Agent in `agent.py`.

**Agent**:
1. `__init__(load_model=False, continue_training=False)` - Initializes variables, the memory, the model, the trainer etc.
2. `get_state(game: SnakeGameAI)` - Builds the state of the game (danger, direction, food location, body proximity)
3. `remember(state, action, reward, next_state, done)` - Saves a single experience into memory for later training (long-term memory)
4. `train_long_memory()` - Trains the model using a random batch of past experiences
5. `train_short_memory(state, action, reward, next_state, done)` - Trains the model on the most recent experience
6. `get_action(state)` - Chooses the action for that state (random or model-predicted)

Diving a little deeper in those functions...

The `get_state()` builds the state vector in the current frame of the game.

It has 19 values:

- **Danger**: straight, right, left
- **Current direction**: left, right, up, down
- **Food direction**: left, right, up, down
- **Body proximity** (8 directions): left, right, up, down, top-left, top-right, bottom-left, bottom-right


```python
state = [
            # Danger straight
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger right
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger left
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x, # food left
            game.food.x > game.head.x, # food right
            game.food.y < game.head.y, # food up
            game.food.y > game.head.y  # food down
        ]

        # Add body proximity detection
        head = game.snake[0]
        directions = [
            (-20, 0),   # left
            (20, 0),    # right
            (0, -20),   # up
            (0, 20),    # down
            (-20, -20), # top-left
            (20, -20),  # top-right
            (-20, 20),  # bottom-left
            (20, 20)    # bottom-right
        ]

        for dx, dy in directions:
            point = Point(head.x + dx, head.y + dy)
            state.append(True if point in game.snake[2:] else False)
```
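
The three danger checks all ask the same question: which cell is straight ahead, to the right, or to the left of the head given the current heading? One way to see the pattern is to reuse the clockwise ordering from `_move()`. This is an alternative sketch, not the code the agent uses; `relative_points` and `CLOCKWISE_OFFSETS` are hypothetical names:

```python
from collections import namedtuple

Point = namedtuple("Point", "x, y")
BLOCK_SIZE = 20

# Offsets in clockwise order: right, down, left, up (y grows downward).
CLOCKWISE_OFFSETS = [(1, 0), (0, 1), (-1, 0), (0, -1)]

def relative_points(head, heading_index):
    """Return the cells straight ahead, to the right and to the left of the head."""
    def cell(index):
        dx, dy = CLOCKWISE_OFFSETS[index % 4]
        return Point(head.x + dx * BLOCK_SIZE, head.y + dy * BLOCK_SIZE)
    return cell(heading_index), cell(heading_index + 1), cell(heading_index - 1)

# Heading right (index 0): "right of the snake" is the cell below it on screen.
straight, right, left = relative_points(Point(100, 100), heading_index=0)
print(straight, right, left)
```

This matches the state code above: when the snake heads right, "danger right" tests `point_d` and "danger left" tests `point_u`.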

The `remember()` simply stores a single experience in the memory.

```python
self.memory.append((state, action, reward, next_state, done))
```
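
The memory is a `deque` with a maximum length, so once it is full the oldest experience is dropped automatically on each append. A tiny demonstration with `maxlen=3` instead of the agent's `MAX_MEMORY`, using placeholder experiences:

```python
from collections import deque

memory = deque(maxlen=3)  # the agent uses maxlen=MAX_MEMORY (100_000)
for step in range(5):
    memory.append(("state", "action", step))  # placeholder experience

print([experience[2] for experience in memory])  # [2, 3, 4]
```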

The `train_short_memory()` calls the trainer's `train_step()` function with **one** experience.

```python
self.trainer.train_step(state, action, reward, next_state, done)
```

The `train_long_memory()` calls the trainer's `train_step()` function with a batch of experiences.
If the memory holds more than `BATCH_SIZE` experiences, we take a random sample of `BATCH_SIZE` of them; otherwise we train on everything we have.

`zip(*mini_sample)` regroups the list of experience tuples into five tuples: all states, all actions, all rewards, all next states, and all done flags.

```python
if len(self.memory) > BATCH_SIZE:
    mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples
else:
    mini_sample = self.memory

states, actions, rewards, next_states, dones = zip(*mini_sample)
self.trainer.train_step(states, actions, rewards, next_states, dones)
```
```
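
What `zip(*mini_sample)` does is easiest to see on a toy sample. The two experiences below are made up and use 2-value states instead of the real 19-value ones:

```python
mini_sample = [
    # (state, action, reward, next_state, done)
    ([0, 1], [1, 0, 0], 1.0, [1, 1], False),
    ([1, 1], [0, 1, 0], -10.0, [0, 0], True),
]

states, actions, rewards, next_states, dones = zip(*mini_sample)

print(rewards)  # (1.0, -10.0)
print(dones)    # (False, True)
```

Each resulting tuple has one entry per experience, which is the batch layout that `train_step()` turns into tensors.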

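The `get_action()` is interesting. We have an **epsilon** value, which controls the randomness of each move: the higher it is, the more often the snake makes a random move (exploration) instead of asking the model (exploitation). Every call subtracts the number of games played from epsilon, so the randomness shrinks as the model plays more games, and once epsilon reaches zero or below, all moves come from the model.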

```python
self.epsilon -= self.number_of_games
action = [0, 0, 0]
if random.randint(0, 200) < self.epsilon:
    move = random.randint(0, 2)
    action[move] = 1
```

When we are not making a random move, the move comes from the model's prediction.
We take the index of the largest Q-value and set that position in the action to 1.

```python
state0 = torch.tensor(state, dtype=torch.float)
prediction = self.model(state0)
move = torch.argmax(prediction).item()
action[move] = 1
```
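
To get a feel for how often a random move happens, note that `random.randint(0, 200)` has 201 equally likely outcomes, and the move is random only when the outcome is below epsilon. A small sketch (the `explore_probability` helper is hypothetical and assumes epsilon is an integer):

```python
def explore_probability(epsilon):
    """Chance that random.randint(0, 200) < epsilon, for an integer epsilon."""
    favourable = min(max(epsilon, 0), 201)  # how many of the 201 outcomes are below epsilon
    return favourable / 201

for epsilon in (80, 50, 20, 0, -30):
    print(epsilon, f"{explore_probability(epsilon):.1%}")
```

With the starting value of 80 this is roughly a 40% chance of a random move, and for any epsilon at or below zero the chance is 0.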

#### **Tying it all together**

We define a `train()` function that creates an instance of the **Agent** and the **SnakeGameAI**, and starts the learning loop. In each iteration:

1. The agent observes the current state of the game.
2. It decides an action (either random or model-predicted).
3. The game processes the action and returns the next state and reward.
4. The agent trains on this immediate experience (short memory).
5. It stores the experience in long-term memory.
6. If the game ends (the snake dies), we:
    - Reset the game,
    - Train on a batch of past experiences (long memory),
    - Update and track scores,
    - Save the model if it reaches a new high score.

Over many iterations, the agent learns from its mistakes and successes, gradually improving its ability to survive and collect food — all without hardcoded game strategies.
But training an agent over hundreds or thousands of games generates a lot of data. To monitor progress, spot trends, and verify improvements, we need a way to visualize performance.

That's where `helper.py` comes in — a lightweight utility to plot scores and moving averages after each game.
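
The contents of `helper.py` are not shown here, so the following is only an assumption about the kind of average it plots: a running mean of all scores so far, computed the way the `total_score` variable in `train()` suggests. The `running_means` function is a hypothetical helper:

```python
def running_means(scores):
    """Mean score after each game: total so far divided by games played."""
    means, total = [], 0
    for games_played, score in enumerate(scores, start=1):
        total += score
        means.append(total / games_played)
    return means

print(running_means([0, 2, 1, 5]))  # [0.0, 1.0, 1.0, 2.0]
```

A rising running mean is a simple sign that the agent is improving across games rather than just getting one lucky score.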

In [None]:
import torch
import random
import numpy as np
from collections import deque
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LEARNING_RATE = 0.001

class Agent:
    def __init__(self, load_model=False, continue_training=False):
        self.number_of_games = 0
        self.epsilon = 80 # control randomness
        self.gamma = 0.9 # discount rate
        self.memory = deque(maxlen=MAX_MEMORY) # popleft()
        self.model = Linear_QNet(19, 256, 3)
        self.trainer = QTrainer(self.model, learning_rate=LEARNING_RATE, gamma=self.gamma)

        if load_model:
            success = self.model.load('model.pth')
            if success and continue_training:
                self.model.train()  # Set the model to training mode if continuing training
                self.epsilon = 50 # Higher than fully trained, lower than new
            else:
                self.epsilon = 20 # Start with some exploration

    def get_state(self, game):
        head = game.snake[0]
        point_l = Point(head.x - 20, head.y) # we use the const 20 because this is the size of each snake rectangle
        point_r = Point(head.x + 20, head.y)
        point_u = Point(head.x, head.y - 20)
        point_d = Point(head.x, head.y + 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        # This is a little confusing before it clicks :D
        state = [
            # Danger straight
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger right
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger left
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x, # food left
            game.food.x > game.head.x, # food right
            game.food.y < game.head.y, # food up
            game.food.y > game.head.y  # food down
        ]

        # Add body proximity detection
        head = game.snake[0]
        directions = [
            (-20, 0),   # left
            (20, 0),    # right
            (0, -20),   # up
            (0, 20),    # down
            (-20, -20), # top-left
            (20, -20),  # top-right
            (-20, 20),  # bottom-left
            (20, 20)    # bottom-right
        ]

        for dx, dy in directions:
            point = Point(head.x + dx, head.y + dy)
            state.append(True if point in game.snake[2:] else False)

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done)) # popleft() if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff between exploration and exploitation
        # epsilon shrinks by the current game count on every call,
        # so the more games we play, the fewer random moves we make
        self.epsilon -= self.number_of_games
        action = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
            action[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0) # this will call the forward function of the model
            move = torch.argmax(prediction).item()
            action[move] = 1
            
        return action

def train(load_model=False, continue_training=False):
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    max_score = 0
    agent = Agent(load_model=load_model, continue_training=continue_training)
    game = SnakeGameAI()

    while True:
        # get old state
        state_old = agent.get_state(game)

        # get action based on state
        action = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(action)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, action, reward, state_new, done)

        # remember
        agent.remember(state_old, action, reward, state_new, done)

        if done:
            # train the long memory, plot result
            game.reset()
            agent.number_of_games += 1
            agent.train_long_memory()

            if score > max_score:
                max_score = score
                agent.model.save()

            print('Game', agent.number_of_games, 'Score', score, 'Record', max_score)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.number_of_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores) # plot the scores


if __name__ == '__main__':
    train()
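
The exploration schedule in `get_action` can be tried out on its own. Note that the method subtracts `number_of_games` from `epsilon` on every call, so `epsilon` drops once per step rather than once per game. The sketch below shows a per-game linear decay instead. `START_EPSILON`, `epsilon_for_game` and `choose_move` are hypothetical names that are not part of `agent.py`, and the starting value of 80 is an assumption, so treat this as an illustration and not as the project's implementation.

```python
import random

START_EPSILON = 80   # hypothetical starting value, not taken from agent.py
RANDOM_RANGE = 200   # same upper bound as random.randint(0, 200) in get_action


def epsilon_for_game(number_of_games, start=START_EPSILON):
    """Linear decay: one unit less exploration per finished game, never below 0."""
    return max(0, start - number_of_games)


def choose_move(q_values, number_of_games, rng=random):
    """Return the index of the move to play (0, 1 or 2 for the snake)."""
    epsilon = epsilon_for_game(number_of_games)
    if rng.randint(0, RANDOM_RANGE) < epsilon:
        # Explore: pick any move at random
        return rng.randint(0, len(q_values) - 1)
    # Exploit: pick the move with the highest predicted value
    return max(range(len(q_values)), key=q_values.__getitem__)
```

With these assumed numbers, the chance of a random move starts at roughly 40% and reaches zero after 80 games, after which the agent always follows the model's prediction.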

#### The plotting

As the agent trains, it's important to visualize its progress to understand how well it's learning over time.

We use a single helper function, `plot(scores, mean_scores)`, which updates a live plot showing:
- The raw score after each game
- The average score across all games (mean = total score / number of games)
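
To make the second curve concrete, here is a minimal sketch of how the running mean is built up game by game, mirroring what `train()` does with `total_score` and `number_of_games`. The helper name `running_means` is hypothetical and only used for illustration.

```python
def running_means(scores):
    """Return the mean score after each game (total score / games played)."""
    means = []
    total = 0
    for games_played, score in enumerate(scores, start=1):
        total += score
        means.append(total / games_played)
    return means


# Three games with scores 0, 2 and 4
print(running_means([0, 2, 4]))  # [0.0, 1.0, 2.0]
```

Because every game counts equally, the mean curve reacts slowly: a few high scores late in training only nudge it upwards.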

To enable real-time updates without interrupting the training loop, we use:

```python
# Enable interactive mode
plt.ion()
```

This lets the plot refresh after each game without blocking the script, so we can monitor training as it runs and see whether the agent is actually improving.

import matplotlib.pyplot as plt
from IPython import display

plt.ion() # To plot interactively

def plot(scores, mean_scores):
    display.clear_output(wait=True)
    display.display(plt.gcf()) # Get current figure
    plt.clf() # Clear the current figure
    plt.title('Training...')
    plt.xlabel('Number of games')
    plt.ylabel('Score')
    plt.plot(scores)
    plt.plot(mean_scores)
    plt.ylim(ymin=0)
    plt.text(len(scores) - 1, scores[-1], str(scores[-1]))
    plt.text(len(mean_scores) - 1, mean_scores[-1], str(mean_scores[-1]))
    plt.show(block=False)
    plt.pause(0.1) # Pause to update the plot

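To load an already trained model, we go to `agent.py` and set both **load_model** and **continue_training** to **True**, either by changing the default values in the definition of `train()` or by calling `train(load_model=True, continue_training=True)` in the `__main__` block.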
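
If editing the file every time gets tedious, one option is to pass the two switches on the command line. This is only a sketch under assumptions: the flags `--load-model` and `--continue-training` and the functions `build_parser` and `main` are hypothetical and do not exist in the project, and the `train` function below is a stand-in for the real one in `agent.py`.

```python
import argparse


def build_parser():
    """Command-line options mirroring the two keyword arguments of train()."""
    parser = argparse.ArgumentParser(description="Train the snake agent")
    parser.add_argument("--load-model", action="store_true",
                        help="start from a previously saved model")
    parser.add_argument("--continue-training", action="store_true",
                        help="keep training the loaded model")
    return parser


def train(load_model=False, continue_training=False):
    # Stand-in for the real train() in agent.py; it only echoes its arguments
    return {"load_model": load_model, "continue_training": continue_training}


def main(argv=None):
    args = build_parser().parse_args(argv)
    return train(load_model=args.load_model,
                 continue_training=args.continue_training)
```

Calling `main()` from the `__main__` block would then allow something like `python agent.py --load-model --continue-training`, while running the script without flags keeps the current behaviour of training from scratch.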

### Resources:

https://www.youtube.com/watch?v=PJl4iabBEz0&list=PLfR10wejCzo_OL-6OsBV-4jAPnSncvZZH

https://en.wikipedia.org/wiki/Reinforcement_learning

https://www.geeksforgeeks.org/what-is-reinforcement-learning/

https://www.geeksforgeeks.org/snake-game-in-python-using-pygame-module/

https://pytorch.org/

https://www.pygame.org/docs/

https://docs.pytorch.org/docs/stable/tensors.html

And of course... **ChatGPT**