# In [None]:
import pygame
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import os
import matplotlib.pyplot as plt
import heapq
# --- Pygame initialisation --------------------------------------------------
pygame.init()

# RGB colours used when rendering (French names are used throughout the file).
blanc = (255, 255, 255)
jaune = (255, 255, 102)
noir = (0, 0, 0)
rouge = (213, 50, 80)
vert = (0, 255, 0)
bleu = (50, 153, 213)

# Window width/height and grid-cell size in pixels; vitesse_serpent is the
# frame-rate value later handed to horloge.tick().
largeur_ecran = 600
hauteur_ecran = 400
taille_bloc = 20
vitesse_serpent = 40

# Display surface, frame clock and the font used for the score overlay.
fenetre = pygame.display.set_mode((largeur_ecran, hauteur_ecran))
pygame.display.set_caption("Snake Game - DQN")
horloge = pygame.time.Clock()
police = pygame.font.SysFont("bahnschrift", 25)

# Display score
def afficher_score(score):
    """Draw the text "Score: <score>" in yellow at the window's top-left corner.

    Args:
        score: value to display; converted with ``str()``.
    """
    libelle = "Score: " + str(score)
    surface = police.render(libelle, True, jaune)  # antialiased text
    fenetre.blit(surface, [0, 0])

# Neural Network for DQN with two selectable architectures

