In [None]:
import pygame
import random
from enum import Enum
from collections import deque
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# ==== GAME SETUP ====
# Side length, in pixels, of one grid cell: snake segments and food are drawn
# as BLOCK_SIZE x BLOCK_SIZE squares and the snake moves BLOCK_SIZE per step.
BLOCK_SIZE = 20
# Frame-rate cap handed to pygame's Clock.tick() after every game step.
SPEED = 100  # reduce if too fast

class Direction(Enum):
    """Heading of the snake; UP decreases y and DOWN increases y on screen."""

    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

class Point:
    """A pixel position on the board, compared and hashed by value.

    Attributes:
        x: Horizontal coordinate.
        y: Vertical coordinate.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        """Return True when *other* is a Point with the same coordinates."""
        if isinstance(other, Point):
            return self.x == other.x and self.y == other.y
        # Let Python try the reflected comparison; it falls back to False.
        return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None, which would make
        # points unusable in sets and as dict keys. Hash by value so it
        # stays consistent with __eq__.
        return hash((self.x, self.y))

    def __repr__(self):
        return f"Point(x={self.x!r}, y={self.y!r})"

class SnakeGameAI:
    """Snake game stepped by an agent's actions instead of keyboard input.

    The board is ``w`` x ``h`` pixels and the snake moves in ``BLOCK_SIZE``
    increments. Each call to :meth:`play_step` advances the game by one move
    and returns the reward for that move.
    """

    # An episode is aborted once this many frames per snake segment have
    # elapsed, so a snake that circles without eating cannot run forever.
    MAX_FRAMES_PER_SEGMENT = 100

    def __init__(self, w=400, h=400):
        """Create the game state and open the pygame window.

        Args:
            w: Board width in pixels.
            h: Board height in pixels.
        """
        self.w = w
        self.h = h
        self.reset()
        self.display = pygame.display.set_mode((self.w, self.h))
        pygame.display.set_caption('Snake RL')
        self.clock = pygame.time.Clock()

    def reset(self):
        """Start a new episode: 3-segment snake at the centre, heading right."""
        self.direction = Direction.RIGHT
        self.head = Point(self.w//2, self.h//2)
        self.snake = [self.head,
                      Point(self.head.x - BLOCK_SIZE, self.head.y),
                      Point(self.head.x - (2*BLOCK_SIZE), self.head.y)]
        self.score = 0
        self.food = None
        self._place_food()
        self.frame_iteration = 0

    def _place_food(self):
        """Put the food on a random grid cell not occupied by the snake."""
        # Retry in a loop rather than recursively: with a long snake many
        # draws may land on the body and recursion could hit the stack limit.
        while True:
            x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
            y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE
            self.food = Point(x, y)
            if self.food not in self.snake:
                return

    def play_step(self, action):
        """Advance the game by one move.

        Args:
            action: One-hot ``[straight, right turn, left turn]`` relative
                to the current heading.

        Returns:
            ``(reward, game_over, score)``. The reward is ``-10`` when the
            episode ends (collision or starvation timeout), ``10`` when food
            is eaten and ``0`` otherwise.

        Raises:
            SystemExit: If the window was closed.
        """
        self.frame_iteration += 1
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                # The site builtin quit() does not stop execution inside a
                # notebook kernel, so the loop would go on to draw on the
                # closed display ("display Surface quit"). Raising
                # SystemExit stops the run in scripts and notebooks alike.
                raise SystemExit

        self._move(action)  # Update the head
        self.snake.insert(0, self.head)

        # frame_iteration is reset every episode; use it to end episodes in
        # which the snake survives indefinitely without reaching the food.
        starved = self.frame_iteration > self.MAX_FRAMES_PER_SEGMENT * len(self.snake)
        if starved or self._is_collision():
            return -10, True, self.score

        reward = 0
        if self.head.x == self.food.x and self.head.y == self.food.y:
            self.score += 1
            reward = 10
            self._place_food()
        else:
            # No food eaten: drop the tail so the length stays constant.
            self.snake.pop()
        self._update_ui()
        self.clock.tick(SPEED)
        return reward, False, self.score

    def _is_collision(self, pt=None):
        """Return True if *pt* is outside the board or on the snake's body.

        Args:
            pt: Point to test; defaults to the current head. The first
                element of ``self.snake`` (the head) is excluded from the
                body check.
        """
        if pt is None:
            pt = self.head
        return (
            pt.x >= self.w or pt.x < 0 or
            pt.y >= self.h or pt.y < 0 or
            pt in self.snake[1:]
        )

    def _update_ui(self):
        """Redraw the board: black background, green snake, red food."""
        self.display.fill(pygame.Color('black'))
        for pt in self.snake:
            pygame.draw.rect(self.display, pygame.Color('green'), pygame.Rect(pt.x, pt.y, BLOCK_SIZE, BLOCK_SIZE))
        pygame.draw.rect(self.display, pygame.Color('red'), pygame.Rect(self.food.x, self.food.y, BLOCK_SIZE, BLOCK_SIZE))
        pygame.display.flip()

    def _move(self, action):
        """Turn according to *action* and compute the new head position.

        ``[1, 0, 0]`` keeps the heading and ``[0, 1, 0]`` turns right
        (clockwise); any other value is treated as a left turn. Only
        ``self.direction`` and ``self.head`` are updated here; the caller
        inserts the new head into ``self.snake``.
        """
        clock_wise = [Direction.RIGHT, Direction.DOWN, Direction.LEFT, Direction.UP]
        idx = clock_wise.index(self.direction)

        if np.array_equal(action, [1, 0, 0]):
            new_dir = clock_wise[idx]  # straight
        elif np.array_equal(action, [0, 1, 0]):
            new_dir = clock_wise[(idx + 1) % 4]  # right turn
        else:
            new_dir = clock_wise[(idx - 1) % 4]  # left turn

        self.direction = new_dir

        x, y = self.snake[0].x, self.snake[0].y
        if self.direction == Direction.RIGHT:
            x += BLOCK_SIZE
        elif self.direction == Direction.LEFT:
            x -= BLOCK_SIZE
        elif self.direction == Direction.DOWN:
            y += BLOCK_SIZE
        elif self.direction == Direction.UP:
            y -= BLOCK_SIZE

        self.head = Point(x, y)  # set new head

    def get_state(self):
        """Return the 11-element observation as an int NumPy array.

        Layout: ``[danger straight, danger right, danger left,
        heading left, heading right, heading up, heading down,
        food left, food right, food up, food down]``, each 0 or 1.
        """
        head = self.snake[0]
        # The four cells adjacent to the head.
        point_l = Point(head.x - BLOCK_SIZE, head.y)
        point_r = Point(head.x + BLOCK_SIZE, head.y)
        point_u = Point(head.x, head.y - BLOCK_SIZE)
        point_d = Point(head.x, head.y + BLOCK_SIZE)

        dir_l = self.direction == Direction.LEFT
        dir_r = self.direction == Direction.RIGHT
        dir_u = self.direction == Direction.UP
        dir_d = self.direction == Direction.DOWN

        # Danger flags are relative to the heading: pick the adjacent cell
        # that lies straight ahead / to the right / to the left.
        danger_straight = (dir_r and self._is_collision(point_r)) or \
                          (dir_l and self._is_collision(point_l)) or \
                          (dir_u and self._is_collision(point_u)) or \
                          (dir_d and self._is_collision(point_d))

        danger_right = (dir_u and self._is_collision(point_r)) or \
                       (dir_d and self._is_collision(point_l)) or \
                       (dir_l and self._is_collision(point_u)) or \
                       (dir_r and self._is_collision(point_d))

        danger_left = (dir_u and self._is_collision(point_l)) or \
                      (dir_d and self._is_collision(point_r)) or \
                      (dir_l and self._is_collision(point_d)) or \
                      (dir_r and self._is_collision(point_u))

        # Food location relative to the head, in absolute screen directions.
        food_left = self.food.x < head.x
        food_right = self.food.x > head.x
        food_up = self.food.y < head.y
        food_down = self.food.y > head.y

        state = [
            danger_straight,
            danger_right,
            danger_left,

            dir_l,
            dir_r,
            dir_u,
            dir_d,

            food_left,
            food_right,
            food_up,
            food_down,
        ]

        return np.array(state, dtype=int)

# ==== AGENT + DQN ====
class Linear_QNet(nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear head.

    Args:
        input_size: Number of features in the state vector.
        hidden_size: Width of both hidden layers.
        output_size: Number of actions (one Q-value per action).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.l1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.l2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.l3 = nn.Linear(in_features=hidden_size, out_features=output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a state tensor to raw (unactivated) Q-values."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = self.relu(layer(hidden))
        return self.l3(hidden)

class QTrainer:
    """Runs Double-DQN gradient steps on an online/target network pair.

    Args:
        model: Online Q-network; the only network whose weights are updated.
        target_model: Network used to evaluate the next-state action value.
        lr: Learning rate for the Adam optimiser.
        gamma: Discount factor applied to the next-state value.
    """

    def __init__(self, model, target_model, lr, gamma):
        self.model = model
        self.target_model = target_model
        self.lr = lr
        self.gamma = gamma
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        """Perform one optimisation step on a transition or a batch of them.

        Each argument is either a single item (1-D state, one-hot action,
        scalar reward, bool done) or a sequence of such items of equal
        length.

        Returns:
            The MSE loss of this step as a Python float.
        """
        # Stack through NumPy first: torch.tensor() on a tuple of ndarrays
        # takes a slow element-by-element path and emits a UserWarning.
        state = torch.as_tensor(np.asarray(state), dtype=torch.float)
        next_state = torch.as_tensor(np.asarray(next_state), dtype=torch.float)
        action = torch.as_tensor(np.asarray(action), dtype=torch.long)
        reward = torch.as_tensor(np.asarray(reward), dtype=torch.float)
        done = torch.as_tensor(np.asarray(done), dtype=torch.bool)

        if state.dim() == 1:
            # Single transition: add a batch dimension of size 1.
            state = state.unsqueeze(0)
            next_state = next_state.unsqueeze(0)
            action = action.unsqueeze(0)
            reward = reward.unsqueeze(0)
            done = done.unsqueeze(0)

        pred = self.model(state)

        with torch.no_grad():
            # 1. Select the best next action with the online model.
            best_next = self.model(next_state).argmax(dim=1, keepdim=True)
            # 2. Evaluate that action with the target model.
            next_q = self.target_model(next_state).gather(1, best_next).squeeze(1)
            # 3. Bootstrap only for non-terminal transitions.
            q_new = torch.where(done, reward, reward + self.gamma * next_q)

        # The target equals the prediction everywhere except the action
        # actually taken, so only that Q-value receives a gradient.
        target = pred.detach().clone()
        rows = torch.arange(target.shape[0])
        target[rows, action.argmax(dim=1)] = q_new

        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()
        self.optimizer.step()
        return loss.item()

class Agent:
    """Epsilon-greedy Double-DQN agent with a replay buffer.

    Holds the online and target Q-networks, the trainer that updates the
    online network, and the exploration schedule.
    """

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # placeholder; recomputed on every get_action() call
        self.gamma = 0.9  # discount rate
        self.memory = deque(maxlen=100000)  # Replay buffer size
        self.model = Linear_QNet(11, 256, 3)  # Main Q-network model
        self.target_model = Linear_QNet(11, 256, 3)
        # The target network starts as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        self.trainer = QTrainer(self.model, self.target_model, lr=0.001, gamma=self.gamma)
        self.update_target_network_counter = 0  # not read anywhere in this class
        # Number of remember() calls (environment steps) between target
        # syncs. NOTE(review): the previous comment said "every 1000 steps"
        # while the value is 5; syncing this often keeps the target almost
        # identical to the online network. Confirm the intended value.
        self.target_update_interval = 5
        self.steps_counter = 0
        # Exploration schedule, in percent: starts at max_epsilon, drops by
        # epsilon_decay per finished game, and never goes below min_epsilon.
        self.min_epsilon = 20
        self.max_epsilon = 80
        self.epsilon_decay = 0.5

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def remember(self, state, action, reward, next_state, done):
        """Store a transition and sync the target network when due.

        Besides appending to the replay buffer this counts steps; every
        ``target_update_interval`` calls the target network is refreshed.
        """
        self.memory.append((state, action, reward, next_state, done))
        self.steps_counter += 1
        if self.steps_counter >= self.target_update_interval:
            self.update_target_model()
            self.steps_counter = 0

    def train_long_memory(self):
        """Train on up to 1000 transitions drawn from the replay buffer."""
        if not self.memory:
            # Nothing stored yet; unpacking zip(*[]) below would raise.
            return

        if len(self.memory) > 1000:
            mini_sample = random.sample(self.memory, 1000)  # Sample from memory
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on the single transition that was just observed."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Choose an action epsilon-greedily.

        Args:
            state: Observation vector accepted by the online network.

        Returns:
            A one-hot list of length 3 marking the chosen action.
        """
        # Reduce exploration as more games are played, down to min_epsilon.
        self.epsilon = self.max_epsilon - self.n_games * self.epsilon_decay
        self.epsilon = max(self.min_epsilon, self.epsilon)

        final_move = [0, 0, 0]
        # randint is inclusive on both ends, so this draws from 101 values
        # and the exploration probability is epsilon / 101.
        if random.randint(0, 100) < self.epsilon:
            move = random.randint(0, 2)
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                prediction = self.model(state0)
            move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move

