In [1]:
import logging
import os
import sys
from typing import Callable

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

sys.path.append('..')

from experior.utils import moving_average
from experior.rl_agents import make_boot_dqn_train
from experior.sampling import langevin_sampling
from experior.envs import DeepSea
from experior.experts import generate_optimal_trajectories
from experior.prior_trainers import make_max_ent_log_pdf, make_max_ent_prior_train

# Turn on 64-bit types in JAX (it uses 32-bit precision unless this flag is set).
jax.config.update('jax_enable_x64', True)

# Timestamped INFO-level log lines, e.g. "2024-04-09 18:31:31 INFO     <message>".
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')


%load_ext autoreload
%autoreload 2


### Environment


In [2]:
# Experiment-wide settings shared by every method in this notebook.
EPISODES = 2000  # EPISODES * HORIZON is the step budget handed to make_boot_dqn_train
NUM_ENVS = 30  # number of PRNG keys split per run, i.e. how many runs are vmapped together
HORIZON = 30  # DeepSea size; also the stride used when subsampling the per-step returns
ENV_SEED = 7192  # seeds the keys passed to env.init_env
METHOD_SEED = 42  # seeds the keys passed to the Boot-DQN train function
PRIOR_SEED = 4748  # seeds prior training, parameter init and Langevin sampling


def get_goal_col_dist(i):
    """Build the goal-column sampler for distribution index ``i``.

    Every returned callable takes ``(key, size)``.

    * ``i == 0``: always returns ``size - 1``; ``key`` is ignored.
    * ``i`` in ``{1, 2, 3}``: draws one integer with ``jax.random.randint``
      from ``[int(f * size), size)``, where ``f`` is 0.75, 0.5 and 0.0
      respectively.

    Raises:
        ValueError: if ``i`` is none of 0, 1, 2, 3.
    """
    if i == 0:

        def last_column(key, size):
            return size - 1

        return last_column

    # Lower bound of the sampling range, as a fraction of `size`.
    lower_fractions = {1: 0.75, 2: 0.5, 3: 0.0}
    if i not in lower_fractions:
        raise ValueError("Invalid goal column distribution")
    fraction = lower_fractions[i]

    def sample_column(key, size):
        return jax.random.randint(
            key, shape=(), minval=int(fraction * size), maxval=size
        )

    return sample_column


# One sampler per goal-column distribution index 0..3, each wrapped in
# jax.tree_util.Partial so that the callable is a valid JAX pytree.
goal_col_dists = [jax.tree_util.Partial(get_goal_col_dist(i)) for i in range(4)]

### Expert Trajectories


In [3]:
EXPERT_TRAJ_SEED = 2874
EXPERT_N_TRAJ = 1000

# One DeepSea instance of size HORIZON per goal-column sampler.
envs = [DeepSea(size=HORIZON, goal_column_dist=sampler) for sampler in goal_col_dists]

# Every environment is given the same PRNG key (seeded by EXPERT_TRAJ_SEED).
expert_key = jax.random.PRNGKey(EXPERT_TRAJ_SEED)
expert_trajectories_list = [
    generate_optimal_trajectories(expert_key, deep_sea, EXPERT_N_TRAJ, HORIZON)
    for deep_sea in envs
]

2024-04-09 18:31:31 INFO     Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
2024-04-09 18:31:31 INFO     Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


### Naive Boot-DQN


In [4]:
import flax.linen as nn
from typing import Sequence


class MLP(nn.Module):
    """Fully connected network: Dense + activation hidden layers, then a linear output.

    Attributes:
        features: layer widths. Every entry but the last becomes an
            ``nn.Dense`` layer followed by ``activation``; the last entry is
            the width of the final ``nn.Dense`` layer, which has no activation.
            Must be non-empty, because the last entry is indexed unconditionally.
        activation: callable applied after each hidden ``nn.Dense`` layer.
    """

    features: Sequence[int]
    # Annotated as a callable: `nn.activation` is the flax.linen.activation
    # module rather than a type, so it is not a meaningful field annotation.
    activation: Callable = nn.relu

    @nn.compact
    def __call__(self, x):
        """Apply the layers to ``x`` and return the un-activated output."""
        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            x = self.activation(x)
        return nn.Dense(self.features[-1])(x)


