In [11]:
import random
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import wandb


In [12]:
class DQN(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, x):
        return self.net(x)


In [13]:

class DQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, lr=1e-3, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon_start
        self.epsilon_min = epsilon_end
        self.epsilon_decay = epsilon_decay

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DQN(state_dim, action_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = nn.MSELoss()

        wandb.watch(self.model, log="all", log_freq=10)

    def select_action(self, state):
        if random.random() < self.epsilon:  # exploration
            return random.randrange(self.action_dim)
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        q_values = self.model(state)
        return torch.argmax(q_values).item()  # exploitation

    def train_step(self, memory, batch_size):
        if len(memory) < batch_size:
            return 0.0# wait until buffer has enough samples

        states, actions, rewards, next_states, dones = memory.sample(batch_size)

        states = torch.FloatTensor(states).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        actions = torch.LongTensor(actions).unsqueeze(1).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        # Q(s, a)
        q_values = self.model(states).gather(1, actions)

        # Target: r + Î³ * max_a' Q(next_state, a')
        next_q_values = self.model(next_states).max(1)[0].unsqueeze(1)
        target_q = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.loss_fn(q_values, target_q.detach())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss.item() 


In [14]:

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards),
            np.array(next_states),
            np.array(dones)
        )

    def __len__(self):
        return len(self.memory)


In [None]:


def train_dqn():

    wandb.init(project="dqn-cartpole", name="dqn_experiment_1")
    
    env = gym.make("CartPole-v1")
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n

    # Corrected print statement
    print(f"state dimension: {state_dim}, action dimension: {action_dim}")

    agent = DQNAgent(state_dim, action_dim, lr=5e-4, epsilon_decay=0.99)
    memory = ReplayBuffer(10000)
    episodes = 500
    batch_size = 64
# In your train_dqn() function, REPLACE the training loop:

    print("Starting training...")

    for ep in range(episodes):
        state, _ = env.reset()
        total_reward = 0
        episode_loss = 0
        train_steps = 0

        for t in range(500):
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            memory.push(state, action, reward, next_state, done)
            loss = agent.train_step(memory, batch_size)  # Now returns loss
            
            if loss > 0:  # Only count when training happened
                episode_loss += loss
                train_steps += 1

            state = next_state
            total_reward += reward

            if done:
                break

        # Calculate average loss
        avg_loss = episode_loss / train_steps if train_steps > 0 else 0
        
        # ADD THESE LINES for WandB logging:
        wandb.log({
            "episode": ep + 1,
            "total_reward": total_reward,
            "epsilon": agent.epsilon,
            "avg_loss": avg_loss,
            "buffer_size": len(memory)
        })

        print(f"Episode {ep+1}, Reward: {total_reward}, Epsilon: {agent.epsilon:.3f}, Avg Loss: {avg_loss:.4f}")

        # ADD THESE LINES before torch.save:
    wandb.config.update({
        "state_dim": state_dim,
        "action_dim": action_dim,
        "gamma": agent.gamma,
        "learning_rate": 5e-4,
        "epsilon_start": 1.0,
        "epsilon_end": 0.01,
        "epsilon_decay": 0.99,
        "batch_size": batch_size,
        "buffer_size": 10000,
        "episodes": episodes
    })

    torch.save(agent.model.state_dict(), "dqn_model.pth")
    wandb.save("dqn_model.pth")  # <-- ADD THIS to save model to WandB
    wandb.finish()  # <-- ADD THIS at the very end


if __name__ == "__main__":
    train_dqn()


Starting training...


NameError: name 'episodes' is not defined

In [None]:

import os
from datetime import datetime


# Suppress warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

# Create a unique folder for videos
video_folder = f"videos/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(video_folder, exist_ok=True)

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.TimeLimit(env, max_episode_steps=1000)  # Increase max steps
env = gym.wrappers.RecordVideo(env, video_folder, episode_trigger=lambda e: True)

# Load the trained model
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)
agent.model.load_state_dict(torch.load("dqn_model.pth"))
agent.epsilon = 0  # Disable exploration

# Record multiple episodes
for episode in range(3):  # Record 3 episodes
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)
        state, _, terminated, truncated, _ = env.step(action)
        print("terminated", terminated, "truncated", truncated)
        done = terminated or truncated

# Close the environment to avoid cleanup errors
env.close()




terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated

  IMAGEMAGICK_BINARY = r"C:\Program Files\ImageMagick-6.8.8-Q16\magick.exe"


terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated False truncated False
terminated