# Tutorial 18 (JAX): Deep Reinforcement Learning

In [1]:
## Standard libraries
import os
import numpy as np
import math
import json
import random
from functools import partial
from PIL import Image
from collections import defaultdict
from typing import Any

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for progress bars
from tqdm.auto import tqdm

## To run JAX on TPU in Google Colab, uncomment the two lines below
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

## JAX
import jax
import jax.numpy as jnp
from jax import random as jrandom

## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
from flax import linen as nn
from flax.training import train_state, checkpoints

## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

## PyTorch
import torch
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../../saved_models/tutorial18_jax"

print("Device:", jax.devices()[0])

  PyTreeDef = type(jax.tree_structure(None))


Device: gpu:0


In [2]:
try:
    import gym
except ModuleNotFoundError:
    !pip install --quiet gym[all]
    import gym

In [3]:
# env with preprocessing
env = gym.make('PongNoFrameskip-v4', 
               # 'ALE/SpaceInvaders-v5', 
               new_step_api=True, 
               full_action_space=False,
               frameskip=1
              )
env = gym.wrappers.AtariPreprocessing(env, 
                                      new_step_api=True,
                                      grayscale_obs=False,
                                      )
env = gym.wrappers.FrameStack(env, 2, new_step_api=True)

  and should_run_async(code)


In [4]:
class ExperienceReplayBuffer:
    
    def __init__(self, capacity=15000):
        self.capacity = capacity
        self.buffer = list()
        
    def add(self, s, s_next, action, reward):
        while len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append((s, s_next, action, reward))
        
    def sample(self, batch_size):
        transitions = random.sample(self.buffer, batch_size)
        transitions = tuple(np.stack(t, axis=0) for t in zip(*transitions))
        return transitions
    
    def avg_reward(self):
        rewards = [b[3] for b in self.buffer]
        return sum(rewards) / len(rewards)

In [5]:
buffer = ExperienceReplayBuffer()

In [6]:
def run_model(apply_fn, params, imgs):
    return apply_fn(params, imgs)

@jax.jit
def sample_action(state, imgs, rng, eps=0.0):
    q_vals = run_model(state.apply_fn, state.params, imgs)
    q_argmax = q_vals.argmax(axis=-1)
#     if eps == 0.0:
#         return q_argmax
#     else:
    probs = jax.nn.one_hot(q_argmax, num_classes=q_vals.shape[-1]) * (1 - eps) + eps / q_vals.shape[-1]
    actions = jrandom.categorical(rng, jnp.log(jnp.maximum(probs, 1e-10)), axis=-1)
    return actions

In [7]:
def run_episode(state, buffer, eps=0.0):
    s = env.reset()
    rng = state.rng
    for t in range(env.spec.max_episode_steps):
        rng, step_rng = jrandom.split(rng)
        a = sample_action(state, np.array(s)[None], rng, eps=eps)
        a = a.item()
        s_next, r, terminated, truncated, info = env.step(a)
        buffer.add(s, s_next, a, r)  #  / 5.
        done = terminated or truncated
        if done:
            break
        s = s_next
    state = state.replace(rng=rng)
    return state

In [8]:
class CNN(nn.Module):
    num_actions : int
    
    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32) / 255. * 2. - 1.
        x = jnp.concatenate([x[:,1] - x[:,0], x[:,1]], axis=-1)
        
        x = nn.Conv(16, kernel_size=(7, 7), strides=(4, 4))(x)  # 84 => 21
        x = nn.relu(x)
        x = nn.Conv(32, kernel_size=(5, 5), strides=(2, 2))(x)  # 21 => 10
        x = nn.relu(x)
        x = x.reshape(x.shape[0], -1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_actions,
                     kernel_init=nn.initializers.zeros)(x)
        return x

In [9]:
model = CNN(num_actions=env.action_space.n)
s = env.reset()
params = model.init(jrandom.PRNGKey(0), np.array(s)[None])

  "Core environment is written in old step API which returns one bool instead of two. "


In [10]:
class TrainState(train_state.TrainState):
    # You can further extend the TrainState by any additional part here
    # For example, rng to keep for init, dropout, etc.
    rng : Any = None

In [11]:
state = TrainState.create(apply_fn=model.apply,
                          params=params,
                          rng=jrandom.PRNGKey(42),
                          tx=optax.adam(3e-4))

In [12]:
state = run_episode(state, buffer, eps=1.0)

