In [1]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=24'

%load_ext autoreload
%autoreload 2

from brax import envs
from brax.io import html
from brax.training import normalization


import flax
import jax
import jax.numpy as jnp
from brax.envs import create_fn

from IPython.display import HTML, clear_output

import optax

import matplotlib.pyplot as plt
import numpy as np

from controllers import GruController, MlpController, LinearController
from common import do_local_apg, add_guassian_noise, add_guassian_noise_mixed_std, do_one_rollout

from functools import partial

jax.config.update('jax_platform_name', 'cpu')

def visualize(sys, qps, height=480):
  """Renders a 3D visualization of the environment."""
  return HTML(html.render(sys, qps, height=height))

len(jax.devices())

24

In [2]:
episode_length = 500
action_repeat = 1
apg_epochs = 75
ars_epochs = 50
batch_size = 1
truncation_length = None
learning_rate = 3e-4
clipping = 1e9

initial_std = 0.03
step_size = 0.02
num_elite = 8
eps = 1e-6

normalize_observations=True

env_name = "inverted_pendulum_swingup"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d, 'ur5e', 'reacher', bball_1dof]
env_fn = create_fn(env_name = env_name, action_repeat=action_repeat, batch_size=None, auto_reset=False)
env = env_fn()

#policy = GruController(env.observation_size, env.action_size, 16)
policy = LinearController(env.observation_size,env.action_size)
#policy = GruController(env.observation_size, env.action_size, 32)



In [3]:
num_directions = jax.local_device_count()
key = jax.random.PRNGKey(0)
reset_keys = jax.random.split(key, num=num_directions)
noise_keys = jax.random.split(reset_keys[0], num=num_directions)
_, model_key = jax.random.split(noise_keys[0])


normalizer_params, obs_normalizer_update_fn, obs_normalizer_apply_fn = normalization.create_observation_normalizer(
          env.observation_size, normalize_observations, num_leading_batch_dims=1)

add_noise_pmap = jax.pmap(add_guassian_noise, in_axes=(None,None,0))
do_apg_pmap = jax.pmap(do_local_apg, in_axes = (None,None,None,None,0,0,None,None,None,None,None,None), static_broadcasted_argnums=(0,1,2,6,7,8,9,10,11,12))
do_rollout_pmap = jax.pmap(do_one_rollout, in_axes = (None,None,None,0,0,None,None,None), static_broadcasted_argnums=(0,1,5,6,7))

init_states = jax.pmap(env.reset)(reset_keys)
x0 = init_states.obs
h0 = jnp.zeros(env.observation_size)

policy_params = policy.init(model_key, h0, x0)

new_params_flat = []
policy_params_flat, policy_params_def = jax.tree_flatten(policy_params)

for p in policy_params_flat:
    new_params_flat.append(jnp.zeros_like(p))

policy_params = jax.tree_unflatten(policy_params_def, new_params_flat)   

best_reward = -float('inf')
meta_rewards_list = []

In [4]:
#%%time
import time

policy_params_flat, policy_params_def = jax.tree_flatten(policy_params)
noise_std = initial_std
best_reward_list = []


