In [None]:
import jax
import jax.numpy as jnp
import gymnasium as gym
import optax
import numpy as np
import flax
import matplotlib.pyplot as plt

In [None]:
import jax_ppo
from gym_runner import train, test

In [None]:
# Seed for the JAX PRNG key threaded through agent initialisation and training.
# NOTE(review): the gym environment keeps its own RNG; it is only seeded if
# `train` does so internally — TODO confirm for full reproducibility.
SEED = 101
k = jax.random.PRNGKey(SEED)

## Initialise Environment

In [None]:
# Training environment: Pendulum with gravity set explicitly to 9.81 (passed through to the env).
env = gym.make(id='Pendulum-v1', g=9.81)

## Hyperparameters

In [None]:
# Number of training iterations (each collects N_SAMPLES and then updates the policy)
N_TRAIN = 300
# Number of environment samples collected per training iteration
N_SAMPLES = 512
# Number of passes (epochs) over the collected samples per training iteration
N_EPOCHS = 4
# Mini-batch size used for each gradient step
MINI_BATCH_SIZE = 32

# The N_STEPS calculation in the next cell is only exact when every rollout
# splits into whole mini-batches.
assert N_SAMPLES % MINI_BATCH_SIZE == 0, "N_SAMPLES must be a multiple of MINI_BATCH_SIZE"

The total number of gradient (optimiser) steps is the total number of samples processed (samples per iteration × epochs × training iterations) divided by the mini-batch size.

In [None]:
# Optimiser steps over the whole run: every sample is seen once per epoch in
# every iteration, and each step consumes one mini-batch.
N_STEPS = (N_TRAIN * N_EPOCHS * N_SAMPLES) // MINI_BATCH_SIZE

## Initialise Policy

Initialise a learning-rate schedule that decays linearly over the `N_STEPS` optimiser steps calculated above.

In [None]:
# Learning rate decays linearly from 3e-2 to 3e-4 across all N_STEPS optimiser steps.
train_schedule = optax.linear_schedule(
    init_value=3e-2,
    end_value=3e-4,
    transition_steps=N_STEPS,
)

In [None]:
# Build the agent from the default PPO parameters, sized by the environment's
# action and observation shapes (in that order), using tanh layers of width 16.
# The call also hands back a fresh PRNG key, which replaces `k`.
k, agent = jax_ppo.init_agent(
    k,
    jax_ppo.default_params,
    env.action_space.shape,
    env.observation_space.shape,
    train_schedule,
    activation=flax.linen.tanh,
    layer_width=16,
)

## Train Loop

In [None]:
# Run N_TRAIN training iterations; `losses` holds the recorded metrics
# (policy_loss, value_loss, entropy, learning_rate) plotted below.
k, agent, losses = train(
    k,
    env,
    agent,
    N_TRAIN,
    N_SAMPLES,
    N_EPOCHS,
    MINI_BATCH_SIZE,
)

In [None]:
# Training is finished: release the training environment.
env.close()

## Training Data

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(10, 6))

# Each metric array is reshaped to (N_TRAIN, -1) and summed, giving one total
# per training iteration (assumes its length is a multiple of N_TRAIN).
summed_metrics = (
    (ax[0][0], "policy_loss", "Total Policy Loss"),
    (ax[0][1], "value_loss", "Total Value Loss"),
    (ax[1][0], "entropy", "Total Entropy"),
)
for metric_ax, metric_key, metric_label in summed_metrics:
    per_iteration = losses[metric_key].reshape(N_TRAIN, -1).sum(axis=1)
    metric_ax.plot(per_iteration, drawstyle="steps-mid")
    metric_ax.set_ylabel(metric_label)
    metric_ax.set_xlabel("Train Step")

# The learning rate is plotted as recorded, without per-iteration aggregation.
ax[1][1].plot(losses["learning_rate"], drawstyle="steps-mid")
ax[1][1].set_ylabel("Learning Rate")
ax[1][1].set_xlabel("Policy Update")

# Stop axis labels of the 2x2 grid from overlapping.
fig.tight_layout();

## Test Trained Policy

In [None]:
# Passed to `test`; presumably the number of environment steps to run — TODO confirm.
N_TEST_STEPS = 1000

# Re-create the environment with on-screen rendering to watch the trained agent.
env = gym.make('Pendulum-v1', g=9.81, render_mode="human")
try:
    test_rewards = test(env, agent, N_TEST_STEPS)
finally:
    # Close the render window even if the rollout raises or is interrupted.
    env.close()

In [None]:
# Reward at each step of the test rollout.
f, ax = plt.subplots(figsize=(7, 3))
ax.plot(test_rewards, drawstyle="steps-mid")
ax.set(xlabel="Step", ylabel="Reward");