In [13]:
def clipped_mse(y_true, y_pred):
    diff = jnp.abs(y_true - y_pred)
    return jnp.where(diff > 1.0, diff, diff**2)

In [14]:
# @jax.jit
def train_step(state, state_target, batch):
    s, s_next, action, reward = batch
    def loss_fn(params):
        q_current = run_model(state.apply_fn, params, s)[np.arange(action.shape[0]), 
                                                         action.astype(np.int32)]
        q_next = run_model(state_target.apply_fn, state_target.params, s_next)
        q_target = reward + q_next.max(axis=-1)
        loss = clipped_mse(q_target, q_current)
        return loss.mean()
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [15]:
_, loss = train_step(state, state, buffer.sample(128))
print(loss)

0.0234375


In [22]:
buffer.sample(128)[3]

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0., -1.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

In [17]:
num_episodes = 2500
for episode_idx in tqdm(range(num_episodes)):
    if episode_idx % 50 == 0:
        state_target = state
    state = run_episode(state, buffer, eps=0.5 * (1.0 - episode_idx / num_episodes))
    for _ in range(5):
        state, loss = train_step(state, state_target, buffer.sample(128))
    if episode_idx % 10 == 0:
        print(f'[Episode {episode_idx}] Loss: {loss.item():5.3f}, Avg reward: {buffer.avg_reward():3.2f}')

  0%|          | 0/2500 [00:00<?, ?it/s]

[Episode 0] Loss: 0.031, Avg reward: -0.02
[Episode 10] Loss: 0.015, Avg reward: -0.02
[Episode 20] Loss: 0.020, Avg reward: -0.03
[Episode 30] Loss: 0.004, Avg reward: -0.03
[Episode 40] Loss: 0.003, Avg reward: -0.02
[Episode 50] Loss: 0.025, Avg reward: -0.02
[Episode 60] Loss: 0.006, Avg reward: -0.03
[Episode 70] Loss: 0.003, Avg reward: -0.02
[Episode 80] Loss: 0.006, Avg reward: -0.02
[Episode 90] Loss: 0.000, Avg reward: -0.02
[Episode 100] Loss: 0.042, Avg reward: -0.02
[Episode 110] Loss: 0.002, Avg reward: -0.02
[Episode 120] Loss: 0.003, Avg reward: -0.02
[Episode 130] Loss: 0.002, Avg reward: -0.02
[Episode 140] Loss: 0.000, Avg reward: -0.03
[Episode 150] Loss: 0.015, Avg reward: -0.03
[Episode 160] Loss: 0.003, Avg reward: -0.02
[Episode 170] Loss: 0.006, Avg reward: -0.02
[Episode 180] Loss: 0.002, Avg reward: -0.02
[Episode 190] Loss: 0.002, Avg reward: -0.02
[Episode 200] Loss: 0.031, Avg reward: -0.02
[Episode 210] Loss: 0.010, Avg reward: -0.02
[Episode 220] Loss: 0

[Episode 1810] Loss: 0.016, Avg reward: -0.01
[Episode 1820] Loss: 0.019, Avg reward: -0.01
[Episode 1830] Loss: 0.010, Avg reward: -0.01
[Episode 1840] Loss: 0.015, Avg reward: -0.01
[Episode 1850] Loss: 0.021, Avg reward: -0.01
[Episode 1860] Loss: 0.016, Avg reward: -0.01
[Episode 1870] Loss: 0.015, Avg reward: -0.01
[Episode 1880] Loss: 0.013, Avg reward: -0.01
[Episode 1890] Loss: 0.012, Avg reward: -0.01
[Episode 1900] Loss: 0.010, Avg reward: -0.01
[Episode 1910] Loss: 0.032, Avg reward: -0.01
[Episode 1920] Loss: 0.024, Avg reward: -0.01
[Episode 1930] Loss: 0.014, Avg reward: -0.01
[Episode 1940] Loss: 0.012, Avg reward: -0.01
[Episode 1950] Loss: 0.011, Avg reward: -0.01
[Episode 1960] Loss: 0.023, Avg reward: -0.01
[Episode 1970] Loss: 0.016, Avg reward: -0.00
[Episode 1980] Loss: 0.013, Avg reward: -0.01
[Episode 1990] Loss: 0.022, Avg reward: -0.01
[Episode 2000] Loss: 0.029, Avg reward: -0.01
[Episode 2010] Loss: 0.016, Avg reward: -0.01
[Episode 2020] Loss: 0.019, Avg re