class QNetwork(nn.Module):
    """Q-network built from an ``MLP`` with three layer widths.

    Attributes:
        n_actions: width of the final (output) layer.
        n_hidden: width of the first layer.
        n_features: width of the second layer.
    """

    n_actions: int
    n_hidden: int
    n_features: int

    @nn.compact
    def __call__(self, inputs: jnp.ndarray):
        """Run ``inputs`` through the MLP and return its output."""
        layer_widths = [self.n_hidden, self.n_features, self.n_actions]
        mlp = MLP(layer_widths)
        return mlp(inputs)

In [5]:
# default parameters
# Shared by both Boot-DQN runs in this notebook (naive and ExPerior).

DQN_BUFFER_SIZE = 10000  # passed to make_boot_dqn_train; presumably replay-buffer capacity
Q_NETWORK_CONFIG = {"n_hidden": 50, "n_features": 50}  # keyword args forwarded to QNetwork
TARGET_FREQ = 4  # last argument of the train call; presumably target-network update period
BATCH_SIZE = 128  # passed to make_boot_dqn_train
TRAIN_FREQ = 1  # argument of the train call
NUM_ENSEMBLES = 25  # passed to make_boot_dqn_train; presumably the ensemble size
EPSILON_FN = lambda _: 0.0  # ignores its argument and always returns 0.0
LEARNING_STARTS = 128  # passed to make_boot_dqn_train
OPTIMIZER = optax.adam(learning_rate=1e-3)  # Adam with learning rate 1e-3
MASK_PROB = 1.0  # argument of the train call — TODO confirm meaning in experior.rl_agents
NOISE_SCALE = 0.01  # argument of the train call — TODO confirm meaning in experior.rl_agents

In [6]:
naive_boot_dqn_outputs = []

for env in envs:
    logging.info(f"Starting Naive Boot-DQN for {env}")

    # Fresh Q-network per environment; initial parameters come straight from
    # the network's own `init`.
    main_q_network = QNetwork(env.num_actions, **Q_NETWORK_CONFIG)
    q_network = main_q_network

    def q_init_fn(key, inputs):
        """Initialise the Q-network parameters from ``key`` and ``inputs``."""
        return q_network.init(key, inputs)

    # Build the training function, then compile it.
    total_steps = EPISODES * HORIZON
    boot_dqn_train = jax.jit(
        make_boot_dqn_train(
            env,
            q_network,
            DQN_BUFFER_SIZE,
            BATCH_SIZE,
            total_steps,
            LEARNING_STARTS,
            NUM_ENSEMBLES,
            OPTIMIZER,
            jax.tree_util.Partial(q_init_fn),
            EPSILON_FN,
        )
    )

    # One set of environment parameters per key; default_params is shared.
    env_keys = jax.random.split(jax.random.PRNGKey(ENV_SEED), NUM_ENVS)
    env_params = jax.vmap(env.init_env, in_axes=(0, None))(
        env_keys, env.default_params
    )

    # Run all NUM_ENVS training runs in one vmapped call: keys and environment
    # parameters are batched, the four scalar settings are broadcast.
    method_keys = jax.random.split(jax.random.PRNGKey(METHOD_SEED), NUM_ENVS)
    batched_train = jax.vmap(boot_dqn_train, in_axes=(0, 0, None, None, None, None))
    state, output = batched_train(
        method_keys, env_params, MASK_PROB, NOISE_SCALE, TRAIN_FREQ, TARGET_FREQ
    )

    # Keep every HORIZON-th entry along the second axis of the logged returns.
    episode_returns = output["info"]["returned_episode_returns"][:, ::HORIZON]
    naive_boot_dqn_outputs.append(episode_returns)

2024-04-09 18:31:36 INFO     Starting Naive Boot-DQN for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f37d90>


2024-04-09 18:45:21 INFO     Starting Naive Boot-DQN for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f37d00>
2024-04-09 18:59:04 INFO     Starting Naive Boot-DQN for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f34220>
2024-04-09 19:12:47 INFO     Starting Naive Boot-DQN for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f379d0>


In [7]:
# Persist the naive Boot-DQN returns as one stacked array (stored under the
# default key "arr_0"). The output directory is created first so that a missing
# folder cannot make the save fail after the long training loop.
NAIVE_BOOT_DQN_PATH = "../output/deep_sea/naive_boot_dqn.npz"
os.makedirs(os.path.dirname(NAIVE_BOOT_DQN_PATH), exist_ok=True)
jnp.savez(NAIVE_BOOT_DQN_PATH, jnp.array(naive_boot_dqn_outputs))

### ExPerior


In [8]:
# hyperparameters for prior training
PRIOR_EPOCHS = 5000
PRIOR_BATCH_SIZE = 1500
EXPERT_BETA = 5.0
PRIOR_LR = 1e-2
PRIOR_LAMBDA = 10.0

