In [1]:
from ale_py import ALEInterface, roms

ale = ALEInterface()
ale.loadROM(roms.get_rom_path("breakout"))
ale.reset_game()

reward = ale.act(0)  # noop
screen_obs = ale.getScreenRGB()

A.L.E: Arcade Learning Environment (version 0.10.1+6a7e0ae)
[Powered by Stella]
Game console created:
  ROM file:  /opt/anaconda3/lib/python3.12/site-packages/ale_py/roms/breakout.bin
  Cart Name: Breakout - Breakaway IV (1978) (Atari)
  Cart MD5:  f34f08e5eb96e500e851a80be3277a56
  Display Format:  AUTO-DETECT ==> NTSC
  ROM Size:        2048
  Bankswitch Type: AUTO-DETECT ==> 2K

Running ROM file...
Random seed is 1734803125


In [7]:
import gymnasium as gym
import ale_py

gym.register_envs(ale_py)  # unnecessary but helpful for IDEs

env = gym.make('ALE/Breakout-v5', render_mode="human")  # remove render_mode in training
obs, info = env.reset()
episode_over = False
# while not episode_over:
#     action = policy(obs)  # to implement - use `env.action_space.sample()` for a random policy
#     obs, reward, terminated, truncated, info = env.step(action)

#     episode_over = terminated or truncated
# env.close()

# print(obs.shape)

In [5]:
import gymnasium as gym
import ale_py
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import cv2
import matplotlib.pyplot as plt  # For plotting graphs
import os
from tqdm import tqdm

# Register Atari environments (optional, helps IDEs)
gym.register_envs(ale_py)

# Hyperparameters
ENV_NAME = 'ALE/Breakout-v5'
GAMMA = 0.99
LEARNING_RATE = 1e-4
BATCH_SIZE = 32
MEMORY_SIZE = 10000
MIN_MEMORY_SIZE = 1000
EPS_START = 0.5
EPS_END = 0.1
EPS_DECAY = 1000000
TARGET_UPDATE_FREQ = 100
NUM_EPISODES = 500
MAX_STEPS = 1000
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def preprocess_frame(frame):
    """Preprocesses a single frame: grayscale, resize, normalize."""
    frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)  # Convert to grayscale
    frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)  # Resize to 84x84
    frame = frame / 255.0  # Normalize pixel values
    return frame

class DQN(nn.Module):
    """Simplified Deep Q-Network for faster training."""
    def __init__(self, input_channels, num_actions):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_channels, 16, kernel_size=8, stride=4),  # Reduced channels
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),  # Reduced channels
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),  # Reduced channels
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 7 * 7, 256),  # Reduced linear layer size
            nn.ReLU(),
            nn.Linear(256, num_actions)
        )
    
    def forward(self, x):
        return self.net(x)

class ReplayBuffer:
    """Experience Replay Buffer."""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.tensor(states, dtype=torch.float32).to(DEVICE),
            torch.tensor(actions, dtype=torch.long).to(DEVICE),
            torch.tensor(rewards, dtype=torch.float32).to(DEVICE),
            torch.tensor(next_states, dtype=torch.float32).to(DEVICE),
            torch.tensor(dones, dtype=torch.float32).to(DEVICE)
        )
    
    def __len__(self):
        return len(self.buffer)

def select_action(state, policy_net, epsilon, num_actions):
    """Selects an action using epsilon-greedy policy."""
    if random.random() < epsilon:
        return random.randrange(num_actions)
    else:
        with torch.no_grad():
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(DEVICE)
            q_values = policy_net(state)
            return q_values.argmax().item()

def plot_metrics(episodes, rewards, epsilons, losses, save_dir, episode):
    """Plots and saves the training metrics after each episode."""
    os.makedirs(save_dir, exist_ok=True)
    
    # Plot Total Rewards per Episode
    plt.figure(figsize=(12, 6))
    plt.plot(episodes, rewards, label='Total Reward per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('Total Rewards Over Episodes')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'total_rewards_episode.png'))
    plt.close()
    
    # Plot Epsilon Decay
    plt.figure(figsize=(12, 6))
    plt.plot(episodes, epsilons, label='Epsilon Value')
    plt.xlabel('Episode')
    plt.ylabel('Epsilon')
    plt.title('Epsilon Decay Over Episodes')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'epsilon_decay_episode.png'))
    plt.close()
    
    # Plot Loss Over Time
    plt.figure(figsize=(12, 6))
    plt.plot(episodes, losses, label='Loss')
    plt.xlabel('Episode')
    plt.ylabel('Loss')
    plt.title('Loss Over Episodes')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'loss_episode.png'))
    plt.close()

