In [11]:
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import wandb


In [12]:
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, x):
        return self.net(x)


In [13]:
class DQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995, enable_wandb=True):
        self.action_dim = action_dim
        self.gamma = gamma # discount factor: how much future rewards are valued
        self.epsilon = epsilon_start # exploration rate: probability of choosing a random action
        self.epsilon_min = epsilon_end # minimum exploration rate
        self.epsilon_decay = epsilon_decay # rate of decay for exploration probability

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()
        
        # Only watch with WandB if enabled (during training)
        if enable_wandb:
            wandb.watch(self.model, log="all", log_freq=10)

    def select_action(self, state):
        if random.random() < self.epsilon:  # exploration
            return random.randrange(self.action_dim)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        q_values = self.model(state)
        return torch.argmax(q_values).item()  # exploitation

    def train_step(self, memory, batch_size):
        if len(memory) < batch_size:
            return 0.0

        states, actions, rewards, next_states, dones = memory.sample(batch_size)

        states = torch.FloatTensor(states).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Q(s, a)
        q_values = self.model(states).gather(1, actions)

        # Target: r + γ * max_a' Q(next_state, a')
        next_q_values = self.model(next_states).max(1)[0].unsqueeze(1)
        target_q = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.loss_fn(q_values, target_q.detach())  # Keep .detach()!

        self.optimizer.zero_grad()
        loss.backward()
        
        # ADD THIS ONE LINE:
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        
        self.optimizer.step()

        # decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
            
        return loss.item()

In [14]:

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
            np.array(dones)
        )

    def __len__(self):
        return len(self.memory)


In [15]:
def train_dqn():
    wandb.init(project="dqn-cartpole", name="dqn_experiment_2")
    
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    print(f"state dimension: {state_dim}, action dimension: {action_dim}")

    agent = DQNAgent(state_dim, action_dim, lr=1e-4, epsilon_decay=0.995, epsilon_end=0.05)
    memory = ReplayBuffer(10000)
    episodes = 500
    batch_size = 64

    # MOVE CONFIG UPDATE HERE - before training starts
    wandb.config.update({
        "state_dim": state_dim,
        "action_dim": action_dim,
        "gamma": agent.gamma,
        "learning_rate": 1e-4,  # Fixed to match your actual LR
        "epsilon_start": 1.0,
        "epsilon_end": 0.05,    # Fixed to match your actual value
        "epsilon_decay": 0.995, # Fixed to match your actual value
        "batch_size": batch_size,
        "buffer_size": 10000,
        "episodes": episodes
    })

    print("Starting training...")

    for ep in range(episodes):
        state, _ = env.reset()
        total_reward = 0
        episode_loss = 0
        train_steps = 0

        for t in range(500):
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            memory.push(state, action, reward, next_state, done)
            loss = agent.train_step(memory, batch_size)
            
            if loss > 0:
                episode_loss += loss
                train_steps += 1

            state = next_state
            total_reward += reward

            if done:
                break

        # Calculate average loss
        avg_loss = episode_loss / train_steps if train_steps > 0 else 0
        
        # FIXED INDENTATION: Log ONCE PER EPISODE
        wandb.log({
            "episode": ep + 1,
            "total_reward": total_reward,
            "epsilon": agent.epsilon,
            "avg_loss": avg_loss,
            "buffer_size": len(memory),
            "episode_length": t + 1
        })

        print(f"Episode {ep+1}, Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}, Avg Loss: {avg_loss:.4f}")

    # Save model AFTER training completes
    torch.save(agent.model.state_dict(), "dqn_model.pth")
    wandb.save("dqn_model.pth")
    wandb.finish()

if __name__ == "__main__":
    train_dqn()

state dimension: 4, action dimension: 2
Starting training...
Episode 1, Reward: 32.0, Epsilon: 1.000, Avg Loss: 0.0000
Episode 2, Reward: 53.0, Epsilon: 0.896, Avg Loss: 1.0925
Episode 3, Reward: 35.0, Epsilon: 0.751, Avg Loss: 1.1046
Episode 4, Reward: 29.0, Epsilon: 0.650, Avg Loss: 1.0659
Episode 5, Reward: 26.0, Epsilon: 0.570, Avg Loss: 1.0489
Episode 6, Reward: 14.0, Epsilon: 0.532, Avg Loss: 1.0501
Episode 7, Reward: 13.0, Epsilon: 0.498, Avg Loss: 1.0589
Episode 8, Reward: 16.0, Epsilon: 0.460, Avg Loss: 1.0748
Episode 9, Reward: 13.0, Epsilon: 0.431, Avg Loss: 1.0980
Episode 10, Reward: 12.0, Epsilon: 0.406, Avg Loss: 1.1360
Episode 11, Reward: 10.0, Epsilon: 0.386, Avg Loss: 1.2318
Episode 12, Reward: 14.0, Epsilon: 0.360, Avg Loss: 1.3489
Episode 13, Reward: 13.0, Epsilon: 0.337, Avg Loss: 1.4463
Episode 14, Reward: 11.0, Epsilon: 0.319, Avg Loss: 1.6021
Episode 15, Reward: 11.0, Epsilon: 0.302, Avg Loss: 1.8723
Episode 16, Reward: 8.0, Epsilon: 0.290, Avg Loss: 2.1414
Episo

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Episode 500, Reward: 142.0, Epsilon: 0.050, Avg Loss: 23411.2636


0,1
avg_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▁▁█▅▆▆▁▁▃
buffer_size,▁▁▁▂▂▂▂▂▂▂▂▃▃▄▇█████████████████████████
episode,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▆▆▇▇▇▇▇▇▇█████
episode_length,▂▁▁▁▁▁▁▁▁▁▁▂▃▂▄▂▂▅▃▆▁▁▁▁▁▁▆▅▅▅▅▅▅██▃▂▁▂▄
epsilon,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total_reward,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▁▁▁▁▁▃▄▃▃▄█▄▄▄▇

0,1
avg_loss,23411.26361
buffer_size,10000.0
episode,500.0
episode_length,142.0
epsilon,0.04991
total_reward,142.0


In [16]:
import os
from datetime import datetime
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Create a unique folder for videos
video_folder = f"videos/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(video_folder, exist_ok=True)

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)
env = gym.wrappers.RecordVideo(env, video_folder, episode_trigger=lambda e: True)

# Load the trained model WITH WandB DISABLED
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim, enable_wandb=False)  # ← ADD THIS PARAMETER
agent.model.load_state_dict(torch.load("dqn_model.pth"))
agent.epsilon = 0  # Disable exploration

# Record multiple episodes
for episode in range(3):
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)
        state, _, terminated, truncated, _ = env.step(action)
        print("terminated", terminated, "truncated", truncated)
        done = terminated or truncated

env.close()

terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated