<a href="https://colab.research.google.com/github/xqz-u/dopamine/blob/master/dopamine/thesis/tests/test_jit_dqv_performance.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
%%capture
# Fresh clone of the thesis fork of dopamine. Any previous copy is removed
# first so that re-running this cell always yields a clean checkout.
!rm -rf /content/dopamine
!git clone https://github.com/xqz-u/dopamine.git

In [4]:
%%capture --no-stderr 
# Install the dependencies listed in the cloned repository's requirements
# file. stdout is captured (hidden) but stderr is not, so install errors stay visible.
!pip install -r dopamine/requirements.txt

In [5]:
# Record which GPU this runtime was assigned; the timings below depend on it.
!nvidia-smi

Fri Dec 10 15:00:28 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P8    26W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
# Mount Google Drive: the offline DQN CartPole checkpoints used to build the
# profiling agent (see `offline_exp_data`) are read from it.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# Work from the cloned repository; later cells assume this working directory.
%cd /content/dopamine

/content/dopamine


In [8]:
# general imports
import jax
from jax import numpy as jnp
from jax import random as jrand

from thesis import utils as u
from thesis.jax import networks
from thesis import experiment_data as exp_data
from thesis.jax.agents import dqv_agent as dqv


# CartPole settings shared by the helpers below.
cartpole_state_shape: tuple = (4, 1)  # observation shape given to the agent / networks
cartpole_stack_size: int = 1  # frames stacked into one observation
cartpole_num_actions: int = 2  # size of the discrete action space


# Create a dummy DQV agent backed by offline experience: this makes the
# functions under test easy to profile, since no interaction with the
# environment is needed.
def offline_exp_data(
    checkpoints_dir="/content/drive/MyDrive/BSc Thesis/data/dqn_cartpole_sample/checkpoints",
    checkpoints_iters=(356, 357),
    seed=0,
    batch_size=128,
):
  """Build an `exp_data.ExperimentData` that replays stored offline experience.

  All arguments default to the values previously hard-coded here, so calling
  `offline_exp_data()` behaves as before; override them to point at another
  checkpoint location or sample other iterations.

  Args:
    checkpoints_dir: directory holding the saved checkpoints. The default is
      the author's Google Drive path, mounted earlier in the notebook.
    checkpoints_iters: checkpoint iterations to load.
    seed: seed forwarded to `ExperimentData`.
    batch_size: replay batch size forwarded to `ExperimentData`.

  Returns:
    The configured `exp_data.ExperimentData`, using `cartpole_stack_size`.
  """
  return exp_data.ExperimentData(
      seed=seed,
      stack_size=cartpole_stack_size,
      batch_size=batch_size,
      checkpoint_dir=checkpoints_dir,
      # Pass a fresh list, as before; the default is a tuple so it can never
      # be mutated across calls.
      checkpoint_iterations=list(checkpoints_iters),
  )

def make_offline_dqv():
  """Create a CartPole DQV agent whose experience comes from offline data."""
  offline_data = offline_exp_data()
  agent = dqv.JaxDQVAgent(
      state_shape=cartpole_state_shape,
      num_actions=cartpole_num_actions,
      exp_data=offline_data,
  )
  return agent

In [9]:
# profile action selection
def create_net_and_params():
  """Build a small classic-control network plus what is needed to call it.

  Returns:
    (rng, net, params, state): the remaining PRNG key after the splits below,
    the network module, its initialised parameters, and the random (4, 1)
    dummy state used for initialisation.
  """
  key, state_key = u.force_devicearray_split(jrand.PRNGKey(42))
  dummy_state = jrand.uniform(state_key, (4, 1))
  key, init_key = u.force_devicearray_split(key)
  model = networks.ClassicControlDNNetwork(output_dim=2)
  model_params = model.init(init_key, dummy_state)
  return key, model, model_params, dummy_state

def profile_action_selection(selection_fn, iters, rng, net, params, state,
                             epsilon=0.01, num_actions=2):
  """Call `selection_fn` `iters` times, resampling a random state each time.

  Args:
    selection_fn: callable `(rng, epsilon, num_actions, net, params, state)`
      returning `(rng, action)`, e.g. `dqv.egreedy_action_selection` or its
      `_jit` variant.
    iters: number of action selections to perform.
    rng: initial PRNG key, threaded through every call.
    net, params: network module and its parameters.
    state: first state; later states are drawn uniformly with the same shape.
    epsilon: exploration rate forwarded to `selection_fn` (default 0.01, the
      value previously hard-coded).
    num_actions: action count forwarded to `selection_fn` (default 2).

  Returns:
    List of the selected actions, in call order.
  """
  chosen_actions = []
  for _ in range(iters):
    rng, action = selection_fn(rng, epsilon, num_actions, net, params, state)
    chosen_actions.append(action)
    rng, k = u.force_devicearray_split(rng)
    # Follow the caller's state shape instead of a hard-coded (4, 1).
    state = jrand.uniform(k, state.shape)
  # JAX dispatches asynchronously: wait for the device so %timeit measures the
  # actual computation, not only its dispatch. Non-array leaves pass through.
  jax.tree_util.tree_map(
      lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x,
      chosen_actions)
  return chosen_actions

def test_action_selection():
  %timeit profile_action_selection(dqv.egreedy_action_selection, 100, *create_net_and_params())
  %timeit profile_action_selection(dqv.egreedy_action_selection_jit, 100, *create_net_and_params())

In [10]:
# profile td error computation
def profile_td_error(td_error_fn, iters, agent):
  """Compute the DQV TD error on `iters` freshly sampled replay batches.

  Args:
    td_error_fn: callable `(network, params, next_states, rewards, terminals,
      gamma)`, e.g. `dqv.dqv_td_error` or `dqv.dqv_td_error_jit`.
    iters: number of batches to process.
    agent: agent providing `sample_memory()`, `V_network`, `V_target` and
      `exp_data.gamma`.

  Returns:
    List with one TD-error result per sampled batch.
  """
  td_errors = []
  for _ in range(iters):
    replay = agent.sample_memory()
    # Use the target V parameters, as the training-loop profilers do, so this
    # benchmarks the same computation performed during training.
    err = td_error_fn(
        agent.V_network,
        agent.V_target,
        replay["next_state"],
        replay["reward"],
        replay["terminal"],
        agent.exp_data.gamma,
    )
    td_errors.append(err)
  # JAX dispatches asynchronously: block so the timing covers device work.
  jax.tree_util.tree_map(
      lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x,
      td_errors)
  return td_errors

def test_td_error():
  %timeit profile_td_error(dqv.dqv_td_error, 100, make_offline_dqv())
  %timeit profile_td_error(dqv.dqv_td_error_jit, 100, make_offline_dqv())

In [16]:
# profile training routine
def _block_until_ready(tree):
    """Wait until every JAX array leaf of `tree` has been computed.

    JAX dispatches work asynchronously, so without this `%timeit` could stop
    the clock before the device has finished. Leaves without a
    `block_until_ready` method (e.g. Python scalars) are passed through.
    """
    return jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready()
        if hasattr(leaf, "block_until_ready")
        else leaf,
        tree,
    )


def _profile_train_loop(train_fn, iters, agent, make_v_mask_fn, make_q_mask_fn):
    """Run `iters` DQV training steps (V module, then Q module) on `agent`.

    Shared implementation of `profile_train_module` and
    `profile_train_module_b4`, which differ only in the mask functions passed
    to `train_fn`.

    Args:
        train_fn: callable taking the arguments used below and returning
            `(optim_state, params, loss)`, e.g. `dqv.train_module` or
            `dqv.train_module_jit`.
        iters: number of training steps.
        agent: DQV agent; its V/Q optimizer states and online parameters are
            overwritten with the updated values on every step.
        make_v_mask_fn, make_q_mask_fn: zero-argument factories, each called
            right before the corresponding `train_fn` call to obtain the mask
            function to pass. A caller can reuse one function object or
            deliberately build a new one on every step.

    Returns:
        `(v_losses, q_losses)`: lists with one loss per step.
    """
    v_losses, q_losses = [], []
    for _ in range(iters):
        replay_elements = agent.sample_memory()
        # TD error computed with the target V parameters (`agent.V_target`).
        td_error = dqv.dqv_td_error_jit(
            agent.V_network,
            agent.V_target,
            replay_elements["next_state"],
            replay_elements["reward"],
            replay_elements["terminal"],
            agent.exp_data.gamma,
        )
        agent.V_optim_state, agent.V_online, v_loss = train_fn(
            agent.V_network,
            agent.V_online,
            td_error,
            agent.optimizer,
            agent.V_optim_state,
            agent.exp_data.loss_fn,
            replay_elements["state"],
            make_v_mask_fn(),
        )
        agent.Q_optim_state, agent.Q_online, q_loss = train_fn(
            agent.Q_network,
            agent.Q_online,
            td_error,
            agent.optimizer,
            agent.Q_optim_state,
            agent.exp_data.loss_fn,
            replay_elements["state"],
            make_q_mask_fn(),
            # Extra argument forwarded to the Q mask function: the b4 variant
            # indexes the estimates with it.
            replay_elements["action"],
        )
        v_losses.append(v_loss)
        q_losses.append(q_loss)
    # Ensure the measured time includes all outstanding device work.
    _block_until_ready((v_losses, q_losses, agent.V_online, agent.Q_online))
    return v_losses, q_losses


def profile_train_module(train_fn, iters, agent):
    """Profile DQV training with the module-level mask functions.

    `dqv.mask_v_estimates` / `dqv.mask_q_estimates` are passed to `train_fn`
    on every step. See `_profile_train_loop` for arguments and return value.
    """
    return _profile_train_loop(
        train_fn,
        iters,
        agent,
        make_v_mask_fn=lambda: dqv.mask_v_estimates,
        make_q_mask_fn=lambda: dqv.mask_q_estimates,
    )


def profile_train_module_b4(train_fn, iters, agent):
    """Profile DQV training with mask functions rebuilt as new lambdas on every
    step (the earlier "b4" approach), for comparison with
    `profile_train_module`.

    NOTE(review): a new function object per step presumably defeats jit
    caching if `train_fn` treats the mask function as a static argument;
    confirm against `dqv.train_module_jit`.
    See `_profile_train_loop` for arguments and return value.
    """
    return _profile_train_loop(
        train_fn,
        iters,
        agent,
        # These factories intentionally return a *new* lambda on each call.
        make_v_mask_fn=lambda: (lambda e, *_, **__: e),
        make_q_mask_fn=lambda: (
            lambda e, *args, **__: jax.vmap(lambda x, y: x[y])(e, args[0])
        ),
    )


def test_train_module():
    %timeit profile_train_module_b4(dqv.train_module_jit, 100, make_offline_dqv()) 
    %timeit profile_train_module(dqv.train_module_jit, 100, make_offline_dqv())
    %timeit profile_train_module(dqv.train_module, 100, make_offline_dqv())

In [17]:
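# Uncomment the benchmark(s) to run; previously observed results are noted
# below each call.
#test_action_selection()
# ~32x speedup with jit; e-greedy action selection appears to be jitted correctly.

#test_td_error()
# ~4.5x speedup with jit, which seems smaller than expected.

#test_train_module()
# ~10x speedup of the correctly optimized (jitted) version over the unoptimized one.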

1 loop, best of 5: 1min 14s per loop
1 loop, best of 5: 2.33 s per loop


KeyboardInterrupt: ignored

In [None]:
# drive.flush_and_unmount()