In [None]:
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=12'

%load_ext autoreload
%autoreload 2
%matplotlib inline

from brax import envs
from brax.io import html
from brax.training import normalization


import flax
import jax
import jax.numpy as jnp
from brax.envs import create_fn

from IPython.display import HTML, clear_output

import optax

import matplotlib.pyplot as plt
import numpy as np

from controllers import GruController, MlpController, LinearController
from common import do_local_apg, add_guassian_noise, add_uniform_noise, add_uniform_and_pareto_noise, add_sym_pareto_noise, do_one_rollout

from functools import partial
from matplotlib import cm


jax.config.update('jax_platform_name', 'cpu')

def visualize(sys, qps, height=480):
  """Return an IPython ``HTML`` object embedding the Brax HTML render.

  Args:
    sys: passed straight through to ``html.render`` as its first argument.
    qps: passed straight through to ``html.render`` as its second argument.
    height: forwarded to ``html.render`` as the ``height`` keyword.

  Returns:
    ``HTML`` wrapping whatever markup ``html.render`` produced.
  """
  markup = html.render(sys, qps, height=height)
  return HTML(markup)

# Show how many devices JAX exposes; XLA_FLAGS at the top of this cell asks the
# host platform for 12 of them.
len(jax.devices())

In [None]:
# --- Experiment configuration -------------------------------------------------
episode_length = 1000
action_repeat = 1

env_name = "acrobot"  # @param ['acrobot', 'ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d', 'ur5e', 'reacher', 'bball_1dof']
env_fn = create_fn(env_name=env_name, action_repeat=action_repeat, batch_size=None, auto_reset=False)
env = env_fn()
policy = LinearController(env.observation_size, env.action_size)

# --- Random keys --------------------------------------------------------------
# JAX keys must not be consumed more than once. Derive one independent key per
# purpose from the root instead of splitting `key` twice and feeding
# `model_key` to both `split` and `policy.init`.
# `model_key` is deliberately left unconsumed here so that later cells can
# derive further keys from it.
key = jax.random.PRNGKey(0)
key, reset_key, noise_key, init_key, model_key = jax.random.split(key, num=5)
reset_keys = jax.random.split(reset_key, num=jax.local_device_count())
noise_keys = jax.random.split(noise_key, num=jax.local_device_count())

# --- Example inputs for parameter initialisation ------------------------------
# One environment reset per local device; pmap maps over the leading axis of
# `reset_keys`, so `x0` carries a leading device axis.
init_states = jax.pmap(env.reset)(reset_keys)
x0 = init_states.obs
h0 = jnp.zeros(env.observation_size)

policy_params = policy.init(init_key, h0, x0)


import brax.jumpy as jp
@jax.jit
def do_rnn_rollout(policy_params, key):
    """Run the policy in the environment for one episode and record the trajectory.

    The environment is reset with ``key``, the hidden state starts as zeros
    shaped like the first observation, and the policy/environment pair is then
    stepped ``episode_length // action_repeat`` times via ``jp.scan``.

    ``env``, ``policy``, ``episode_length`` and ``action_repeat`` are read from
    the notebook globals; because the function is jitted, re-run this cell
    after changing any of them.

    Args:
        policy_params: parameters handed to ``policy.apply`` at every step.
        key: PRNG key passed to ``env.reset``.

    Returns:
        ``(rewards, obs, acts, states)`` as collected by the scan: the reward
        of each post-step state, the observation each action was computed
        from, the actions, and the post-step environment states.
    """
    num_steps = episode_length // action_repeat
    first_state = env.reset(key)
    hidden0 = jp.zeros_like(first_state.obs)

    def _advance(carry, _step):
        env_state, hidden, params = carry

        next_hidden, action = policy.apply(params, hidden, env_state.obs)
        next_state = env.step(env_state, action)

        # Gradients are stopped on the environment state handed to the next
        # step; the hidden state is carried on as-is (it is not reset on done).
        next_carry = (jax.lax.stop_gradient(next_state), next_hidden, params)
        step_record = (next_state.reward, env_state.obs, action, next_state)
        return next_carry, step_record

    _, trajectory = jp.scan(
        _advance, (first_state, hidden0, policy_params),
        jnp.arange(num_steps),
        length=num_steps)

    rewards, obs, acts, states = trajectory
    return rewards, obs, acts, states

In [None]:
# Overwrite the first four rows of the Dense_0 kernel with hand-set values
# (row k is filled entirely with manual_row_values[k]) and roll the policy out
# once to inspect the resulting trajectory.
manual_row_values = (1649.0, 460.2, 716.1, 278.2)

