In [2]:
import os
%load_ext autoreload
%autoreload 2

from brax import envs
from brax.io import html
from brax.training import normalization


import flax
import jax
import jax.numpy as jnp
from brax.envs import create_fn

from IPython.display import HTML, clear_output

import optax

import matplotlib.pyplot as plt
import numpy as np

from controllers import GruController, MlpController
from common import do_local_apg, add_guassian_noise, add_uniform_noise, add_uniform_and_pareto_noise, add_sym_pareto_noise, do_one_rollout

from functools import partial

jax.config.update('jax_platform_name', 'gpu')


def visualize(sys, qps, height=480):
  """Build an inline HTML viewer for a brax system.

  Args:
    sys: the brax system to draw (the notebook passes ``env.sys``).
    qps: list of QP states, one per rendered frame.
    height: viewer height passed straight to ``html.render``
      (presumably pixels -- TODO confirm).

  Returns:
    An ``IPython.display.HTML`` object; it renders when it is the last
    expression of a cell.
  """
  page = html.render(sys, qps, height=height)
  return HTML(page)

len(jax.devices())

1

In [3]:
# --- Rollout / environment settings -----------------------------------------
episode_length = 500
action_repeat = 1

# --- Parameter-noise settings (see add_uniform_and_pareto_noise) ------------
noise_scale = 5.0
noise_beta = 1.8

# --- Local APG settings (see do_local_apg) -----------------------------------
apg_epochs = 50
# Also the env batch size passed to create_fn below (a previous
# `batch_size = jax.local_device_count()` was dead: always overwritten here).
batch_size = 16
truncation_length = 50
learning_rate = 1e-4
clipping = 1e9

normalize_observations = True

env_name = "reacher"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d', 'ur5e', 'reacher', 'bball_1dof']
env_fn = create_fn(env_name=env_name, action_repeat=action_repeat, batch_size=batch_size, auto_reset=False)
env = env_fn()

In [7]:
# Root PRNG key; every random consumer below gets its own sub-key.
key = jax.random.PRNGKey(0)

# Derive the reset, model-init and noise sub-keys one split at a time,
# advancing the root key after each split.
_derived_keys = []
for _split_round in range(3):
    key, _fresh = jax.random.split(key)
    _derived_keys.append(_fresh)
reset_key, model_key, noise_key = _derived_keys
del _derived_keys, _fresh, _split_round

In [8]:
policy = GruController(env.observation_size, env.action_size, 64)
normalizer_params, obs_normalizer_update_fn, obs_normalizer_apply_fn = normalization.create_observation_normalizer(
    env.observation_size, normalize_observations, num_leading_batch_dims=1)

# Initialise the policy from a real reset observation so parameter shapes
# match what the environment produces.
init_states = env.reset(reset_key)
x0 = init_states.obs
# Initial recurrent state; its size is the observation size.
h0 = jnp.zeros(env.observation_size)

policy_params = policy.init(model_key, h0, x0)

# Meta-training bookkeeping.
best_reward = -float('inf')
meta_rewards_list = []

# Defaults for the "best so far" policy/normalizer. The training loop only
# overwrites these, so the evaluation cells still run (with the initial
# policy) if training never accepts a candidate or fails part-way.
top_params = policy_params
top_normalizer_params = normalizer_params

In [12]:
%%time
for i in range(1):
    key, noise_key, train_key = jax.random.split(key, num=3)

    #policy_params_with_noise, noise = add_noise_pmap(policy_params, noise_std, noise_keys)
    policy_params_with_noise, noise1, noise2 = add_uniform_and_pareto_noise(policy_params, noise_beta, noise_scale, noise_key)
    
    rewards_before, obs, acts, states_before = do_one_rollout(env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_key, episode_length, action_repeat, normalize_observations)
    policy_params_trained, rewards_lists = do_local_apg(apg_epochs, env_fn, policy.apply, normalizer_params, policy_params_with_noise, train_key, learning_rate, episode_length, action_repeat, normalize_observations, batch_size, clipping, truncation_length)
    rewards_after, obs, acts, states_after = do_one_rollout(env_fn, policy.apply, normalizer_params, policy_params_trained, train_key, episode_length, action_repeat, normalize_observations)
            
    print(jnp.any(policy_params_trained['params']['Dense_1']['kernel'] - policy_params_with_noise['params']['Dense_1']['kernel']))
    
    top_idx = sorted(range(len(rewards_lists)), key=lambda k: jnp.mean(rewards_lists[k][-5:]), reverse=True)
    
    normalizer_params = obs_normalizer_update_fn(normalizer_params, obs[top_idx[0],:])
    
    _, params_def = jax.tree_flatten(policy_params)
    params_flat, _ = jax.tree_flatten(policy_params_trained)
    top_params_flat = [param[top_idx[0]] for param in params_flat]
    top_params = jax.tree_unflatten(params_def, top_params_flat)
    
    