for i in range(50):
    noise_keys = jax.random.split(noise_keys[0], num=num_directions)
    train_keys = jax.random.split(noise_keys[0], num=num_directions)
    
    base_rewards, _, _, _ = do_one_rollout(env_fn, policy.apply, normalizer_params, policy_params, train_keys[0], episode_length, action_repeat, normalize_observations)
    base_reward_sum = sum(base_rewards)
    print("before apg", base_reward_sum)
    policy_params_with_noise, rewards_lists = do_apg_pmap(apg_epochs, env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_keys, learning_rate, episode_length, action_repeat, normalize_observations, batch_size, clipping, truncation_length)
    reward_sums = [jnp.mean(rew[-5:]) for rew in rewards_lists]
    top_idx = sorted(range(len(reward_sums)), key=lambda k: reward_sums[k], reverse=True)

    print("after apg, before ars: ", reward_sums[top_idx[0]])

    
    for j in range(ars_epochs):


        params_with_noise_flat, params_with_noise_def = jax.tree_flatten(policy_params_with_noise)
        noise_flat, noise_def = jax.tree_flatten(noise)
        
        params_elite_flat = [p[top_idx[:num_elite], :] for p in params_with_noise_flat]
        noise_elite_flat = [p[top_idx[:num_elite], :] for p in noise_flat]
        
        reward_sums_elite = jnp.array([reward_sums[t] for t in top_idx][:num_elite])
        
        reward_std = jnp.std(reward_sums_elite)
            
        new_params_flat = []
        for k, (old_params, elite_noise) in enumerate(zip(policy_params_flat, noise_elite_flat)):
            reward_weight = reward_sums_elite - base_reward_sum
            reward_weight = reward_weight.reshape((num_elite, *[1 for _ in range(len(elite_noise.shape)-1)]))

            new_params_flat.append(old_params + step_size/(num_elite*reward_std) * jnp.sum(reward_weight*elite_noise, axis=0))
        
        policy_params = jax.tree_unflatten(policy_params_def, new_params_flat)   
        policy_params_flat, policy_params_def = jax.tree_flatten(policy_params)
        
        base_rewards, _, _, _ = do_one_rollout(env_fn, policy.apply, normalizer_params, policy_params, train_keys[0], episode_length, action_repeat, normalize_observations)
        base_reward_sum = sum(base_rewards)
        
        policy_params_with_noise, policy_params_with_anti_noise, noise = add_noise_pmap(policy_params, noise_std, noise_keys)
        rewards, obs, acts, states_before = do_rollout_pmap(env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_keys, episode_length, action_repeat, normalize_observations)
        reward_sums = jnp.sum(rewards, axis=1)
        top_idx = sorted(range(len(reward_sums)), key=lambda k: reward_sums[k], reverse=True)

        
    print("after ars", reward_sums[top_idx[0]])

    

before apg -4929.81


NameError: name 'policy_params_with_noise' is not defined

In [None]:
#%%time
import time

policy_params_flat, policy_params_def = jax.tree_flatten(policy_params)
noise_std = initial_std
best_reward_list = []


