This notebook renders some of the policies found by BGPBT.
- ```env_seed{seed}_Agent{agent}```: the best agent found by BGPBT.
- ```hopper_seed3_pendulum``` and ```humanoid_seed0_stiff```: the "inverted pendulum" mode found by BGPBT on Hopper and "stiff humanoid without knee joint movement" mode on Humanoid.

In [1]:
import os
os.chdir('../..')

import functools
from brax import envs
from brax import jumpy as jp
from custom_brax_train import ppo_torch as ppo
import torch

device = "cpu"

In [2]:
from IPython.display import HTML, clear_output
from brax.io import html
from tqdm import tqdm

In [3]:
all_envs = ['ant', 'halfcheetah', 'humanoid', 'hopper', 'reacher', 'fetch', 'ur5e']

In [4]:
# Make environment
ckpt_name = 'reacher_seed3_Agent5'   # enter the full name of the checkpoint, less the extension
if ckpt_name in all_envs:
    env_name = ckpt_name
else:
    env_name = ckpt_name.split('_')[0]

env_fn = envs.create_fn(env_name=env_name)
env = env_fn()

if 'seed' in ckpt_name:
    seed = int(ckpt_name.split('seed')[1].split('_')[0])
else:
    seed = 0

In [5]:
arch_dict = {
    # numbers of neurons and number of hidden layers in order of (V-net, policy-net)
    'ant_seed300_Agent4': ['256,256,256,256','256,256'],
    'halfcheetah_seed2_Agent6':['256,256,256,256,256','256,256,256,256'],
    'humanoid_seed3_Agent2':['256,256,256,256,256', '32,32,32,32',],
    'hopper_seed200_Agent1':['256,256,256,256,256', '32,32,32,32'],
    'fetch_seed2_Agent6':['256,256', '128,128'],
    'reacher_seed3_Agent5': ['128,128,128,128','32,32,32'],
    'ur5e_seed1_Agent6': ['256,256,256,256,256', '128,128,128'],
    'ur5e_seed300_default': ['256,256,256,256,256', '32,32,32,32'],
    'hopper_seed3_pendulum': ['256,256,256,256,256', '32,32,32,32'],
    'humanoid_seed0_stiff': ['128,128,128,128','64,64,64'],
}

In [6]:
sd = torch.load(os.path.join('./data/ckpts', ckpt_name+'.pt'), map_location=torch.device('cpu'))

In [34]:
agent = ppo.Agent(
    obs_dim = env.observation_size,
    act_dim = env.action_size,
    policy_hidden_layer_sizes = arch_dict[ckpt_name][1],
    policy_activation = torch.nn.SiLU(),
    v_hidden_layer_sizes = arch_dict[ckpt_name][0],
    value_activation = torch.nn.SiLU(),
    entropy_cost = 1e-2,
    discounting = 0.97,
    reward_scaling = 10,
    lambda_ = 0.95,
    ppo_epsilon = 0.3,
    unroll_length = 5,
    batch_size = 1024,
    num_minibatches = 32,
    num_update_epochs = 4,
    device = device,
)

Policy Network: [32, 32, 32]
Value Network: [128, 128, 128, 128]


In [35]:
try:
    agent.load_state_dict(sd['agent'])
    agent.num_steps = sd['num_steps']
    agent.running_mean = sd['running_mean']
    agent.running_variance = sd['running_variance']

except:
    ns, rm, rv = sd['agent']['num_steps'], sd['agent']['running_mean'], sd['agent']['running_variance']
    del sd['agent']['num_steps'], sd['agent']['running_mean'], sd['agent']['running_variance']
    agent.load_state_dict(sd['agent'])
    agent.num_steps = ns
    agent.running_mean = rm
    agent.running_variacne = rv


In [36]:
# Load policy
rollout = []
state = env.reset(rng=jp.random_prngkey(seed=seed))
print(f'seed={seed}')
for idx in tqdm(range(1000)):
    rollout.append(state)
    _, act = agent.get_logits_action(torch.from_numpy(state.obs).to(device).float())
    act = agent.dist_postprocess(act)
    state = env.step(state, act.cpu().data.numpy())

seed=3


100%|██████████| 1000/1000 [00:09<00:00, 111.10it/s]


In [37]:
print(sum([s.reward for s in rollout]))

-13.756391609535529


In [38]:
HTML(html.render(env.sys, [s.qp for s in rollout]))