In [None]:
# brax imports
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

# custom imports
from envs.cart_pole_env import CartPoleEnv, CartPoleConfig

In [None]:
# Report which devices JAX will run on (an accelerator if one is visible, otherwise CPU).
available_devices = jax.devices()
print(available_devices)

In [None]:
# import the environment
env = CartPoleEnv(CartPoleConfig())

# select some hyperparameters (for ppo.train)
num_timesteps = 1_000_000       # total number of training timesteps (1M)
num_evals = 20                  # number of times to evaluate the policy
reward_scaling = 10.0           # scaling factor on the reward
episode_length = 750            # max length of each episode
normalize_observations = True   # whether to normalize observations
action_repeat = 1               # number of env steps per action
unroll_length = 8               # number of steps to unroll the policy
num_minibatches = 32            # number of minibatches for ppo
num_updates_per_batch = 8       # number of updates per batch
discounting = 0.97              # discounting factor for future rewards
learning_rate = 3e-4            # learning rate for the optimizer
entropy_cost = 1e-2             # cost for the policy entropy
num_envs = 2048                 # number of parallel environments
# Minibatch size for each SGD step. NOTE(review): in brax PPO each element is an
# unroll of `unroll_length` steps, not a single transition — confirm for your brax version.
batch_size = 1024
seed = 1                        # random seed for reproducibility

# make the train function
# Bind every hyperparameter now so the training cell only has to supply the
# progress callback; `progress` also reads `num_timesteps` back from
# `train_fn.keywords` to size the plot's x-axis.
train_fn = functools.partial(
    ppo.train,
    environment=env,
    num_timesteps=num_timesteps,
    num_evals=num_evals,
    reward_scaling=reward_scaling,
    episode_length=episode_length,
    normalize_observations=normalize_observations,
    action_repeat=action_repeat,
    unroll_length=unroll_length,
    num_minibatches=num_minibatches,
    num_updates_per_batch=num_updates_per_batch,
    discounting=discounting,
    learning_rate=learning_rate,
    entropy_cost=entropy_cost,
    num_envs=num_envs,
    batch_size=batch_size,
    seed=seed,
)


In [None]:
# Fixed y-axis bounds for the live reward plot.
min_y, max_y = -100, 8000

# Accumulators appended to by the progress callback: step counts, reward
# values and wall-clock timestamps. The first timestamp is taken here, i.e.
# just before training is launched further down in this cell.
xdata = []
ydata = []
times = [datetime.now()]

def progress(num_steps, metrics):
  """Progress callback for ``train_fn``: record one reward sample and redraw the curve.

  Appends a timestamp, the step count and a reward value to the module-level
  ``times``, ``xdata`` and ``ydata`` lists, then replaces the cell output with
  a fresh plot of the whole reward history.

  Args:
    num_steps: environment-step count passed in by the trainer.
    metrics: mapping of metric name -> value passed in by the trainer.
  """
  times.append(datetime.now())
  xdata.append(num_steps)

  # Prefer a custom 'reward_total' metric, falling back to the eval reward.
  reward_key = next(
      (key for key in ('reward_total', 'eval/episode_reward') if key in metrics),
      None,
  )
  # With no reward available, record NaN so matplotlib leaves a gap rather
  # than drawing a fabricated 0 reward; keeps xdata/ydata the same length.
  ydata.append(metrics[reward_key] if reward_key is not None else float('nan'))

  # wait=True defers the clear until the new figure is emitted, avoiding
  # flicker. Anything printed before this point would be wiped, so the chosen
  # reward key is reported in the plot title instead.
  clear_output(wait=True)
  fig, ax = plt.subplots()
  ax.set_xlim([0, train_fn.keywords['num_timesteps']])
  ax.set_ylim([min_y, max_y])
  ax.set_xlabel('# environment steps')
  ax.set_ylabel('reward per episode')
  ax.set_title(f'reward source: {reward_key or "no reward found"}')
  ax.plot(xdata, ydata)
  plt.show()
  # Release the figure; otherwise non-inline backends accumulate one per call.
  plt.close(fig)

# Launch PPO training; `progress` is called back along the way to log/plot rewards.
make_inference_fn, params, metrics = train_fn(progress_fn=progress, environment=env)

# times[0] is stamped before training starts, times[1] at the first progress
# callback (so that interval includes compilation), times[-1] at the last one.
t_start, t_first_report, t_last_report = times[0], times[1], times[-1]
print(f'time to jit: {t_first_report - t_start}')
print(f'time to train: {t_last_report - t_first_report}')

In [None]:
# Print every final metric returned by training as a "name: value" line.
for metric_name in metrics:
    print(f"{metric_name}: {metrics[metric_name]}")