for i in range(500):
    
    noise_keys = jax.random.split(noise_keys[0], num=num_directions)
    train_keys = jax.random.split(noise_keys[0], num=num_directions)

    base_rewards, _, _, _ = do_one_rollout(env_fn, policy.apply, normalizer_params, policy_params, train_keys[0], episode_length, action_repeat, normalize_observations)
    base_reward_sum = sum(base_rewards)
    policy_params_with_noise, policy_params_with_anti_noise, noise = add_noise_pmap(policy_params, noise_std, noise_keys)
    
    rewards_before, obs, acts, states_before = do_rollout_pmap(env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_keys, episode_length, action_repeat, normalize_observations)
    normalizer_params = obs_normalizer_update_fn(normalizer_params, obs[0,:])

    #policy_params_with_noise, rewards_lists = do_apg_pmap(apg_epochs, env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_keys, learning_rate, episode_length, action_repeat, normalize_observations, batch_size, clipping, truncation_length)
    reward_sums = [rew[-1] for rew in rewards_lists]
    top_idx = sorted(range(len(reward_sums)), key=lambda k: reward_sums[k], reverse=True)
    
    best_reward_list.append(jnp.mean(rewards_lists[top_idx[0]][-5:]))    
    
    if i % 5 == 0:
        print(f" Iteration {i} --------------------------------")
        #print(f" Time: {time.time() - start}")

        for j in range(len(top_idx)):
            done_idx = jnp.where(states_before.done[top_idx[j], :], size=1)[0].item()
            if done_idx == 0:
                done_idx = rewards_before.shape[-1]
            rewards_sum_before = jnp.sum(rewards_before[top_idx[j],:done_idx])

            print(f"{j} : reward: {rewards_sum_before} -> {reward_sums[top_idx[j]]}")
            if j == num_elite-1:
                print("---")
        print("-------------------------------------------------")
        print()
    
    for j in range(ars_epochs):
        top_idx = sorted(range(len(reward_sums)), key=lambda k: reward_sums[k], reverse=True)

        params_with_noise_flat, params_with_noise_def = jax.tree_flatten(policy_params_with_noise)
        noise_flat, noise_def = jax.tree_flatten(noise)
        
        params_elite_flat = [p[top_idx[:num_elite], :] for p in params_with_noise_flat]
        noise_elite_flat = [p[top_idx[:num_elite], :] for p in noise_flat]
        reward_sums_elite = jnp.array([reward_sums[t] for t in top_idx][:num_elite])
        reward_std = jnp.std(reward_sums_elite)
            
        new_params_flat = []
        for k, (old_params, elite_noise) in enumerate(zip(policy_params_flat, noise_elite_flat)):
            reward_weight = reward_sums_elite - base_reward_sum
            reward_weight = reward_weight.reshape((num_elite, *[1 for _ in range(len(elite_noise.shape)-1)]))

            new_params_flat.append(old_params + step_size/(n_elite*reward_std) * jnp.sum(reward_weight*elite_noise, axis=0))
        
        policy_params = jax.tree_unflatten(policy_params_def, new_params_flat)   
        policy_params_flat, policy_params_def = jax.tree_flatten(policy_params)

        policy_params_with_noise, _,  noise = add_noise_pmap(policy_params, noise_std, noise_keys)
        rewards, obs, acts, states_before = do_rollout_pmap(env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_keys, episode_length, action_repeat, normalize_observations)
        reward_sums = jnp.sum(rewards, axis=1)
    
    

In [None]:
plt.plot(best_reward_list)

In [None]:
import brax.jumpy as jp
@jax.jit
def do_rnn_rollout(policy_params, normalizer_params, key):
    init_state = env.reset(key)
    h0 = jp.zeros_like(init_state.obs)

    def do_one_rnn_step(carry, step_idx):
        state, h, policy_params, normalizer_params  = carry

        normed_obs = obs_normalizer_apply_fn(normalizer_params, state.obs)
        h1 , actions = policy.apply(policy_params, h, normed_obs)
        #actions = jp.ones_like(actions)*0.0
        nstate = env.step(state, actions)    
        #h1 = jax.lax.cond(nstate.done, lambda x: jnp.zeros_like(h1), lambda x: h1, None)
        return (jax.lax.stop_gradient(nstate), h1, policy_params, normalizer_params), (nstate.reward,state.obs, actions, nstate)


    _, (rewards, obs, acts, states) = jp.scan(
        do_one_rnn_step, (init_state, h0, policy_params, normalizer_params),
        (jnp.array(range(episode_length // action_repeat))),
        length=episode_length // action_repeat)

    return rewards, obs, acts, states

In [None]:
key, reset_key = jax.random.split(key)
(rewards, obs, acts, states) = do_rnn_rollout(policy_params, normalizer_params, reset_key)

done_idx = jnp.where(states.done, size=1)[0].item()
if done_idx == 0:
    done_idx = states.done.shape[0]
rewards_sum = jnp.sum(rewards[:done_idx])

plt.plot(obs[:,0]);
plt.figure()

plt.plot(acts);
print(rewards_sum)
print(states.done)

In [None]:
qp_flat, qp_def = jax.tree_flatten(states.qp)

qp_list = []

for i in range(qp_flat[0].shape[0]):
    qpc=[]
    for thing in qp_flat:
        qpc.append(thing[i,:])
    qp_list.append(jax.tree_unflatten(qp_def, qpc))
    

visualize(env.sys, qp_list, height=800)

In [None]:
plt.plot(best_reward_list)