# ==== MAIN LOOP ====
def train(max_games=None):
    """Run the training loop, printing a summary line after every game.

    The online model's weights are saved to ``model.pth`` whenever a new
    record score is reached.

    Args:
        max_games: Stop after this many finished games. ``None`` (the
            default) keeps training until the process is interrupted or
            the window is closed.
    """
    pygame.init()
    game = SnakeGameAI()
    agent = Agent()
    record = 0

    try:
        while max_games is None or agent.n_games < max_games:
            state_old = game.get_state()
            final_move = agent.get_action(state_old)
            reward, done, score = game.play_step(final_move)
            state_new = game.get_state()

            # Learn from the fresh transition immediately, then store it
            # for replay at the end of the episode.
            agent.train_short_memory(state_old, final_move, reward, state_new, done)
            agent.remember(state_old, final_move, reward, state_new, done)

            if done:
                game.reset()
                agent.n_games += 1
                agent.train_long_memory()

                if score > record:
                    record = score
                    torch.save(agent.model.state_dict(), 'model.pth')

                print(f'Game: {agent.n_games}, Score: {score}, Record: {record}')
    finally:
        # Release the window on any exit path, including KeyboardInterrupt;
        # calling pygame.quit() more than once is harmless.
        pygame.quit()


# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    train()

