In [69]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from collections import deque
import gymnasium as gym
import torch.optim as optim
from collections import namedtuple, deque
import random
import random

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
LR = 1e-4
EPS_END = 0.005
EPS_START = 1
BATCH_SIZE = 128 # expirement with this more
EPS_DECAY = 1000 # does this cause faster or slower update?
TAU = 0.05

In [71]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)
        
    def append(self, *args):
        self.memory.append(Transition(*args))
        
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
    
    def __getitem__(self, i):
        return self.memory[i]

class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(n_observations, 128),
            nn.ReLU(),
            
            nn.Linear(128, 128),
            nn.ReLU(),
            
            nn.Linear(128, n_actions),
        )
        
    def forward(self, x):
        return self.layers(x)

In [72]:
env = gym.make("CartPole-v1")

n_actions = env.action_space.n
n_observations = env.observation_space.shape[0]

policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())

opt = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
criterion = nn.HuberLoss()
replay_memory = ReplayMemory(10000)

In [73]:
steps_done = 0


def select_action(state):
    global steps_done

    eps = EPS_END + (EPS_START - EPS_END) * np.exp(
        -1.0 * steps_done / EPS_DECAY
    )  # try to look at how eps is updated
    steps_done += 1
    print(f"eps: {eps}, steps_done: {steps_done}")

    if random.random() < eps:
        torch.tensor([[env.action_space.sample()]], dtype=torch.float32)

    with torch.no_grad():  # attempt without this
        return torch.argmax(policy_net(state)).view(
            1, 1
        )  # why this weird shape? why not get the item directly

In [None]:
def optimize_model():
    if len(replay_memory) < BATCH_SIZE:
        return

    transitions = replay_memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    non_final_mask = torch.tensor(
        list(map(lambda state: state is not None, batch.next_state))
    )

    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    q_values = policy_net(state_batch).gather(1, action_batch)

    next_q_values = torch.zeros(BATCH_SIZE)
    next_q_values[non_final_mask] = (
        target_net(non_final_next_states).max(1).values.detach()
    )

    expected_q_values = reward_batch + next_q_values

    loss = criterion(q_values, expected_q_values)

    opt.zero_grad()

    # test
    grads = [p.grad for p in policy_net.parameters() if p.grad is not None]
    all_grads = torch.cat([g.view(-1) for g in grads])
    print("Before clipping:")
    print("Max grad:", all_grads.max().item())
    print("Min grad:", all_grads.min().item())
    print("Mean grad:", all_grads.mean().item())
    print("Std grad:", all_grads.std().item())

    nn.utils.clip_grad_value_(policy_net.parameters(), 100) # test without it

    grads = [p.grad for p in policy_net.parameters() if p.grad is not None]
    all_grads = torch.cat([g.view(-1) for g in grads])
    print("After clipping:")
    print("Max grad:", all_grads.max().item())
    print("Min grad:", all_grads.min().item())
    print("Mean grad:", all_grads.mean().item())
    print("Std grad:", all_grads.std().item())
    # end test

    loss.backward()
    opt.step()

In [77]:
num_episodes = 1000

total_rewards = []
for episode in range(num_episodes):
    state, _ = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)  # (1, 4)

    done = False
    R = 0
    while not done:
        action = select_action(state)  # (1, 1)

        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated

        reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)  # (1, 1)

        next_state = None if done else torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)

        replay_memory.append(state, action, next_state, reward)

        state = next_state

        optimize_model()

        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()

        for key in target_net_state_dict:  # inspect how this key looks like
            target_net_state_dict[key] = (
                TAU * policy_net_state_dict[key]
                + (1 - TAU) * target_net_state_dict[key]
            )
        R += reward
    
    total_rewards += R
    print(R)

plt.plot(total_rewards)

eps: 0.8804541122489206, steps_done: 129
[]


  return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta)


RuntimeError: torch.cat(): expected a non-empty list of Tensors

In [None]:
for key in target_net.state_dict():
    print('hi')

hi
hi
hi
hi
hi
hi
