In [5]:
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import matplotlib.pyplot as plt
import os

# -------------------------------
# Q-Network definition
# -------------------------------
class QNetwork(nn.Module):
    """Multi-layer perceptron that outputs one Q-value per action.

    The stack is ``Linear -> ReLU`` repeated for each hidden layer, followed
    by a final ``Linear`` projection to ``output_dim``. There is always at
    least one hidden layer; ``num_layers`` values above 1 add further
    ``hidden_dim -> hidden_dim`` layers.

    Args:
        input_dim: size of the input feature vector.
        output_dim: number of outputs (one per action).
        hidden_dim: width of every hidden layer.
        num_layers: number of hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=256, num_layers=2):
        super().__init__()
        # Feature widths seen by the hidden stack: the input width followed
        # by one hidden width per hidden layer (never fewer than one).
        extra_hidden = max(num_layers - 1, 0)
        widths = [input_dim] + [hidden_dim] * (extra_hidden + 1)

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(fan_in, fan_out))
            modules.append(nn.ReLU())
        # Final projection has no activation: Q-values are unbounded.
        modules.append(nn.Linear(hidden_dim, output_dim))

        self.model = nn.Sequential(*modules)
        # Overwrite the default initialisation of every Linear layer with
        # small uniform values in [-0.001, 0.001].
        self.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise weights and bias of a Linear module uniformly in ±0.001.

        Modules that are not ``nn.Linear`` are left untouched.
        """
        if not isinstance(m, nn.Linear):
            return
        nn.init.uniform_(m.weight, -0.001, 0.001)
        if m.bias is not None:
            nn.init.uniform_(m.bias, -0.001, 0.001)

    def forward(self, x):
        """Return the network output for input ``x``."""
        q_values = self.model(x)
        return q_values