pygame 2.6.1 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


  state = torch.tensor(state, dtype=torch.float)


Game: 1, Score: 0, Record: 0
Game: 2, Score: 0, Record: 0
Game: 3, Score: 0, Record: 0
Game: 4, Score: 0, Record: 0
Game: 5, Score: 0, Record: 0
Game: 6, Score: 1, Record: 1
Game: 7, Score: 0, Record: 1
Game: 8, Score: 0, Record: 1
Game: 9, Score: 0, Record: 1
Game: 10, Score: 0, Record: 1
Game: 11, Score: 0, Record: 1
Game: 12, Score: 1, Record: 1
Game: 13, Score: 0, Record: 1
Game: 14, Score: 0, Record: 1
Game: 15, Score: 0, Record: 1
Game: 16, Score: 0, Record: 1
Game: 17, Score: 0, Record: 1
Game: 18, Score: 0, Record: 1
Game: 19, Score: 0, Record: 1
Game: 20, Score: 1, Record: 1
Game: 21, Score: 0, Record: 1
Game: 22, Score: 0, Record: 1
Game: 23, Score: 0, Record: 1
Game: 24, Score: 0, Record: 1
Game: 25, Score: 0, Record: 1
Game: 26, Score: 0, Record: 1
Game: 27, Score: 1, Record: 1
Game: 28, Score: 0, Record: 1
Game: 29, Score: 0, Record: 1
Game: 30, Score: 1, Record: 1
Game: 31, Score: 0, Record: 1
Game: 32, Score: 1, Record: 1
Game: 33, Score: 0, Record: 1
Game: 34, Score: 0,

error: display Surface quit

: 