In [4]:
pip install pygame matplotlib numpy

Collecting pygame
  Using cached pygame-2.6.1-cp313-cp313-win_amd64.whl.metadata (13 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.3-cp313-cp313-win_amd64.whl.metadata (11 kB)
Collecting numpy
  Using cached numpy-2.3.1-cp313-cp313-win_amd64.whl.metadata (60 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.2-cp313-cp313-win_amd64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.59.0-cp313-cp313-win_amd64.whl.metadata (110 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Using cached kiwisolver-1.4.8-cp313-cp313-win_amd64.whl.metadata (6.3 kB)
Collecting pillow>=8 (from matplotlib)
  Downloading pillow-11.3.0-cp313-cp313-win_amd64.whl.metadata (9.2 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Using cached pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)
Using cached pygame-2.6.1-cp313-

In [7]:
import pygame
import random
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import time

class DynaQAgent:
    """Algorithme Dyna-Q selon Sutton & Barto"""
    
    def __init__(self, n_actions, alpha=0.1, gamma=0.95, epsilon=0.1, n_planning=50):
        self.n_actions = n_actions
        self.alpha = alpha  # Taux d'apprentissage
        self.gamma = gamma  # Facteur de discount
        self.epsilon = epsilon  # Exploration ε-greedy
        self.n_planning = n_planning  # Étapes de planification
        
        # Table Q : Q[état][action] = valeur
        self.Q = defaultdict(lambda: defaultdict(float))
        
        # Modèle : Model[état][action] = (récompense, nouvel_état)
        self.Model = defaultdict(lambda: defaultdict(lambda: None))
        
        # Paires (état, action) visitées pour planification
        self.visited_state_actions = set()
    
    def select_action(self, state, valid_actions=None):
        """Politique ε-greedy"""
        if valid_actions is None:
            valid_actions = list(range(self.n_actions))
        
        if random.random() < self.epsilon:
            return random.choice(valid_actions)
        else:
            q_values = [self.Q[state][action] for action in valid_actions]
            max_q = max(q_values)
            best_actions = [a for a, q in zip(valid_actions, q_values) if q == max_q]
            return random.choice(best_actions)
    
    def update_q(self, state, action, reward, next_state, valid_next_actions=None):
        """Mise à jour Q-learning"""
        if valid_next_actions is None:
            valid_next_actions = list(range(self.n_actions))
        
        if valid_next_actions:
            max_next_q = max([self.Q[next_state][a] for a in valid_next_actions])
        else:
            max_next_q = 0.0
        
        # Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
        current_q = self.Q[state][action]
        td_target = reward + self.gamma * max_next_q
        self.Q[state][action] = current_q + self.alpha * (td_target - current_q)
    
    def update_model(self, state, action, reward, next_state):
        """Mise à jour du modèle"""
        self.Model[state][action] = (reward, next_state)
        self.visited_state_actions.add((state, action))
    
    def planning_step(self, get_valid_actions_fn=None):
        """Une étape de planification"""
        if not self.visited_state_actions:
            return
        
        # Échantillonner (état, action) aléatoirement
        state, action = random.choice(list(self.visited_state_actions))
        
        if self.Model[state][action] is not None:
            reward, next_state = self.Model[state][action]
            
            # Actions valides dans next_state
            if get_valid_actions_fn:
                valid_next_actions = get_valid_actions_fn(next_state)
            else:
                valid_next_actions = list(range(self.n_actions))
            
            # Mise à jour Q avec expérience simulée
            self.update_q(state, action, reward, next_state, valid_next_actions)
    
    def learn_step(self, state, action, reward, next_state, valid_next_actions=None, get_valid_actions_fn=None):
        """Étape complète Dyna-Q"""
        
        # 1. APPRENTISSAGE DIRECT
        self.update_q(state, action, reward, next_state, valid_next_actions)
        
        # 2. MISE À JOUR MODÈLE
        self.update_model(state, action, reward, next_state)
        
        # 3. PLANIFICATION (n étapes)
        for _ in range(self.n_planning):
            self.planning_step(get_valid_actions_fn)

class LineWorld:
    """Environnement Line World - monde linéaire 1D"""
    
    def __init__(self, size=10, goal_state=None, obstacles=None):
        self.size = size
        self.goal_state = goal_state if goal_state is not None else size - 1
        self.obstacles = obstacles if obstacles is not None else []
        self.current_state = 0
        self.actions = ['LEFT', 'RIGHT']  # 0: gauche, 1: droite
        self.n_actions = len(self.actions)
        
    def reset(self):
        """Réinitialiser l'environnement"""
        self.current_state = 0
        return self.current_state
    
    def get_valid_actions(self, state):
        """Retourne les actions valides pour un état donné"""
        valid_actions = []
        
        # Gauche (0) : possible si pas au début et pas d'obstacle
        if state > 0 and (state - 1) not in self.obstacles:
            valid_actions.append(0)
            
        # Droite (1) : possible si pas à la fin et pas d'obstacle
        if state < self.size - 1 and (state + 1) not in self.obstacles:
            valid_actions.append(1)
            
        return valid_actions
    
    def step(self, action):
        """Exécuter une action"""
        valid_actions = self.get_valid_actions(self.current_state)
        
        if action not in valid_actions:
            # Action invalide, reste sur place
            reward = -0.1  # Petite pénalité
            done = False
            return self.current_state, reward, done
        
        # Exécuter l'action
        if action == 0:  # Gauche
            self.current_state = max(0, self.current_state - 1)
        elif action == 1:  # Droite
            self.current_state = min(self.size - 1, self.current_state + 1)
        
        # Calculer la récompense
        if self.current_state == self.goal_state:
            reward = 10.0  # Récompense pour atteindre le but
            done = True
        else:
            reward = -0.01  # Petite pénalité pour chaque pas
            done = False
            
        return self.current_state, reward, done

class DynaQVisualization:
    """Visualisation Pygame pour Dyna-Q dans Line World"""
    
    def __init__(self, world, agent, width=1200, height=600):
        pygame.init()
        self.world = world
        self.agent = agent
        self.width = width
        self.height = height
        self.screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Dyna-Q Line World")
        
        # Couleurs
        self.WHITE = (255, 255, 255)
        self.BLACK = (0, 0, 0)
        self.BLUE = (0, 0, 255)
        self.RED = (255, 0, 0)
        self.GREEN = (0, 255, 0)
        self.GRAY = (128, 128, 128)
        self.YELLOW = (255, 255, 0)
        self.PURPLE = (128, 0, 128)
        
        # Paramètres de visualisation
        self.cell_width = 60
        self.cell_height = 60
        self.world_y = height // 2 - self.cell_height // 2
        self.world_x = (width - world.size * self.cell_width) // 2
        
        self.font = pygame.font.Font(None, 24)
        self.small_font = pygame.font.Font(None, 16)
        
        # Statistiques
        self.episode_rewards = []
        self.episode_steps = []
        self.current_episode = 0
        
    def draw_world(self):
        """Dessiner le monde linéaire"""
        for i in range(self.world.size):
            x = self.world_x + i * self.cell_width
            y = self.world_y
            
            # Couleur de la cellule
            if i == self.world.goal_state:
                color = self.GREEN
            elif i in self.world.obstacles:
                color = self.GRAY
            elif i == self.world.current_state:
                color = self.BLUE
            else:
                color = self.WHITE
            
            # Dessiner la cellule
            pygame.draw.rect(self.screen, color, (x, y, self.cell_width, self.cell_height))
            pygame.draw.rect(self.screen, self.BLACK, (x, y, self.cell_width, self.cell_height), 2)
            
            # Numéro de l'état
            text = self.font.render(str(i), True, self.BLACK)
            text_rect = text.get_rect(center=(x + self.cell_width // 2, y + self.cell_height // 2))
            self.screen.blit(text, text_rect)
    
    def draw_q_values(self):
        """Dessiner les valeurs Q sous chaque état"""
        for state in range(self.world.size):
            x = self.world_x + state * self.cell_width
            y = self.world_y + self.cell_height + 10
            
            # Valeurs Q pour cet état
            q_left = self.agent.Q[state][0]
            q_right = self.agent.Q[state][1]
            
            # Afficher Q-values
            q_text = f"L:{q_left:.2f} R:{q_right:.2f}"
            text = self.small_font.render(q_text, True, self.BLACK)
            text_rect = text.get_rect(center=(x + self.cell_width // 2, y))
            self.screen.blit(text, text_rect)
    
    def draw_policy(self):
        """Dessiner la politique (flèches)"""
        for state in range(self.world.size):
            if state == self.world.goal_state:
                continue
                
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                continue
                
            # Choisir la meilleure action
            q_values = [self.agent.Q[state][action] for action in valid_actions]
            if max(q_values) == min(q_values):
                continue  # Pas de préférence claire
                
            best_action = valid_actions[np.argmax(q_values)]
            
            # Dessiner la flèche
            x = self.world_x + state * self.cell_width + self.cell_width // 2
            y = self.world_y - 20
            
            if best_action == 0:  # Gauche
                pygame.draw.polygon(self.screen, self.RED, 
                                  [(x-10, y), (x+5, y-5), (x+5, y+5)])
            else:  # Droite
                pygame.draw.polygon(self.screen, self.RED, 
                                  [(x+10, y), (x-5, y-5), (x-5, y+5)])
    
    def draw_stats(self):
        """Dessiner les statistiques"""
        y_offset = 20
        
        # Informations générales
        info_lines = [
            f"Épisode: {self.current_episode}",
            f"Position: {self.world.current_state}",
            f"Epsilon: {self.agent.epsilon:.3f}",
            f"Alpha: {self.agent.alpha:.3f}",
            f"Planning steps: {self.agent.n_planning}",
        ]
        
        for i, line in enumerate(info_lines):
            text = self.font.render(line, True, self.BLACK)
            self.screen.blit(text, (20, y_offset + i * 25))
        
        # Statistiques récentes
        if self.episode_rewards:
            recent_rewards = self.episode_rewards[-10:]  # 10 derniers épisodes
            avg_reward = sum(recent_rewards) / len(recent_rewards)
            
            stats_lines = [
                f"Récompense moyenne (10 derniers): {avg_reward:.2f}",
                f"Dernier épisode - Récompense: {self.episode_rewards[-1]:.2f}",
                f"Dernier épisode - Steps: {self.episode_steps[-1]}",
            ]
            
            for i, line in enumerate(stats_lines):
                text = self.font.render(line, True, self.BLACK)
                self.screen.blit(text, (20, y_offset + 150 + i * 25))
    
    def draw_legend(self):
        """Dessiner la légende"""
        legend_x = self.width - 200
        legend_y = 20
        
        legend_items = [
            ("Agent", self.BLUE),
            ("But", self.GREEN),
            ("Obstacle", self.GRAY),
            ("Vide", self.WHITE),
        ]
        
        title = self.font.render("Légende:", True, self.BLACK)
        self.screen.blit(title, (legend_x, legend_y))
        
        for i, (label, color) in enumerate(legend_items):
            y = legend_y + 30 + i * 25
            pygame.draw.rect(self.screen, color, (legend_x, y, 20, 20))
            pygame.draw.rect(self.screen, self.BLACK, (legend_x, y, 20, 20), 1)
            text = self.font.render(label, True, self.BLACK)
            self.screen.blit(text, (legend_x + 30, y))
    
    def update_display(self):
        """Mettre à jour l'affichage"""
        self.screen.fill(self.WHITE)
        
        self.draw_world()
        self.draw_q_values()
        self.draw_policy()
        self.draw_stats()
        self.draw_legend()
        
        pygame.display.flip()
    
    def run_episode(self, max_steps=100, delay=0.1):
        """Exécuter un épisode avec visualisation"""
        state = self.world.reset()
        total_reward = 0
        steps = 0
        
        for step in range(max_steps):
            # Gérer les événements Pygame
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    return False, total_reward, steps
            
            # Sélectionner et exécuter une action
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                break
                
            action = self.agent.select_action(state, valid_actions)
            next_state, reward, done = self.world.step(action)
            
            # Apprentissage Dyna-Q
            valid_next_actions = self.world.get_valid_actions(next_state)
            self.agent.learn_step(state, action, reward, next_state, 
                                valid_next_actions, self.world.get_valid_actions)
            
            total_reward += reward
            steps += 1
            state = next_state
            
            # Mettre à jour l'affichage
            self.update_display()
            time.sleep(delay)
            
            if done:
                break
        
        return True, total_reward, steps
    
    def run_training(self, n_episodes=100, delay=0.1):
        """Exécuter l'entraînement complet"""
        print("Démarrage de l'entraînement Dyna-Q...")
        print("Appuyez sur ESC pour arrêter, SPACE pour pause")
        
        running = True
        paused = False
        
        for episode in range(n_episodes):
            if not running:
                break
                
            self.current_episode = episode + 1
            
            # Gérer les événements
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                    break
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                        break
                    elif event.key == pygame.K_SPACE:
                        paused = not paused
            
            if paused:
                continue
            
            # Exécuter un épisode
            continue_training, reward, steps = self.run_episode(delay=delay)
            if not continue_training:
                break
            
            # Enregistrer les statistiques
            self.episode_rewards.append(reward)
            self.episode_steps.append(steps)
            
            # Réduire progressivement l'exploration
            if episode % 10 == 0:
                self.agent.epsilon = max(0.01, self.agent.epsilon * 0.95)
            
            # Afficher les progrès
            if episode % 20 == 0:
                avg_reward = sum(self.episode_rewards[-20:]) / min(20, len(self.episode_rewards))
                print(f"Épisode {episode}: Récompense moyenne = {avg_reward:.2f}, "
                      f"Epsilon = {self.agent.epsilon:.3f}")
        
        pygame.quit()
        print("Entraînement terminé!")

def main():
    """Fonction principale"""
    # Créer l'environnement Line World
    world = LineWorld(size=8, goal_state=7, obstacles=[3, 5])
    
    # Créer l'agent Dyna-Q
    agent = DynaQAgent(
        n_actions=2,
        alpha=0.1,
        gamma=0.95,
        epsilon=0.3,
        n_planning=10
    )
    
    # Créer la visualisation
    viz = DynaQVisualization(world, agent)
    
    # Lancer l'entraînement
    viz.run_training(n_episodes=200, delay=0.05)

if __name__ == "__main__":
    main()

Démarrage de l'entraînement Dyna-Q...
Appuyez sur ESC pour arrêter, SPACE pour pause
Épisode 0: Récompense moyenne = -1.00, Epsilon = 0.285
Épisode 20: Récompense moyenne = -1.00, Epsilon = 0.257
Épisode 40: Récompense moyenne = -1.00, Epsilon = 0.232
Épisode 60: Récompense moyenne = -1.00, Epsilon = 0.210
Épisode 80: Récompense moyenne = -1.00, Epsilon = 0.189
Épisode 100: Récompense moyenne = -1.00, Epsilon = 0.171
Épisode 120: Récompense moyenne = -1.00, Epsilon = 0.154
Épisode 140: Récompense moyenne = -1.00, Epsilon = 0.139
Épisode 160: Récompense moyenne = -1.00, Epsilon = 0.125
Épisode 180: Récompense moyenne = -1.00, Epsilon = 0.113
Entraînement terminé!


In [None]:
import pygame
import random
import numpy as np
from collections import defaultdict
import time
import matplotlib.pyplot as plt

class DynaQAgent:
    """Algorithme Dyna-Q selon Sutton & Barto"""
    
    def __init__(self, n_actions, alpha=0.1, gamma=0.95, epsilon=0.1, n_planning=50):
        self.n_actions = n_actions
        self.alpha = alpha  # Taux d'apprentissage
        self.gamma = gamma  # Facteur de discount
        self.epsilon = epsilon  # Exploration ε-greedy
        self.n_planning = n_planning  # Étapes de planification
        
        # Table Q : Q[état][action] = valeur
        self.Q = defaultdict(lambda: defaultdict(float))
        
        # Modèle : Model[état][action] = (récompense, nouvel_état)
        self.Model = defaultdict(lambda: defaultdict(lambda: None))
        
        # Paires (état, action) visitées pour planification
        self.visited_state_actions = set()
    
    def select_action(self, state, valid_actions=None):
        """Politique ε-greedy"""
        if valid_actions is None:
            valid_actions = list(range(self.n_actions))
        
        if random.random() < self.epsilon:
            return random.choice(valid_actions)
        else:
            q_values = [self.Q[state][action] for action in valid_actions]
            if not q_values:
                return random.choice(valid_actions)
            max_q = max(q_values)
            best_actions = [a for a, q in zip(valid_actions, q_values) if q == max_q]
            return random.choice(best_actions)
    
    def update_q(self, state, action, reward, next_state, valid_next_actions=None):
        """Mise à jour Q-learning"""
        if valid_next_actions is None:
            valid_next_actions = list(range(self.n_actions))
        
        if valid_next_actions:
            max_next_q = max([self.Q[next_state][a] for a in valid_next_actions])
        else:
            max_next_q = 0.0
        
        # Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
        current_q = self.Q[state][action]
        td_target = reward + self.gamma * max_next_q
        self.Q[state][action] = current_q + self.alpha * (td_target - current_q)
    
    def update_model(self, state, action, reward, next_state):
        """Mise à jour du modèle"""
        self.Model[state][action] = (reward, next_state)
        self.visited_state_actions.add((state, action))
    
    def planning_step(self, get_valid_actions_fn=None):
        """Une étape de planification"""
        if not self.visited_state_actions:
            return
        
        # Échantillonner (état, action) aléatoirement
        state, action = random.choice(list(self.visited_state_actions))
        
        if self.Model[state][action] is not None:
            reward, next_state = self.Model[state][action]
            
            # Actions valides dans next_state
            if get_valid_actions_fn:
                valid_next_actions = get_valid_actions_fn(next_state)
            else:
                valid_next_actions = list(range(self.n_actions))
            
            # Mise à jour Q avec expérience simulée
            self.update_q(state, action, reward, next_state, valid_next_actions)
    
    def learn_step(self, state, action, reward, next_state, valid_next_actions=None, get_valid_actions_fn=None):
        """Étape complète Dyna-Q"""
        
        # 1. APPRENTISSAGE DIRECT
        self.update_q(state, action, reward, next_state, valid_next_actions)
        
        # 2. MISE À JOUR MODÈLE
        self.update_model(state, action, reward, next_state)
        
        # 3. PLANIFICATION (n étapes)
        for _ in range(self.n_planning):
            self.planning_step(get_valid_actions_fn)


class LineWorld:
    def __init__(self, size=5, start_state=2):
        """
        Initialize the LineWorld environment.

        Args:
            size (int): Number of states in the environment (default 5).
            start_state (int): The agent's initial position (default 2).
        """
        self.size = size
        self.start_state = start_state
        self.state = start_state
        self.terminal_states = [0, size - 1]
        self.action_space = [0, 1]  # 0: left, 1: right

    def state_to_index(self, state):
        """For compatibility with RL algos. State is already int, so just return."""
        return state

    def index_to_state(self, index):
        """For compatibility with RL algos. State is already int, so just return."""
        return index

    @property
    def n_states(self):
        """Total number of states (for tabular RL)."""
        return self.size

    def reset(self):
        """
        Reset the environment to the initial state.
        Returns:
            int: The initial state.
        """
        self.state = self.start_state
        return self.state

    def is_terminal(self, state):
        """
        Check if a given state is terminal.
        Args:
            state (int): State to check.
        Returns:
            bool: True if the state is terminal, False otherwise.
        """
        return state in self.terminal_states

    def get_reward(self, next_state):
        """
        Get the reward for transitioning into the next state.
        Args:
            next_state (int): The state resulting from the agent's action.
        Returns:
            float: The reward received for the transition.
        """
        if next_state == 0:
            return -1.0
        elif next_state == self.size - 1:
            return 1.0
        else:
            return 0.0

    def get_valid_actions(self, state):
        """Get valid actions for a given state"""
        if self.is_terminal(state):
            return []
        return self.action_space.copy()

    def step(self, action):
        """
        Take an action in the environment.
        Args:
            action (int): Action to take (0 = left, 1 = right).
        Returns:
            tuple: (next_state (int), reward (float), done (bool))
        """
        if self.is_terminal(self.state):
            return self.state, 0.0, True

        if action == 0:
            next_state = max(self.state - 1, 0)
        elif action == 1:
            next_state = min(self.state + 1, self.size - 1)
        else:
            raise ValueError("Invalid action (0=left, 1=right)")

        reward = self.get_reward(next_state)
        done = self.is_terminal(next_state)
        self.state = next_state
        return next_state, reward, done

    def simulate_step(self, state, action):
        """
        Simulate a step from a given state and action without modifying self.state.
        Returns:
            next_state, reward, done
        """
        if self.is_terminal(state):
            return state, 0.0, True

        if action == 0:
            next_state = max(state - 1, 0)
        elif action == 1:
            next_state = min(state + 1, self.size - 1)
        else:
            raise ValueError("Invalid action (0=left, 1=right)")

        reward = self.get_reward(next_state)
        done = self.is_terminal(next_state)
        return next_state, reward, done

    def render(self):
        """
        Print a visual representation of the environment.
        'A': Agent position
        'T': Terminal state(s)
        '_': Normal state
        """
        line = []
        for i in range(self.size):
            if i == self.state:
                line.append('A')
            elif i in self.terminal_states:
                line.append('T')
            else:
                line.append('_')
        print(' '.join(line))


class DynaQLineWorldVisualization:
    """Visualisation Pygame pour Dyna-Q dans LineWorld"""
    
    def __init__(self, world, agent, width=1200, height=700):
        pygame.init()
        self.world = world
        self.agent = agent
        self.width = width
        self.height = height
        self.screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Dyna-Q LineWorld - Apprentissage par Renforcement")
        
        # Couleurs
        self.WHITE = (255, 255, 255)
        self.BLACK = (0, 0, 0)
        self.BLUE = (0, 100, 255)
        self.RED = (255, 50, 50)
        self.GREEN = (50, 255, 50)
        self.GRAY = (180, 180, 180)
        self.YELLOW = (255, 255, 0)
        self.PURPLE = (150, 50, 200)
        self.DARK_GREEN = (0, 150, 0)
        self.DARK_RED = (150, 0, 0)
        self.LIGHT_BLUE = (173, 216, 230)
        
        # Paramètres de visualisation
        self.cell_width = 80
        self.cell_height = 80
        self.world_y = height // 2 - self.cell_height // 2
        self.world_x = (width - world.size * self.cell_width) // 2
        
        # Polices
        self.font_large = pygame.font.Font(None, 28)
        self.font_medium = pygame.font.Font(None, 24)
        self.font_small = pygame.font.Font(None, 18)
        self.font_tiny = pygame.font.Font(None, 14)
        
        # Statistiques
        self.episode_rewards = []
        self.episode_steps = []
        self.current_episode = 0
        self.total_steps = 0
        
        # Animation
        self.animation_step = 0
        
    def draw_world(self):
        """Dessiner le monde linéaire"""
        for i in range(self.world.size):
            x = self.world_x + i * self.cell_width
            y = self.world_y
            
            # Couleur de base selon le type de cellule
            if i in self.world.terminal_states:
                if i == 0:  # Terminal négatif
                    base_color = self.DARK_RED
                    label_color = self.WHITE
                    reward_text = "-1"
                else:  # Terminal positif
                    base_color = self.DARK_GREEN
                    label_color = self.WHITE
                    reward_text = "+1"
            else:
                base_color = self.LIGHT_BLUE
                label_color = self.BLACK
                reward_text = "0"
            
            # Effet de surbrillance pour la position actuelle
            if i == self.world.state:
                # Animation de pulsation
                pulse = abs(np.sin(self.animation_step * 0.1)) * 30
                base_color = tuple(min(255, c + pulse) for c in base_color)
            
            # Dessiner la cellule
            pygame.draw.rect(self.screen, base_color, (x, y, self.cell_width, self.cell_height))
            pygame.draw.rect(self.screen, self.BLACK, (x, y, self.cell_width, self.cell_height), 3)
            
            # Numéro de l'état
            state_text = self.font_medium.render(f"S{i}", True, label_color)
            state_rect = state_text.get_rect(center=(x + self.cell_width // 2, y + 15))
            self.screen.blit(state_text, state_rect)
            
            # Récompense
            reward_surface = self.font_small.render(f"R: {reward_text}", True, label_color)
            reward_rect = reward_surface.get_rect(center=(x + self.cell_width // 2, y + 35))
            self.screen.blit(reward_surface, reward_rect)
            
            # Agent (marqueur spécial)
            if i == self.world.state:
                agent_radius = 15
                pygame.draw.circle(self.screen, self.YELLOW, 
                                 (x + self.cell_width // 2, y + self.cell_height - 25), 
                                 agent_radius)
                pygame.draw.circle(self.screen, self.BLACK, 
                                 (x + self.cell_width // 2, y + self.cell_height - 25), 
                                 agent_radius, 2)
                agent_text = self.font_small.render("A", True, self.BLACK)
                agent_rect = agent_text.get_rect(center=(x + self.cell_width // 2, y + self.cell_height - 25))
                self.screen.blit(agent_text, agent_rect)
    
    def draw_q_values(self):
        """Dessiner les valeurs Q sous chaque état"""
        for state in range(self.world.size):
            if self.world.is_terminal(state):
                continue
                
            x = self.world_x + state * self.cell_width
            y = self.world_y + self.cell_height + 10
            
            # Valeurs Q pour cet état
            q_left = self.agent.Q[state][0]
            q_right = self.agent.Q[state][1]
            
            # Couleurs basées sur les valeurs
            left_color = self.GREEN if q_left > q_right else self.GRAY
            right_color = self.GREEN if q_right > q_left else self.GRAY
            
            # Afficher Q-values avec couleurs
            left_text = f"←{q_left:.2f}"
            right_text = f"→{q_right:.2f}"
            
            left_surface = self.font_tiny.render(left_text, True, left_color)
            right_surface = self.font_tiny.render(right_text, True, right_color)
            
            # Positionner les textes
            left_rect = left_surface.get_rect(center=(x + self.cell_width // 4, y))
            right_rect = right_surface.get_rect(center=(x + 3 * self.cell_width // 4, y))
            
            self.screen.blit(left_surface, left_rect)
            self.screen.blit(right_surface, right_rect)
    
    def draw_policy(self):
        """Dessiner la politique (flèches pour la meilleure action)"""
        for state in range(self.world.size):
            if self.world.is_terminal(state):
                continue
                
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                continue
            
            # Trouver la meilleure action
            q_values = [self.agent.Q[state][action] for action in valid_actions]
            if max(q_values) == min(q_values):
                continue  # Pas de préférence claire
                
            best_action = valid_actions[np.argmax(q_values)]
            
            # Dessiner la flèche
            x = self.world_x + state * self.cell_width + self.cell_width // 2
            y = self.world_y - 30
            
            arrow_size = 12
            if best_action == 0:  # Gauche
                points = [(x-arrow_size, y), (x+arrow_size//2, y-arrow_size//2), (x+arrow_size//2, y+arrow_size//2)]
                pygame.draw.polygon(self.screen, self.PURPLE, points)
            else:  # Droite
                points = [(x+arrow_size, y), (x-arrow_size//2, y-arrow_size//2), (x-arrow_size//2, y+arrow_size//2)]
                pygame.draw.polygon(self.screen, self.PURPLE, points)
    
    def draw_model_info(self):
        """Afficher les informations sur le modèle appris"""
        y_start = self.world_y + self.cell_height + 60
        
        title = self.font_medium.render("Modèle Dyna-Q (échantillon):", True, self.BLACK)
        self.screen.blit(title, (20, y_start))
        
        # Afficher quelques transitions apprises
        displayed = 0
        max_display = 6
        
        # Parcourir correctement le modèle à deux niveaux
        for state in list(self.agent.Model.keys())[:max_display]:
            if displayed >= max_display:
                break
                
            for action in self.agent.Model[state]:
                if displayed >= max_display:
                    break
                    
                transition = self.agent.Model[state][action]
                if transition is not None:  # Vérifier qu'il y a bien une transition
                    reward, next_state = transition
                    action_name = "←" if action == 0 else "→"
                    text = f"S{state} {action_name} → S{next_state} (R={reward})"
                    surface = self.font_small.render(text, True, self.BLACK)
                    self.screen.blit(surface, (20, y_start + 25 + displayed * 20))
                    displayed += 1
    
    def draw_statistics(self):
        """Dessiner les statistiques détaillées"""
        stats_x = 20
        stats_y = 20
        
        # Titre
        title = self.font_large.render("Statistiques Dyna-Q", True, self.BLACK)
        self.screen.blit(title, (stats_x, stats_y))
        
        # Informations de l'épisode
        info_lines = [
            f"Épisode: {self.current_episode}",
            f"Position: S{self.world.state}",
            f"Steps totaux: {self.total_steps}",
            "",
            f"Paramètres:",
            f"  ε (exploration): {self.agent.epsilon:.3f}",
            f"  α (apprentissage): {self.agent.alpha:.3f}",
            f"  γ (discount): {self.agent.gamma:.3f}",
            f"  Planning steps: {self.agent.n_planning}",
        ]
        
        for i, line in enumerate(info_lines):
            if line:  # Ignorer les lignes vides pour le rendu
                text = self.font_small.render(line, True, self.BLACK)
                self.screen.blit(text, (stats_x, stats_y + 35 + i * 18))
        
        # Statistiques de performance
        if self.episode_rewards:
            perf_y = stats_y + 35 + len(info_lines) * 18 + 20
            
            recent_rewards = self.episode_rewards[-20:]  # 20 derniers épisodes
            avg_reward = sum(recent_rewards) / len(recent_rewards)
            
            perf_lines = [
                "Performance:",
                f"  Récompense moyenne (20 derniers): {avg_reward:.3f}",
                f"  Dernier épisode: {self.episode_rewards[-1]:.3f}",
                f"  Steps dernier épisode: {self.episode_steps[-1]}",
                f"  États-actions explorés: {len(self.agent.visited_state_actions)}",
            ]
            
            for i, line in enumerate(perf_lines):
                text = self.font_small.render(line, True, self.BLACK)
                self.screen.blit(text, (stats_x, perf_y + i * 18))
    
    def draw_legend(self):
        """Dessiner la légende"""
        legend_x = self.width - 250
        legend_y = 20
        
        title = self.font_medium.render("Légende:", True, self.BLACK)
        self.screen.blit(title, (legend_x, legend_y))
        
        legend_items = [
            ("Agent (A)", self.YELLOW),
            ("État terminal +1", self.DARK_GREEN),
            ("État terminal -1", self.DARK_RED),
            ("État normal", self.LIGHT_BLUE),
            ("Meilleure action", self.PURPLE),
            ("Q-value élevée", self.GREEN),
        ]
        
        for i, (label, color) in enumerate(legend_items):
            y = legend_y + 30 + i * 25
            pygame.draw.rect(self.screen, color, (legend_x, y, 20, 20))
            pygame.draw.rect(self.screen, self.BLACK, (legend_x, y, 20, 20), 1)
            text = self.font_small.render(label, True, self.BLACK)
            self.screen.blit(text, (legend_x + 30, y + 2))
    
    def draw_progress_bar(self):
        """Dessiner une barre de progression pour l'épisode"""
        if not hasattr(self, 'episode_length'):
            self.episode_length = 100  # Longueur maximale estimée
            
        bar_x = self.world_x
        bar_y = self.world_y + self.cell_height + 120
        bar_width = self.world.size * self.cell_width
        bar_height = 20
        
        # Fond de la barre
        pygame.draw.rect(self.screen, self.GRAY, (bar_x, bar_y, bar_width, bar_height))
        
        # Progression
        if hasattr(self, 'current_step'):
            progress = min(1.0, self.current_step / self.episode_length)
            pygame.draw.rect(self.screen, self.GREEN, (bar_x, bar_y, bar_width * progress, bar_height))
        
        # Bordure
        pygame.draw.rect(self.screen, self.BLACK, (bar_x, bar_y, bar_width, bar_height), 2)
        
        # Texte
        text = self.font_small.render("Progression de l'épisode", True, self.BLACK)
        text_rect = text.get_rect(center=(bar_x + bar_width // 2, bar_y - 15))
        self.screen.blit(text, text_rect)
    
    def update_display(self):
        """Mettre à jour l'affichage complet"""
        self.screen.fill(self.WHITE)
        
        self.draw_world()
        self.draw_q_values()
        self.draw_policy()
        self.draw_statistics()
        self.draw_legend()
        self.draw_model_info()
        self.draw_progress_bar()
        
        # Incrémenter l'animation
        self.animation_step += 1
        
        pygame.display.flip()
    
    def run_episode(self, max_steps=100, delay=0.1):
        """Exécuter un épisode avec visualisation"""
        state = self.world.reset()
        total_reward = 0
        steps = 0
        self.current_step = 0
        self.episode_length = max_steps
        
        for step in range(max_steps):
            self.current_step = step
            
            # Gérer les événements Pygame
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    return False, total_reward, steps
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # Pause - attendre une autre pression sur espace
                        paused = True
                        while paused:
                            for pause_event in pygame.event.get():
                                if pause_event.type == pygame.QUIT:
                                    return False, total_reward, steps
                                elif pause_event.type == pygame.KEYDOWN:
                                    if pause_event.key == pygame.K_SPACE:
                                        paused = False
                            time.sleep(0.1)
            
            # Vérifier si l'état est terminal
            if self.world.is_terminal(state):
                break
            
            # Sélectionner et exécuter une action
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                break
                
            action = self.agent.select_action(state, valid_actions)
            next_state, reward, done = self.world.step(action)
            
            # Apprentissage Dyna-Q
            valid_next_actions = self.world.get_valid_actions(next_state)
            self.agent.learn_step(state, action, reward, next_state, 
                                valid_next_actions, self.world.get_valid_actions)
            
            total_reward += reward
            steps += 1
            self.total_steps += 1
            state = next_state
            
            # Mettre à jour l'affichage
            self.update_display()
            time.sleep(delay)
            
            if done:
                break
        
        return True, total_reward, steps
    
    def run_training(self, n_episodes=200, delay=0.1):
        """Exécuter l'entraînement complet"""
        print("=== Démarrage de l'entraînement Dyna-Q ===")
        print("Contrôles:")
        print("- ESC: Arrêter l'entraînement")
        print("- SPACE: Pause/Reprendre")
        print("- Fermer la fenêtre: Quitter")
        print()
        
        running = True
        
        for episode in range(n_episodes):
            if not running:
                break
                
            self.current_episode = episode + 1
            
            # Exécuter un épisode
            continue_training, reward, steps = self.run_episode(delay=delay)
            if not continue_training:
                break
            
            # Enregistrer les statistiques
            self.episode_rewards.append(reward)
            self.episode_steps.append(steps)
            
            # Réduire progressivement l'exploration
            if episode % 20 == 0 and episode > 0:
                self.agent.epsilon = max(0.01, self.agent.epsilon * 0.95)
            
            # Afficher les progrès dans la console
            if episode % 50 == 0 or episode < 10:
                recent_rewards = self.episode_rewards[-20:] if len(self.episode_rewards) >= 20 else self.episode_rewards
                avg_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0
                print(f"Épisode {episode}: Récompense moyenne = {avg_reward:.3f}, "
                      f"Epsilon = {self.agent.epsilon:.3f}, Steps = {steps}")
        
        print("\n=== Entraînement terminé ===")
        print(f"Episodes complétés: {len(self.episode_rewards)}")
        if self.episode_rewards:
            final_avg = sum(self.episode_rewards[-50:]) / min(50, len(self.episode_rewards))
            print(f"Performance finale (50 derniers épisodes): {final_avg:.3f}")
        
        # Afficher un résumé final
        input("Appuyez sur Entrée pour fermer...")
        pygame.quit()


def main():
    """Fonction principale"""
    print("=== Configuration Dyna-Q LineWorld ===")
    
    # Créer l'environnement LineWorld
    world_size = 7
    start_pos = world_size // 2  # Position centrale
    world = LineWorld(size=world_size, start_state=start_pos)
    
    print(f"Monde linéaire de taille {world_size}")
    print(f"Position de départ: S{start_pos}")
    print(f"États terminaux: S0 (récompense -1) et S{world_size-1} (récompense +1)")
    
    # Créer l'agent Dyna-Q
    agent = DynaQAgent(
        n_actions=2,           # Gauche, Droite
        alpha=0.1,             # Taux d'apprentissage
        gamma=0.95,            # Facteur de discount
        epsilon=0.3,           # Exploration initiale
        n_planning=10          # Étapes de planification
    )
    
    print(f"Agent Dyna-Q configuré:")
    print(f"- Taux d'apprentissage (α): {agent.alpha}")
    print(f"- Facteur de discount (γ): {agent.gamma}")
    print(f"- Exploration initiale (ε): {agent.epsilon}")
    print(f"- Étapes de planification: {agent.n_planning}")
    
    # Créer la visualisation
    viz = DynaQLineWorldVisualization(world, agent)
    
    # Lancer l'entraînement
    print("\nLancement de la visualisation...")
    viz.run_training(n_episodes=300, delay=0.08)


if __name__ == "__main__":
    main()

pygame 2.6.1 (SDL 2.28.4, Python 3.13.5)
Hello from the pygame community. https://www.pygame.org/contribute.html
=== Configuration Dyna-Q LineWorld ===
Monde linéaire de taille 7
Position de départ: S3
États terminaux: S0 (récompense -1) et S6 (récompense +1)
Agent Dyna-Q configuré:
- Taux d'apprentissage (α): 0.1
- Facteur de discount (γ): 0.95
- Exploration initiale (ε): 0.3
- Étapes de planification: 10

Lancement de la visualisation...
=== Démarrage de l'entraînement Dyna-Q ===
Contrôles:
- ESC: Arrêter l'entraînement
- SPACE: Pause/Reprendre
- Fermer la fenêtre: Quitter

Épisode 0: Récompense moyenne = -1.000, Epsilon = 0.300, Steps = 13
Épisode 1: Récompense moyenne = 0.000, Epsilon = 0.300, Steps = 3
Épisode 2: Récompense moyenne = 0.333, Epsilon = 0.300, Steps = 7
Épisode 3: Récompense moyenne = 0.500, Epsilon = 0.300, Steps = 5
Épisode 4: Récompense moyenne = 0.600, Epsilon = 0.300, Steps = 3
Épisode 5: Récompense moyenne = 0.667, Epsilon = 0.300, Steps = 9
Épisode 6: Récompen

In [None]:
import pygame
import random
import numpy as np
from collections import defaultdict
import time
import matplotlib.pyplot as plt

class DynaQAgent:
    """Algorithme Dyna-Q selon Sutton & Barto"""
    
    def __init__(self, n_actions, alpha=0.1, gamma=0.95, epsilon=0.1, n_planning=50):
        self.n_actions = n_actions
        self.alpha = alpha  # Taux d'apprentissage
        self.gamma = gamma  # Facteur de discount
        self.epsilon = epsilon  # Exploration ε-greedy
        self.n_planning = n_planning  # Étapes de planification
        
        # Table Q : Q[état][action] = valeur
        self.Q = defaultdict(lambda: defaultdict(float))
        
        # Modèle : Model[état][action] = (récompense, nouvel_état)
        self.Model = defaultdict(lambda: defaultdict(lambda: None))
        
        # Paires (état, action) visitées pour planification
        self.visited_state_actions = set()
    
    def select_action(self, state, valid_actions=None):
        """Politique ε-greedy"""
        if valid_actions is None:
            valid_actions = list(range(self.n_actions))
        
        if random.random() < self.epsilon:
            return random.choice(valid_actions)
        else:
            q_values = [self.Q[state][action] for action in valid_actions]
            if not q_values:
                return random.choice(valid_actions)
            max_q = max(q_values)
            best_actions = [a for a, q in zip(valid_actions, q_values) if q == max_q]
            return random.choice(best_actions)
    
    def update_q(self, state, action, reward, next_state, valid_next_actions=None):
        """Mise à jour Q-learning"""
        if valid_next_actions is None:
            valid_next_actions = list(range(self.n_actions))
        
        if valid_next_actions:
            max_next_q = max([self.Q[next_state][a] for a in valid_next_actions])
        else:
            max_next_q = 0.0
        
        # Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
        current_q = self.Q[state][action]
        td_target = reward + self.gamma * max_next_q
        self.Q[state][action] = current_q + self.alpha * (td_target - current_q)
    
    def update_model(self, state, action, reward, next_state):
        """Mise à jour du modèle"""
        self.Model[state][action] = (reward, next_state)
        self.visited_state_actions.add((state, action))
    
    def planning_step(self, get_valid_actions_fn=None):
        """Une étape de planification"""
        if not self.visited_state_actions:
            return
        
        # Échantillonner (état, action) aléatoirement
        state, action = random.choice(list(self.visited_state_actions))
        
        if self.Model[state][action] is not None:
            reward, next_state = self.Model[state][action]
            
            # Actions valides dans next_state
            if get_valid_actions_fn:
                valid_next_actions = get_valid_actions_fn(next_state)
            else:
                valid_next_actions = list(range(self.n_actions))
            
            # Mise à jour Q avec expérience simulée
            self.update_q(state, action, reward, next_state, valid_next_actions)
    
    def learn_step(self, state, action, reward, next_state, valid_next_actions=None, get_valid_actions_fn=None):
        """Étape complète Dyna-Q"""
        
        # 1. APPRENTISSAGE DIRECT
        self.update_q(state, action, reward, next_state, valid_next_actions)
        
        # 2. MISE À JOUR MODÈLE
        self.update_model(state, action, reward, next_state)
        
        # 3. PLANIFICATION (n étapes)
        for _ in range(self.n_planning):
            self.planning_step(get_valid_actions_fn)


class LineWorld:
    def __init__(self, size=5, start_state=2):
        """
        Initialize the LineWorld environment.

        Args:
            size (int): Number of states in the environment (default 5).
            start_state (int): The agent's initial position (default 2).
        """
        self.size = size
        self.start_state = start_state
        self.state = start_state
        self.terminal_states = [0, size - 1]
        self.action_space = [0, 1]  # 0: left, 1: right

    def state_to_index(self, state):
        """For compatibility with RL algos. State is already int, so just return."""
        return state

    def index_to_state(self, index):
        """For compatibility with RL algos. State is already int, so just return."""
        return index

    @property
    def n_states(self):
        """Total number of states (for tabular RL)."""
        return self.size

    def reset(self):
        """
        Reset the environment to the initial state.
        Returns:
            int: The initial state.
        """
        self.state = self.start_state
        return self.state

    def is_terminal(self, state):
        """
        Check if a given state is terminal.
        Args:
            state (int): State to check.
        Returns:
            bool: True if the state is terminal, False otherwise.
        """
        return state in self.terminal_states

    def get_reward(self, next_state):
        """
        Get the reward for transitioning into the next state.
        Args:
            next_state (int): The state resulting from the agent's action.
        Returns:
            float: The reward received for the transition.
        """
        if next_state == 0:
            return -1.0
        elif next_state == self.size - 1:
            return 1.0
        else:
            return 0.0

    def get_valid_actions(self, state):
        """Get valid actions for a given state"""
        if self.is_terminal(state):
            return []
        return self.action_space.copy()

    def step(self, action):
        """
        Take an action in the environment.
        Args:
            action (int): Action to take (0 = left, 1 = right).
        Returns:
            tuple: (next_state (int), reward (float), done (bool))
        """
        if self.is_terminal(self.state):
            return self.state, 0.0, True

        if action == 0:
            next_state = max(self.state - 1, 0)
        elif action == 1:
            next_state = min(self.state + 1, self.size - 1)
        else:
            raise ValueError("Invalid action (0=left, 1=right)")

        reward = self.get_reward(next_state)
        done = self.is_terminal(next_state)
        self.state = next_state
        return next_state, reward, done

    def simulate_step(self, state, action):
        """
        Simulate a step from a given state and action without modifying self.state.
        Returns:
            next_state, reward, done
        """
        if self.is_terminal(state):
            return state, 0.0, True

        if action == 0:
            next_state = max(state - 1, 0)
        elif action == 1:
            next_state = min(state + 1, self.size - 1)
        else:
            raise ValueError("Invalid action (0=left, 1=right)")

        reward = self.get_reward(next_state)
        done = self.is_terminal(next_state)
        return next_state, reward, done

    def render(self):
        """
        Print a visual representation of the environment.
        'A': Agent position
        'T': Terminal state(s)
        '_': Normal state
        """
        line = []
        for i in range(self.size):
            if i == self.state:
                line.append('A')
            elif i in self.terminal_states:
                line.append('T')
            else:
                line.append('_')
        print(' '.join(line))


class DynaQLineWorldVisualization:
    """Visualisation Pygame pour Dyna-Q dans LineWorld"""
    
    def __init__(self, world, agent, width=1200, height=700):
        pygame.init()
        self.world = world
        self.agent = agent
        self.width = width
        self.height = height
        self.screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Dyna-Q LineWorld - Apprentissage par Renforcement")
        
        # Couleurs
        self.WHITE = (255, 255, 255)
        self.BLACK = (0, 0, 0)
        self.BLUE = (0, 100, 255)
        self.RED = (255, 50, 50)
        self.GREEN = (50, 255, 50)
        self.GRAY = (180, 180, 180)
        self.YELLOW = (255, 255, 0)
        self.PURPLE = (150, 50, 200)
        self.DARK_GREEN = (0, 150, 0)
        self.DARK_RED = (150, 0, 0)
        self.LIGHT_BLUE = (173, 216, 230)
        
        # Paramètres de visualisation
        self.cell_width = 80
        self.cell_height = 80
        self.world_y = height // 2 - self.cell_height // 2
        self.world_x = (width - world.size * self.cell_width) // 2
        
        # Polices
        self.font_large = pygame.font.Font(None, 28)
        self.font_medium = pygame.font.Font(None, 24)
        self.font_small = pygame.font.Font(None, 18)
        self.font_tiny = pygame.font.Font(None, 14)
        
        # Statistiques
        self.episode_rewards = []
        self.episode_steps = []
        self.current_episode = 0
        self.total_steps = 0
        
        # Animation
        self.animation_step = 0
        
    def draw_world(self):
        """Dessiner le monde linéaire"""
        for i in range(self.world.size):
            x = self.world_x + i * self.cell_width
            y = self.world_y
            
            # Couleur de base selon le type de cellule
            if i in self.world.terminal_states:
                if i == 0:  # Terminal négatif
                    base_color = self.DARK_RED
                    label_color = self.WHITE
                    reward_text = "-1"
                else:  # Terminal positif
                    base_color = self.DARK_GREEN
                    label_color = self.WHITE
                    reward_text = "+1"
            else:
                base_color = self.LIGHT_BLUE
                label_color = self.BLACK
                reward_text = "0"
            
            # Effet de surbrillance pour la position actuelle
            if i == self.world.state:
                # Animation de pulsation
                pulse = abs(np.sin(self.animation_step * 0.1)) * 30
                base_color = tuple(min(255, c + pulse) for c in base_color)
            
            # Dessiner la cellule
            pygame.draw.rect(self.screen, base_color, (x, y, self.cell_width, self.cell_height))
            pygame.draw.rect(self.screen, self.BLACK, (x, y, self.cell_width, self.cell_height), 3)
            
            # Numéro de l'état
            state_text = self.font_medium.render(f"S{i}", True, label_color)
            state_rect = state_text.get_rect(center=(x + self.cell_width // 2, y + 15))
            self.screen.blit(state_text, state_rect)
            
            # Récompense
            reward_surface = self.font_small.render(f"R: {reward_text}", True, label_color)
            reward_rect = reward_surface.get_rect(center=(x + self.cell_width // 2, y + 35))
            self.screen.blit(reward_surface, reward_rect)
            
            # Agent (marqueur spécial)
            if i == self.world.state:
                agent_radius = 15
                pygame.draw.circle(self.screen, self.YELLOW, 
                                 (x + self.cell_width // 2, y + self.cell_height - 25), 
                                 agent_radius)
                pygame.draw.circle(self.screen, self.BLACK, 
                                 (x + self.cell_width // 2, y + self.cell_height - 25), 
                                 agent_radius, 2)
                agent_text = self.font_small.render("A", True, self.BLACK)
                agent_rect = agent_text.get_rect(center=(x + self.cell_width // 2, y + self.cell_height - 25))
                self.screen.blit(agent_text, agent_rect)
    
    def draw_q_values(self):
        """Dessiner les valeurs Q sous chaque état"""
        for state in range(self.world.size):
            if self.world.is_terminal(state):
                continue
                
            x = self.world_x + state * self.cell_width
            y = self.world_y + self.cell_height + 10
            
            # Valeurs Q pour cet état
            q_left = self.agent.Q[state][0]
            q_right = self.agent.Q[state][1]
            
            # Couleurs basées sur les valeurs
            left_color = self.GREEN if q_left > q_right else self.GRAY
            right_color = self.GREEN if q_right > q_left else self.GRAY
            
            # Afficher Q-values avec couleurs
            left_text = f"←{q_left:.2f}"
            right_text = f"→{q_right:.2f}"
            
            left_surface = self.font_tiny.render(left_text, True, left_color)
            right_surface = self.font_tiny.render(right_text, True, right_color)
            
            # Positionner les textes
            left_rect = left_surface.get_rect(center=(x + self.cell_width // 4, y))
            right_rect = right_surface.get_rect(center=(x + 3 * self.cell_width // 4, y))
            
            self.screen.blit(left_surface, left_rect)
            self.screen.blit(right_surface, right_rect)
    
    def draw_policy(self):
        """Dessiner la politique (flèches pour la meilleure action)"""
        for state in range(self.world.size):
            if self.world.is_terminal(state):
                continue
                
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                continue
            
            # Trouver la meilleure action
            q_values = [self.agent.Q[state][action] for action in valid_actions]
            if max(q_values) == min(q_values):
                continue  # Pas de préférence claire
                
            best_action = valid_actions[np.argmax(q_values)]
            
            # Dessiner la flèche
            x = self.world_x + state * self.cell_width + self.cell_width // 2
            y = self.world_y - 30
            
            arrow_size = 12
            if best_action == 0:  # Gauche
                points = [(x-arrow_size, y), (x+arrow_size//2, y-arrow_size//2), (x+arrow_size//2, y+arrow_size//2)]
                pygame.draw.polygon(self.screen, self.PURPLE, points)
            else:  # Droite
                points = [(x+arrow_size, y), (x-arrow_size//2, y-arrow_size//2), (x-arrow_size//2, y+arrow_size//2)]
                pygame.draw.polygon(self.screen, self.PURPLE, points)
    
    def draw_model_info(self):
        """Afficher les informations sur le modèle appris"""
        y_start = self.world_y + self.cell_height + 60
        
        title = self.font_medium.render("Modèle Dyna-Q (échantillon):", True, self.BLACK)
        self.screen.blit(title, (20, y_start))
        
        # Afficher quelques transitions apprises
        displayed = 0
        max_display = 6
        
        # Parcourir correctement le modèle à deux niveaux
        for state in list(self.agent.Model.keys())[:max_display]:
            if displayed >= max_display:
                break
                
            for action in self.agent.Model[state]:
                if displayed >= max_display:
                    break
                    
                transition = self.agent.Model[state][action]
                if transition is not None:  # Vérifier qu'il y a bien une transition
                    reward, next_state = transition
                    action_name = "←" if action == 0 else "→"
                    text = f"S{state} {action_name} → S{next_state} (R={reward})"
                    surface = self.font_small.render(text, True, self.BLACK)
                    self.screen.blit(surface, (20, y_start + 25 + displayed * 20))
                    displayed += 1
    
    def draw_statistics(self):
        """Dessiner les statistiques détaillées"""
        stats_x = 20
        stats_y = 20
        
        # Titre
        title = self.font_large.render("Statistiques Dyna-Q", True, self.BLACK)
        self.screen.blit(title, (stats_x, stats_y))
        
        # Informations de l'épisode
        info_lines = [
            f"Épisode: {self.current_episode}",
            f"Position: S{self.world.state}",
            f"Steps totaux: {self.total_steps}",
            "",
            f"Paramètres:",
            f"  ε (exploration): {self.agent.epsilon:.3f}",
            f"  α (apprentissage): {self.agent.alpha:.3f}",
            f"  γ (discount): {self.agent.gamma:.3f}",
            f"  Planning steps: {self.agent.n_planning}",
        ]
        
        for i, line in enumerate(info_lines):
            if line:  # Ignorer les lignes vides pour le rendu
                text = self.font_small.render(line, True, self.BLACK)
                self.screen.blit(text, (stats_x, stats_y + 35 + i * 18))
        
        # Statistiques de performance
        if self.episode_rewards:
            perf_y = stats_y + 35 + len(info_lines) * 18 + 20
            
            recent_rewards = self.episode_rewards[-20:]  # 20 derniers épisodes
            avg_reward = sum(recent_rewards) / len(recent_rewards)
            
            perf_lines = [
                "Performance:",
                f"  Récompense moyenne (20 derniers): {avg_reward:.3f}",
                f"  Dernier épisode: {self.episode_rewards[-1]:.3f}",
                f"  Steps dernier épisode: {self.episode_steps[-1]}",
                f"  États-actions explorés: {len(self.agent.visited_state_actions)}",
            ]
            
            for i, line in enumerate(perf_lines):
                text = self.font_small.render(line, True, self.BLACK)
                self.screen.blit(text, (stats_x, perf_y + i * 18))
    
    def draw_legend(self):
        """Dessiner la légende"""
        legend_x = self.width - 250
        legend_y = 20
        
        title = self.font_medium.render("Légende:", True, self.BLACK)
        self.screen.blit(title, (legend_x, legend_y))
        
        legend_items = [
            ("Agent (A)", self.YELLOW),
            ("État terminal +1", self.DARK_GREEN),
            ("État terminal -1", self.DARK_RED),
            ("État normal", self.LIGHT_BLUE),
            ("Meilleure action", self.PURPLE),
            ("Q-value élevée", self.GREEN),
        ]
        
        for i, (label, color) in enumerate(legend_items):
            y = legend_y + 30 + i * 25
            pygame.draw.rect(self.screen, color, (legend_x, y, 20, 20))
            pygame.draw.rect(self.screen, self.BLACK, (legend_x, y, 20, 20), 1)
            text = self.font_small.render(label, True, self.BLACK)
            self.screen.blit(text, (legend_x + 30, y + 2))
    
    def draw_progress_bar(self):
        """Dessiner une barre de progression pour l'épisode"""
        if not hasattr(self, 'episode_length'):
            self.episode_length = 100  # Longueur maximale estimée
            
        bar_x = self.world_x
        bar_y = self.world_y + self.cell_height + 120
        bar_width = self.world.size * self.cell_width
        bar_height = 20
        
        # Fond de la barre
        pygame.draw.rect(self.screen, self.GRAY, (bar_x, bar_y, bar_width, bar_height))
        
        # Progression
        if hasattr(self, 'current_step'):
            progress = min(1.0, self.current_step / self.episode_length)
            pygame.draw.rect(self.screen, self.GREEN, (bar_x, bar_y, bar_width * progress, bar_height))
        
        # Bordure
        pygame.draw.rect(self.screen, self.BLACK, (bar_x, bar_y, bar_width, bar_height), 2)
        
        # Texte
        text = self.font_small.render("Progression de l'épisode", True, self.BLACK)
        text_rect = text.get_rect(center=(bar_x + bar_width // 2, bar_y - 15))
        self.screen.blit(text, text_rect)
    
    def update_display(self):
        """Mettre à jour l'affichage complet"""
        self.screen.fill(self.WHITE)
        
        self.draw_world()
        self.draw_q_values()
        self.draw_policy()
        self.draw_statistics()
        self.draw_legend()
        self.draw_model_info()
        self.draw_progress_bar()
        
        # Incrémenter l'animation
        self.animation_step += 1
        
        pygame.display.flip()
    
    def run_episode(self, max_steps=100, delay=0.1):
        """Exécuter un épisode avec visualisation"""
        state = self.world.reset()
        total_reward = 0
        steps = 0
        self.current_step = 0
        self.episode_length = max_steps
        
        for step in range(max_steps):
            self.current_step = step
            
            # Gérer les événements Pygame
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    return False, total_reward, steps
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # Pause - attendre une autre pression sur espace
                        paused = True
                        while paused:
                            for pause_event in pygame.event.get():
                                if pause_event.type == pygame.QUIT:
                                    return False, total_reward, steps
                                elif pause_event.type == pygame.KEYDOWN:
                                    if pause_event.key == pygame.K_SPACE:
                                        paused = False
                            time.sleep(0.1)
            
            # Vérifier si l'état est terminal
            if self.world.is_terminal(state):
                break
            
            # Sélectionner et exécuter une action
            valid_actions = self.world.get_valid_actions(state)
            if not valid_actions:
                break
                
            action = self.agent.select_action(state, valid_actions)
            next_state, reward, done = self.world.step(action)
            
            # Apprentissage Dyna-Q
            valid_next_actions = self.world.get_valid_actions(next_state)
            self.agent.learn_step(state, action, reward, next_state, 
                                valid_next_actions, self.world.get_valid_actions)
            
            total_reward += reward
            steps += 1
            self.total_steps += 1
            state = next_state
            
            # Mettre à jour l'affichage
            self.update_display()
            time.sleep(delay)
            
            if done:
                break
        
        return True, total_reward, steps
    
    def run_training(self, n_episodes=200, delay=0.1):
        """Exécuter l'entraînement complet"""
        print("=== Démarrage de l'entraînement Dyna-Q ===")
        print("Contrôles:")
        print("- ESC: Arrêter l'entraînement")
        print("- SPACE: Pause/Reprendre")
        print("- Fermer la fenêtre: Quitter")
        print()
        
        running = True
        
        for episode in range(n_episodes):
            if not running:
                break
                
            self.current_episode = episode + 1
            
            # Exécuter un épisode
            continue_training, reward, steps = self.run_episode(delay=delay)
            if not continue_training:
                break
            
            # Enregistrer les statistiques
            self.episode_rewards.append(reward)
            self.episode_steps.append(steps)
            
            # Réduire progressivement l'exploration
            if episode % 20 == 0 and episode > 0:
                self.agent.epsilon = max(0.01, self.agent.epsilon * 0.95)
            
            # Afficher les progrès dans la console
            if episode % 50 == 0 or episode < 10:
                recent_rewards = self.episode_rewards[-20:] if len(self.episode_rewards) >= 20 else self.episode_rewards
                avg_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0
                print(f"Épisode {episode}: Récompense moyenne = {avg_reward:.3f}, "
                      f"Epsilon = {self.agent.epsilon:.3f}, Steps = {steps}")
        
        print("\n=== Entraînement terminé ===")
        print(f"Episodes complétés: {len(self.episode_rewards)}")
        if self.episode_rewards:
            final_avg = sum(self.episode_rewards[-50:]) / min(50, len(self.episode_rewards))
            print(f"Performance finale (50 derniers épisodes): {final_avg:.3f}")
        
        # Afficher un résumé final
        input("Appuyez sur Entrée pour fermer...")
        pygame.quit()

    def run_manual_test(self):
        """Mode manuel : l'utilisateur contrôle l'agent avec le clavier"""
        state = self.world.reset()
        self.current_episode = "MANUEL"
        self.current_step = 0
        self.episode_length = 100

        running = True
        clock = pygame.time.Clock()

        print("=== MODE TEST MANUEL ===")
        print("Touches :")
        print("← (flèche gauche) = aller à gauche")
        print("→ (flèche droite) = aller à droite")
        print("R = Réinitialiser à la position de départ")
        print("ESC = Quitter le test\n")

        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                    break
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                        break
                    elif event.key == pygame.K_LEFT:
                        action = 0
                    elif event.key == pygame.K_RIGHT:
                        action = 1
                    elif event.key == pygame.K_r:
                        state = self.world.reset()
                        continue
                    else:
                        action = None

                    if action is not None and not self.world.is_terminal(state):
                        next_state, reward, done = self.world.step(action)
                        print(f"Action {'←' if action == 0 else '→'} | S{state} → S{next_state} | R = {reward}")
                        state = next_state

                        if done:
                            print(f"🎯 État terminal atteint: S{state} | Récompense: {reward}")
                            print("Appuie sur R pour recommencer ou ESC pour quitter.")

            self.update_display()
            clock.tick(10)



def main():
    """Fonction principale"""
    print("=== Configuration Dyna-Q LineWorld ===")
    
    # Créer l'environnement LineWorld
    world_size = 7
    start_pos = world_size // 2
    world = LineWorld(size=world_size, start_state=start_pos)
    
    print(f"Monde linéaire de taille {world_size}")
    print(f"Position de départ: S{start_pos}")
    print(f"États terminaux: S0 (récompense -1) et S{world_size-1} (récompense +1)")
    
    # Créer l'agent Dyna-Q
    agent = DynaQAgent(
        n_actions=2,
        alpha=0.1,
        gamma=0.95,
        epsilon=0.3,
        n_planning=10
    )
    
    print(f"Agent Dyna-Q configuré:")
    print(f"- Taux d'apprentissage (α): {agent.alpha}")
    print(f"- Facteur de discount (γ): {agent.gamma}")
    print(f"- Exploration initiale (ε): {agent.epsilon}")
    print(f"- Étapes de planification: {agent.n_planning}")
    
    # Créer la visualisation
    viz = DynaQLineWorldVisualization(world, agent)
    
    # Lancer l'entraînement
    print("\nLancement de la visualisation...")
    viz.run_training(n_episodes=300, delay=0.08)

    # Lancer le test manuel après l'entraînement
    print("\nSouhaitez-vous tester manuellement le comportement de l'agent ? (y/n)")
    choice = input().strip().lower()
    if choice == "y":
        viz.run_manual_test()



if __name__ == "__main__":
    main()