In [9]:
experior_outputs = []

# the only hyperparameter we tuned: PRIOR_REG, chosen per environment index.
# Environment indices not listed in the table use PRIOR_REG_DEFAULT.
PRIOR_REG_DEFAULT = 0.1
PRIOR_REG_BY_ENV_INDEX = {2: 1, 3: 10}

for j, (env, expert_trajectories) in enumerate(zip(envs, expert_trajectories_list)):
    logging.info(f"Starting ExPerior for {env}")

    PRIOR_REG = PRIOR_REG_BY_ENV_INDEX.get(j, PRIOR_REG_DEFAULT)

    # 1. Train the max-entropy prior on this environment's expert trajectories.
    prior_q_network = QNetwork(env.num_actions, **Q_NETWORK_CONFIG)
    max_ent_prior_train = make_max_ent_prior_train(
        prior_q_network, PRIOR_EPOCHS, PRIOR_BATCH_SIZE
    )
    prior_rng = jax.random.PRNGKey(PRIOR_SEED)
    prior_rng, rng_ = jax.random.split(prior_rng)
    prior_state, prior_loss = jax.jit(max_ent_prior_train)(
        rng_,
        expert_trajectories,
        EXPERT_BETA,
        PRIOR_LR,
        PRIOR_LAMBDA,
        PRIOR_REG,
    )

    # 2. Turn the trained prior into a log-density over Q-network parameters.
    prior_log_pdf = make_max_ent_log_pdf(
        prior_state, prior_q_network, expert_trajectories, EXPERT_BETA, PRIOR_REG
    )

    # 3. Draw parameter samples with Langevin sampling, starting from a fresh
    #    network initialisation. Gradients are clipped element-wise to [-50, 50].
    #    NOTE(review): 1e-2 and 10000 are presumably the step size and the
    #    number of steps — confirm against experior.sampling.langevin_sampling.
    prior_rng, rng_ = jax.random.split(prior_rng)
    init_params = prior_q_network.init(rng_, expert_trajectories.obs[0])
    grad_opt = lambda g: jax.tree_util.tree_map(lambda x: jnp.clip(x, -50, 50), g)
    prior_rng, rng_ = jax.random.split(prior_rng)
    _, samples = langevin_sampling(
        rng_, init_params, prior_log_pdf, 1e-2, 10000, grad_opt
    )

    # Keep every second entry of the last 2 * num_samples along the leading
    # axis of each leaf, i.e. num_samples entries per leaf.
    num_samples = 1500
    samples = jax.tree_util.tree_map(lambda p: p[-(2 * num_samples) :: 2], samples)

    def q_init_fn(key, _):
        """Return one stored parameter sample, picked uniformly at random with ``key``.

        The second argument is ignored.
        """
        ind = jax.random.randint(key, shape=(), minval=0, maxval=num_samples)
        return jax.tree_util.tree_map(lambda x: x[ind], samples)

    # 4. Run Boot-DQN exactly as in the naive baseline, except that the
    #    Q-network parameters are initialised from the prior samples.
    q_network = QNetwork(env.num_actions, **Q_NETWORK_CONFIG)

    boot_max_ent_train = make_boot_dqn_train(
        env,
        q_network,
        DQN_BUFFER_SIZE,
        BATCH_SIZE,
        EPISODES * HORIZON,
        LEARNING_STARTS,
        NUM_ENSEMBLES,
        OPTIMIZER,
        jax.tree_util.Partial(q_init_fn),
        EPSILON_FN,
    )

    boot_max_ent_train = jax.jit(boot_max_ent_train)
    env_params = jax.vmap(env.init_env, (0, None))(
        jax.random.split(jax.random.PRNGKey(ENV_SEED), NUM_ENVS), env.default_params
    )
    state, max_ent_output = jax.vmap(
        boot_max_ent_train, (0, 0, None, None, None, None)
    )(
        jax.random.split(jax.random.PRNGKey(METHOD_SEED), NUM_ENVS),
        env_params,
        MASK_PROB,
        NOISE_SCALE,
        TRAIN_FREQ,
        TARGET_FREQ,
    )
    # Keep every HORIZON-th entry along the second axis of the logged returns.
    experior_outputs.append(
        max_ent_output["info"]["returned_episode_returns"][:, ::HORIZON]
    )

2024-04-09 19:26:30 INFO     Starting ExPerior for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f37d90>