policy_params = policy_params.unfreeze()
kernel = policy_params['params']['Dense_0']['kernel']
for row, value in enumerate(manual_row_values):
    kernel = kernel.at[row].set(value)
policy_params['params']['Dense_0']['kernel'] = kernel
policy_params = flax.core.frozen_dict.FrozenDict(policy_params)

key, reset_key = jax.random.split(key)
reward,obs,acts,states = do_rnn_rollout(policy_params, reset_key)

plt.figure()
plt.plot(obs)
plt.title("Observations during the rollout")
plt.xlabel("step")

plt.figure()
plt.plot(acts)
plt.title("Actions during the rollout")
plt.xlabel("step")

# jnp.sum reduces on-device in one call; the builtin sum() would add the
# per-step rewards one Python iteration at a time.
print(jnp.sum(reward))

In [None]:
# Reward landscape sweep: for several random initialisations of the policy,
# sweep rows 0 and 1 of the Dense_0 kernel over a square grid and plot the
# summed episode reward at every grid point. All other parameters keep their
# randomly initialised values.
n_inits = 10
size = 100
weight_grid = jnp.linspace(-10, 10, size)

for init_idx in range(n_inits):
    # Split from `key` so the key chain keeps advancing and the key handed to
    # policy.init is never consumed a second time.
    key, model_key = jax.random.split(key)
    policy_params = policy.init(model_key, h0, x0)

    # Fill a host-side NumPy buffer: updating a JAX array with .at[].set() in
    # a Python loop would copy the whole grid on every one of the size*size
    # iterations.
    rewards = np.zeros((size, size), dtype=np.float32)
    for i, x in enumerate(weight_grid):
        for j, y in enumerate(weight_grid):
            key, reset_key = jax.random.split(key)

            policy_params = policy_params.unfreeze()
            kernel = policy_params['params']['Dense_0']['kernel']
            policy_params['params']['Dense_0']['kernel'] = kernel.at[0].set(x).at[1].set(y)
            policy_params = flax.core.frozen_dict.FrozenDict(policy_params)

            reward,obs,acts,states = do_rnn_rollout(policy_params, reset_key)
            reward_sum = jnp.sum(reward)
            rewards[i, j] = float(reward_sum)

    # rewards[i, j] holds the return for kernel row 0 = weight_grid[i] and
    # kernel row 1 = weight_grid[j]; pcolormesh draws rows along the vertical
    # axis and columns along the horizontal axis.
    plt.figure(figsize=(7, 6))
    plt.pcolormesh(rewards)
    plt.colorbar(label="summed episode reward")
    plt.xlabel("grid index of kernel row 1 value")
    plt.ylabel("grid index of kernel row 0 value")
    plt.title(f"Reward landscape, random init {init_idx + 1}/{n_inits}");
 

In [None]:
# Materialise the reward grid as a host-side NumPy array for the plotting
# cells that follow.
rewards = np.array(rewards)

In [None]:
# 3D surface of the reward grid.
# rewards[i, j] was filled with i indexing the kernel row 0 value and j indexing
# the kernel row 1 value, so the coordinate grids must use matrix ("ij")
# indexing. np.meshgrid's default "xy" indexing would plot the surface
# transposed, i.e. with the two weight axes swapped.
X = jnp.linspace(-10,10,size)
Y = jnp.linspace(-10,10,size)
X, Y = np.meshgrid(X, Y, indexing="ij")

fig, ax = plt.subplots(subplot_kw={"projection": "3d"}, figsize=(16,16))
ax.set_xlabel("kernel row 0 value")
ax.set_ylabel("kernel row 1 value")
ax.set_zlabel("summed episode reward")
# Plot the surface.
surf = ax.plot_surface(X, Y, rewards, cmap=cm.coolwarm,
                       linewidth=0, antialiased=False)

In [None]:
# Heat map of the reward grid with a colour scale, built through the explicit
# figure/axes interface.
heat_fig, heat_ax = plt.subplots(figsize=(7, 6))
heat_mesh = heat_ax.pcolormesh(rewards)
heat_fig.colorbar(heat_mesh, ax=heat_ax)

In [None]:
# Same heat map of the reward grid as the previous cell; the colour bar is
# attached to the mesh explicitly instead of relying on the current image.
plt.figure(figsize=(7, 6))
reward_mesh = plt.pcolormesh(rewards)
plt.colorbar(reward_mesh)