#     _, norm_def = jax.tree_flatten(normalizer_params)
#     norm_flat, _ = jax.tree_flatten(normalizer_params_all)
#     top_norm_flat = [param[top_idx[0]] for param in norm_flat]
#     top_norms = jax.tree_unflatten(norm_def, top_norm_flat)
    
    noise_beta -= .1
    
    if rewards_lists[top_idx[0]][-1] > best_reward:
        noise_beta = 2.0
        policy_params = top_params
        top_normalizer_params = normalizer_params
        best_reward = jnp.mean(rewards_lists[top_idx[0]][-5:])
        
    meta_rewards_list.append(best_reward)
    
    print(f" Iteration {i} --------------------------------")
    
    for j in range(len(top_idx)):
        done_idx = jnp.where(states_before.done[top_idx[j], :], size=1)[0].item()
        if done_idx == 0:
            done_idx = rewards_before.shape[-1]
        rewards_sum_before = jnp.sum(rewards_before[top_idx[j],:done_idx])

        done_idx = jnp.where(states_after.done[top_idx[j], :], size=1)[0].item()
        if done_idx == 0:
            done_idx = rewards_after.shape[-1]
        rewards_sum_after = jnp.sum(rewards_after[top_idx[j],:done_idx])
        
        print(f"{j} : reward: {rewards_sum_before} -> {jnp.mean(rewards_lists[top_idx[j]][-5:])}  |  {rewards_sum_after}")

    
    #print(f"{i} : best reward: {rewards_sum_before} -> {rewards_lists[top_idx[0]][-1]}  |  {rewards_sum_after}")

    print("Best reward so far: ", best_reward)
    print('--------------------------------------')

TypeError: true_fun and false_fun output must have identical types, got
((State(qp=QP(pos='ShapedArray(float32[16,4,3])', rot='ShapedArray(float32[16,4,4])', vel='ShapedArray(float32[16,4,3])', ang='ShapedArray(float32[16,4,3])'), obs='ShapedArray(float32[16,11])', reward='ShapedArray(float32[16])', done='ShapedArray(float32[16])', metrics={'rewardCtrl': 'ShapedArray(float32[16])', 'rewardDist': 'ShapedArray(float32[16])'}, info={'steps': 'ShapedArray(float32[16])', 'truncation': 'ShapedArray(float32[16])'}), 'ShapedArray(float32[16,11])', FrozenDict({
    params: {
        Dense_0: {
            bias: 'ShapedArray(float32[64])',
            kernel: 'ShapedArray(float32[11,64])',
        },
        Dense_1: {
            bias: 'ShapedArray(float32[64])',
            kernel: 'ShapedArray(float32[64,64])',
        },
        Dense_2: {
            bias: 'ShapedArray(float32[2])',
            kernel: 'ShapedArray(float32[64,2])',
        },
        GRUCell_0: {
            hn: {
                bias: 'ShapedArray(float32[11])',
                kernel: 'ShapedArray(float32[11,11])',
            },
            hr: {
                kernel: 'ShapedArray(float32[11,11])',
            },
            hz: {
                kernel: 'ShapedArray(float32[11,11])',
            },
            in: {
                bias: 'ShapedArray(float32[11])',
                kernel: 'ShapedArray(float32[11,11])',
            },
            ir: {
                bias: 'ShapedArray(float32[11])',
                kernel: 'ShapedArray(float32[11,11])',
            },
            iz: {
                bias: 'ShapedArray(float32[11])',
                kernel: 'ShapedArray(float32[11,11])',
            },
        },
    },
}), ('ShapedArray(float32[])', 'ShapedArray(float32[11])', 'ShapedArray(float32[11])'), 'DIFFERENT ShapedArray(float32[], weak_type=True) vs. ShapedArray(float32[16])', 'ShapedArray(int32[], weak_type=True)'), None).

