# Tutorial 18 (JAX): Deep Reinforcement Learning

In [1]:
## Standard libraries
import os
import numpy as np
import math
import json
from functools import partial
from PIL import Image
from collections import defaultdict
from typing import Any, Callable
from types import SimpleNamespace
from copy import deepcopy
from statistics import mean
import pickle

## Imports for plotting
import matplotlib.pyplot as plt
plt.set_cmap('cividis')
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## tqdm for progress bars
from tqdm.auto import tqdm

## To run JAX on TPU in Google Colab, uncomment the two lines below
# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

## JAX
import jax
import jax.numpy as jnp
from jax import random

## Flax (NN in JAX)
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax
from flax import linen as nn
from flax.training import checkpoints
from flax.training.train_state import TrainState

## Optax (Optimizers in JAX)
try:
    import optax
except ModuleNotFoundError: # Install optax if missing
    !pip install --quiet optax
    import optax

## PyTorch
import torch
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../../saved_models/tutorial18_jax"

print("Device:", jax.devices()[0])

  PyTreeDef = type(jax.tree_structure(None))


Device: gpu:0


In [2]:
try:
    import gym
except ModuleNotFoundError:
    !pip install --quiet gym[all]
    import gym

In [3]:
def jax_to_np_rng(rng):
    """Derive a plain Python int seed in [0, 1e6) from a JAX PRNG key.

    Used to seed NumPy / gym components from JAX keys. The draw uses int32:
    1e6 fits comfortably, and requesting int64 would only make JAX (in its
    default 32-bit mode) warn and truncate to int32 anyway.
    """
    return random.randint(rng, shape=(1,), minval=0, maxval=int(1e6), dtype=jnp.int32).item()

  and should_run_async(code)


In [4]:
class MazeEnv(gym.Env):
    """Open grid world in which the agent (red cell) must reach the goal (green cell).

    Observations are (grid_size, grid_size, 3) uint8 RGB images on a white
    background. Actions: 0 = up, 1 = right, 2 = down, 3 = left; a move that
    would leave the grid keeps the agent on the border. Each step yields
    -0.1 reward; stepping onto the goal yields +1 and terminates the episode.
    ``spec.max_episode_steps`` (100) is only advertised to callers — ``step``
    itself never reports truncation.
    """

    def __init__(self, grid_size=5):
        super().__init__()
        self.grid_size = grid_size
        self.action_space = gym.spaces.Discrete(4)
        self.rng = np.random.default_rng()
        self.spec = SimpleNamespace(max_episode_steps=100)

    def reset(self, seed=None):
        """Place agent and goal on distinct random cells and return the first frame.

        ``seed`` may be an int or a JAX PRNG key (converted via jax_to_np_rng);
        when given, the environment's NumPy generator is re-created from it.
        """
        if isinstance(seed, jnp.ndarray):
            seed = jax_to_np_rng(seed)
        super().reset(seed=seed)
        if seed is not None:
            self.rng = np.random.default_rng(seed)
        self.pos = self.rng.integers(0, high=self.grid_size, size=(2,))
        # Rejection-sample the goal until it differs from the start cell.
        self.goal = self.rng.integers(0, high=self.grid_size, size=(2,))
        while self._is_on_goal():
            self.goal = self.rng.integers(0, high=self.grid_size, size=(2,))
        self.grid = np.ones((self.grid_size, self.grid_size, 3), dtype=np.uint8) * 255
        return self._render_img()

    def _render_img(self):
        """Return a fresh RGB frame: agent cell red (255, 0, 0), goal cell green (0, 255, 0)."""
        frame = np.copy(self.grid)
        agent_row, agent_col = self.pos[0], self.pos[1]
        goal_row, goal_col = self.goal[0], self.goal[1]
        frame[agent_row, agent_col, 1:] = 0
        frame[goal_row, goal_col, ::2] = 0
        return frame

    def _is_on_goal(self):
        """True when both coordinates of the agent match the goal."""
        return all(self.pos[k] == self.goal[k] for k in (0, 1))

    def step(self, action):
        """Apply one move; return (obs, reward, terminated, truncated=False, info={})."""
        # (action id, axis of self.pos it changes, direction along that axis)
        for code, axis, delta in ((0, 0, -1), (1, 1, 1), (2, 0, 1), (3, 1, -1)):
            if action == code:
                moved = self.pos[axis] + delta
                self.pos[axis] = max(0, moved) if delta < 0 else min(self.grid_size - 1, moved)
                break

        terminated = self._is_on_goal()
        reward = 1 if terminated else -0.1
        observation = self._render_img()
        return observation, reward, terminated, False, {}

    def close(self):
        """Nothing to release."""
        pass

In [None]:
# Quick check: derive a plain integer seed (used to seed MazeEnv) from a JAX key.
jax_to_np_rng(random.PRNGKey(10))

In [None]:
# Disabled alternative: a preprocessed Atari environment with 2 stacked frames
# (would require Backbone's Atari conv stack). Kept for reference.
# # env with preprocessing
# env = gym.make('PongNoFrameskip-v4', 
#                # 'ALE/SpaceInvaders-v5', 
#                new_step_api=True, 
#                full_action_space=False,
#                frameskip=1
#               )
# env = gym.wrappers.AtariPreprocessing(env, 
#                                       new_step_api=True,
#                                       grayscale_obs=False,
#                                       )
# env = gym.wrappers.FrameStack(env, 2, new_step_api=True)
# Environment used by all trainers below: a 10x10 maze.
env = MazeEnv(grid_size = 10)

## Deep Q-Network

In [None]:
class Backbone(nn.Module):
    """Shared image encoder for the Q-networks, producing 256 features per input.

    Attributes:
        atari_convs: if True, use the strided two-layer conv stack intended
            for 84x84 Atari frames; otherwise (default) a single stride-1 3x3
            conv suited to small grid-world images.
    """
    atari_convs : bool = False

    @nn.compact
    def __call__(self, x):
        # Map pixel values from [0, 255] to [-1, 1] (assumes uint8 images — TODO confirm for new envs).
        x = x.astype(jnp.float32) / 255. * 2. - 1.
        if len(x.shape) >= 5:
            # Stacked frames (batch, frames, H, W, C): feed the difference between
            # the two frames plus the latest frame, concatenated on channels.
            # NOTE(review): only frames 0 and 1 are used — assumes FrameStack(env, 2).
            x = jnp.concatenate([x[:,1] - x[:,0], x[:,1]], axis=-1)

        if self.atari_convs:
            x = nn.Conv(16, kernel_size=(7, 7), strides=(4, 4))(x)  # 84 => 21
            x = nn.relu(x)
            x = nn.Conv(32, kernel_size=(5, 5), strides=(2, 2))(x)  # 21 => 11 (default 'SAME' padding)
            x = nn.relu(x)
        else:
            x = nn.Conv(16, kernel_size=(3, 3), strides=(1, 1))(x)
            x = nn.relu(x)
        x = x.reshape(x.shape[0], -1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        return x

In [None]:
class DQN(nn.Module):
    """Q-network: a shared Backbone followed by a linear head with one Q-value per action.

    Attributes:
        num_actions: number of discrete actions (output width).
    """
    num_actions : int
    
    @nn.compact
    def __call__(self, x):
        x = Backbone()(x)
        # Zero kernel init: at initialization the Q-values do not depend on the
        # input (they equal the head's bias).
        q = nn.Dense(self.num_actions,
                     kernel_init=nn.initializers.zeros)(x)
        return q

### Experience Replay

In [None]:
class ExperienceReplayBuffer:
    """Fixed-capacity replay memory with optional per-transition sampling weights.

    Transitions are stored as ``(s, s_next, action, reward, done)`` tuples in
    ``self.buffer`` with the matching unnormalized weight in ``self.probs``.
    Once ``capacity`` is reached, each new transition overwrites the oldest one
    (ring buffer), so ``add`` is O(1) rather than the O(n) of ``list.pop(0)``.
    Storage order is therefore not chronological once the buffer has wrapped.
    """

    def __init__(self, capacity=25000):
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.buffer = list()
        self.probs = list()
        # Slot holding the oldest transition, i.e. the one overwritten next once full.
        self._next_idx = 0

    def add(self, s, s_next, action, reward, done, prob=1.0):
        """Store one transition with unnormalized sampling weight ``prob``."""
        transition = (s, s_next, action, reward, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
            self.probs.append(prob)
        else:
            self.buffer[self._next_idx] = transition
            self.probs[self._next_idx] = prob
        self._next_idx = (self._next_idx + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with probability proportional to their weights.

        Samples with replacement only when the buffer holds no more than
        ``batch_size`` transitions.

        Returns:
            Tuple of NumPy arrays (s, s_next, action, reward, done), each
            stacked along a new leading batch axis.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.buffer:
            raise ValueError("cannot sample from an empty replay buffer")
        probs = np.asarray(self.probs, dtype=np.float64)
        probs = probs / probs.sum()
        trans_idxs = np.random.choice(len(self.buffer),
                                      batch_size,
                                      replace=len(self.buffer) <= batch_size,
                                      p=probs)
        transitions = [self.buffer[idx] for idx in trans_idxs]
        return tuple(np.stack(field, axis=0) for field in zip(*transitions))

    def avg_reward(self):
        """Mean reward over the stored transitions (ZeroDivisionError when empty)."""
        rewards = [b[3] for b in self.buffer]
        return sum(rewards) / len(rewards)

    def clear(self):
        """Remove all transitions and weights."""
        self.buffer.clear()
        self.probs.clear()
        self._next_idx = 0

    def __len__(self):
        return len(self.buffer)

### Training setup

In [None]:
@jax.jit
def sample_action(state, inps, rng, eps=0.0):
    """Epsilon-greedy action sampling for a batch.

    Args:
        state: TrainState whose apply_fn/params map observations to Q-values
            (not touched when ``inps`` is 2-D).
        inps: either precomputed Q-values of shape (batch, num_actions) — any
            2-D input is treated this way — or a batch of observations.
        rng: JAX PRNG key.
        eps: exploration rate; the greedy action gets probability
            1 - eps + eps / num_actions, every other action eps / num_actions.

    Returns:
        (actions, q-values of the chosen actions), both of shape (batch,).
    """
    q_vals = inps if inps.ndim == 2 else state.apply_fn(state.params, inps)
    num_actions = q_vals.shape[-1]
    greedy = jax.nn.one_hot(jnp.argmax(q_vals, axis=-1), num_classes=num_actions)
    probs = greedy * (1 - eps) + eps / num_actions
    # Floor before the log so zero-probability actions get a finite logit.
    logits = jnp.log(jnp.maximum(probs, 1e-10))
    actions = random.categorical(rng, logits, axis=-1)
    rows = jnp.arange(actions.shape[0])
    return actions, q_vals[rows, actions]

@jax.jit
def apply_model(state, params, imgs):
    """Jitted forward pass of ``state.apply_fn`` with an explicitly supplied ``params`` pytree.

    Taking ``params`` separately lets callers evaluate parameters other than
    ``state.params`` (e.g. the trainers' ``target_params``) with the same model.
    """
    return state.apply_fn(params, imgs)

In [None]:
def td_error_func(reward, q_vals, target_q, done, gamma):
    """One-step Q-learning TD error: r + (1 - done) * gamma * max_a Q_target(s', a) - Q(s, a).

    Args:
        reward: reward(s) of the transition(s).
        q_vals: Q-value(s) of the taken action(s).
        target_q: Q-values of the next state(s); maximized over the last axis.
        done: termination flag(s); when set, the bootstrap term is dropped.
        gamma: discount factor.
    """
    bootstrap = target_q.max(axis=-1)
    keep = 1 - done
    return reward + keep * gamma * bootstrap - q_vals

In [None]:
def clipped_mse(error, delta=1.0):
    """Elementwise Huber loss of the TD error ("MSE with clipped gradients").

    Quadratic, 0.5 * e**2, for |e| <= delta and linear, delta * (|e| - delta / 2),
    beyond it. Value and slope agree at |e| = delta, so the gradient magnitude
    is min(|e|, delta): large TD errors get a bounded, never-smaller gradient
    than moderate ones. A plain |e|**2 / |e| split would make the slope drop
    from 2 to 1 at the threshold.

    Args:
        error: array of TD errors.
        delta: threshold between the quadratic and linear regimes (default 1.0).
    """
    diff = jnp.abs(error)
    return jnp.where(diff > delta, delta * (diff - 0.5 * delta), 0.5 * diff ** 2)

In [None]:
def eval_episode(state, env, seed, eps=0.0):
    """Run one evaluation episode; ``seed`` fixes both the env reset and action sampling.

    Returns:
        (won, steps, return):
            won: True only if the episode *terminated* (for MazeEnv: the goal
                was reached). Running out of steps or a truncation is not a win.
            steps: number of ``env.step`` calls made.
            return: undiscounted sum of rewards.
    """
    s = env.reset(seed)
    rng = random.PRNGKey(seed)
    rew_return = 0.0
    steps = 0  # stays 0 if max_episode_steps is 0
    for steps in range(1, env.spec.max_episode_steps + 1):
        rng, step_rng = random.split(rng)
        a, _ = sample_action(state, np.array(s)[None], step_rng, eps=eps)
        s, r, terminated, truncated, _ = env.step(a.item())
        rew_return += r
        if terminated or truncated:
            return (bool(terminated), steps, rew_return)
    return (False, steps, rew_return)

def eval_policy(state, env, num_episodes=100):
    """Evaluate the greedy policy on seeds 0..num_episodes-1.

    Returns:
        Dict with 'wins' (fraction of won episodes), 'steps' and 'returns'
        (means over episodes) and 'unsolved' (seeds of episodes not won).

    Raises:
        ValueError: if ``num_episodes`` < 1.
    """
    if num_episodes < 1:
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
    results = [eval_episode(state, env, seed=i, eps=0.0) for i in range(num_episodes)]
    wins, steps, returns = zip(*results)
    return {'wins': mean(wins),
            'steps': mean(steps),
            'returns': mean(returns),
            'unsolved': [seed for seed, won in enumerate(wins) if not won]}

In [None]:
class QTrainer:
    """DQN trainer with experience replay and a periodically synced target network.

    Transitions are collected with an epsilon-greedy policy whose epsilon decays
    linearly to 0 over ``train_model``'s run. After a 1000-transition warm-up:
    - one gradient step on a replayed batch is taken every ``train_freq`` env steps;
    - the policy is evaluated (and the best-return parameters checkpointed)
      every ``eval_freq`` steps;
    - the target network is refreshed every ``target_freq`` steps.
    Logs and checkpoints go to ``CHECKPOINT_PATH/<model_name>/seed_<seed>/``.
    """

    def __init__(self,
                 env : gym.Env,
                 seed : int = 42,
                 lr : float = 3e-4,
                 gamma : float = 0.99,
                 train_freq : int = 4,
                 eval_freq : int = 10000,
                 target_freq : int = 2000,
                 batch_size : int = 64,
                 model_name : str = "DQN"):
        super().__init__()
        # Private copy so training never mutates the caller's environment.
        self.env = deepcopy(env)
        self.train_freq = train_freq
        self.eval_freq = eval_freq
        self.target_freq = target_freq
        self.batch_size = batch_size
        self.lr = lr
        self.gamma = gamma
        self.create_model()
        self.init(seed)
        self.create_train_step()
        self.buffer = ExperienceReplayBuffer()
        self.log_dir = os.path.join(CHECKPOINT_PATH,
                                    f'{model_name}/seed_{seed}/')
        self.logger = SummaryWriter(self.log_dir)
        self.all_evals = {}

    def create_model(self):
        """Build the Q-network (overridden by subclasses to swap architectures)."""
        self.model = DQN(num_actions=self.env.action_space.n)

    def init(self, seed):
        """Initialize parameters from a reset observation and create the Adam TrainState."""
        rng = random.PRNGKey(seed)
        rng, init_rng, env_rng = random.split(rng, 3)
        s = self.env.reset(env_rng)
        params = self.model.init(init_rng, s[None])
        optimizer = optax.adam(self.lr)
        self.state = TrainState.create(apply_fn=self.model.apply,
                                       params=params,
                                       tx=optimizer)
        self.rng = rng

    def create_train_step(self):
        """Build the jitted (state, target_params, batch) -> (state, loss) update."""
        self.create_loss_fn()
        def train_step(state, target_params, batch):
            loss, grads = jax.value_and_grad(self.loss_fn)(state.params,
                                                           state,
                                                           target_params,
                                                           batch)
            state = state.apply_gradients(grads=grads)
            return state, loss
        self.train_step = jax.jit(train_step)

    def create_loss_fn(self):
        """Loss: mean clipped_mse of the TD error, bootstrapping with the target network's max."""
        def loss_fn(params, state, target_params, batch):
            s, s_next, action, reward, done = batch
            q_current = state.apply_fn(params, s)
            q_next = state.apply_fn(target_params, s_next)
            q_current = q_current[np.arange(action.shape[0]),
                                  action.astype(np.int32)]
            error = td_error_func(reward, q_current, q_next, done, self.gamma)
            loss = clipped_mse(error)
            return loss.mean()
        self.loss_fn = loss_fn

    def init_environment_state(self):
        """Reset the training environment and remember its first observation."""
        self.rng, env_rng = random.split(self.rng, 2)
        s = self.env.reset(env_rng)
        self.env_state = {
            'last_state': s
        }

    def take_environment_step(self, eps):
        """Act epsilon-greedily once, store the transition, and reset when the episode ends."""
        self.rng, step_rng = random.split(self.rng)
        s = self.env_state['last_state']
        a, _ = sample_action(self.state,
                             s[None],
                             step_rng,
                             eps=eps)
        a = a.item()
        s_next, r, terminated, truncated, _ = self.env.step(a)
        # Only a true termination stops bootstrapping; a truncated episode's
        # next state still has value, so it is stored as not-done.
        self.buffer.add(s, s_next, a, r, terminated)
        if terminated or truncated:
            self.rng, env_rng = random.split(self.rng)
            s_next = self.env.reset(env_rng)
        self.env_state['last_state'] = s_next

    def train_model(self, num_steps=500000):
        """Train for ``num_steps`` environment steps (see class docstring for the schedule)."""
        self.target_params = self.state.params
        best_return = float('-inf')
        eval_env = deepcopy(self.env)
        self.init_environment_state()
        losses = []
        for step_idx in tqdm(range(1, num_steps + 1)):
            # Epsilon decays linearly and reaches 0 at the last step.
            self.take_environment_step(eps=1.0 - step_idx / num_steps)
            # Warm-up: fill the replay buffer before learning starts.
            if len(self.buffer) < 1000:
                continue

            if step_idx % self.train_freq == 0:
                batch = self.buffer.sample(self.batch_size)
                self.state, loss = self.train_step(self.state, self.target_params, batch)
                losses.append(loss.item())
                # Log the mean of every 50 losses to keep TensorBoard light.
                if len(losses) >= 50:
                    self.logger.add_scalar('train/loss', mean(losses), global_step=step_idx)
                    losses.clear()

            if step_idx % self.eval_freq == 0:
                eval_dict = eval_policy(self.state, eval_env)
                for metric in ('wins', 'steps', 'returns'):
                    self.logger.add_scalar(f'val/{metric}',
                                           eval_dict[metric],
                                           global_step=step_idx)
                self.save_eval(eval_dict, step_idx)
                if eval_dict['returns'] > best_return:
                    best_return = eval_dict['returns']
                    self.save_model(step_idx)

            if step_idx % self.target_freq == 0:
                self.target_params = self.state.params
        # Push any buffered scalars to disk before returning.
        self.logger.flush()

    def save_eval(self, eval_dict, step):
        """Record ``eval_dict`` under ``step`` and rewrite the evals pickle in log_dir."""
        self.all_evals[step] = eval_dict
        with open(os.path.join(self.log_dir, 'evals.pik'), 'wb') as f:
            pickle.dump(self.all_evals, f)

    def load_eval(self):
        """Reload all evaluation results written by save_eval (trusted local file only)."""
        with open(os.path.join(self.log_dir, 'evals.pik'), 'rb') as f:
            self.all_evals = pickle.load(f)

    def save_model(self, step=0):
        """Checkpoint the current parameters (tagged with ``step``) into log_dir."""
        checkpoints.save_checkpoint(ckpt_dir=self.log_dir,
                                    target=self.state.params,
                                    step=step,
                                    overwrite=True)

    def load_model(self):
        """Load the latest checkpoint from log_dir into ``self.state.params``.

        Raises:
            FileNotFoundError: if log_dir contains no checkpoint.
        """
        if checkpoints.latest_checkpoint(self.log_dir) is None:
            raise FileNotFoundError(f"No checkpoint found in {self.log_dir}")
        # Restoring against the current params keeps the same pytree structure
        # that the jitted functions were traced with.
        params = checkpoints.restore_checkpoint(ckpt_dir=self.log_dir,
                                                target=self.state.params)
        self.state = self.state.replace(params=params)

In [None]:
def train_models(trainer_class, env, num_seeds, num_steps=250000, base_seed=42, **trainer_kwargs):
    """Train one ``trainer_class`` instance per seed, sequentially, and return them.

    Args:
        trainer_class: trainer type constructed as
            ``trainer_class(env=env, seed=..., **trainer_kwargs)``.
        env: environment handed to each trainer.
        num_seeds: number of independent runs; seeds are base_seed, base_seed + 1, ...
        num_steps: environment steps per run.
        base_seed: seed of the first run.
        **trainer_kwargs: extra keyword arguments for the trainer constructor.

    Returns:
        List of trained trainers, in seed order.
    """
    trainers = []
    for offset in range(num_seeds):
        trainer = trainer_class(env=env, seed=base_seed + offset, **trainer_kwargs)
        trainer.train_model(num_steps=num_steps)
        trainers.append(trainer)
    return trainers

In [None]:
# Vanilla DQN: three independent runs (seeds 42, 43, 44).
q_trainers = train_models(QTrainer, env, num_seeds=3)

### Visualizing learned Q function

In [None]:
def visualize_value_func(state, seed, maze_env=None):
    """Plot max_a Q(s, a) for every agent position in the maze produced by ``seed``.

    Left panel: the value estimate for each cell, obtained by placing the agent
    there. Right panel: the actual start state after ``reset(seed)``, with the
    agent in red and the goal in green. The start position and the value grid
    are also printed.

    Args:
        state: TrainState whose apply_fn/params produce Q-values.
        seed: reset seed selecting goal and start position.
        maze_env: MazeEnv to use; defaults to the notebook-level ``env``. It is
            deep-copied, so the caller's environment is left untouched.
    """
    vis_env = deepcopy(env if maze_env is None else maze_env)
    start_img = vis_env.reset(seed)
    print(vis_env.pos)
    size = vis_env.grid_size
    env_states = []
    for i in range(size):
        for j in range(size):
            vis_env.pos = np.array([i, j])
            env_states.append(vis_env._render_img())
    env_states = np.stack(env_states, axis=0)
    q_vals = state.apply_fn(state.params, env_states)
    # Row-major loop above, so entry [i, j] is the value of position (i, j).
    v_vals = jax.device_get(q_vals.max(axis=-1).reshape(size, size))
    fig, ax = plt.subplots(1, 2, figsize=(8, 8))
    ax[0].imshow(v_vals)
    ax[0].set_title('max_a Q(s, a) per agent position')
    ax[1].imshow(start_img)
    ax[1].set_title('Start state (agent red, goal green)')
    plt.show()
    print(v_vals)
    
# visualize_value_func(state, seed=eval_dict['unsolved'][0])

## Improving DQN

### Prioritized Experience Replay

In [None]:
class QPrioTrainer(QTrainer):
    """DQN trainer with prioritized experience replay.

    Each new transition is stored with sampling weight
    ``(|TD error| + prio_eps) ** prio_alpha``, computed once at collection time
    (nothing in this class updates it afterwards). ExperienceReplayBuffer.sample
    draws transitions proportionally to these weights.
    NOTE(review): no importance-sampling correction is applied in the loss.
    """

    def __init__(self, *args,
                 prio_alpha : float = 0.6,
                 prio_eps : float = 1e-5,
                 model_name : str = 'DQN_Prio',
                 **kwargs):
        super().__init__(*args,
                         model_name=model_name,
                         **kwargs)
        self.prio_alpha = prio_alpha
        self.prio_eps = prio_eps

    def get_prio_td_error(self, r, q_vals, s_next, done):
        """TD error of a single transition, bootstrapping with the target network."""
        target_q = apply_model(self.state, self.target_params, s_next[None])
        td_error = td_error_func(r, q_vals, target_q, done, self.gamma)
        return td_error

    def take_environment_step(self, eps):
        """Act epsilon-greedily, store the transition with its TD-error priority, reset on episode end."""
        self.rng, step_rng = random.split(self.rng)
        s = self.env_state['last_state']
        a, q_vals = sample_action(self.state,
                                  s[None],
                                  step_rng,
                                  eps=eps)
        a = a.item()
        s_next, r, terminated, truncated, _ = self.env.step(a)
        # Only true termination stops bootstrapping; truncation merely ends the episode.
        td_error = self.get_prio_td_error(r, q_vals, s_next, terminated)
        sample_prob = (abs(td_error.item()) + self.prio_eps) ** self.prio_alpha
        self.buffer.add(s, s_next, a, r, terminated, prob=sample_prob)
        if terminated or truncated:
            self.rng, env_rng = random.split(self.rng)
            s_next = self.env.reset(env_rng)
        self.env_state['last_state'] = s_next

In [None]:
# DQN with prioritized experience replay: three runs (seeds 42, 43, 44).
qprio_trainers = train_models(QPrioTrainer, env, num_seeds=3)

### Double Q-Learning

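Double Q-learning decouples action selection from action evaluation. The online network chooses the greedy next action, and the target network scores that action. This reduces the overestimation bias of taking a max over noisy Q-estimates. The implementation below builds on the prioritized-replay trainer, and the Dueling DQN in the next section builds on it in turn.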

In [None]:
def doubleQ_error_func(reward, q_vals, q_next, target_q, done, gamma):
    """Double Q-learning TD error.

    The next action is chosen greedily from ``q_next`` (online network) and
    evaluated with ``target_q`` (target network):
    r + (1 - done) * gamma * Q_target(s', argmax_a Q_online(s', a)) - Q(s, a).
    Both ``q_next`` and ``target_q`` are (batch, num_actions).
    """
    greedy_next = q_next.argmax(axis=-1)
    rows = jnp.arange(target_q.shape[0])
    bootstrap = target_q[rows, greedy_next]
    return reward + (1 - done) * gamma * bootstrap - q_vals

In [None]:
class DoubleQTrainer(QPrioTrainer):
    """Double DQN on top of prioritized experience replay.

    Both the loss and the collection-time priorities use doubleQ_error_func:
    the online network selects the next action, the target network evaluates it.
    As an optimization, the online Q-values of the next state computed for the
    priority are cached and reused to choose the following action.
    NOTE(review): that cache is computed before any train_step in the same
    iteration, so the action can be chosen with Q-values one update stale.
    """

    def __init__(self, *args,
                 model_name : str = 'DoubleDQN',
                 **kwargs):
        super().__init__(*args, model_name=model_name, **kwargs)

    def create_loss_fn(self):
        """Loss: mean clipped_mse of the Double-Q TD error."""
        def loss_fn(params, state, target_params, batch):
            s, s_next, action, reward, done = batch
            # A single online forward pass over [s; s_next], split back into halves.
            q = state.apply_fn(params, jnp.concatenate([s, s_next], axis=0))
            q_current, q_next = jnp.split(q, 2, axis=0)
            q_target = state.apply_fn(target_params, s_next)
            q_current = q_current[np.arange(action.shape[0]),
                                  action.astype(np.int32)]
            error = doubleQ_error_func(reward, q_current, q_next, q_target, done, self.gamma)
            loss = clipped_mse(error)
            return loss.mean()
        self.loss_fn = loss_fn

    def get_prio_td_error(self, r, q_vals, s_next, done):
        """Double-Q TD error of one transition; caches online Q(s_next) in env_state['last_qvals']."""
        target_q = apply_model(self.state, self.target_params, s_next[None])
        orig_next_q = apply_model(self.state, self.state.params, s_next[None])
        self.env_state['last_qvals'] = orig_next_q
        td_error = doubleQ_error_func(r, q_vals, orig_next_q, target_q, done, self.gamma)
        return td_error

    def take_environment_step(self, eps):
        """Act (reusing cached Q-values when available), store with priority, reset on episode end."""
        self.rng, step_rng = random.split(self.rng)
        s = self.env_state['last_state']
        # 2-D cached Q-values are used directly by sample_action; otherwise the
        # observation is passed through the network.
        a, q_vals = sample_action(self.state,
                                  self.env_state.get('last_qvals', s[None]),
                                  step_rng,
                                  eps=eps)
        a = a.item()
        s_next, r, terminated, truncated, _ = self.env.step(a)
        # Only true termination stops bootstrapping; truncation merely ends the episode.
        td_error = self.get_prio_td_error(r, q_vals, s_next, terminated)
        sample_prob = (abs(td_error.item()) + self.prio_eps) ** self.prio_alpha
        self.buffer.add(s, s_next, a, r, terminated, prob=sample_prob)
        if terminated or truncated:
            # The cached Q-values belong to the final state, not the next reset state.
            self.env_state.pop('last_qvals', None)
            self.rng, env_rng = random.split(self.rng)
            s_next = self.env.reset(env_rng)
        self.env_state['last_state'] = s_next

In [None]:
# Double DQN (+ prioritized replay): three runs (seeds 42, 43, 44).
doubleq_trainers = train_models(DoubleQTrainer, env, num_seeds=3)

### Dueling DQN

In [None]:
class DuelingDQN(nn.Module):
    """Dueling Q-network: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a').

    Attributes:
        num_actions: number of discrete actions (width of the advantage head).
    """
    num_actions : int

    @nn.compact
    def __call__(self, x, return_separate=False):
        """Return Q-values, or (Q-values, {'v': value, 'a': centered advantages}) if ``return_separate``."""
        features = Backbone()(x)
        # Value head is created before the advantage head; the creation order
        # determines flax's auto-generated parameter names.
        value = nn.Dense(1, kernel_init=nn.initializers.zeros)(features)
        advantage = nn.Dense(self.num_actions, kernel_init=nn.initializers.zeros)(features)
        # Center advantages so that V and A are identifiable.
        advantage = advantage - jnp.mean(advantage, axis=-1, keepdims=True)
        q = value + advantage
        if return_separate:
            return q, {'v': value, 'a': advantage}
        return q

In [None]:
class DuelingQTrainer(DoubleQTrainer):
    """Double DQN + prioritized replay trainer that uses the DuelingDQN architecture."""
    
    def __init__(self, *args,
                 model_name : str = 'DuelingDQN',
                 **kwargs):
        super().__init__(*args, model_name=model_name, **kwargs)
    
    def create_model(self):
        # Only the network changes; training logic is inherited unchanged.
        self.model = DuelingDQN(num_actions=self.env.action_space.n)

In [None]:
# Dueling Double DQN (+ prioritized replay): three runs (seeds 42, 43, 44).
duelingq_trainers = train_models(DuelingQTrainer, env, num_seeds=3)