In [1]:
# Imports
import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, FrameStackObservation
import matplotlib.pyplot as plt
import ale_py
from stable_baselines3.common.atari_wrappers import AtariWrapper
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque
import time

In [2]:
# Register the environments provided by ale_py with Gymnasium.
# This has to run before any gym.make(...) call that refers to an Atari environment id.
gym.register_envs(ale_py)

In [3]:
# DQN Class
class DQN(nn.Module):
    """Convolutional Q-network: a stack of frames in, one Q-value per action out.

    The layer layout (and therefore the ``state_dict`` keys ``conv.0``,
    ``conv.2``, ``conv.4``, ``fc.0``, ``fc.2``) is fixed, so weights saved
    from earlier versions of this class still load when the defaults are used.

    Args:
        n_actions: Number of discrete actions, i.e. the width of the output
            layer. Defaults to 6, the value that used to be hard-coded.
        in_channels: Number of stacked frames fed to the first convolution.
            Defaults to 4, the value that used to be hard-coded.

    The fully connected head expects the convolutional stack to produce
    64 x 7 x 7 features, which is what 84 x 84 inputs yield
    (84 -> 20 -> 9 -> 7 through the three convolutions).
    """

    def __init__(self, n_actions=6, in_channels=4):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for a batch of frame stacks."""
        features = self.conv(x)
        # flatten keeps the batch axis and, unlike view, does not require contiguous memory.
        return self.fc(features.flatten(start_dim=1))

class ReplayBuffer:
    """Bounded first-in-first-out store of transitions with random sampling."""

    def __init__(self, capacity):
        # A deque with a maximum length drops its oldest entry once it is full.
        self.buffer = deque([], capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition as a (state, action, reward, next_state, done) tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return a list of `batch_size` stored transitions drawn without replacement."""
        picked = random.sample(self.buffer, batch_size)
        return picked

    def __len__(self):
        """Number of transitions currently stored."""
        return self.buffer.__len__()