In [None]:
trained_dense1 = policy_params_trained['params']['Dense_1']['kernel']
noisy_dense1 = policy_params_with_noise['params']['Dense_1']['kernel']
# True when at least one entry of the Dense_1 kernel differs after APG.
jnp.any(trained_dense1 - noisy_dense1)

In [None]:
import brax.jumpy as jp
@jax.jit
def do_rnn_rollout(policy_params, normalizer_params, key):
    """Roll out the recurrent policy for one episode in ``env``.

    Args:
      policy_params: parameters for ``policy.apply``.
      normalizer_params: observation-normalizer state for
        ``obs_normalizer_apply_fn``.
      key: PRNG key used to reset the environment.

    Returns:
      ``(rewards, obs, acts, states)`` stacked over
      ``episode_length // action_repeat`` steps (leading axis = time).
      ``obs`` is the observation each action was computed from; ``rewards``
      and ``states`` come from ``env.step`` for that action.
    """
    num_steps = episode_length // action_repeat
    init_state = env.reset(key)
    # The recurrent state starts at zero and is shaped like the observation.
    h0 = jp.zeros_like(init_state.obs)

    def do_one_rnn_step(carry, _step_idx):
        # Only the evolving quantities live in the carry; the parameters are
        # loop-invariant and are closed over instead.
        state, h = carry
        normed_obs = obs_normalizer_apply_fn(normalizer_params, state.obs)
        h1, actions = policy.apply(policy_params, h, normed_obs)
        nstate = env.step(state, actions)
        # TODO: h1 is not reset when nstate.done is set; consider zeroing it
        # for finished environments.
        # No gradient flows through the environment state carried to the next step.
        return (jax.lax.stop_gradient(nstate), h1), (nstate.reward, state.obs, actions, nstate)

    _, (rewards, obs, acts, states) = jp.scan(
        do_one_rnn_step, (init_state, h0), jnp.arange(num_steps), length=num_steps)

    return rewards, obs, acts, states

In [None]:
# Evaluate the accepted policy with the in-notebook RNN rollout.
key, reset_key = jax.random.split(key)
(rewards, obs, acts, states) = do_rnn_rollout(policy_params, top_normalizer_params, reset_key)

# `env` is batched (batch_size envs), so every output is (time, env, ...);
# plotting the raw 3-D arrays fails, so inspect one environment.
env_idx = 0
done_flags = np.asarray(states.done[:, env_idx])
done_steps = np.flatnonzero(done_flags)
# Truncate at the first done step; use the full episode if the env never finished
# (a `jnp.where(..., size=1)` lookup would return 0 and truncate everything).
done_idx = int(done_steps[0]) if done_steps.size else done_flags.shape[0]
rewards_sum = jnp.sum(rewards[:done_idx, env_idx])

fig_obs, ax_obs = plt.subplots()
ax_obs.plot(np.asarray(obs[:done_idx, env_idx, :]))
ax_obs.set(title=f"Observations (env {env_idx})", xlabel="step")

fig_act, ax_act = plt.subplots()
ax_act.plot(np.asarray(acts[:, env_idx, :]))
ax_act.set(title=f"Actions (env {env_idx})", xlabel="step")

print(rewards_sum)
print(done_flags)

In [None]:
# Draw a fresh sub-key: `key` has already been split, and re-using a key
# that was split produces correlated randomness.
key, rollout_key = jax.random.split(key)
# Use the accepted policy so it matches top_normalizer_params. `top_params` is
# the last *candidate* and can differ when the last iteration did not improve.
rewards, obs, acts, states = do_one_rollout(env_fn, policy.apply, top_normalizer_params, policy_params, rollout_key, episode_length, action_repeat, normalize_observations)
# NOTE(review): built-in sum() reduces over the leading axis only -- confirm
# that is the intended total for do_one_rollout's reward layout.
print(sum(rewards))

In [None]:
# Turn the stacked QP pytree (leading axis indexed per frame) into a list of
# per-frame QP pytrees for the HTML viewer.
qp_flat, qp_def = jax.tree_flatten(states.qp)

num_frames = qp_flat[0].shape[0]
qp_list = [
    jax.tree_unflatten(qp_def, [leaf[frame, :] for leaf in qp_flat])
    for frame in range(num_frames)
]

visualize(env.sys, qp_list, height=800)

In [None]:
# Done flags from the most recently assigned `states`; shown as the cell output.
states.done