class Linear_QNet(nn.Module):
    """Fully connected Q-network: one output (Q-value) per action.

    Two architectures are available:

    * simple (default): ``input -> hidden -> output`` with a ReLU after the
      hidden layer;
    * complex (``complex_model=True``): ``input -> hidden -> 2*hidden -> output``
      with a ReLU after each hidden layer.

    Args:
        input_size: number of features in the state vector.
        hidden_size: width of the first hidden layer.
        output_size: number of Q-values produced.
        complex_model: select the deeper architecture.
    """

    def __init__(self, input_size, hidden_size, output_size, complex_model=False):
        super().__init__()
        self.complex = bool(complex_model)

        # Layers are created in the same order as before (fc1, fc2[, fc3]) so
        # state_dict keys are unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        if self.complex:
            self.fc2 = nn.Linear(hidden_size, hidden_size * 2)
            self.fc3 = nn.Linear(hidden_size * 2, output_size)
        else:
            self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the Q-values for ``x`` (last dimension must be ``input_size``)."""
        x = F.relu(self.fc1(x))
        if self.complex:
            x = F.relu(self.fc2(x))
            return self.fc3(x)
        return self.fc2(x)

    def save(self, file_name='model.pth', folder='./models'):
        """Save the network's ``state_dict`` to ``folder/file_name``.

        Args:
            file_name: name of the file written inside ``folder``.
            folder: target directory; created (with parents) if missing.
        """
        # exist_ok=True replaces the racy "exists? then makedirs" sequence.
        os.makedirs(folder, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(folder, file_name))

# Q-learning Trainer Class
class QTrainer:
    """Temporal-difference (DQN) updates for a Q-network.

    Args:
        model: network being optimised; called on a 2-D float tensor and
            expected to return one Q-value per action along dim 1.
        lr: Adam learning rate.
        gamma: discount factor applied to the bootstrapped next-state value.
        target_model: optional network used only to compute the bootstrap
            target ``max_a Q_target(s', a)``; it is never optimised here.
            When omitted, ``model`` itself is used.
    """

    def __init__(self, model, lr, gamma, target_model=None):
        self.model = model
        # Falls back to the online model so QTrainer(model, lr, gamma) callers
        # keep their previous behaviour.
        self.target_model = model if target_model is None else target_model
        self.gamma = gamma
        self.optimizer = optim.Adam(model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

    @staticmethod
    def _to_batch(x):
        """Convert a state (1-D) or batch of states (2-D) to a 2-D float32 tensor."""
        tensor = torch.as_tensor(np.asarray(x, dtype=np.float32))
        return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor

    @staticmethod
    def _to_vector(x, dtype):
        """Convert a scalar or a sequence to a 1-D tensor with NumPy ``dtype``.

        Going through ``np.asarray`` handles Python scalars, NumPy scalars
        (``np.int64``, ``np.bool_`` - not subclasses of int/bool) and
        sequences uniformly, unlike the former isinstance checks feeding the
        legacy ``torch.FloatTensor``/``LongTensor`` constructors.
        """
        return torch.as_tensor(np.asarray(x, dtype=dtype).reshape(-1))

    def train_step(self, state, action, reward, next_state, done):
        """Apply one optimiser step on a single transition or a batch.

        Args:
            state, next_state: one state vector or a 2-D batch of them.
            action: action index (or sequence of indices).
            reward: reward (or sequence of rewards).
            done: terminal flag (or sequence); terminal transitions use the
                reward alone as target.

        Returns:
            The MSE loss of this step as a Python float.
        """
        state = self._to_batch(state)
        next_state = self._to_batch(next_state)
        action = self._to_vector(action, np.int64)
        reward = self._to_vector(reward, np.float32)
        done = self._to_vector(done, np.float32)

        # Q(s, a) for the actions actually taken -> shape (batch, 1).
        pred = self.model(state).gather(1, action.view(-1, 1))

        # Bootstrap target from the (frozen) target network.
        with torch.no_grad():
            next_q_values = self.target_model(next_state).max(1)[0]
            target_q_values = reward + (1 - done) * self.gamma * next_q_values

        loss = self.criterion(pred, target_q_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


# DQN hyper-parameters (tuned values)
HIDDEN_SIZE = 128  # Hidden layer width
OUTPUT_SIZE = 4  # One Q-value per move direction
BATCH_SIZE = 64  # NOTE(review): AIAgent defaults to batch_size=128; here only used by START_TRAINING_THRESHOLD
MEMORY_SIZE = 200000  # Replay-buffer capacity (AIAgent default)
UPDATE_TARGET_EVERY = 4  # Sync the target network every N episodes
START_TRAINING_THRESHOLD = BATCH_SIZE * 10
MAX_EPISODES = 1000


class AIAgent:
    """Epsilon-greedy DQN agent for the Snake game.

    Holds an online network (``model``), a target network (``target_model``,
    used by the trainer for bootstrap targets), a replay memory and a
    per-episode buffer of transitions.

    Args:
        input_size, hidden_size, output_size: sizes forwarded to Linear_QNet.
        gamma, lr: discount factor and learning rate for QTrainer.
        epsilon_start, epsilon_end, epsilon_decay: stored only;
            get_action uses its own fixed schedule.
        complex_model: select the deeper Linear_QNet architecture.
        max_episodes: stored only.
        batch_size: number of transitions sampled per train_long_memory call.
        memory_size: replay-memory capacity (oldest transitions dropped).
    """

    def __init__(self, input_size, hidden_size, output_size, gamma=0.99, lr=0.001,
                 epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.9,
                 complex_model=False, max_episodes=1000, batch_size=128,
                 memory_size=MEMORY_SIZE):
        self.n_games = 0
        self.epsilon = epsilon_start
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.output_size = output_size
        self.memory = deque(maxlen=memory_size)

        # complex_model used to be accepted but never forwarded to the networks.
        self.model = Linear_QNet(input_size, hidden_size, output_size, complex_model=complex_model)
        self.target_model = Linear_QNet(input_size, hidden_size, output_size, complex_model=complex_model)
        self.target_model.load_state_dict(self.model.state_dict())
        # The trainer now bootstraps from target_model (previously it was
        # created and synced but never read).
        self.trainer = QTrainer(self.model, lr=lr, gamma=gamma, target_model=self.target_model)

        # NOTE: games_played is not incremented here; the training loop is
        # expected to do it after each game.
        self.games_played = 0
        self.episode_memory = []
        self.max_episodes = max_episodes
        self.batch_size = batch_size

    def get_state(self, snake, food):
        """Encode the game as an 11-element 0/1 integer vector.

        Args:
            snake: list of ``[x, y]`` blocks with the head last (``snake[-1]``).
            food: ``(x, y)`` position of the food.

        Returns:
            ``np.ndarray`` of ints, in order: danger straight / right / left
            (relative to the current heading), heading left / right / up /
            down, food left / right / up / down. "Up" means a smaller y.
            A one-block snake has no heading, so its heading and danger
            flags are all 0.
        """
        head = snake[-1]
        point_l = [head[0] - taille_bloc, head[1]]
        point_r = [head[0] + taille_bloc, head[1]]
        point_u = [head[0], head[1] - taille_bloc]
        point_d = [head[0], head[1] + taille_bloc]

        if len(snake) > 1:
            neck = snake[-2]
            # Fixed: these comparisons were inverted (a head to the right of
            # the neck was reported as heading left, and so on).
            dir_l = head[0] < neck[0]
            dir_r = head[0] > neck[0]
            dir_u = head[1] < neck[1]
            dir_d = head[1] > neck[1]
        else:
            dir_l = dir_r = dir_u = dir_d = False

        state = [
            # Danger straight ahead
            (dir_r and self.is_collision(point_r, snake)) or
            (dir_l and self.is_collision(point_l, snake)) or
            (dir_u and self.is_collision(point_u, snake)) or
            (dir_d and self.is_collision(point_d, snake)),

            # Danger after a right (clockwise) turn
            (dir_u and self.is_collision(point_r, snake)) or
            (dir_d and self.is_collision(point_l, snake)) or
            (dir_l and self.is_collision(point_u, snake)) or
            (dir_r and self.is_collision(point_d, snake)),

            # Danger after a left (counter-clockwise) turn
            (dir_d and self.is_collision(point_r, snake)) or
            (dir_u and self.is_collision(point_l, snake)) or
            (dir_r and self.is_collision(point_u, snake)) or
            (dir_l and self.is_collision(point_d, snake)),

            # Current heading
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location relative to the head
            food[0] < head[0],  # food left
            food[0] > head[0],  # food right
            food[1] < head[1],  # food up
            food[1] > head[1],  # food down
        ]

        return np.array(state, dtype=int)

    def is_collision(self, point, snake):
        """Return True if ``point`` is outside the window or on the body (head excluded)."""
        if point[0] < 0 or point[0] >= largeur_ecran or point[1] < 0 or point[1] >= hauteur_ecran:
            return True
        if point in snake[:-1]:
            return True
        return False

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory."""
        self.memory.append((state, action, reward, next_state, done))

    def train_short_memory(self, state, action, reward, next_state, done):
        """Run one training step on a single transition."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def train_long_memory(self):
        """Sample ``batch_size`` transitions and train on each of them in turn.

        Does nothing until the memory holds at least ``batch_size`` items.
        Each sampled transition is its own optimiser step (as before).
        """
        if len(self.memory) < self.batch_size:
            return
        for transition in random.sample(self.memory, self.batch_size):
            self.trainer.train_step(*transition)

    def get_action(self, state):
        """Choose an action index with a linearly decaying epsilon-greedy policy.

        ``self.epsilon`` is recomputed as ``80 - games_played``; a random
        action is taken when ``randint(0, 200) < epsilon``, so exploration
        stops once 80 games have been played.
        """
        self.epsilon = 80 - self.games_played
        if random.randint(0, 200) < self.epsilon:
            # Fixed: randint(0, 2) never explored the last of the 4 actions.
            return random.randrange(self.output_size)
        with torch.no_grad():  # inference only, no autograd graph needed
            prediction = self.model(torch.tensor(state, dtype=torch.float))
        return torch.argmax(prediction).item()

    def store_episode(self, state, action, reward, next_state, done):
        """Buffer one transition of the current episode."""
        self.episode_memory.append((state, action, reward, next_state, done))

    def train_on_episode(self):
        """Move the episode's transitions to replay memory, train, and clear the buffer."""
        for experience in self.episode_memory:
            self.remember(*experience)

        # train_long_memory itself returns early while memory < batch_size.
        self.train_long_memory()

        self.episode_memory.clear()

def _position_nourriture_aleatoire():
    """Return a random grid-aligned ``(x, y)`` pixel position for the food (x drawn first)."""
    x = round(random.randrange(0, largeur_ecran - taille_bloc) / taille_bloc) * taille_bloc
    y = round(random.randrange(0, hauteur_ecran - taille_bloc) / taille_bloc) * taille_bloc
    return x, y


def _fenetre_fermee():
    """Drain pygame's event queue and return True if a QUIT event was received.

    Pumping events every frame also keeps the OS from treating the window
    as frozen.
    """
    return any(event.type == pygame.QUIT for event in pygame.event.get())


def _dessiner(serpent, x_nourriture, y_nourriture, score):
    """Render one frame: black background, red food, white snake blocks, score."""
    fenetre.fill(noir)
    pygame.draw.rect(fenetre, rouge, [x_nourriture, y_nourriture, taille_bloc, taille_bloc])
    for bloc in serpent:
        pygame.draw.rect(fenetre, blanc, [bloc[0], bloc[1], taille_bloc, taille_bloc])
    afficher_score(score)
    pygame.display.update()


def _tracer_historique(scores, avg_scores):
    """Plot per-episode scores and their running average with matplotlib."""
    plt.figure()
    plt.plot(scores, label="Score")
    plt.plot(avg_scores, label="Average score")
    plt.xlabel("Episode")
    plt.ylabel("Score")
    plt.legend()
    plt.show()


def jeu():
    """Train a DQN agent on Snake for up to MAX_EPISODES episodes, rendering every frame.

    Each episode starts with a one-block snake in the middle of the window.
    At each step the agent chooses one of four absolute moves; the transition
    is buffered and the agent trains once the episode ends. The target
    network is synced every UPDATE_TARGET_EVERY episodes and the model is
    saved every 100 episodes. Closing the window stops training early.
    Finally the score history is plotted.

    Per-step rewards:
      * -10 for hitting a wall or the snake's body (episode ends);
      * +50 for eating the food;
      * otherwise -0.01 + 0.1, plus +10 if the head got closer to the food
        or -0.5 if not; more than ``max_pas_sans_nourriture`` consecutive
        steps without food costs an extra -1 and ends the episode.
    """
    agent = AIAgent(11, HIDDEN_SIZE, OUTPUT_SIZE, complex_model=True)
    nb_episodes, total_score, high_score = 0, 0, 0
    avg_score = 0.0  # defined up front so the final summary works with 0 episodes
    scores, avg_scores = [], []
    # Once exploration stops, a deterministic policy can move a one-block
    # snake back and forth forever (it can never hit itself) and training
    # only happens at episode end, so such episodes must be cut off.
    max_pas_sans_nourriture = 100
    # Move offsets indexed by action: up, down, left, right (y grows downwards).
    deplacements = [(0, -taille_bloc), (0, taille_bloc), (-taille_bloc, 0), (taille_bloc, 0)]
    fermeture = False

    while nb_episodes < MAX_EPISODES and not fermeture:
        game_over = False
        score = 0
        serpent = [[largeur_ecran / 2, hauteur_ecran / 2]]
        x_nourriture, y_nourriture = _position_nourriture_aleatoire()
        steps_without_food = 0

        while not game_over:
            if _fenetre_fermee():
                fermeture = True
                break

            # Observe and decide.
            state = agent.get_state(serpent, (x_nourriture, y_nourriture))
            action = agent.get_action(state)

            # Move: the new head is appended at the end of the list.
            x_change, y_change = deplacements[action]
            ancienne_tete = serpent[-1]
            nouvelle_tete = [ancienne_tete[0] + x_change, ancienne_tete[1] + y_change]
            serpent.append(nouvelle_tete)

            if (nouvelle_tete[0] < 0 or nouvelle_tete[0] >= largeur_ecran or
                    nouvelle_tete[1] < 0 or nouvelle_tete[1] >= hauteur_ecran or
                    nouvelle_tete in serpent[:-1]):
                game_over = True
                reward = -10
            else:
                reward = -0.01  # small per-step penalty

                if nouvelle_tete[0] == x_nourriture and nouvelle_tete[1] == y_nourriture:
                    reward = 50
                    score += 1
                    x_nourriture, y_nourriture = _position_nourriture_aleatoire()
                    steps_without_food = 0
                else:
                    serpent.pop(0)  # move without growing

                    # Distance shaping against the head *before* this move.
                    # (Reading serpent[-1] after the pop made both distances
                    # equal for a one-block snake, so the shaping was always -0.5.)
                    nourriture = np.array([x_nourriture, y_nourriture])
                    distance_before = np.linalg.norm(np.array(ancienne_tete) - nourriture)
                    distance_after = np.linalg.norm(np.array(nouvelle_tete) - nourriture)
                    reward += 10 if distance_after < distance_before else -0.5
                    reward += 0.1  # survival bonus

                    steps_without_food += 1
                    if steps_without_food > max_pas_sans_nourriture:
                        reward -= 1
                        game_over = True

            next_state = agent.get_state(serpent, (x_nourriture, y_nourriture))
            agent.store_episode(state, action, reward, next_state, game_over)

            _dessiner(serpent, x_nourriture, y_nourriture, score)
            horloge.tick(vitesse_serpent)

        if fermeture:
            break  # window closed mid-episode: do not train on the partial episode

        # Training and score tracking.
        agent.train_on_episode()
        nb_episodes += 1
        agent.games_played += 1
        total_score += score
        high_score = max(high_score, score)
        avg_score = total_score / nb_episodes
        scores.append(score)
        avg_scores.append(avg_score)

        print(f"Episode: {nb_episodes}/{MAX_EPISODES}, Score: {score}, Avg Score: {avg_score:.2f}, High Score: {high_score}, Epsilon: {agent.epsilon:.4f}")

        # Periodic target-network sync and checkpoint.
        if nb_episodes % UPDATE_TARGET_EVERY == 0:
            agent.target_model.load_state_dict(agent.model.state_dict())
        if nb_episodes % 100 == 0:
            agent.model.save(f'snake_dqn_model_episode_{nb_episodes}.pth')

    pygame.quit()
    # AIAgent has no plot_history method; plot here instead of crashing.
    _tracer_historique(scores, avg_scores)
    print(f"Training completed. Total episodes: {nb_episodes}, Final Avg Score: {avg_score:.2f}, High Score: {high_score}")

# Run the training loop when executed as a script or notebook cell, but not
# when this module is imported.
if __name__ == "__main__":
    jeu()