In [None]:
import jax
import jax.numpy as jnp
import gymnax as gym
import optax
import matplotlib.pyplot as plt

In [None]:
import jax_ppo
from gym_runner import train

In [None]:
# Seed of the root PRNG key that is handed to agent initialisation and training
SEED = 101
k = jax.random.PRNGKey(SEED)

## Initialise Environment

In [None]:
# Environment ID to train on; swap in "MountainCarContinuous-v0" to try that task instead
ENV_NAME = "Pendulum-v1"
env, env_params = gym.make(ENV_NAME)

## Hyper Parameters

In [None]:
# How many policy updates to run
N_TRAIN = 2_500
# Training samples gathered for each policy update
N_SAMPLES = 2_048
# Training loops run for each policy update
N_EPOCHS = 2
# Size of the mini-batches used for the actual training
MINI_BATCH_SIZE = 256
# How many test steps to run
N_TEST = 2_000

The total number of gradient steps is the total number of trajectory samples (policy updates × samples per update), multiplied by the number of epochs per update and divided by the mini-batch size

In [None]:
# Length of the step-size schedule: (updates x epochs x samples) consumed
# MINI_BATCH_SIZE at a time, rounded down to a whole number of steps
N_STEPS = (N_TRAIN * N_EPOCHS * N_SAMPLES) // MINI_BATCH_SIZE

In [None]:
# Fields of the library's default parameter set that this run overrides
ppo_overrides = {
    "gamma": 0.95,
    "gae_lambda": 0.9,
    "entropy_coeff": 0.0001,
    "adam_eps": 1e-8,
    "clip_coeff": 0.2,
}
params = jax_ppo.default_params._replace(**ppo_overrides)

## Initialise Policy

Initialise a linear step-size schedule for the number of steps calculated above

In [None]:
# Step size moves linearly from 2e-3 to 2e-5 across N_STEPS steps
train_schedule = optax.linear_schedule(
    init_value=2e-3,
    end_value=2e-5,
    transition_steps=N_STEPS,
)

In [None]:
# Initialise the agent from the environment's action and observation shapes.
# `env_params` is passed to BOTH space getters so the two shapes always
# describe the same environment configuration, rather than the action space
# being queried without the parameters the observation space receives.
# `init_agent` also returns a key, which replaces `k` for the cells below.
k, agent = jax_ppo.init_agent(
    k,
    params,
    env.action_space(env_params).shape,
    env.observation_space(env_params).shape,
    train_schedule,
    layer_width=16,
)

## Train Loop

In [None]:
# Run the training loop. The first returned value (presumably the updated
# PRNG key - TODO confirm against gym_runner.train) is not used afterwards.
_k, trained_agent, losses, ts, rewards = train(
    k,
    env,
    env_params,
    agent,
    N_TRAIN,
    N_SAMPLES,
    N_EPOCHS,
    MINI_BATCH_SIZE,
    N_TEST,
    params,
    greedy_test_policy=True,
)

In [None]:
# Collapse axis 1 of `rewards` so each row contributes a single total
total_rewards = jnp.sum(rewards, axis=1)

rewards_fig, rewards_ax = plt.subplots()
rewards_ax.plot(total_rewards, drawstyle="steps-mid")
rewards_ax.set_xlabel("Training Step")
rewards_ax.set_ylabel("Total Rewards");

## Training Data

In [None]:
# One panel per metric stored in `losses`: (dictionary key, y-axis label)
loss_panels = (
    ("policy_loss", "Total Policy Loss"),
    ("value_loss", "Total Value Loss"),
    ("entropy_loss", "Total Entropy"),
    ("kl_divergence", "KL-Divergence"),
)

f, ax = plt.subplots(2, 2, figsize=(10, 6))

# `ax.flat` walks the 2x2 grid row by row, matching the order of `loss_panels`
for panel, (loss_key, y_label) in zip(ax.flat, loss_panels):
    # Flatten the metric to 1-D so it is drawn as a single series
    panel.plot(losses[loss_key].reshape(-1), drawstyle="steps-mid")
    panel.set_ylabel(y_label)
    panel.set_xlabel("Policy Update")

# Re-space the grid so axis labels do not collide with neighbouring panels
f.tight_layout()

In [None]:
# Rewards stored in the last row of `rewards`, one point per entry
final_fig, final_ax = plt.subplots()
final_ax.plot(rewards[-1], drawstyle="steps-mid")
final_ax.set_xlabel("Step")
final_ax.set_ylabel("Reward");