In [4]:
# DQN Agent
class DQNAgent:
    """DQN agent: a policy network, a periodically synced target network and a replay memory.

    Hyper-parameters are class attributes so a subclass can override them.
    Epsilon is a function of the current episode index:
    ``max(EPSILON_CUTOFF, EPSILON_START * EPSILON_DECAY ** current_episode)``.
    """
    BATCH_SIZE = 32
    EPSILON_START = 1
    EPSILON_CUTOFF = 0.1
    EPSILON_DECAY = 0.99684
    GAMMA = 0.99
    LR = 1e-4
    MEMORY_SIZE = 10000
    # Class-level default kept for backward compatibility; every instance
    # gets its own counter in __init__ instead of shadowing this one lazily.
    total_steps = 0

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = DQN().to(self.device)
        self.target_net = DQN().to(self.device)
        # The target network starts as an exact copy and is only used for inference.
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.LR)
        self.memory = ReplayBuffer(self.MEMORY_SIZE)
        self.steps_done = 0  # not used inside this class; kept because external code may read it
        self.total_steps = 0
        self.epsilon = self.EPSILON_START
        self.current_episode = 0

    def select_e_greedy_action(self, env, state):
        """Pick a uniformly random action with probability epsilon, else the greedy one.

        `env` is only consulted (for `action_space.n`) on the random branch.
        """
        if random.random() < self.epsilon:
            return random.randint(0, env.action_space.n - 1)
        return self.select_greedy_action(state)

    def select_greedy_action(self, state):
        """Return the index of the action with the highest policy-network Q-value."""
        # unsqueeze(0) adds the batch axis the network expects.
        batch = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            return self.policy_net(batch).argmax(dim=1).item()

    def decay_epsilon(self):
        """Count one environment step and recompute epsilon from the current episode index."""
        self.total_steps += 1
        self.epsilon = max(self.EPSILON_CUTOFF, self.EPSILON_START * (self.EPSILON_DECAY ** self.current_episode))

    def step(self):
        """Run one optimisation step on a random minibatch from replay memory.

        Does nothing until the memory holds at least BATCH_SIZE transitions.
        """
        if len(self.memory) < self.BATCH_SIZE:
            return
        batch = self.memory.sample(self.BATCH_SIZE)

        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.tensor(np.array(states), dtype=torch.float32, device=self.device)
        # actions / rewards / dones become column vectors so they line up with gather() and max(keepdim=True).
        actions = torch.tensor(np.array(actions), dtype=torch.int64, device=self.device).unsqueeze(1)
        rewards = torch.tensor(np.array(rewards), dtype=torch.float32, device=self.device).unsqueeze(1)
        next_states = torch.tensor(np.array(next_states), dtype=torch.float32, device=self.device)
        dones = torch.tensor(np.array(dones), dtype=torch.float32, device=self.device).unsqueeze(1)

        current_q = self.policy_net(states).gather(1, actions)
        # The bootstrap target is a constant for this update, so no autograd graph is built for it.
        with torch.no_grad():
            next_q = self.target_net(next_states).max(1, keepdim=True)[0]
            # (1 - dones) removes the bootstrap term for transitions flagged as done.
            target_q = rewards + (self.GAMMA * next_q * (1 - dones))

        # mse_loss takes (input, target); the value is the same either way, this is the documented order.
        loss = F.mse_loss(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Copy the policy network's weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())


In [5]:
# Training Function
def _plot_series(values, xlabel, ylabel, title, label=None):
    """Draw a single line plot with axis labels and a title, then show it."""
    plt.plot(values, label=label)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    if label is not None:
        # Without a legend the label passed to plot() is never displayed.
        plt.legend()
    plt.show()


def train_q_values(env, env_name, target_update_interval=10, training_episodes=1000, agent_class=DQNAgent, checkpoint_path=None,
                   model_path="dqn_pong_env_v5.pth", checkpoint_save_path="dqn_pong_checkpoint.pth"):
    """Train an agent on `env`, checkpointing every 10 episodes, and plot the training curves.

    Every episode is played to the end with the agent's epsilon-greedy policy.
    The collected transitions are then pushed to the agent's replay memory and
    one optimisation step is run per collected transition.

    Args:
        env: Environment following the Gymnasium API, i.e. `reset()` returns
            `(obs, info)` and `step(action)` returns
            `(obs, reward, terminated, truncated, info)`.
        env_name: Name used in the plot titles.
        target_update_interval: The target network is synced after every
            episode whose index is a multiple of this value.
        training_episodes: Episode index at which training stops (exclusive).
        agent_class: Zero-argument callable that builds the agent.
        checkpoint_path: Optional checkpoint file to resume from.
        model_path: Where the policy weights are written every 10 episodes.
        checkpoint_save_path: Where the full checkpoint is written every 10 episodes.

    Returns:
        The trained agent.
    """
    agent = agent_class()
    rewards_per_episode = []
    epsilon_values = []
    episode_times = []
    start_episode = 0

    # Load checkpoint if provided
    if checkpoint_path is not None:
        # map_location lets a checkpoint written on one device be resumed on another.
        checkpoint = torch.load(checkpoint_path, map_location=getattr(agent, "device", None))
        agent.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
        # The checkpoint stores only the policy weights, so the target network restarts as a copy of them.
        agent.target_net.load_state_dict(checkpoint['policy_net_state_dict'])
        agent.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        agent.total_steps = checkpoint['total_steps']
        agent.current_episode = checkpoint['current_episode']
        rewards_per_episode = checkpoint['rewards_per_episode']
        epsilon_values = checkpoint['epsilon_values']
        episode_times = checkpoint['episode_times']
        start_episode = checkpoint['current_episode'] + 1
        agent.epsilon = max(agent.EPSILON_CUTOFF, agent.EPSILON_START * (agent.EPSILON_DECAY ** agent.current_episode))
        print(f"Resumed training from episode {start_episode}")

    for episode in range(start_episode, training_episodes):
        agent.current_episode = episode
        start_time = time.time()
        obs, _ = env.reset()
        # Same dtype as every later state, so all stored transitions are uniform.
        state = np.array(obs, dtype=np.float32)
        total_reward = 0
        done = False
        steps_taken = 0
        episode_transitions = []

        epsilon_values.append(agent.epsilon)

        while not done:
            action = agent.select_e_greedy_action(env, state)
            obs, reward, terminated, truncated, _ = env.step(action)
            # An episode ends on termination OR truncation (e.g. a time limit);
            # stepping an environment past either one is not valid.
            done = terminated or truncated
            next_state = np.array(obs, dtype=np.float32)
            # Only `terminated` is stored: a truncated episode did not reach a
            # terminal state, so its last transition should still bootstrap.
            episode_transitions.append((state, action, reward, next_state, terminated))
            state = next_state
            total_reward += reward
            steps_taken += 1

            agent.decay_epsilon()

        for transition in episode_transitions:
            agent.memory.push(*transition)

        # One gradient step per environment step collected in this episode.
        for _ in range(len(episode_transitions)):
            agent.step()

        rewards_per_episode.append(total_reward)
        end_time = time.time()
        episode_time = end_time - start_time
        episode_times.append(episode_time)

        if episode % 10 == 0:
            print(f"Episode {episode + 1}: Total Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}, Steps: {steps_taken}, Time: {episode_time:.2f}s")
            # Save model
            torch.save(agent.policy_net.state_dict(), model_path)
            # Save checkpoint
            checkpoint = {
                'policy_net_state_dict': agent.policy_net.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'total_steps': agent.total_steps,
                'current_episode': agent.current_episode,
                'rewards_per_episode': rewards_per_episode,
                'epsilon_values': epsilon_values,
                'episode_times': episode_times
            }
            torch.save(checkpoint, checkpoint_save_path)

        if episode % target_update_interval == 0:
            agent.update_target_network()

    _plot_series(rewards_per_episode, "Iteration", "Reward",
                 f"{env_name} Reward Characteristics - Training")
    # Read the decay rate from the agent actually trained, which may not be a DQNAgent.
    _plot_series(epsilon_values, "Episode", "Epsilon",
                 f"{env_name} Epsilon Characteristics - Training",
                 label=f"Decay rate = {agent.EPSILON_DECAY}")
    _plot_series(episode_times, "Episode", "Time (seconds)",
                 f"{env_name} Episode Processing Time")

    return agent

In [6]:
# Pong DQN Agent Training
# Seed every RNG the training run draws from so a Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env_name = "PongNoFrameskip-v4"
pong_env = gym.make(env_name, render_mode="rgb_array", frameskip=1)
pong_env = AtariPreprocessing(
    pong_env,
    frame_skip=4,
    grayscale_obs=True,
    # Scaled observations keep the network inputs in [0, 1) instead of raw 0-255 pixel values.
    scale_obs=True,
    terminal_on_life_loss=False
)
pong_env = FrameStackObservation(pong_env, stack_size=4)
# Seeding through reset() also seeds the environment's own random generator.
pong_env.reset(seed=SEED)

try:
    trained_agent = train_q_values(pong_env, env_name, training_episodes=600)
    torch.save(trained_agent.policy_net.state_dict(), "dqn_pongnoframeskip_env_v5.pth")
finally:
    # Release the emulator even when training is interrupted (e.g. KeyboardInterrupt).
    pong_env.close()

Episode 1: Total Reward: -21.0, Epsilon: 1.000, Steps: 1049, Time: 34.79s
Episode 11: Total Reward: -20.0, Epsilon: 0.969, Steps: 912, Time: 30.48s
Episode 21: Total Reward: -21.0, Epsilon: 0.939, Steps: 849, Time: 28.66s
Episode 31: Total Reward: -19.0, Epsilon: 0.909, Steps: 1251, Time: 42.55s
Episode 41: Total Reward: -21.0, Epsilon: 0.881, Steps: 846, Time: 29.20s
Episode 51: Total Reward: -20.0, Epsilon: 0.854, Steps: 832, Time: 29.20s
Episode 61: Total Reward: -20.0, Epsilon: 0.827, Steps: 855, Time: 29.98s
Episode 71: Total Reward: -20.0, Epsilon: 0.801, Steps: 1020, Time: 37.02s
Episode 81: Total Reward: -18.0, Epsilon: 0.776, Steps: 1218, Time: 47.32s
Episode 91: Total Reward: -21.0, Epsilon: 0.752, Steps: 926, Time: 36.74s
Episode 101: Total Reward: -17.0, Epsilon: 0.729, Steps: 1228, Time: 47.54s
Episode 111: Total Reward: -20.0, Epsilon: 0.706, Steps: 1099, Time: 42.60s
Episode 121: Total Reward: -18.0, Epsilon: 0.684, Steps: 1081, Time: 43.77s
Episode 131: Total Reward: -1

KeyboardInterrupt: 