Setup

In [1]:
# Install
!pip install gymnasium[atari]
!pip install autorom[accept-rom-license]
!pip install torch

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Imports
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import matplotlib.pyplot as plt
import ale_py
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import time
import os

# Register
gym.register_envs(ale_py)

# DQN
class DQN(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Args:
        n_actions: Number of outputs of the final linear layer (default 6,
            the value that was previously hard-coded).
        in_channels: Number of input channels of the first convolution
            (default 4, the value that was previously hard-coded).

    The first linear layer expects 64 * 7 * 7 features, which is what the
    three convolutions produce for 84x84 inputs.
    """

    def __init__(self, n_actions=6, in_channels=4):
        super().__init__()
        # Layer order and container names are kept so that state_dict keys
        # (conv.0, conv.2, conv.4, fc.0, fc.2) stay compatible with saved models.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a batched input ``x``."""
        features = self.conv(x)
        # Collapse everything except the batch axis before the linear head.
        return self.fc(features.flatten(start_dim=1))

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest entry once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition as a 5-tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    Hyper-parameters are class attributes. ``epsilon`` follows
    ``EPSILON_START * EPSILON_DECAY ** current_episode`` and is floored at
    ``EPSILON_CUTOFF``.
    """
    BATCH_SIZE = 32
    EPSILON_START = 1
    EPSILON_CUTOFF = 0.05
    EPSILON_DECAY = 0.9977 # Will take about 1000 episodes to reach 0.1 and 1300 to reach 0.05
    GAMMA = 0.99
    LR = 1e-4
    MEMORY_SIZE = 30000
    total_steps = 0  # class-level default; __init__ gives every instance its own counter

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN().to(self.device)
        self.target_net = DQN().to(self.device)
        # The target network starts as an exact copy and is never trained directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.LR)
        self.memory = ReplayBuffer(self.MEMORY_SIZE)
        self.steps_done = 0  # not read anywhere in this class
        self.total_steps = 0  # explicit per-instance counter (incremented in decay_epsilon)
        self.epsilon = self.EPSILON_START
        self.current_episode = 0

    def _to_tensor(self, values, dtype):
        """Stack a sequence of per-transition values into one tensor on the agent's device."""
        return torch.tensor(np.array(values), dtype=dtype, device=self.device)

    def select_e_greedy_action(self, env, state):
        """Return a random action with probability ``epsilon``, else the greedy one.

        Random actions are drawn uniformly from ``range(env.action_space.n)``.
        """
        if random.random() < self.epsilon:
            return random.randrange(env.action_space.n)
        return self.select_greedy_action(state)

    def select_greedy_action(self, state):
        """Return the index of the highest Q-value the policy network gives for ``state``.

        ``state`` is a single observation; a leading batch axis is added here.
        """
        state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            return self.policy_net(state).argmax(dim=1).item()

    def decay_epsilon(self):
        """Count one step and recompute ``epsilon`` from ``current_episode``."""
        self.total_steps += 1
        decayed = self.EPSILON_START * (self.EPSILON_DECAY ** self.current_episode)
        self.epsilon = max(self.EPSILON_CUTOFF, decayed)

    def step(self):
        """Run one gradient update on a sampled minibatch.

        Does nothing until the buffer holds at least ``BATCH_SIZE`` transitions.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return

        batch = self.memory.sample(self.BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states, torch.float32)
        actions = self._to_tensor(actions, torch.int64).unsqueeze(1)
        rewards = self._to_tensor(rewards, torch.float32).unsqueeze(1)
        next_states = self._to_tensor(next_states, torch.float32)
        dones = self._to_tensor(dones, torch.float32).unsqueeze(1)

        # Q(s, a) of the actions that were actually taken.
        current_q = self.policy_net(states).gather(1, actions)
        # Bootstrap target from the frozen network; no gradient is needed for it.
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1, keepdim=True)[0]
        # (1 - dones) removes the bootstrap term for transitions flagged as done.
        target_q = rewards + (self.GAMMA * next_q * (1 - dones))

        # mse_loss takes (input, target); the prediction goes first.
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())
        
def train_q_values(env, env_name, target_update_interval=10, training_episodes=1000, agent_class=DQNAgent, checkpoint_path=None):
    """Train an agent on ``env`` and return it together with per-episode statistics.

    Each episode is played to completion first; its transitions are then pushed
    into the agent's replay memory and one ``agent.step()`` is run per transition.

    Args:
        env: Environment with the 5-tuple ``step`` API
            (obs, reward, terminated, truncated, info).
        env_name: Not used inside this function; kept for interface compatibility.
        target_update_interval: The target network is synced on episodes whose
            index is a multiple of this value.
        training_episodes: Exclusive upper bound on the episode index.
        agent_class: Zero-argument callable that builds the agent.
        checkpoint_path: File name (relative to the checkpoint folder) of a
            checkpoint to resume from, or None to start fresh. The replay
            memory is not part of the checkpoint and starts empty on resume.

    Returns:
        Tuple ``(agent, rewards_per_episode, epsilon_values, episode_times)``.
    """
    agent = agent_class()
    rewards_per_episode = []
    epsilon_values = []
    episode_times = []
    start_episode = 0
    early_stop_counter = 0

    drive_path = '/content/drive/MyDrive/dqn_pong_checkpoints/'
    # Every torch.save below writes into this folder, so make sure it exists.
    os.makedirs(drive_path, exist_ok=True)

    # Loading from last checkpoint in case training fails
    if checkpoint_path is not None:
        # map_location lets a checkpoint written on a GPU runtime be resumed on a
        # CPU-only runtime; None (agent without `device`) keeps torch's default.
        checkpoint = torch.load(
            os.path.join(drive_path, checkpoint_path),
            map_location=getattr(agent, 'device', None),
        )
        agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
        agent.target_net.load_state_dict(checkpoint['policy_net_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        agent.total_steps = checkpoint['total_steps']
        agent.current_episode = checkpoint['current_episode']
        rewards_per_episode = checkpoint['rewards_per_episode']
        epsilon_values = checkpoint['epsilon_values']
        episode_times = checkpoint['episode_times']
        start_episode = checkpoint['current_episode'] + 1
        agent.epsilon = max(agent.EPSILON_CUTOFF, agent.EPSILON_START * (agent.EPSILON_DECAY ** agent.current_episode))
        print(f"Resumed training from episode {start_episode}")

    for episode in range(start_episode, training_episodes):
        agent.current_episode = episode
        start_time = time.time()
        state, _ = env.reset()
        total_reward = 0
        done = False
        steps_taken = 0
        episode_transitions = []

        # Recorded before this episode's first decay_epsilon() call, so this is
        # the value carried over from the previous episode (or the initial one).
        epsilon_values.append(agent.epsilon)

        while not done:
            action = agent.select_e_greedy_action(env, state)
            obs, reward, terminated, truncated, _ = env.step(action)
            # The episode is over on either signal; ignoring `truncated` would
            # keep stepping an environment that has already ended.
            done = terminated or truncated
            # Keep the environment's own dtype: the agent casts to float32 when
            # it builds a batch, so upcasting here only inflates the buffer.
            next_state = np.array(obs)
            # Only real termination is stored as the done flag of the transition.
            episode_transitions.append((state, action, reward, next_state, terminated))
            state = next_state
            total_reward += reward
            steps_taken += 1

            agent.decay_epsilon()

        for transition in episode_transitions:
            agent.memory.push(*transition)

        # One optimisation step per environment step taken in this episode.
        for _ in range(steps_taken):
            agent.step()

        rewards_per_episode.append(total_reward)
        end_time = time.time()
        episode_time = end_time - start_time
        episode_times.append(episode_time)

        # Count consecutive episodes with a total reward of at least 20.
        if total_reward >= 20:
            early_stop_counter += 1
        else:
            early_stop_counter = 0

        if early_stop_counter >= 15:
            print(f"Early stopping at episode {episode}")
            torch.save(agent.policy_net.state_dict(), os.path.join(drive_path, "dqn_pong_env_v5_early_stop_train.pth"))
            break

        if episode % 10 == 0:
            print(f"Episode {episode + 1}: Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}, Steps: {steps_taken}, Time: {episode_time:.2f}s")
            # Saving model
            torch.save(agent.policy_net.state_dict(), os.path.join(drive_path, "dqn_pong_env_v5.pth"))
            # Creating checkpoint dictionary and saving
            checkpoint = {
                'policy_net_state_dict': agent.policy_net.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'total_steps': agent.total_steps,
                'current_episode': agent.current_episode,
                'rewards_per_episode': rewards_per_episode,
                'epsilon_values': epsilon_values,
                'episode_times': episode_times
            }
            torch.save(checkpoint, os.path.join(drive_path, "dqn_pong_checkpoints.pth"))

        if episode % target_update_interval == 0:
            agent.update_target_network()

    return agent, rewards_per_episode, epsilon_values, episode_times

# Environment: base ALE env -> Atari preprocessing -> stack of 4 observations
env_name = "PongNoFrameskip-v4"
pong_env = FrameStackObservation(
    AtariPreprocessing(
        gym.make(env_name, render_mode="rgb_array", frameskip=1),
        frame_skip=4,
        grayscale_obs=True,
        scale_obs=False,
        terminal_on_life_loss=False,
    ),
    stack_size=4,
)

# Creating folder
drive_path = '/content/drive/MyDrive/dqn_pong_checkpoints/'
os.makedirs(drive_path, exist_ok=True)

# Training
training_results = train_q_values(pong_env, env_name, training_episodes=5000)
trained_agent, rewards_per_episode, epsilon_values, episode_times = training_results

# Persist the final policy weights next to the checkpoints.
final_weights_path = os.path.join(drive_path, "dqn_pongnoframeskip_env_v5.pth")
torch.save(trained_agent.policy_net.state_dict(), final_weights_path)

# Relevant Plots
# rewards_per_episode holds one entry per episode, so the x-axis is labelled
# "Episode" like the two plots below.
plt.plot(rewards_per_episode)
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title(f"{env_name} Reward Characteristics - Training")
plt.show()

plt.plot(epsilon_values, label=f"Decay rate = {DQNAgent.EPSILON_DECAY}")
plt.xlabel("Episode")
plt.ylabel("Epsilon")
plt.title(f"{env_name} Epsilon Characteristics - Training")
# Without a legend the label passed to plot() is never displayed.
plt.legend()
plt.show()

plt.plot(episode_times)
plt.xlabel("Episode")
plt.ylabel("Time (seconds)")
plt.title(f"{env_name} Episode Processing Time")
plt.show()

^C
^C



In [None]:
# Install
!pip install gymnasium[atari]
!pip install autorom[accept-rom-license]
!pip install torch

Collecting autorom[accept-rom-license]
  Downloading AutoROM-0.6.1-py3-none-any.whl.metadata (2.4 kB)
Collecting AutoROM.accept-rom-license (from autorom[accept-rom-license])
  Downloading AutoROM.accept-rom-license-0.6.1.tar.gz (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.7/434.7 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Downloading AutoROM-0.6.1-py3-none-any.whl (9.4 kB)
Building wheels for collected packages: AutoROM.accept-rom-license
  Building wheel for AutoROM.accept-rom-license (pyproject.toml) ... [?25l[?25hdone
  Created wheel for AutoROM.accept-rom-license: filename=autorom_accept_rom_license-0.6.1-py3-none-any.whl size=446709 sha256=8ad27fa28155ba78bd88cb2f93a797deb1319fe75689fb38a3326e9704da3a7d
  Stored in directory: /root/.cache/pip/wheels/bc/fc/c6/8aa657c0d208

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Imports
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import matplotlib.pyplot as plt
import ale_py
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import time
import os

In [None]:
# Register
gym.register_envs(ale_py)

Defining Classes

In [None]:
# DQN
class DQN(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Args:
        n_actions: Number of outputs of the final linear layer (default 6,
            the value that was previously hard-coded).
        in_channels: Number of input channels of the first convolution
            (default 4, the value that was previously hard-coded).

    The first linear layer expects 64 * 7 * 7 features, which is what the
    three convolutions produce for 84x84 inputs.
    """

    def __init__(self, n_actions=6, in_channels=4):
        super().__init__()
        # Layer order and container names are kept so that state_dict keys
        # (conv.0, conv.2, conv.4, fc.0, fc.2) stay compatible with saved models.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a batched input ``x``."""
        features = self.conv(x)
        # Collapse everything except the batch axis before the linear head.
        return self.fc(features.flatten(start_dim=1))

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # A deque with maxlen drops the oldest entry once `capacity` is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition as a 5-tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen at random."""
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# DQN Agent
class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    Hyper-parameters are class attributes. ``epsilon`` follows
    ``EPSILON_START * EPSILON_DECAY ** current_episode`` and is floored at
    ``EPSILON_CUTOFF``.
    """
    BATCH_SIZE = 32
    EPSILON_START = 1
    EPSILON_CUTOFF = 0.05
    EPSILON_DECAY = 0.9977 # Will take about 1000 episodes to reach 0.1 and 1300 to reach 0.05
    GAMMA = 0.99
    LR = 1e-4
    MEMORY_SIZE = 30000
    total_steps = 0  # class-level default; __init__ gives every instance its own counter

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN().to(self.device)
        self.target_net = DQN().to(self.device)
        # The target network starts as an exact copy and is never trained directly.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.LR)
        self.memory = ReplayBuffer(self.MEMORY_SIZE)
        self.steps_done = 0  # not read anywhere in this class
        self.total_steps = 0  # explicit per-instance counter (incremented in decay_epsilon)
        self.epsilon = self.EPSILON_START
        self.current_episode = 0

    def _to_tensor(self, values, dtype):
        """Stack a sequence of per-transition values into one tensor on the agent's device."""
        return torch.tensor(np.array(values), dtype=dtype, device=self.device)

    def select_e_greedy_action(self, env, state):
        """Return a random action with probability ``epsilon``, else the greedy one.

        Random actions are drawn uniformly from ``range(env.action_space.n)``.
        """
        if random.random() < self.epsilon:
            return random.randrange(env.action_space.n)
        return self.select_greedy_action(state)

    def select_greedy_action(self, state):
        """Return the index of the highest Q-value the policy network gives for ``state``.

        ``state`` is a single observation; a leading batch axis is added here.
        """
        state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            return self.policy_net(state).argmax(dim=1).item()

    def decay_epsilon(self):
        """Count one step and recompute ``epsilon`` from ``current_episode``."""
        self.total_steps += 1
        decayed = self.EPSILON_START * (self.EPSILON_DECAY ** self.current_episode)
        self.epsilon = max(self.EPSILON_CUTOFF, decayed)

    def step(self):
        """Run one gradient update on a sampled minibatch.

        Does nothing until the buffer holds at least ``BATCH_SIZE`` transitions.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return

        batch = self.memory.sample(self.BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = self._to_tensor(states, torch.float32)
        actions = self._to_tensor(actions, torch.int64).unsqueeze(1)
        rewards = self._to_tensor(rewards, torch.float32).unsqueeze(1)
        next_states = self._to_tensor(next_states, torch.float32)
        dones = self._to_tensor(dones, torch.float32).unsqueeze(1)

        # Q(s, a) of the actions that were actually taken.
        current_q = self.policy_net(states).gather(1, actions)
        # Bootstrap target from the frozen network; no gradient is needed for it.
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1, keepdim=True)[0]
        # (1 - dones) removes the bootstrap term for transitions flagged as done.
        target_q = rewards + (self.GAMMA * next_q * (1 - dones))

        # mse_loss takes (input, target); the prediction goes first.
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

Training Function

In [None]:
def train_q_values(env, env_name, target_update_interval=10, training_episodes=1000, agent_class=DQNAgent, checkpoint_path=None):
    """Train an agent on ``env`` and return it together with per-episode statistics.

    Each episode is played to completion first; its transitions are then pushed
    into the agent's replay memory and one ``agent.step()`` is run per transition.

    Args:
        env: Environment with the 5-tuple ``step`` API
            (obs, reward, terminated, truncated, info).
        env_name: Not used inside this function; kept for interface compatibility.
        target_update_interval: The target network is synced on episodes whose
            index is a multiple of this value.
        training_episodes: Exclusive upper bound on the episode index.
        agent_class: Zero-argument callable that builds the agent.
        checkpoint_path: File name (relative to the checkpoint folder) of a
            checkpoint to resume from, or None to start fresh. The replay
            memory is not part of the checkpoint and starts empty on resume.

    Returns:
        Tuple ``(agent, rewards_per_episode, epsilon_values, episode_times)``.
    """
    agent = agent_class()
    rewards_per_episode = []
    epsilon_values = []
    episode_times = []
    start_episode = 0
    early_stop_counter = 0

    drive_path = '/content/drive/MyDrive/dqn_pong_checkpoints/'
    # Every torch.save below writes into this folder, so make sure it exists.
    os.makedirs(drive_path, exist_ok=True)

    # Loading from last checkpoint in case training fails
    if checkpoint_path is not None:
        # map_location lets a checkpoint written on a GPU runtime be resumed on a
        # CPU-only runtime; None (agent without `device`) keeps torch's default.
        checkpoint = torch.load(
            os.path.join(drive_path, checkpoint_path),
            map_location=getattr(agent, 'device', None),
        )
        agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
        agent.target_net.load_state_dict(checkpoint['policy_net_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        agent.total_steps = checkpoint['total_steps']
        agent.current_episode = checkpoint['current_episode']
        rewards_per_episode = checkpoint['rewards_per_episode']
        epsilon_values = checkpoint['epsilon_values']
        episode_times = checkpoint['episode_times']
        start_episode = checkpoint['current_episode'] + 1
        agent.epsilon = max(agent.EPSILON_CUTOFF, agent.EPSILON_START * (agent.EPSILON_DECAY ** agent.current_episode))
        print(f"Resumed training from episode {start_episode}")

    for episode in range(start_episode, training_episodes):
        agent.current_episode = episode
        start_time = time.time()
        state, _ = env.reset()
        total_reward = 0
        done = False
        steps_taken = 0
        episode_transitions = []

        # Recorded before this episode's first decay_epsilon() call, so this is
        # the value carried over from the previous episode (or the initial one).
        epsilon_values.append(agent.epsilon)

        while not done:
            action = agent.select_e_greedy_action(env, state)
            obs, reward, terminated, truncated, _ = env.step(action)
            # The episode is over on either signal; ignoring `truncated` would
            # keep stepping an environment that has already ended.
            done = terminated or truncated
            # Keep the environment's own dtype: the agent casts to float32 when
            # it builds a batch, so upcasting here only inflates the buffer.
            next_state = np.array(obs)
            # Only real termination is stored as the done flag of the transition.
            episode_transitions.append((state, action, reward, next_state, terminated))
            state = next_state
            total_reward += reward
            steps_taken += 1

            agent.decay_epsilon()

        for transition in episode_transitions:
            agent.memory.push(*transition)

        # One optimisation step per environment step taken in this episode.
        for _ in range(steps_taken):
            agent.step()

        rewards_per_episode.append(total_reward)
        end_time = time.time()
        episode_time = end_time - start_time
        episode_times.append(episode_time)

        # Count consecutive episodes with a total reward of at least 20.
        if total_reward >= 20:
            early_stop_counter += 1
        else:
            early_stop_counter = 0

        if early_stop_counter >= 15:
            print(f"Early stopping at episode {episode}")
            torch.save(agent.policy_net.state_dict(), os.path.join(drive_path, "dqn_pong_env_v5_early_stop_train.pth"))
            break

        if episode % 10 == 0:
            print(f"Episode {episode + 1}: Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}, Steps: {steps_taken}, Time: {episode_time:.2f}s")
            # Saving model
            torch.save(agent.policy_net.state_dict(), os.path.join(drive_path, "dqn_pong_env_v5.pth"))
            # Creating checkpoint dictionary and saving
            checkpoint = {
                'policy_net_state_dict': agent.policy_net.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'total_steps': agent.total_steps,
                'current_episode': agent.current_episode,
                'rewards_per_episode': rewards_per_episode,
                'epsilon_values': epsilon_values,
                'episode_times': episode_times
            }
            torch.save(checkpoint, os.path.join(drive_path, "dqn_pong_checkpoints.pth"))

        if episode % target_update_interval == 0:
            agent.update_target_network()

    return agent, rewards_per_episode, epsilon_values, episode_times

Training Pong

In [None]:
# Environment: base ALE env -> Atari preprocessing -> stack of 4 observations
env_name = "PongNoFrameskip-v4"
pong_env = FrameStackObservation(
    AtariPreprocessing(
        gym.make(env_name, render_mode="rgb_array", frameskip=1),
        frame_skip=4,
        grayscale_obs=True,
        scale_obs=False,
        terminal_on_life_loss=False,
    ),
    stack_size=4,
)

# Creating folder
drive_path = '/content/drive/MyDrive/dqn_pong_checkpoints/'
os.makedirs(drive_path, exist_ok=True)

# Training
training_results = train_q_values(pong_env, env_name, training_episodes=5000)
trained_agent, rewards_per_episode, epsilon_values, episode_times = training_results

# Persist the final policy weights next to the checkpoints.
final_weights_path = os.path.join(drive_path, "dqn_pongnoframeskip_env_v5.pth")
torch.save(trained_agent.policy_net.state_dict(), final_weights_path)

Episode 1: Total Reward: -20.0, Epsilon: 1.000, Steps: 837, Time: 6.55s
Episode 11: Total Reward: -21.0, Epsilon: 0.977, Steps: 880, Time: 6.10s
Episode 21: Total Reward: -21.0, Epsilon: 0.955, Steps: 869, Time: 6.04s
Episode 31: Total Reward: -21.0, Epsilon: 0.933, Steps: 954, Time: 6.57s
Episode 41: Total Reward: -19.0, Epsilon: 0.912, Steps: 1038, Time: 7.30s
Episode 51: Total Reward: -20.0, Epsilon: 0.891, Steps: 933, Time: 6.55s
Episode 61: Total Reward: -17.0, Epsilon: 0.871, Steps: 1252, Time: 8.67s
Episode 71: Total Reward: -20.0, Epsilon: 0.851, Steps: 863, Time: 6.03s
Episode 81: Total Reward: -21.0, Epsilon: 0.832, Steps: 876, Time: 6.09s
Episode 91: Total Reward: -21.0, Epsilon: 0.813, Steps: 975, Time: 6.90s
Episode 101: Total Reward: -18.0, Epsilon: 0.794, Steps: 1240, Time: 8.71s
Episode 111: Total Reward: -20.0, Epsilon: 0.776, Steps: 898, Time: 6.36s
Episode 121: Total Reward: -21.0, Epsilon: 0.759, Steps: 968, Time: 6.84s
Episode 131: Total Reward: -21.0, Epsilon: 0.7

In [None]:
    # Relevant Plots
    # rewards_per_episode holds one entry per episode, so the x-axis is labelled
    # "Episode" like the two plots below.
    plt.plot(rewards_per_episode)
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.title(f"{env_name} Reward Characteristics - Training")
    plt.show()

    plt.plot(epsilon_values, label=f"Decay rate = {DQNAgent.EPSILON_DECAY}")
    plt.xlabel("Episode")
    plt.ylabel("Epsilon")
    plt.title(f"{env_name} Epsilon Characteristics - Training")
    # Without a legend the label passed to plot() is never displayed.
    plt.legend()
    plt.show()

    plt.plot(episode_times)
    plt.xlabel("Episode")
    plt.ylabel("Time (seconds)")
    plt.title(f"{env_name} Episode Processing Time")
    plt.show()