def main():
    # Initialize environment
    env = gym.make(ENV_NAME)
    num_actions = env.action_space.n

    # Initialize networks
    input_channels = 4  # Stacked frames
    policy_net = DQN(input_channels, num_actions).to(DEVICE)
    target_net = DQN(input_channels, num_actions).to(DEVICE)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
    criterion = nn.MSELoss()

    # Initialize replay buffer
    replay_buffer = ReplayBuffer(MEMORY_SIZE)

    epsilon = EPS_START
    epsilon_decay = (EPS_START - EPS_END) / EPS_DECAY

    # Lists to store metrics
    episode_rewards = []
    episode_epsilons = []
    episode_losses = []
    episodes = []

    # Directory to save plots and model
    save_directory = 'dqn_training_results'
    os.makedirs(save_directory, exist_ok=True)

    episode_bar = tqdm(range(1, NUM_EPISODES + 1), desc="Training Episodes", unit="episode")
    for episode in episode_bar:
        obs, info = env.reset()
        state = preprocess_frame(obs)
        state_stack = deque([state] * 4, maxlen=4)  # Initialize with 4 frames
        total_reward = 0
        done = False
        loss_per_episode = 0
        steps = 0

        for step in range(MAX_STEPS):
            stacked_state = np.stack(state_stack, axis=0)
            action = select_action(stacked_state, policy_net, epsilon, num_actions)
            next_obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            total_reward += reward

            next_state = preprocess_frame(next_obs)
            state_stack.append(next_state)
            stacked_next_state = np.stack(state_stack, axis=0)

            replay_buffer.push(stacked_state, action, reward, stacked_next_state, done)

            if len(replay_buffer) > MIN_MEMORY_SIZE:
                # Sample a batch
                states, actions, rewards, next_states, dones = replay_buffer.sample(BATCH_SIZE)

                # Compute current Q values
                q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

                # Compute target Q values
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                    target_q_values = rewards + GAMMA * next_q_values * (1 - dones)

                # Compute loss
                loss = criterion(q_values, target_q_values)
                loss_per_episode += loss.item()

                # Optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Update epsilon
                if epsilon > EPS_END:
                    epsilon -= epsilon_decay

                # Update target network
                if step % TARGET_UPDATE_FREQ == 0:
                    target_net.load_state_dict(policy_net.state_dict())

            if done:
                break
            steps += 1

        # Record metrics
        episode_rewards.append(total_reward)
        episode_epsilons.append(epsilon)
        average_loss = loss_per_episode / steps if steps > 0 else 0
        episode_losses.append(average_loss)
        episodes.append(episode)

        # Update progress bar with latest metrics
        episode_bar.set_postfix({
            'Total Reward': total_reward,
            'Epsilon': f"{epsilon:.4f}",
            'Avg Loss': f"{average_loss:.4f}"
        })

        # Plot and save metrics after each episode
        plot_metrics(episodes, episode_rewards, episode_epsilons, episode_losses, save_directory, episode)

        # Print progress every 100 episodes
        if episode % 100 == 0:
            print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {epsilon:.4f}, Average Loss: {average_loss:.4f}")

    # Plot and save final metrics
    plot_metrics(episodes, episode_rewards, episode_epsilons, episode_losses, save_directory, 'final')

    # Save the trained model
    model_path = os.path.join(save_directory, 'dqn_breakout_model.pth')
    torch.save(policy_net.state_dict(), model_path)
    print(f"Training completed! Model saved at {model_path}")
    env.close()

if __name__ == "__main__":
    main()

Training Episodes:  20%|██        | 100/500 [1:22:59<5:46:10, 51.93s/episode, Total Reward=2, Epsilon=0.4915, Avg Loss=0.0144]

Episode 100, Total Reward: 2.0, Epsilon: 0.4915, Average Loss: 0.0144


Training Episodes:  40%|████      | 200/500 [2:32:59<3:14:59, 39.00s/episode, Total Reward=1, Epsilon=0.4830, Avg Loss=0.0081]

Episode 200, Total Reward: 1.0, Epsilon: 0.4830, Average Loss: 0.0081


Training Episodes:  60%|██████    | 300/500 [3:57:02<3:04:00, 55.20s/episode, Total Reward=4, Epsilon=0.4727, Avg Loss=0.0089]

Episode 300, Total Reward: 4.0, Epsilon: 0.4727, Average Loss: 0.0089


Training Episodes:  80%|███████▉  | 398/500 [5:44:08<1:27:20, 51.38s/episode, Total Reward=3, Epsilon=0.4606, Avg Loss=0.0200] 