# DQN

## Preliminary

In [1]:
import gymnasium as gym
import math
import random
from itertools import count
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm 

from IPython import display
import wandb

import os
import sys
import time

# trick to import from relative path
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from replay_memory import ReplayMemory, Transition
from nb_utils.widgets import ArrayRenderWidget


# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
%matplotlib inline

In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmax-schik[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Hyperparameters for the run; the whole dict is also passed to wandb.init as the run config.
config = {
    "num_episodes": 2000,  # number of training episodes
    "batch_size": 128,     # transitions sampled from replay memory per optimisation step
    "gamma": 0.99,         # discount factor applied to the bootstrapped next-state value
    "eps_start": 0.9,      # epsilon of the epsilon-greedy policy at step 0
    "eps_end": 0.05,       # value epsilon decays towards as the step count grows
    "eps_decay": 1000,     # decay constant (in environment steps) of the exponential epsilon schedule
    "tau": 0.005,          # interpolation factor of the soft target-network update
    "lr": 1e-4,            # AdamW learning rate
    "env": "Acrobot-v1",   # Gymnasium environment id
}

In [5]:
# Training environment; created without a render_mode (the evaluation cell re-creates it with rendering).
env = gym.make(config["env"])

In [6]:
class DQN(nn.Module):
    """Fully connected Q-network: observation -> one Q-value per action.

    The network is a three-layer MLP with ReLU activations after the first
    two layers and a linear output layer.

    Args:
        n_observations: Size of the (flat) observation vector.
        n_actions: Number of discrete actions, i.e. the output width.
        hidden_size: Width of the two hidden layers. Defaults to 128, which
            is the width the notebook has always used.
    """

    def __init__(self, n_observations, n_actions, hidden_size=128):
        super().__init__()
        # Attribute names are kept as layer1/2/3 so saved state_dicts stay loadable.
        self.layer1 = nn.Linear(n_observations, hidden_size)
        self.layer2 = nn.Linear(hidden_size, hidden_size)
        self.layer3 = nn.Linear(hidden_size, n_actions)

    def forward(self, x):
        """Return the Q-values for ``x``; the last dimension must be ``n_observations``."""
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [7]:
class DQNAgent:
    """DQN agent with epsilon-greedy acting, replay-based learning and soft target updates.

    Hyperparameters (``lr``, ``batch_size``, ``gamma``, ``tau``) are read from the
    module-level ``config`` dict.

    Args:
        env: Environment; only ``env.action_space.sample()`` is used (random actions).
        target_net: Network used to compute bootstrapped next-state values.
        policy_net: Network being optimised and used for greedy action selection.
        memory: Replay buffer exposing ``push(state, action, next_state, reward)``,
            ``sample(n)`` and ``__len__``.
        update_interval: Run a learning step every ``update_interval`` calls to ``step``.
    """

    def __init__(self, env, target_net, policy_net, memory, update_interval=1):
        self.env = env
        self.target_net = target_net
        self.policy_net = policy_net
        self.memory = memory
        self.update_interval = update_interval
        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=config["lr"], amsgrad=True)

        # Counts calls to step() modulo update_interval.
        self.t_step = 0

    def step(self, state, action, reward, next_state, done):
        """Store one transition and, when due, run one optimisation step.

        Args:
            state: State tensor the action was taken in.
            action: Action tensor that was taken.
            reward: Reward tensor received.
            next_state: Resulting state tensor, or ``None`` for a terminal state.
            done: Accepted for interface compatibility; currently unused (terminal
                transitions are signalled by ``next_state is None``).

        Returns:
            The loss tensor from ``learn`` if a learning step ran, otherwise ``None``.
        """
        # Store into the agent's own buffer rather than a notebook-global `memory`.
        self.memory.push(state, action, next_state, reward)
        self.t_step = (self.t_step + 1) % self.update_interval

        if self.t_step == 0 and len(self.memory) > config["batch_size"]:
            batch = self.memory.sample(config["batch_size"])
            return self.learn(batch)
        return None

    def act(self, state, eps=0.):
        """Pick an action epsilon-greedily.

        Args:
            state: State tensor with a leading batch dimension (the greedy branch
                takes ``max(1)`` over the network output).
            eps: Probability threshold for taking a random action.

        Returns:
            A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
        """
        if random.random() > eps:
            with torch.no_grad():
                return self.policy_net(state).max(1).indices.view(1, 1)
        # Random action, placed on the same device as the state it responds to.
        return torch.tensor([[self.env.action_space.sample()]], device=state.device, dtype=torch.long)

    def learn(self, batch):
        """Run one optimisation step on a batch of transitions, then soft-update the target.

        Args:
            batch: Non-empty sequence of ``Transition`` objects.

        Returns:
            The scalar Huber-loss tensor for this batch.
        """
        # Size everything from the batch itself instead of config["batch_size"].
        batch_size = len(batch)

        # This converts batch-array of Transitions to Transition of batch-arrays.
        batch = Transition(*zip(*batch))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)
        batch_device = state_batch.device

        # Mask of transitions whose next state is not terminal (terminal == None).
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=batch_device, dtype=torch.bool)
        non_final_next_states = [s for s in batch.next_state if s is not None]

        # Q(s_t, a): evaluate the policy net and select the column of the action taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # V(s_{t+1}) from the target net; stays 0 for terminal next states.
        next_state_values = torch.zeros(batch_size, device=batch_device)
        # torch.cat raises on an empty list, so skip the target pass when every
        # sampled transition is terminal.
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = (
                    self.target_net(torch.cat(non_final_next_states)).max(1).values
                )

        # Bellman target: r + gamma * V(s_{t+1})
        expected_state_action_values = reward_batch + (next_state_values * config["gamma"])

        # Compute Huber loss
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
        self.optimizer.step()

        self.soft_update()

        return loss

    def soft_update(self):
        """Blend the policy weights into the target network: θ′ ← τ θ + (1 − τ) θ′."""
        tau = config["tau"]
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key]*(1-tau)
        self.target_net.load_state_dict(target_net_state_dict)

