## Game Source Code

In [2]:
!pip install matplotlib
!pip install pyvirtualdisplay
!apt-get install -y xvfb # Install Xvfb


import matplotlib
matplotlib.use('Agg')


import turtle
import time
import random
import torch
import numpy as np

from pyvirtualdisplay import Display # Import Display
from google.colab import drive
import os

# Create virtual display
display = Display(visible=0, size=(1400, 900))
display.start() # Start the virtual display


GAME_SPEED = 0.5 # MS
ACTIONS = [1,2,3,4]
google_drive_weights_dir = '/content/drive/MyDrive/SnakeRLAgent/'
drive.mount('/content/drive')

class Head:
    def __init__(self):
        self.head = turtle.Turtle();
        self.head.speed(0)
        self.head.shape("square")
        self.head.color("black")
        self.head.penup()
        self.head.goto(0,0)
        self.head.direction = "stop"

    def go_up(self):
        if self.head.direction != "down":
            self.head.direction = "up"

    def go_down(self):
        if self.head.direction != "up":
            self.head.direction = "down"

    def go_left(self):
        if self.head.direction != "right":
            self.head.direction = "left"

    def go_right(self):
        if self.head.direction != "left":
            self.head.direction = "right"

    def move(self):
        if self.head.direction == "up":
            y = self.head.ycor()
            self.head.sety(y + 20)

        if self.head.direction == "down":
            y = self.head.ycor()
            self.head.sety(y - 20)

        if self.head.direction == "left":
            x = self.head.xcor()
            self.head.setx(x - 20)

        if self.head.direction == "right":
            x = self.head.xcor()
            self.head.setx(x + 20)

