Solve the Trading Bot example using Rainbow DQN


In [None]:
%%bash

# Install the pinned PyTorch Lightning release this notebook is written against.
pip install \
  pytorch-lightning==1.6.0 \


#### Import the necessary code libraries

In [2]:
import copy
import torch
import random
import gym
import matplotlib

import numpy as np
import matplotlib.pyplot as plt

import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from gym.wrappers import TransformObservation, NormalizeObservation, \
  NormalizeReward, RecordVideo, RecordEpisodeStatistics, AtariPreprocessing


# Use the first CUDA device when CUDA is available, otherwise fall back to CPU.
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

# Number of CUDA devices visible to this process (0 on CPU-only machines).
num_gpus = torch.cuda.device_count()

print(f'device {device} num_gpus {num_gpus}')

  from .autonotebook import tqdm as notebook_tqdm


device cuda:0 num_gpus 2


  if not hasattr(tensorboard, "__version__") or LooseVersion(
  np.bool8: (False, True),


#### Create the Deep Q-Network

In [3]:
import math 
from torch.nn.init import kaiming_uniform_, zeros_

class NoisyLinear(nn.Module):
  """Linear layer whose parameters are perturbed by Gaussian noise while training.

  In training mode the effective weight is ``w_mu + w_sigma * noise`` and the
  effective bias is ``b_mu + b_sigma * noise``, with fresh noise drawn on every
  forward pass. In evaluation mode only the mean parameters are used.

  Args:
    in_features: Size of each input sample.
    out_features: Size of each output sample.
    sigma: Accepted for interface compatibility; this constructor does not
      use it. The noise scale is taken from the ``sigma`` argument of
      ``forward``.
  """

  def __init__(self, in_features, out_features, sigma):
    super().__init__()
    self.w_mu = nn.Parameter(torch.empty((out_features, in_features)))
    self.w_sigma = nn.Parameter(torch.empty((out_features, in_features)))
    self.b_mu = nn.Parameter(torch.empty((out_features)))
    self.b_sigma = nn.Parameter(torch.empty((out_features)))

    kaiming_uniform_(self.w_mu, a=math.sqrt(5))
    kaiming_uniform_(self.w_sigma, a=math.sqrt(5))
    zeros_(self.b_mu)
    zeros_(self.b_sigma)

  def forward(self, x, sigma=0.5):
    """Apply the (optionally noisy) affine transform to ``x``.

    Args:
      x: Input tensor accepted by ``F.linear`` for a weight of shape
        ``(out_features, in_features)``.
      sigma: Standard deviation of the zero-mean Gaussian noise drawn in
        training mode. Ignored in evaluation mode.

    Returns:
      The transformed tensor.
    """
    if not self.training:
      # Evaluation uses the mean parameters only (attribute is `w_mu`, lowercase).
      return F.linear(x, self.w_mu, self.b_mu)

    # Draw the noise directly on the parameters' device/dtype so the layer
    # works wherever the module has been moved, without a global `device`.
    w_noise = torch.randn_like(self.w_mu) * sigma
    b_noise = torch.randn_like(self.b_mu) * sigma
    weight = self.w_mu + self.w_sigma * w_noise
    bias = self.b_mu + self.b_sigma * b_noise
    return F.linear(x, weight, bias)

In [4]:
class DQN(nn.Module):
  """Dueling, distributional Q-network built on noisy linear layers.

  A two-stage convolutional feature extractor feeds a noisy hidden layer,
  which in turn feeds separate advantage and value streams. The output is,
  for every action, a probability distribution over ``atoms`` support points.

  Args:
    hidden_size: Width of the hidden noisy layer.
    obs_shape: Shape of a single observation; ``obs_shape[0]`` is used as the
      number of input channels of the first convolution.
    n_actions: Number of discrete actions.
    atoms: Number of support points per action distribution.
    sigma: Forwarded to every ``NoisyLinear`` constructor.
  """

  def __init__(self, hidden_size, obs_shape, n_actions, atoms=51, sigma=0.5):
    super().__init__()
    self.atoms = atoms
    self.n_actions = n_actions

    in_channels = obs_shape[0]
    feature_layers = [
      nn.Conv2d(in_channels, 64, kernel_size=3),
      nn.MaxPool2d(kernel_size=4),
      nn.ReLU(),
      nn.Conv2d(64, 64, kernel_size=3),
      nn.MaxPool2d(kernel_size=4),
      nn.ReLU(),
    ]
    self.conv = nn.Sequential(*feature_layers)

    # The flattened feature count depends on the observation shape, so it is
    # measured with a dummy forward pass rather than computed by hand.
    flat_features = self._get_conv_out(obs_shape)
    self.head = nn.Sequential(
      NoisyLinear(flat_features, hidden_size, sigma=sigma),
      nn.ReLU(),
    )

    self.fc_adv = NoisyLinear(hidden_size, self.n_actions * self.atoms, sigma=sigma)
    self.fc_value = NoisyLinear(hidden_size, self.atoms, sigma=sigma)

  def _get_conv_out(self, shape):
    """Return the number of features the conv stack emits for one observation."""
    probe = torch.zeros(1, *shape)
    return self.conv(probe).numel()

  def forward(self, x):
    """Return per-action atom probabilities of shape (B, A, N)."""
    batch = x.size(0)
    features = self.conv(x.float()).view(batch, -1)
    hidden = self.head(features)

    adv = self.fc_adv(hidden).view(-1, self.n_actions, self.atoms)  # (B, A, N)
    value = self.fc_value(hidden).view(-1, 1, self.atoms)  # (B, 1, N)

    # Dueling combination: value plus mean-centred advantage, per atom.
    combined = value + adv  # (B, A, N)
    q_logits = combined - adv.mean(dim=1, keepdim=True)  # (B, A, N)
    return F.softmax(q_logits, dim=-1)  # (B, A, N)

#### Create the policy

In [6]:
def greedy(state, net, support):
  """Return the action whose value distribution has the highest expectation.

  Args:
    state: A single observation (no batch axis); one is added here.
    net: Callable returning atom probabilities of shape (1, A, N).
    support: Atom values, broadcast against the last axis of the network
      output.

  Returns:
    The selected action as a plain ``int``.
  """
  batched = torch.tensor(np.array([state])).to(device)  # leading batch axis of 1
  atom_probs = net(batched)  # (1, A, N)
  expected = (support * atom_probs).sum(dim=-1)  # (1, A)
  best = torch.argmax(expected, dim=-1)  # (1,)
  return int(best.item())

#### Create the replay buffer

In [8]:
class ReplayBuffer:
  """Fixed-capacity experience store with priority-weighted sampling.

  Experiences and their priorities live in two parallel deques that share
  the same ``maxlen``, so the oldest entry of each is dropped together once
  capacity is reached.
  """

  def __init__(self, capacity):
    self.capacity = capacity
    self.buffer = deque(maxlen=capacity)
    self.priorities = deque(maxlen=capacity)
    self.max_priority = 0.0
    self.alpha = 0.0  # anneal.
    self.beta = 1.0  # anneal.

  def __len__(self):
    return len(self.buffer)

  def append(self, experience):
    """Store ``experience`` with the largest priority seen so far."""
    self.buffer.append(experience)
    self.priorities.append(self.max_priority)

  def update(self, index, priority):
    """Set the priority at ``index`` and track the running maximum."""
    self.max_priority = max(self.max_priority, priority)
    self.priorities[index] = priority

  def sample(self, batch_size):
    """Draw ``batch_size`` experiences with replacement.

    Returns:
      A list of tuples ``(index, weight, *experience)`` where ``weight`` is
      ``(len * prob) ** -beta`` normalised by the largest such weight.
    """
    size = len(self)

    # The small constant keeps zero-priority entries sampleable.
    scaled = (np.array(self.priorities, dtype=np.float64) + 1e-4) ** self.alpha
    probs = scaled / scaled.sum()

    is_weights = (size * probs) ** -self.beta
    is_weights = is_weights / is_weights.max()

    chosen = random.choices(range(size), weights=probs, k=batch_size)
    return [(i, is_weights[i], *self.buffer[i]) for i in chosen]

In [9]:
class RLDataset(IterableDataset):
  """Iterable dataset that streams one fresh sample drawn from a buffer.

  Args:
    buffer: Object exposing ``sample(n)`` that returns an iterable of
      experiences.
    sample_size: Number of experiences requested from the buffer each time
      the dataset is iterated.
  """

  def __init__(self, buffer, sample_size=400):
    self.sample_size = sample_size
    self.buffer = buffer

  def __iter__(self):
    # A new sample is drawn on every iteration of the dataset.
    yield from self.buffer.sample(self.sample_size)

In [11]:
class TradingBotEnvironment:
  """Environment for the trading-bot example.

  Only the constructor is defined here; it records its configuration and
  performs no I/O.

  Args:
    file: Identifier of the data source, stored unchanged on ``self.file``.
    state_size: Requested state size, stored unchanged on
      ``self.state_size``. Optional so that callers which pass only ``file``
      keep working; ``None`` means "not specified".
  """

  def __init__(self, file, state_size=None):
    self.file = file
    # Previously accepted but discarded; keep it so later code can read it.
    self.state_size = state_size


#### Create the Deep Q-Learning algorithm

In [None]:
class DeepQLearning(LightningModule):
  """Distributional deep Q-learning agent trained with PyTorch Lightning.

  The module owns the environment, an online and a target ``DQN``, and a
  prioritized ``ReplayBuffer``. Each training epoch draws
  ``samples_per_epoch`` experiences from the buffer; at the end of every
  epoch one more episode is played with the current policy, the buffer's
  ``alpha``/``beta`` are annealed, and the target network is synchronised
  every ``sync_rate`` epochs.
  """

  # Initialize.
  def __init__(self, env_name, policy=greedy, capacity=100_000, 
               batch_size=256, lr=1e-3, hidden_size=128, gamma=0.99, 
               loss_fn=F.smooth_l1_loss, optim=AdamW, samples_per_epoch=10_000, 
               sync_rate=10, sigma=0.5, a_start=0.5, a_end=0.0, a_last_episode=100, 
               b_start=0.4, b_end=1.0, b_last_episode=100, n_steps=3, 
               v_min=-10.0, v_max=10.0, atoms=51):
    """Build the networks and pre-fill the replay buffer.

    Args:
      env_name: Passed to ``TradingBotEnvironment``.
      policy: Callable ``(state, net, support) -> action`` used when playing
        the end-of-epoch episode.
      capacity: Maximum number of transitions kept in the replay buffer.
      batch_size: Batch size of the training dataloader.
      lr: Learning rate handed to the optimizer.
      hidden_size: Hidden width of the Q-network.
      gamma: Discount factor.
      loss_fn: Stored in ``hparams``; not referenced by the methods of this
        class.
      optim: Optimizer class, called as ``optim(params, lr=lr)``.
      samples_per_epoch: Experiences drawn per epoch; also the minimum buffer
        fill level reached before the constructor returns.
      sync_rate: Target network is refreshed when
        ``current_epoch % sync_rate == 0``.
      sigma: Forwarded to ``DQN``.
      a_start, a_end, a_last_episode: Linear schedule of the buffer's
        ``alpha`` (decreasing from ``a_start`` towards ``a_end``).
      b_start, b_end, b_last_episode: Linear schedule of the buffer's
        ``beta`` (increasing from ``b_start`` towards ``b_end``).
      n_steps: Length of the multi-step return window.
      v_min, v_max: Bounds of the value support.
      atoms: Number of support points.
    """
    super().__init__()

    self.support = torch.linspace(v_min, v_max, atoms, device=device)  # (N)
    self.delta = (v_max - v_min) / (atoms - 1)

    # NOTE(review): only one positional argument is passed here; confirm this
    # matches the TradingBotEnvironment constructor signature.
    self.env = TradingBotEnvironment(env_name)

    obs_size = self.env.observation_space.shape
    n_actions = self.env.action_space.n

    self.q_net = DQN(hidden_size, obs_size, n_actions, atoms=atoms, sigma=sigma)

    self.target_q_net = copy.deepcopy(self.q_net)

    self.policy = policy
    self.buffer = ReplayBuffer(capacity=capacity)

    self.save_hyperparameters()

    # Warm up the buffer with random-action episodes (no policy is passed).
    while len(self.buffer) < self.hparams.samples_per_epoch:
      print(f"{len(self.buffer)} samples in experience buffer. Filling...")
      self.play_episode()
    
  @torch.no_grad()
  def play_episode(self, policy=None):
    """Play one episode and append its n-step transitions to the buffer.

    Args:
      policy: Optional callable ``(state, net, support) -> action``. When
        falsy, actions are sampled from ``env.action_space``.
    """
    state = self.env.reset()
    done = False
    transitions = []

    while not done:
      if policy:
        action = policy(state, self.q_net, self.support)
      else:
        action = self.env.action_space.sample()
      
      next_state, reward, done, info = self.env.step(action)
      exp = (state, action, reward, done, next_state)
      transitions.append(exp)
      state = next_state

    # Collapse each window of up to n_steps transitions into one entry:
    # first state/action, discounted reward sum, last done flag/next state.
    for i, (s, a, r, d, ns) in enumerate(transitions):
      batch = transitions[i:i+self.hparams.n_steps]
      ret = sum(t[2] * self.hparams.gamma**j for j, t in enumerate(batch))
      _, _, _, ld, ls = batch[-1]
      self.buffer.append((s, a, ret, ld, ls))

  # Forward.
  def forward(self, x):
    """Return the online network's atom probabilities for ``x``."""
    return self.q_net(x)

  # Configure optimizers.
  def configure_optimizers(self):
    """Create the optimizer for the online network's parameters."""
    q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
    return [q_net_optimizer]

  # Create dataloader.
  def train_dataloader(self):
    """Return a dataloader that samples ``samples_per_epoch`` experiences."""
    dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
    # NOTE(review): with an IterableDataset every DataLoader worker iterates
    # its own copy of the dataset, so 36 workers presumably yield 36 samples
    # of samples_per_epoch each per epoch - confirm this is intended.
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
        num_workers=36
    )
    return dataloader

  # Training step.
  def training_step(self, batch, batch_idx):
    """Compute the importance-weighted cross-entropy loss for one batch.

    The target distribution is obtained by letting the online network choose
    the next action, reading that action's distribution from the target
    network, shifting the support by the n-step return, and projecting the
    result back onto the fixed support. Per-sample cross-entropies are also
    written back to the replay buffer as new priorities.

    Args:
      batch: Tuple ``(indices, weights, states, actions, returns, dones,
        next_states)`` as produced by ``ReplayBuffer.sample`` after collation.
      batch_idx: Unused.

    Returns:
      The scalar loss tensor.
    """
    indices, weights, states, actions, returns, dones, next_states = batch
    returns = returns.unsqueeze(1)
    dones = dones.unsqueeze(1)
    batch_size = len(indices)
    atoms = self.hparams.atoms

    q_value_probs = self.q_net(states)  # (B, A, N)

    action_value_probs = q_value_probs[range(batch_size), actions, :]  # (B, N)
    log_action_value_probs = torch.log(action_value_probs + 1e-6)  # (B, N)

    with torch.no_grad():
      # Online network selects the next action ...
      next_q_value_probs = self.q_net(next_states)  # (B, A, N)
      next_q_values = (next_q_value_probs * self.support).sum(dim=-1)  # (B, A)
      next_actions = next_q_values.argmax(dim=-1)  # (B,)

      # ... and the target network evaluates it.
      next_q_value_probs = self.target_q_net(next_states)  # (B, A, N)
      next_action_value_probs = next_q_value_probs[range(batch_size), next_actions, :]  # (B, N)

    Tz = returns + ~dones * self.hparams.gamma**self.hparams.n_steps * self.support.unsqueeze(0)  # (B, N)

    Tz.clamp_(min=self.hparams.v_min, max=self.hparams.v_max)  # (B, N)
    b = (Tz - self.hparams.v_min) / self.delta  # (B, N)
    # Guard against float round-off pushing a bin index outside [0, N - 1].
    b.clamp_(min=0, max=atoms - 1)
    l, u = b.floor().long(), b.ceil().long()  # (B, N)

    # When b falls exactly on an atom, l == u and both (u - b) and (b - l)
    # are zero, which would silently drop that probability mass. Route the
    # full mass to the lower index in that case.
    on_atom = (l == u).to(b.dtype)  # (B, N)
    lower_share = (u - b) + on_atom  # fraction of mass assigned to atom l
    upper_share = b - l  # fraction of mass assigned to atom u

    offset = torch.arange(batch_size, device=device).view(-1, 1) * atoms  # (B, 1)

    l_idx = (l + offset).flatten()  # (B * N)
    u_idx = (u + offset).flatten()  # (B * N)

    to_lower = (next_action_value_probs * lower_share).flatten()  # (B * N)
    to_upper = (next_action_value_probs * upper_share).flatten()  # (B * N)

    # index_add_ requires the accumulator and the source to share a dtype.
    m = torch.zeros(batch_size * atoms, device=device, dtype=to_lower.dtype)  # (B * N)
    m.index_add_(dim=0, index=l_idx, source=to_lower)
    m.index_add_(dim=0, index=u_idx, source=to_upper)

    m = m.reshape(batch_size, atoms)  # (B, N)

    cross_entropies = - (m * log_action_value_probs).sum(dim=-1)  # (B,)

    # The per-sample error becomes the sample's new replay priority.
    for idx, e in zip(indices, cross_entropies):
      self.buffer.update(idx, e.detach().item())

    loss = (weights * cross_entropies).mean()

    self.log('episode/Q-Error', loss)
    return loss

  # Training epoch end.
  def training_epoch_end(self, training_step_outputs):
    """Anneal the buffer exponents, play an episode, and sync the target net."""
    alpha = max(
        self.hparams.a_end,
        self.hparams.a_start - self.current_epoch / self.hparams.a_last_episode
    )
    beta = min(
        self.hparams.b_end,
        self.hparams.b_start + self.current_epoch / self.hparams.b_last_episode
    )
    self.buffer.alpha = alpha
    self.buffer.beta = beta

    self.play_episode(policy=self.policy)
    self.log('episode/Return', self.env.return_queue[-1])

    if self.current_epoch % self.hparams.sync_rate == 0:
      self.target_q_net.load_state_dict(self.q_net.state_dict())

#### Purge old logs and launch the visualization tool (TensorBoard)

In [None]:
# Delete the existing ./lightning_logs/ directory to purge earlier logs.
!rm -r ./lightning_logs/
# Load the TensorBoard notebook extension and point it at the log directory.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/

#### Train the policy

In [None]:
# Build the agent; keyword values override the constructor defaults.
algo = DeepQLearning(
  "QbertNoFrameskip-v4",
  hidden_size=512,
  lr=1e-4,
  sigma=0.5,
  n_steps=8,
  a_last_episode=2000,
  b_last_episode=2000,
)

# Single-GPU trainer that logs on every step.
trainer = Trainer(
  accelerator="gpu",
  devices=1,
  strategy="dp",
  max_epochs=2400,
  log_every_n_steps=1,
)

trainer.fit(algo)