2024-04-09 19:41:00 INFO     Starting ExPerior for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f37d00>
2024-04-09 19:55:18 INFO     Starting ExPerior for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f34220>
2024-04-09 20:09:36 INFO     Starting ExPerior for <experior.envs.deep_sea.DeepSea object at 0x7f21c6f379d0>


In [10]:
# Persist the ExPerior returns as one stacked array (stored under the default
# key "arr_0"). The output directory is created first so that a missing folder
# cannot make the save fail after the long training loop.
EXPERIOR_PATH = "../output/deep_sea/experior.npz"
os.makedirs(os.path.dirname(EXPERIOR_PATH), exist_ok=True)
jnp.savez(EXPERIOR_PATH, jnp.array(experior_outputs))

### Plot the results


In [27]:
# Reload the saved results from disk, reading the array stored under "arr_0"
# in each archive.
# NOTE(review): explore.npz is not written by any cell shown in this notebook;
# it presumably comes from a separate ExPLORe run and must already exist.
experior_outputs = jnp.load("../output/deep_sea/experior.npz")["arr_0"]
naive_boot_dqn_outputs = jnp.load("../output/deep_sea/naive_boot_dqn.npz")["arr_0"]
explore_outputs = jnp.load("../output/deep_sea/explore.npz")["arr_0"]

In [28]:
import matplotlib as mpl

from experior.utils import (
    latexify,
    FIG_WIDTH,
    GOLDEN_RATIO,
    FONT_SIZE,
    LEGEND_SIZE,
    LIGHT_COLORS,
)

# Select matplotlib's non-interactive PDF backend; the figure below is only
# written to a .pdf file.
mpl.use("pdf")

In [30]:
# Bounds used to clip the shaded mean +/- std band.
max_rewards = 0.99
min_rewards = 0.0

fig_size = (FIG_WIDTH, FIG_WIDTH * GOLDEN_RATIO * 0.8)
latexify(
    fig_size[0],
    fig_size[1],
    font_size=FONT_SIZE,
    legend_size=LEGEND_SIZE,
    labelsize=LEGEND_SIZE,
)
fig, axes = plt.subplots(2, 2, figsize=fig_size)

window = 5  # window passed to moving_average
start = 1  # entries before this index are dropped before smoothing

# Method order here must match `names`, `color_list` and `fill_colors` below.
method_outputs = (naive_boot_dqn_outputs, experior_outputs, explore_outputs)


def _smoothed(reduce_runs):
    """Return, per environment and method, the smoothed curve of ``reduce_runs``.

    ``reduce_runs`` collapses axis 0 of one environment's results; the first
    ``start`` entries of the result are dropped before ``moving_average``.
    """
    return [
        [
            moving_average(reduce_runs(runs[i])[start:], window)
            for runs in method_outputs
        ]
        for i in range(4)
    ]


bases = _smoothed(lambda runs: runs.mean(0))
stds = _smoothed(lambda runs: runs.std(0))

names = [r"Naïve Boot-DQN", r"ExPerior ({Ours})", r"ExPLORe"]
color_list = ["red", "blue", "green"]
fill_colors = ["red", "blue", "green"]
linestyles = ["-"] * 3

# One panel per environment, filled row by row in the 2x2 grid.
for i, (axis, panel_means, panel_stds) in enumerate(zip(axes.flat, bases, stds)):
    x, y = divmod(i, 2)
    for j, (base, spread) in enumerate(zip(panel_means, panel_stds)):
        ranges = list(range(1, len(base) + 1))
        axis.plot(
            ranges,
            base,
            label=names[j],
            linewidth=1.5,
            linestyle=linestyles[j],
            c=LIGHT_COLORS[color_list[j]],
        )
        band_low = np.clip(base - spread, min_rewards, max_rewards)
        band_high = np.clip(base + spread, min_rewards, max_rewards)
        axis.fill_between(
            ranges,
            band_low,
            band_high,
            color=LIGHT_COLORS[fill_colors[j]],
            alpha=0.3,
            linewidth=1,
        )
    # Label only the bottom row's x-axis and the left column's y-axis.
    if x == 1:
        axis.set_xlabel(r"Episodes, $T$")
    if y == 0:
        axis.set_ylabel(r"Average Reward")

# A single shared legend, built from the first panel's line handles.
hs, _ = axes[0, 0].get_legend_handles_labels()

fig.legend(hs, names, loc="upper center", ncol=3)
plt.subplots_adjust(left=0.1, right=0.97)

fig.savefig("../output/deep_sea/deep_sea_results.pdf")
plt.close()