class GameLauncher:
    def __init__(self):
        self.turtle = turtle
        self.boundary = 140  # Game boundary is ±140
        self.cell_size = 20  # Each block in the snake segment is 20 x 20

        # Calculate grid size from boundaries and cell size
        # (-140 to +140 = 280 units total, divided by 20 = 14 cells)
        self.grid_size = (self.boundary * 2) // self.cell_size  # Equals 14``

        self.setup_window()
        self.segments = []
        self.high_score = 0
        self.head = Head()
        self.setup_pen()
        self.setup_food()
        self.score = 0

    def quit(self):
        self.turtle.reset()


    def check_wall_collision(self):
        head = self.head
        segments = self.segments
        if head.head.xcor()>140 or head.head.xcor()<-140 or head.head.ycor()>140 or head.head.ycor()<-140:
            time.sleep(1)
            head.head.goto(0,0)
            head.head.direction = "stop"

            # Hide the segments
            for segment in segments:
                segment.goto(1000, 1000)

            # Clear the segments list
            segments.clear()

            # Reset the score
            self.score = 0

            # Reset the delay
            self.delay = GAME_SPEED

            self.pen.clear()
            self.pen.write("Score: {}  High Score: {}".format(self.score, self.high_score), align="center", font=("Courier", 24, "normal"))
            return 1
        return 0

    def check_food_collision(self):
        if self.head.head.distance(self.food) < 20:
            # Move the self.food to a random spot
            x = random.randint(-140, 140)
            y = random.randint(-140, 140)
            self.food.goto(x,y)

            # Add a segment
            new_segment = turtle.Turtle()
            new_segment.speed(0)
            new_segment.shape("square")
            new_segment.color("grey")
            new_segment.penup()
            self. segments.append(new_segment)

            # Shorten the delay
            self.delay -= 0.001

            # Increase the score
            self.score += 10

            if self.score > self.high_score:
                self.high_score = self.score

            self.pen.clear()
            self.pen.write("Score: {}  High Score: {}".format(self.score, self.high_score), align="center", font=("Courier", 24, "normal"))

    def state_to_array(self):
        def turtle_to_grid(x, y):
            # Convert from (-140, 140) range to (0, 13) range
            grid_x = int((x + self.boundary) // self.cell_size)
            grid_y = int((y + self.boundary) // self.cell_size)
            # Ensure coordinates are within grid bounds
            grid_x = max(0, min(grid_x, self.grid_size - 1))
            grid_y = max(0, min(grid_y, self.grid_size - 1))
            return grid_x, grid_y

        # Create empty grid
        grid = np.zeros((self.grid_size, self.grid_size, 3))

        # Set head
        head_x, head_y = turtle_to_grid(self.head.head.xcor(), self.head.head.ycor())
        grid[head_y, head_x, 0] = 1

        # Set body segments
        for segment in self.segments:
            seg_x, seg_y = turtle_to_grid(segment.xcor(), segment.ycor())
            grid[seg_y, seg_x, 1] = 1

        # Set food
        food_x, food_y = turtle_to_grid(self.food.xcor(), self.food.ycor())
        grid[food_y, food_x, 2] = 1  # Make sure this line is actually setting the value


        return torch.from_numpy(grid.flatten()).float()




    def update_snake_body(self):
        # Move the end segments first in reverse order
        for index in range(len(self.segments)-1, 0, -1):
            x = self.segments[index-1].xcor()
            y = self.segments[index-1].ycor()
            self.segments[index].goto(x, y)

        # Move segment 0 to where the head is
        if len(self.segments) > 0:
            x = self.head.head.xcor()
            y = self.head.head.ycor()
            self.segments[0].goto(x,y)


    def get_random_action(self):
        # This is a placeholder. In a real RL setup, you'd get the action from your model
        return random.choice(ACTIONS)

    def get_reward(self, old_score, new_score, collision):
        if collision:
            return -1
        elif new_score > old_score:
            return 1
        else:
            return 0

    def step(self,action):
        old_score = self.score
        head = self.head
        self.delay = GAME_SPEED

        # Score
        segments = self.segments

        match action:
            case 0:
                self.head.go_up()
            case 1:
                self.head.go_down()
            case 2:
                self.head.go_left()
            case _:
                self.head.go_right()

        self.wn.update()
        collision_check = self.check_wall_collision() or self.check_self_collision()
        self.check_food_collision()
        self.update_snake_body()
        head.move()
        new_state = self.state_to_array()
        reward = self.get_reward(old_score, self.score, collision_check)
        done = collision_check

        return new_state,reward,done

    def check_self_collision(self):
        segments = self.segments
        head = self.head
        pen = self.pen

    # Check for head collision with the body segments
        for segment in segments:
            if segment.distance(head.head) < 20:
                time.sleep(1)
                head.head.goto(0,0)
                head.direction = "stop"

                # Hide the segments
                for segment in segments:
                    segment.goto(1000, 1000)

                # Clear the segments list
                segments.clear()

                # Reset the score
                self.score = 0

                # Reset the delay
                self.delay = 0.1

                # Update the score display
                pen.clear()
                pen.write("Score: {}  High Score: {}".format(self.score, self.high_score), align="center", font=("Courier", 24, "normal"))

                return 1
        return 0

    def setup_window(self):
        self.wn = self.turtle.Screen()
        self.wn = turtle.Screen()
        self.wn.title("SnaKE")
        self.wn.bgcolor("green")
        self.wn.setup(width=300, height=300)
        self.wn.tracer(0) # Turns off the screen updates

    def setup_pen(self):
        self.pen = self.turtle.Turtle()
        self.pen.speed(0)
        self.pen.shape("square")
        self.pen.color("white")
        self.pen.penup()
        self.pen.hideturtle()
        self.pen.goto(0, 140)
        self.pen.write("Score: 0  High Score: 0", align="center", font=("Courier", 24, "normal"))

    def setup_food(self):
        self.food = turtle.Turtle()
        self.food.speed(0)
        self.food.shape("circle")
        self.food.color("red")
        self.food.penup()
        self.food.goto(0,100)


# game = GameLauncher()
# for i in range(20):
#     action = game.get_random_action()
#     new_state,reward,done =game.step(action)
#     print(f"Run {new_state.shape} |Reward is {reward} | Done is {done}")
#     time.sleep(2)


Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
xvfb is already the newest version (2:21.1.4-2ubuntu1.7~22.04.12).
0 upgraded, 0 newly installed, 0 to remove and 49 not upgraded.
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Model and Replay Buffer definition

In [4]:
from collections import deque, namedtuple
import torch.nn as nn
import random

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
class ReplayBuffer:
    def __init__(self,max_capacity):
        self.memory = deque([],max_capacity)
        self.device = "cuda"

    # Takes a named tuple of Transition
    def push(self,transition_):
        self.memory.append(transition_)

    def can_sample(self,batch_size):
        return len(self.memory) >= batch_size

    def sample(self, batch_size):
            """Sample a batch of transitions and convert to torch tensors"""
            if not self.can_sample(batch_size):
                raise ValueError(f"Not enough samples in buffer. Has {len(self)} samples, but {batch_size} requested")

            # Sample random transitions
            transitions = random.sample(list(self.memory), batch_size)

            # Transpose the batch
            batch = Transition(*zip(*transitions))

            try:
                # Convert states to tensor
                state_batch = torch.stack([
                    torch.as_tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s
                    for s in batch.state
                ]).to(self.device)

                # Convert actions to tensor - handle scalar integers
                action_batch = torch.tensor(
                    [a if isinstance(a, (list, tuple)) else [a] for a in batch.action],
                    dtype=torch.long
                ).to(self.device)

                # Convert next_states to tensor
                next_state_batch = torch.stack([
                    torch.as_tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s
                    for s in batch.next_state
                ]).to(self.device)

                # Convert rewards to tensor - handle scalar values
                reward_batch = torch.tensor(
                    [r if isinstance(r, (list, tuple)) else [r] for r in batch.reward],
                    dtype=torch.float32
                ).to(self.device)

                return state_batch, action_batch, next_state_batch, reward_batch

            except Exception as e:
                # Detailed error reporting
                shapes = {
                    'state': [np.shape(s) for s in batch.state[:3]],  # Show first 3 shapes
                    'action': [np.shape(a) if hasattr(a, 'shape') else type(a) for a in batch.action[:3]],
                    'next_state': [np.shape(s) for s in batch.next_state[:3]],
                    'reward': [np.shape(r) if hasattr(r, 'shape') else type(r) for r in batch.reward[:3]]
                }
                raise RuntimeError(f"Error creating batch: {str(e)}\n"
                                f"Sample shapes/types:\n"
                                f"States: {shapes['state']}\n"
                                f"Actions: {shapes['action']}\n"
                                f"Next States: {shapes['next_state']}\n"
                                f"Rewards: {shapes['reward']}")

    def __len__(self):
        return len(self.memory)

class DQN(nn.Module):
    def __init__(self,n_observations,actions):
        super(DQN,self).__init__()
        self.relu = nn.ReLU()
        self.input = nn.Linear(n_observations,128)
        self.middle = nn.Linear(128,256)
        self.fc = nn.Linear(256,len(actions))

    def forward(self,x):
        x = self.relu(self.input(x))
        x = self.relu(self.middle(x))
        return self.fc(x)



# Training

In [3]:
import torch
from torch.functional import F
import copy
epsilon = 1.0       # Start epsilon at 1.0 for exploration
epsilon_min = 0.01  # Minimum epsilon for a reasonable amount of exploitation
epsilon_decay = 0.995
device = torch.device("cuda")

def policy(state,action_list,inference_model):
    global epsilon  # Ensure epsilon is tracked across calls
    if torch.rand(1) < epsilon:
        epsilon = max(epsilon_min,epsilon * epsilon_decay)
        return random.randrange(len(action_list))
    else:
        return inference_model(state).detach().argmax().item()



def training_model(policy_net:DQN,game_instance:GameLauncher,lr,batch_size,episodes = 20,gamma=0.99,actions=[1,2,3,4]):
    optimizer = torch.optim.AdamW(policy_net.parameters(),lr=lr,)
    replay_buffer = ReplayBuffer(max_capacity=batch_size + 1)
    target_q_model = copy.deepcopy(policy_net).to(device).eval()

    for episode in range(1,episodes + 1):
        done = False
        total_loss = 0
        num_updates = 0  # Track number of updates for averaging
        while not done:
            current_state = game_instance.state_to_array().to(device)
            action = policy(state=current_state,action_list=actions,inference_model=target_q_model)
            new_state,reward,done =game_instance.step(action)

            replay_buffer.push(Transition(state=current_state,action=action,next_state=new_state,reward=reward))

            if replay_buffer.can_sample(batch_size):
                state_b,action_b,next_state_b,reward_b = replay_buffer.sample(batch_size=batch_size)
                q_values = policy_net(state_b)

                current_q_value = q_values.gather(1, action_b)

                with torch.no_grad():
                    maximum_next_q_value = torch.max(target_q_model(next_state_b),dim=1,keepdim=True)[0]
                    target_q_value =  reward_b + (gamma * maximum_next_q_value * reward_b)

                criterion = nn.SmoothL1Loss()

                loss = criterion(current_q_value,target_q_value)
                loss.backward() # compute gradients for all parameters
                total_loss += loss.item()
                num_updates += 1  # Track number of updates


                torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
                optimizer.zero_grad()
                optimizer.step()


            if episode % 2 == 0 and num_updates != 0 and replay_buffer.can_sample(batch_size):
                target_q_model.load_state_dict(policy_net.state_dict())
                print(f"Episode {episode} done with average error | {total_loss / num_updates}")
                total_loss = 0
                torch.save(policy_net.state_dict(),f"{google_drive_weights_dir}/policy_episode({episode}).pth")


game = GameLauncher()
policy_net = DQN(len(torch.rand(588,)),actions=ACTIONS).to(device)
training_model(policy_net=policy_net,game_instance=game,lr=0.001,batch_size=10_000,episodes=5_000)
# game.quit()




Episode 1060 done with average error | 0.05951027572154999
Episode 1060 done with average error | 0.029755156487226486
Episode 1060 done with average error | 0.019836706419785816
Episode 1060 done with average error | 0.014877482317388058
Episode 1060 done with average error | 0.011901963502168655
Episode 1060 done with average error | 0.009918297330538431
Episode 1060 done with average error | 0.00850138972912516
Episode 1060 done with average error | 0.0074386862106621265
Episode 1060 done with average error | 0.006618380960490968
Episode 1062 done with average error | 0.059564799070358276
Episode 1062 done with average error | 0.029782427474856377
Episode 1062 done with average error | 0.019854942957560223
Episode 1062 done with average error | 0.014891154132783413
Episode 1062 done with average error | 0.011912871152162552
Episode 1062 done with average error | 0.009927392626802126
Episode 1062 done with average error | 0.008509200598512377
Episode 1062 done with average error | 0.

KeyboardInterrupt: 

## Inference

In [None]:
game = GameLauncher()
model = DQN(len(torch.rand(588,)),actions=ACTIONS)
model.load_state_dict(torch.load(f"{google_drive_weights_dir}/policy_episode(1152).pth"))
for i in range(1000):
    with torch.no_grad():
        q_values = model(game.state_to_array())
        action =  torch.argmax(q_values).item()
    new_state,reward,done =game.step(action)
    print(f"Run {new_state.shape} |Reward is {reward} | Done is {done}")
    time.sleep(2)