Notebook created by [Víctor Campos](https://imatge.upc.edu/web/people/victor-campos) for UPC ETSETB AAL 2019

Updates:

[Xavier Giró](https://imatge.upc.edu/web/people/xavier-giro) - UPC ETSETB AAL 2019

[Juan José Nieto](https://www.linkedin.com/in/juan-jose-nieto-salas/) - UPC School - AIDL Spring 2021

# DQN example in PyTorch

This notebook is adapted from the [official DQN tutorial](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html). Unlike the tutorial, we will feed the agent the environment's standard observation vector (cart position and velocity, pole angle and angular velocity) instead of RGB images.

## Installing dependencies


In [None]:
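# The code below relies on `return_info` and `new_step_api`, which (to our
# knowledge) are only available in the gym 0.25.x releases, so we pin that series.
!pip install gym==0.25.* wandb pygame --quiet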

# install utilities for rendering OpenAI Gym videos in Colab
!apt-get -qq install -y xvfb x11-utils
!pip install pyvirtualdisplay==0.2.* \
             PyOpenGL==3.1.* \
             PyOpenGL-accelerate==3.1.* \
             --quiet


## Setting up the environment

In [None]:
import base64
import glob
import io
import os
import math
import timeit
import warnings

from IPython.display import HTML
from IPython.display import display

In [None]:
import gym
import wandb
import random

import numpy as np
from random import randint
from datetime import datetime
from collections import namedtuple

import matplotlib.pyplot as plt

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# starting a fake screen in the background
#  in order to render videos
os.system("Xvfb :1 -screen 0 1024x768x24 &")
os.environ["DISPLAY"] = ":1"

# utility to get video file from directory
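def get_video_filename(dir="video"):
  glob_mp4 = os.path.join(dir, "*.mp4")
  mp4list = glob.glob(glob_mp4)
  assert len(mp4list) > 0, "couldn't find video files"
  # glob returns files in arbitrary order, so pick the most recently modified one
  return max(mp4list, key=os.path.getmtime)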

## Visualize a random policy in the environment

Our goal is to train an agent that is capable of solving the CartPole problem, where a pole is attached to a cart moving along a horizontal track. The agent interacts with the environment by pushing the cart to the left or to the right (a force of -1 or +1). The episode terminates whenever the pole tilts more than 12 degrees from vertical or the cart goes out of bounds on the horizontal axis. The agent receives a reward of +1 for every timestep in which neither condition has been violated.

We can visualize what a random policy would do in this environment:

In [None]:
env = gym.make("CartPole-v1", render_mode="rgb_array")

env = gym.wrappers.RecordVideo(env, "./video")

# With return_info=True, reset() returns an (observation, info) pair
(ob, info), done, total_rew = env.reset(return_info=True), False, 0

while not done:
  env.render()
  
  ac = env.action_space.sample()
  
  ob, rew, done, info = env.step(ac)
  
  total_rew += rew
  
print('Cumulative reward:', total_rew)
  
env.close()

## Log in to your Wandb account

In [None]:
wandb.login()

## Visualize random policy in Wandb

In [None]:
PROJECT = "AIDL-Spring-DRL"

In [None]:
wandb.init(project=PROJECT)
wandb.run.name = 'cartpole_random_agent'
mp4 = get_video_filename()
wandb.log({"Video eval": wandb.Video(mp4, fps=4, format="mp4")})
wandb.finish()

## Replay memory

The buffer will be a FIFO queue: when full, oldest experiences are removed to make room for new transitions.
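
To make the FIFO behaviour concrete, here is a minimal sketch that uses Python's `collections.deque` with `maxlen` instead of the list-plus-pointer design used in this notebook. The capacity of 3 and the integer "transitions" are illustrative assumptions only.

```python
from collections import deque

# Illustrative stand-in for the replay memory: a bounded FIFO queue.
# The capacity of 3 and the integer "transitions" are made up for the demo.
buffer = deque(maxlen=3)

for transition in range(5):
    buffer.append(transition)  # once full, the oldest item is dropped

print(list(buffer))  # the two oldest items (0 and 1) have been evicted
```

The `ReplayMemory` class in this notebook aims for the same eviction order with a plain list, whose constant-time indexing suits random sampling.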

**Exercise #1.** Update the pointer to the next position to be filled in the replay memory, so that the buffer behaves as a FIFO queue.
(TIP: remember the [modulus % operator](https://python-reference.readthedocs.io/en/latest/docs/operators/modulus.html)).

In [None]:
Transition = namedtuple(
    'Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        
        # TODO: Update the pointer to the next position in the replay memory
        self.position = ...

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
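
`sample` returns a list of `Transition` tuples, whereas the `train` function later in the notebook needs one batch per field. The sketch below shows that regrouping with `zip(*...)`; the strings and numbers are toy placeholders standing in for tensors.

```python
from collections import namedtuple

# Same field layout as the notebook's Transition; the values are toy
# placeholders (strings and numbers instead of tensors).
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

sampled = [
    Transition('s0', 0, 's1', 1.0),
    Transition('s1', 1, None, 1.0),   # None marks a terminal next state
]

# zip(*sampled) groups the i-th field of every transition together
batch = Transition(*zip(*sampled))

print(batch.state)       # ('s0', 's1')
print(batch.next_state)  # ('s1', None)
```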

## Model definition

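Now we will define the Q-network: a feedforward neural network that maps a state to one Q-value per action. The policy is derived from it by acting (epsilon-)greedily on those values.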

In [None]:
class DQN(nn.Module):
    def __init__(self, inputs, outputs, hidden_size=128):
        super(DQN, self).__init__()
        self.affine1 = nn.Linear(inputs, hidden_size)
        self.affine2 = nn.Linear(hidden_size, outputs)

    def forward(self, x):
        x = self.affine1(x)
        x = F.relu(x)
        x = self.affine2(x)
        return x
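
For a sense of scale, the following sketch counts the trainable parameters of this two-layer network. It assumes CartPole's 4-dimensional observation and 2 discrete actions, together with the default `hidden_size=128`; `linear_params` is a hypothetical helper, not part of the notebook.

```python
def linear_params(in_features, out_features):
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features

# Assumed sizes: CartPole's 4-dimensional observation, 2 actions,
# and the default hidden_size=128 from the DQN class.
n_inputs, n_actions, hidden_size = 4, 2, 128

total = linear_params(n_inputs, hidden_size) + linear_params(hidden_size, n_actions)
print(total)  # 898
```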

## Functions for collecting experience and updating the policy

**Exercise #2.** Complete the epsilon-greedy policy in `select_action` to enable exploration.

**Exercise #3.** Replace each `TODO_net` in the code with either `policy_net` or `target_net`.

**Exercise #4.** Compute the TD target.
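
For reference, the target described in the code comments can be written as follows. The notation is ours rather than the notebook's: $\theta$ denotes the parameters of `policy_net` and $\theta^-$ those of `target_net`.

$$
y_t =
\begin{cases}
r_t & \text{if } s_{t+1} \text{ is terminal} \\
r_t + \gamma \max_{a'} Q_{\theta^-}(s_{t+1}, a') & \text{otherwise}
\end{cases}
$$

The loss then compares $Q_{\theta}(s_t, a_t)$ with $y_t$.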

In [None]:
def compute_eps_threshold(step, eps_start, eps_end, eps_decay):
  return eps_end + (eps_start - eps_end) * math.exp(-1. * step / eps_decay)


def select_action(policy, state, eps_greedy_threshold, n_actions):
    # TODO: Select action using an epsilon-greedy strategy
    if random.random() ...

        with torch.no_grad():
            # policy(state) returns one Q-value per action, with shape (1, n_actions).
            # max(1) returns (values, indices) along the action dimension; the index
            # is the action with the largest expected return, reshaped to (1, 1).
            action = policy(state).max(1)[1].view(1, 1)
    else:
        action = torch.tensor(
            [[random.randrange(n_actions)]], device=device, dtype=torch.long)
    return action

    
def train(policy_net, target_net, optimizer, memory, batch_size, gamma):
    if len(memory) < batch_size:
        return 0
    transitions = memory.sample(batch_size)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(
        tuple(map(lambda s: s is not None, batch.next_state)), 
        device=device, 
        dtype=torch.bool)
    
    non_final_next_states = torch.cat(
        [s for s in batch.next_state if s is not None])
    
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # TODO: Compute Q(s_t, a) - the model computes Q(s_t) for all a, then we select 
    # the columns of actions taken. These are the actions which would've been 
    # taken for each batch state according to policy_net
    state_action_values = TODO_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(batch_size, device=device)

    # TODO : Compute Q(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    # Note the call to detach() on Q(s_{t+1}), which prevents gradient flow
    next_state_values[non_final_mask] = TODO_net(non_final_next_states).max(1)[0].detach()


    # TODO: Compute targets for Q values: y_t = r_t + gamma*max(Q_{t+1})
    expected_state_action_values = TODO

    # Compute Huber loss between predicted Q values and targets y
    loss = F.smooth_l1_loss(
        state_action_values, expected_state_action_values.unsqueeze(1))

    # Take one optimization step with the optimizer passed in
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
    
    
def test(env, policy, video_path='./video', render=False):
    state, ep_reward, done = env.reset(), 0, False
    while not done:
        if render:
          env.render()
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        # Act greedily (epsilon = 0) with the policy passed as an argument
        action = select_action(policy, state, 0., env.action_space.n)
        state, reward, done, info = env.step(action.item())
        ep_reward += reward

    env.close()
    mp4 = get_video_filename(video_path)
    wandb.log({"Video eval": wandb.Video(mp4, fps=4, format="mp4")})
    return ep_reward
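
The `train` function above uses `F.smooth_l1_loss` as its Huber loss. As a rough illustration of why it is less sensitive to large TD errors than a squared loss, here is a plain-Python sketch of the piecewise definition for a single error value, assuming the default `beta=1.0`; `huber` is a hypothetical helper written for this example.

```python
def huber(error, beta=1.0):
    """Smooth L1 (Huber-style) loss for a single error value.

    Mirrors the piecewise definition documented for
    torch.nn.functional.smooth_l1_loss with its default beta=1.0.
    """
    abs_error = abs(error)
    if abs_error < beta:
        return 0.5 * abs_error ** 2 / beta   # quadratic near zero
    return abs_error - 0.5 * beta            # linear for large errors

small = huber(0.5)    # 0.125, same as 0.5 * 0.5**2
large = huber(10.0)   # 9.5, versus 50.0 for a squared loss 0.5 * 10**2
print(small, large)
```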

## Training the agent

In [None]:
hparams = {
    'gamma' : 0.99,             # discount factor
    'log_interval' : 25,        # controls how often we log progress, in episodes
    'num_steps': 60000,         # number of steps to train on
    'batch_size': 256,          # batch size for optimization
    'lr' : 1e-4,                # learning rate
    'eps_start': 1.0,           # initial value for epsilon (in epsilon-greedy)
    'eps_end': 0.1,             # final value for epsilon (in epsilon-greedy)
    'eps_decay': 20000,         # time constant of the exponential epsilon decay, in env steps
    'target_update': 1000,      # how often to update target net, in env steps
    'replay_size': 10000,       # replay memory size
}
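
To see what these exploration settings imply, the sketch below evaluates the notebook's `compute_eps_threshold` formula at a few step counts, using the `eps_start`, `eps_end` and `eps_decay` values from `hparams`. The chosen steps are illustrative.

```python
import math

def compute_eps_threshold(step, eps_start, eps_end, eps_decay):
    # Same formula as in the notebook
    return eps_end + (eps_start - eps_end) * math.exp(-1. * step / eps_decay)

# Values copied from hparams
eps_start, eps_end, eps_decay = 1.0, 0.1, 20000

for step in (0, 20000, 60000):
    eps = compute_eps_threshold(step, eps_start, eps_end, eps_decay)
    print(f"step {step:>6}: epsilon = {eps:.3f}")
```

This prints approximately 1.000, 0.431 and 0.145. In other words, `eps_decay` acts as a time constant rather than a hard cutoff: after `eps_decay` steps about 37% of the gap between `eps_start` and `eps_end` remains, and epsilon has not yet reached `eps_end` when training stops at 60,000 steps.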


In [None]:
# Create environment
env_name = 'CartPole-v1'
env = gym.make(env_name, render_mode="rgb_array", new_step_api=True)

In [None]:
# Get number of actions from gym action space
n_inputs = env.observation_space.shape[0]
n_actions = env.action_space.n

**Exercise #5.** Complete the call to `memory.push`.

In [None]:
# Fix random seed (for reproducibility)
seed = 543
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Initialize wandb run
wandb.finish()  # close any active run to avoid overlapping runs (tip: remove duplicate runs in wandb afterwards)
wandb.init(project=PROJECT, config=hparams)
wandb.run.name = 'dqn_cartpole_train_0'


# Initialize policy and target networks
policy_net = DQN(n_inputs, n_actions).to(device)
target_net = DQN(n_inputs, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters(), lr=hparams['lr'])
memory = ReplayMemory(hparams['replay_size'])

print(f"Target reward: {env.spec.reward_threshold}")
step_count = 0
running_reward = 0

ep_rew_history = []
i_episode, ep_reward = 0, -float('inf')
while step_count < hparams['num_steps']:
    # Initialize the environment and state
    state, done = env.reset(), False
    state = torch.from_numpy(state).float().unsqueeze(0).to(device)
    reward_episode = 0
    losses = []
    while not done:
        # Select an action
        eps_greedy_threshold = compute_eps_threshold(
            step_count, hparams['eps_start'], hparams['eps_end'], hparams['eps_decay'])
        action = select_action(
            policy_net, state, eps_greedy_threshold, n_actions)

        # Perform action in env
        next_state, reward, terminated, truncated, info = env.step(action.item())
        done = terminated or truncated

        # Bookkeeping
        if done:
            # train() treats states as terminal when next_state is None
            next_state = None
        else:
            next_state = torch.from_numpy(next_state).float().unsqueeze(0).to(device)
        
        reward = torch.tensor([reward], device=device)
        step_count += 1

        # TODO: Store the transition in memory
        memory.TODO

        # Move to the next state
        state = next_state

        # Reward episode
        reward_episode += reward.item()

        # Perform one step of the optimization (on the policy network)
        loss = train(policy_net, target_net, optimizer, memory, hparams['batch_size'], hparams['gamma'])
        losses.append(loss)
        # Update the target network, copying all weights and biases in DQN
        if step_count % hparams['target_update'] == 0:
            target_net.load_state_dict(policy_net.state_dict())

    i_episode += 1

    running_reward = 0.05 * reward_episode + (1 - 0.05) * running_reward
    wandb.log(
        {
        'loss': np.mean(losses),
        'running_reward': running_reward,
        'ep_reward': reward_episode,
        'epsilon': eps_greedy_threshold # log last epsilon of the episode
        })
    
    # Evaluate greedy policy
    if i_episode % hparams['log_interval'] == 0:
        video_path = datetime.now().isoformat(timespec="seconds")
        test_env = gym.wrappers.RecordVideo(env, f"./{video_path}")
        ep_reward = test(test_env, policy_net, video_path=video_path)
        ep_rew_history.append(ep_reward)
        print(f'Episode {i_episode}\tSteps: {step_count/1000:.2f}k\tEval reward: {ep_reward}\tRunning reward: {running_reward:.2f}')

wandb.finish()
print(f"Finished training! Eval reward: {ep_reward}")
if not os.path.exists('checkpoints'):
    os.makedirs('checkpoints')
torch.save(policy_net.state_dict(), f'checkpoints/dqn-{env_name}.pt')
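
A note on reading the `running_reward` curve: it is an exponential moving average with weight 0.05 that starts at 0, so it underestimates performance during the first episodes. The sketch below illustrates this with a hypothetical agent that scores exactly 100 in every episode; `update_running_reward` is a helper written for this example.

```python
def update_running_reward(running_reward, episode_reward, alpha=0.05):
    # Same exponential moving average as in the training loop
    return alpha * episode_reward + (1 - alpha) * running_reward

running_reward = 0.0        # same initial value as in the notebook
history = []
for _ in range(50):
    # Hypothetical agent that scores exactly 100 in every episode
    running_reward = update_running_reward(running_reward, 100.0)
    history.append(running_reward)

print(f"after  1 episode:  {history[0]:.1f}")
print(f"after 50 episodes: {history[-1]:.1f}")
```

After one episode the average is only 5.0, and after 50 episodes it is still about 92.3, so the early part of the curve says more about the averaging than about the agent.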

In [None]:
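# One evaluation is run every `log_interval` training episodes
eval_episodes = np.arange(1, len(ep_rew_history) + 1) * hparams['log_interval']
plt.plot(eval_episodes, ep_rew_history)
plt.xlabel('Training episode')
plt.ylabel('Eval reward')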