In [None]:
import jax
import jax.numpy as jnp
import optax
import matplotlib.pyplot as plt

In [None]:
import gymnax

In [None]:
import jax_ppo
from masked_pendulum import MaskedPendulum

In [None]:
# Root PRNG key, built from a fixed seed so a fresh run draws the same random numbers.
k = jax.random.PRNGKey(101)

## Initialise Environment

This is a customized version of the pendulum environment that masks the velocity component of the observation.

In [None]:
# Instantiate the masked-observation pendulum and take its default parameter set;
# both are passed to agent initialisation and to the training loop below.
env = MaskedPendulum()
env_params = env.default_params

## Hyperparameters

In [None]:
# Number of policy updates
N_TRAIN = 2_500
# Number of training environments
N_TRAIN_ENV = 32
# Number of test environments
N_TEST_ENV = 5
# Number of environment steps (taken from the environment's maximum episode length)
N_ENV_STEPS = env_params.max_steps_in_episode
# Number of training loops per policy update
N_EPOCHS = 2
# Mini-batch size used for actual training
MINI_BATCH_SIZE = 512
# Length of input sequence
SEQ_LEN = 8
# Number of LSTM hidden state burn-in steps
N_BURN_IN = 8

In [None]:
# Length of the learning-rate schedule: every sample gathered over the whole run
# (updates x environments x steps), repeated once per epoch, divided into
# mini-batches. The product is formed first and floor-divided last.
N_STEPS = (
    N_TRAIN * N_TRAIN_ENV * N_ENV_STEPS * N_EPOCHS
) // MINI_BATCH_SIZE

In [None]:
# PPO settings: start from the library defaults and override only these fields.
ppo_overrides = {
    "gamma": 0.95,
    "gae_lambda": 0.99,
    "entropy_coeff": 0.0001,
    "adam_eps": 1e-8,
    "clip_coeff": 0.1,
}
params = jax_ppo.default_params._replace(**ppo_overrides)

## Initialise Policy

In [None]:
# Learning rate decays linearly from 2e-3 to 2e-6 over N_STEPS schedule steps.
train_schedule = optax.linear_schedule(
    init_value=2e-3,
    end_value=2e-6,
    transition_steps=N_STEPS,
)

In [None]:
# Build the recurrent (LSTM) PPO agent together with its initial hidden states.
#
# The first return value is rebound to `k` instead of being thrown away, so that
# later cells that pass `k` on (the training loop) do not replay the exact key
# already consumed here; JAX PRNG keys must not be used twice.
# NOTE(review): this assumes the first return value is the advanced PRNG key
# (train_recurrent's first return is bound to `_k` in the same position) -- confirm
# against init_lstm_agent's signature.
k, agent, hidden_states = jax_ppo.init_lstm_agent(
    k,
    params,
    env.action_space().shape,
    env.observation_space(env_params).shape,
    train_schedule,
    layer_width=16,
    n_layers=2,
    n_recurrent_layers=1,
    seq_len=SEQ_LEN,
)

## Train Loop

In [None]:
# Train the recurrent agent. Every argument up to `params` is positional, so the
# order below has to match train_recurrent's signature exactly.
# Returns, in order: a value bound to `_k` (presumably the updated PRNG key),
# the trained agent, a mapping of loss histories, `ts`, and the recorded rewards.
_k, trained_agent, losses, ts, rewards = jax_ppo.train_recurrent(
    k, env, env_params, agent, 
    N_TRAIN, 
    N_TRAIN_ENV, 
    N_EPOCHS, 
    MINI_BATCH_SIZE, 
    N_TEST_ENV, 
    SEQ_LEN,
    # NOTE(review): unnamed literal -- presumably the number of recurrent layers
    # (n_recurrent_layers=1 was passed to init_lstm_agent); confirm and consider
    # promoting it to a named constant.
    1,
    N_BURN_IN,
    params,
    greedy_test_policy=True,
)

In [None]:
# Take component 0 of the last axis, add rewards up along axis 2 (the step axis)
# to get one total per episode, then average along axis 1 (presumably the test
# environments) so there is a single value per training step.
episode_totals = jnp.sum(rewards[:, :, :, 0], axis=2)
avg_episode_total = jnp.mean(episode_totals, axis=1)

plt.plot(avg_episode_total)
plt.xlabel("Training Step")
plt.ylabel("Avg Total Rewards");

## Training Data

In [None]:
f, ax = plt.subplots(2, 2, figsize=(10, 6))

# (key in `losses`, y-axis label) for each panel, listed in row-major order so
# they line up with `ax.flat`: top-left, top-right, bottom-left, bottom-right.
loss_panels = [
    ("policy_loss", "Total Policy Loss"),
    ("value_loss", "Total Value Loss"),
    ("entropy_loss", "Total Entropy"),
    ("kl_divergence", "KL-Divergence"),
]

for panel, (loss_key, y_label) in zip(ax.flat, loss_panels):
    # Flatten the recorded values into one sequence so they share a single x-axis.
    panel.plot(losses[loss_key].reshape(-1), drawstyle="steps-mid")
    panel.set_ylabel(y_label)
    panel.set_xlabel("Train Step")

# Re-space the grid so axis labels do not run into neighbouring panels.
f.tight_layout()

In [None]:
# Rewards (component 0) from the last entry along the first axis, i.e. the final
# training step. Transposed so the step axis runs along x and each remaining
# column is drawn as its own line.
final_rewards = rewards[-1, :, :, 0]

plt.plot(final_rewards.T, drawstyle="steps-mid")
plt.xlabel("Step")
plt.ylabel("Reward");