In [8]:
# Size the networks from the environment: one output per discrete action,
# one input per entry of the observation returned by reset().
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

# Online (policy) network, and a target network that starts as an exact copy of it.
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Replay buffer created with capacity argument 10000, handed to the agent.
memory = ReplayMemory(10000)
agent = DQNAgent(env=env, target_net=target_net, policy_net=policy_net, memory=memory)

In [9]:
wandb.init(
    project=f"learn-rl.dqn.{config['env']}",
    config=config
)

# Global environment-step counter; drives the exponential epsilon schedule.
step = 0

# try/finally so the W&B run is closed even if training raises or is interrupted.
try:
    for i_episode in tqdm(range(config["num_episodes"])):
        # Initialize the environment and get its state
        state, info = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        # Plain Python numbers for logging: no per-step tensor accumulation.
        episode_return = 0.0
        losses = []

        for t in count():
            # Epsilon decays exponentially from eps_start towards eps_end.
            eps = config["eps_end"] + (config["eps_start"] - config["eps_end"]) * \
                  math.exp(-1. * step / config["eps_decay"])
            action = agent.act(state, eps)
            step += 1
            observation, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            episode_return += float(reward)
            # Explicit float32 so the target keeps the network's dtype whatever the env returns.
            reward = torch.tensor([reward], dtype=torch.float32, device=device)

            # Only a true termination zeroes the bootstrap target; a truncated
            # episode still stores its real next state.
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

            loss = agent.step(state, action, reward, next_state, done)

            # Move to the next state
            state = next_state

            if loss is not None:
                losses.append(loss.item())

            if done:
                wandb.log({"reward": episode_return, "loss": sum(losses) / max(len(losses), 1), "length": t+1})
                break
finally:
    wandb.finish()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [06:39<00:00,  5.01it/s]


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
length,█▁▂▃▁▁█▁▁▁▁▁▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▃▁▁▁▁▁▁▂▁▁▁▁
loss,▁▄▇▇▇██▇█▇█▇▆▆▆▅▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▄
reward,▁█▇▆██▁█████▇▇▇▇██▇▇████████▆██████▇████

0,1
length,75.0
loss,0.26157
reward,-74.0


In [10]:
# Save the trained policy weights. Create the target directory first so a
# missing ../weights folder does not fail the save after a long training run.
weights_path = f"../weights/dqn_{config['env']}.pt"
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(policy_net.state_dict(), weights_path)

In [11]:
# Widget used by the evaluation cell to show frames from env.render().
# NOTE(review): ArrayRenderWidget is a project helper; width/height are presumably pixels — confirm.
image_widget = ArrayRenderWidget(
    format='png',
    width=600,
    height=400,
)

In [12]:
# Roll out one greedy episode with the trained policy and render it in the widget.
env = gym.make(config["env"], render_mode="rgb_array")

# try/finally so the environment is closed even if the rollout raises.
try:
    observation, info = env.reset()

    display.display(image_widget)
    image_widget.render(env.render())  # initial frame

    for _ in range(1000):
        # Pure inference: no autograd graph needed.
        with torch.no_grad():
            obs_tensor = torch.tensor(observation, dtype=torch.float32, device=device)
            q_values = policy_net(obs_tensor)
        # Greedy action as a plain Python int for the environment.
        action = int(torch.argmax(q_values).item())

        observation, reward, terminated, truncated, info = env.step(action)
        # Render after stepping so the final frame of the episode is shown too.
        image_widget.render(env.render())
        time.sleep(1/30)

        if terminated or truncated:
            break
finally:
    env.close()

ArrayRenderWidget(value=b'', height='400', width='600')