# -------------------------------
# Replay Buffer for mini-batch updates
# -------------------------------
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are stored, each new one evicts the oldest.
    """

    def __init__(self, capacity):
        # A deque with maxlen drops items from the left when it overflows.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Five numpy arrays ``(states, actions, rewards, next_states,
            dones)``, each stacking the corresponding field of the sampled
            transitions along the first axis.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        columns = zip(*picked)
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in columns
        )
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

# -------------------------------
# Preprocessing function
# -------------------------------
def preprocess_state(state, env_name):
    """Convert an observation to a float32 array, rescaling RAM observations.

    When ``env_name`` contains ``"Assault-ram"`` the values are divided by
    255 (the observation is a uint8 array in [0, 255], so the result lies in
    [0, 1]). Any other environment's observation is only cast to float32.
    """
    needs_byte_scaling = "Assault-ram" in env_name
    as_float = np.array(state, dtype=np.float32)
    if needs_byte_scaling:
        return as_float / 255.0
    return as_float

# -------------------------------
# Single-step update without replay buffer
# -------------------------------
def train_episode(env, model, optimizer, algorithm, epsilon, gamma, device, env_name):
    """Run one episode, updating the network after every environment step.

    Actions are chosen ε-greedily from the current Q-values. After each step
    the squared difference between ``Q(s, a)`` and a one-step bootstrapped
    target is minimised with a single optimizer step (no replay buffer).

    Args:
        env: environment whose ``reset()`` returns an observation, whose
            ``step(action)`` returns ``(obs, reward, done, info)``, and whose
            ``action_space`` provides ``sample()`` and ``n``.
        model: network mapping a 1-D state tensor to a 1-D tensor of Q-values.
        optimizer: optimizer over ``model``'s parameters.
        algorithm: ``"q_learning"`` (max over next Q-values) or
            ``"expected_sarsa"`` (expectation under the ε-greedy policy).
        epsilon: exploration probability; also defines the policy used for
            the Expected SARSA target.
        gamma: discount factor applied to the bootstrapped value.
        device: torch device the state tensors are moved to.
        env_name: forwarded to ``preprocess_state``.

    Returns:
        The sum of rewards collected during the episode.

    Raises:
        ValueError: if ``algorithm`` is not one of the two supported names.
    """
    # Fail fast: otherwise an unsupported name only surfaces mid-episode as
    # an unbound ``target`` variable.
    if algorithm not in ("q_learning", "expected_sarsa"):
        raise ValueError(f"Unknown algorithm: {algorithm!r}")

    state = preprocess_state(env.reset(), env_name)
    total_reward = 0.0
    done = False

    while not done:
        state_tensor = torch.FloatTensor(state).to(device)
        # Kept with gradients: q_values[action] is the prediction in the loss.
        q_values = model(state_tensor)
        # ε–greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = torch.argmax(q_values).item()

        next_state, reward, done, _ = env.step(action)
        next_state = preprocess_state(next_state, env_name)
        total_reward += reward

        # Compute target value
        if done:
            # Terminal transition: nothing to bootstrap from, so the forward
            # pass on the next state is skipped entirely.
            target = reward
        else:
            state_tensor_next = torch.FloatTensor(next_state).to(device)
            with torch.no_grad():
                q_next = model(state_tensor_next)
            if algorithm == "q_learning":
                target = reward + gamma * torch.max(q_next).item()
            else:  # expected_sarsa
                num_actions = env.action_space.n
                # ε–greedy probabilities for next state: ε/n on every action
                # plus the remaining (1 - ε) mass on the greedy one.
                greedy_action = torch.argmax(q_next).item()
                probs = np.ones(num_actions) * (epsilon / num_actions)
                probs[greedy_action] += (1 - epsilon)
                expected_q = (q_next.cpu().numpy() * probs).sum()
                target = reward + gamma * expected_q

        # Loss: squared error between the taken action's Q-value and target
        loss = (q_values[action] - target) ** 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state
    return total_reward

# -------------------------------
# Mini-batch update using replay buffer
# -------------------------------
def update_model_from_replay(model, optimizer, replay_buffer, algorithm, epsilon, gamma, batch_size, device):
    """Perform one mini-batch TD update from transitions in the replay buffer.

    Does nothing while the buffer holds fewer than ``batch_size`` transitions.
    Otherwise samples a batch, builds one-step targets and takes a single
    optimizer step on the mean squared error between ``Q(s, a)`` and target.

    Args:
        model: network mapping a ``[batch, state_dim]`` tensor to
            ``[batch, num_actions]`` Q-values.
        optimizer: optimizer over ``model``'s parameters.
        replay_buffer: object supporting ``len()`` and ``sample(batch_size)``
            returning ``(states, actions, rewards, next_states, dones)`` as
            numpy arrays.
        algorithm: ``"q_learning"`` or ``"expected_sarsa"``.
        epsilon: ε of the ε-greedy policy used for the Expected SARSA target.
        gamma: discount factor.
        batch_size: number of transitions per update.
        device: torch device the batch tensors are moved to.

    Returns:
        None.

    Raises:
        ValueError: if ``algorithm`` is not one of the two supported names.
    """
    # Fail fast: otherwise an unsupported name leaves ``target_next``
    # unbound and only fails once the buffer has filled up.
    if algorithm not in ("q_learning", "expected_sarsa"):
        raise ValueError(f"Unknown algorithm: {algorithm!r}")
    if len(replay_buffer) < batch_size:
        return

    states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones.astype(np.float32)).to(device)

    q_values = model(states)  # shape: [batch, num_actions]
    # Pick the Q-value of the action actually taken in each transition.
    q_selected = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        q_next = model(next_states)  # shape: [batch, num_actions]
        greedy_value = q_next.max(1)[0]
        if algorithm == "q_learning":
            target_next = greedy_value
        else:  # expected_sarsa
            # Expectation under the ε–greedy policy: every action has
            # probability ε/n and the greedy action gets the remaining
            # (1 - ε), so  E[Q] = (ε/n) * Σ_a Q(a) + (1 - ε) * max_a Q(a).
            num_actions = q_next.size(1)
            target_next = (epsilon / num_actions) * q_next.sum(dim=1) + (1 - epsilon) * greedy_value
    # (1 - dones) removes the bootstrap term on terminal transitions.
    targets = rewards + gamma * target_next * (1 - dones)
    loss = torch.mean((q_selected - targets) ** 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# -------------------------------
# Single episode using replay buffer updates
# -------------------------------
def train_episode_replay(env, model, optimizer, algorithm, epsilon, gamma, device, replay_buffer, batch_size, env_name):
    """Run one episode, learning from replayed mini-batches after every step.

    Each transition is pushed to ``replay_buffer`` and followed by one call
    to ``update_model_from_replay``. Actions are chosen ε-greedily from the
    current Q-values.

    Args:
        env: environment whose ``reset()`` returns an observation, whose
            ``step(action)`` returns ``(obs, reward, done, info)``, and whose
            ``action_space`` provides ``sample()``.
        model: network mapping a state tensor to Q-values.
        optimizer: optimizer over ``model``'s parameters.
        algorithm: ``"q_learning"`` or ``"expected_sarsa"``; forwarded to the
            mini-batch update.
        epsilon: exploration probability (also forwarded to the update).
        gamma: discount factor forwarded to the update.
        device: torch device used for tensors.
        replay_buffer: buffer receiving ``push(state, action, reward,
            next_state, done)`` for every step.
        batch_size: mini-batch size forwarded to the update.
        env_name: forwarded to ``preprocess_state``.

    Returns:
        The sum of rewards collected during the episode.

    Raises:
        ValueError: if ``algorithm`` is not one of the two supported names.
    """
    # Fail fast instead of crashing inside the first mini-batch update.
    if algorithm not in ("q_learning", "expected_sarsa"):
        raise ValueError(f"Unknown algorithm: {algorithm!r}")

    state = preprocess_state(env.reset(), env_name)
    total_reward = 0.0
    done = False

    while not done:
        state_tensor = torch.FloatTensor(state).to(device)
        # These Q-values only drive action selection (learning happens on
        # replayed batches), so no autograd graph is needed for them.
        with torch.no_grad():
            q_values = model(state_tensor)
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = torch.argmax(q_values).item()

        next_state, reward, done, _ = env.step(action)
        next_state = preprocess_state(next_state, env_name)
        total_reward += reward

        replay_buffer.push(state, action, reward, next_state, done)
        state = next_state

        # Perform a mini-batch update every step
        update_model_from_replay(model, optimizer, replay_buffer, algorithm, epsilon, gamma, batch_size, device)
    return total_reward

# -------------------------------
# Run training for one full run (one seed) over many episodes
# -------------------------------
import tqdm
def run_training(env_name, algorithm, epsilon, lr, num_episodes, gamma, device, replay, batch_size, seed):
    """Train a fresh Q-network on one environment for one seed.

    Creates the environment, seeds every random source used during training,
    builds a new ``QNetwork`` with an Adam optimizer and runs
    ``num_episodes`` episodes, with or without a replay buffer.

    Args:
        env_name: id passed to ``gym.make``.
        algorithm: ``"q_learning"`` or ``"expected_sarsa"``.
        epsilon: ε of the ε-greedy policy.
        lr: learning rate for Adam.
        num_episodes: number of episodes to run.
        gamma: discount factor.
        device: torch device for the model and tensors.
        replay: if true, learn from a replay buffer via
            ``train_episode_replay``; otherwise use ``train_episode``.
        batch_size: mini-batch size (only used when ``replay`` is true).
        seed: seed applied to the environment, its action space, numpy,
            ``random`` and torch.

    Returns:
        List with the total reward of each episode, in order.
    """
    env = gym.make(env_name)
    # try/finally so the environment is released even when training is
    # interrupted (e.g. KeyboardInterrupt during a long sweep).
    try:
        env.seed(seed)
        # The action space has its own RNG, used by action_space.sample()
        # for exploratory actions; seed it so runs are reproducible.
        env.action_space.seed(seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

        input_dim = env.observation_space.shape[0]
        output_dim = env.action_space.n

        model = QNetwork(input_dim, output_dim, hidden_dim=256, num_layers=2).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)

        replay_buffer = ReplayBuffer(capacity=int(1e6)) if replay else None

        episode_rewards = []
        for _ in tqdm.tqdm(range(num_episodes)):
            if replay:
                total_reward = train_episode_replay(env, model, optimizer, algorithm, epsilon, gamma, device, replay_buffer, batch_size, env_name)
            else:
                total_reward = train_episode(env, model, optimizer, algorithm, epsilon, gamma, device, env_name)
            episode_rewards.append(total_reward)
    finally:
        env.close()
    return episode_rewards

In [6]:
# Create output directory for plots
%matplotlib inline
os.makedirs("plots", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env_names = ["Acrobot-v1", "ALE/Assault-ram-v5"]
algorithms = ["q_learning", "expected_sarsa"]
# Example ε values; adjust as desired.
epsilons = [0.1, 0.2, 0.3]
# Step–sizes (learning rates): 1/4, 1/8, 1/16
lrs = [0.25, 0.125, 0.0625]
num_runs = 50       # number of seeds per configuration
num_episodes = 1000 # episodes per run (set smaller for debugging)
gamma = 0.99
batch_size = 64

# We will aggregate results in a dictionary
results = {}

# Loop over environments and replay settings; for each we create one combined plot.
for env_name in env_names:
    for replay in [False, True]:
        plt.figure(figsize=(12, 8))
        legend_entries = []
        # Loop over algorithms, epsilons, and learning rates.
        for algorithm in algorithms:
            for epsilon in epsilons:
                for lr in lrs:
                    config_key = (env_name, replay, algorithm, epsilon, lr)
                    all_rewards = []
                    for run in range(num_runs):
                        seed = run  # or any seed scheduling
                        rewards = run_training(env_name, algorithm, epsilon, lr, num_episodes, gamma, device, replay, batch_size, seed)
                        all_rewards.append(rewards)
                    all_rewards = np.array(all_rewards)  # shape: (num_runs, num_episodes)
                    results[config_key] = all_rewards
                    mean_rewards = np.mean(all_rewards, axis=0)
                    std_rewards = np.std(all_rewards, axis=0)
                    episodes = np.arange(num_episodes)
                    
                    # Determine color and linestyle:
                    color = "green" if algorithm == "q_learning" else "red"
                    # Use learning rate to determine line style:
                    if np.isclose(lr, lrs[0]):
                        linestyle = "-"
                    elif np.isclose(lr, lrs[1]):
                        linestyle = "--"
                    else:
                        linestyle = ":"
                    # Here, we also encode epsilon in the label.
                    label = f"{algorithm} (ε={epsilon}, lr={lr})"
                    plt.plot(episodes, mean_rewards, color=color, linestyle=linestyle, label=label)
                    plt.fill_between(episodes, mean_rewards - std_rewards, mean_rewards + std_rewards, color=color, alpha=0.1)
                    legend_entries.append(label)
        plt.xlabel("Episode")
        plt.ylabel("Return")
        replay_str = "Replay" if replay else "No Replay"
        plt.title(f"{env_name} - {replay_str} Buffer")
        plt.legend(fontsize=8, loc="lower right", ncol=2)
        plot_filename = f"plots/{env_name}_{replay_str.replace(' ', '').lower()}_combined.png"
        plt.savefig(plot_filename)
        plt.show()
        plt.close()
        print(f"Saved plot: {plot_filename}")

100%|██████████| 1000/1000 [12:03<00:00,  1.38it/s]
100%|██████████| 1000/1000 [10:54<00:00,  1.53it/s]
100%|██████████| 1000/1000 [10:12<00:00,  1.63it/s]
100%|██████████| 1000/1000 [10:13<00:00,  1.63it/s]
100%|██████████| 1000/1000 [10:11<00:00,  1.64it/s]
100%|██████████| 1000/1000 [10:06<00:00,  1.65it/s]


KeyboardInterrupt: 

<Figure size 1200x800 with 0 Axes>