In [None]:
# brax imports
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

import brax
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

In [None]:
# custom imports
from envs.cart_pole_env import CartPoleEnv, CartPoleConfig

In [None]:
# Instantiate the cart-pole environment with its default configuration.
env = CartPoleEnv(CartPoleConfig())

# PPO hyperparameters, forwarded unchanged to `ppo.train` below.
num_timesteps = 1_000_000      # total number of environment steps to train for
num_evals = 20                 # number of policy evaluations during training
reward_scaling = 10.0          # multiplicative scaling applied to rewards
episode_length = 1000          # max length of each episode
normalize_observations = True  # whether to normalize observations
action_repeat = 1              # number of env steps per action
unroll_length = 32             # length of each rollout segment collected per env
num_minibatches = 32           # minibatches each collected batch is split into
num_updates_per_batch = 8      # passes (epochs) over each collected batch
discounting = 0.97             # discount factor for future rewards
learning_rate = 3e-4           # optimizer learning rate
entropy_cost = 1e-2            # weight of the policy-entropy bonus
num_envs = 2048                # number of parallel environments
batch_size = 1024              # rollout segments per minibatch; one gradient
                               # step sees batch_size * unroll_length transitions
seed = 0                       # random seed for reproducibility

# brax's PPO trainer requires this divisibility (it asserts it internally);
# checking here fails fast, before any JIT compilation or training starts.
assert (batch_size * num_minibatches) % num_envs == 0, (
    "batch_size * num_minibatches must be divisible by num_envs"
)

# Bind the environment and hyperparameters. `ppo.train` names its environment
# parameter `environment` (not `env`); binding the wrong keyword only surfaces
# as a TypeError when train_fn is finally called. Extra arguments such as
# progress_fn can still be supplied at call time: train_fn(progress_fn=...).
train_fn = functools.partial(
    ppo.train,
    environment=env,
    num_timesteps=num_timesteps,
    num_evals=num_evals,
    reward_scaling=reward_scaling,
    episode_length=episode_length,
    normalize_observations=normalize_observations,
    action_repeat=action_repeat,
    unroll_length=unroll_length,
    num_minibatches=num_minibatches,
    num_updates_per_batch=num_updates_per_batch,
    discounting=discounting,
    learning_rate=learning_rate,
    entropy_cost=entropy_cost,
    num_envs=num_envs,
    batch_size=batch_size,
    seed=seed,
)

