In [1]:
import os
import jax
import pickle

# Ask XLA to expose 24 host-platform devices via the
# `--xla_force_host_platform_device_count` flag.
# NOTE(review): this is set after `import jax`; presumably it only takes
# effect if it runs before JAX initialises its backend (i.e. before the first
# device query / computation) — confirm, and keep this cell first.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=24'
# Select the 'cpu' platform for JAX.
jax.config.update('jax_platform_name', 'cpu')

#from jax.config import config; config.update("jax_enable_x64", True)

%load_ext autoreload
%autoreload 2

from brax import envs
from brax.io import html
from brax.training import normalization

import flax
import jax.numpy as jnp
from brax.envs import create_fn

from IPython.display import HTML, clear_output

import optax

import matplotlib.pyplot as plt
import numpy as np

from controllers import GruController, MlpController, LinearController

from ce_apg import do_one_rollout, cem_apg

from functools import partial



def visualize(sys, qps, height=480):
  """Build an inline HTML widget with a 3D view of a trajectory.

  Args:
    sys: system description, forwarded unchanged to `html.render`.
    qps: sequence of states to render, forwarded unchanged to `html.render`.
    height: forwarded as the `height` keyword of `html.render`.

  Returns:
    An `IPython.display.HTML` object wrapping the rendered markup.
  """
  markup = html.render(sys, qps, height=height)
  return HTML(markup)

# Directory that the training cells below write their pickled results into.
save_dir = "save_reacher"
# Create it up front: open(..., 'wb') does not create missing parent
# directories, so the first pickle.dump would otherwise fail on a fresh
# checkout. exist_ok=True keeps the cell safe to re-run.
os.makedirs(save_dir, exist_ok=True)

# Kept as the cell's last expression so the device count is actually displayed
# (a bare expression in the middle of a cell is silently discarded).
len(jax.devices())

In [None]:
env_name = "reacher"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d', 'ur5e', 'reacher', 'bball_1dof']

episode_length = 500
action_repeat = 1
env_fn = create_fn(env_name = env_name, action_repeat=action_repeat, batch_size=None, auto_reset=False)
env = env_fn()

# Smallest power of two >= 4 * observation_size. It is only printed for
# reference: the controller below is built with a hard-coded 128 as its third
# argument (presumably the hidden size — TODO confirm), not with policy_size.
policy_size = int(2**jnp.ceil(jnp.log2(env.observation_size*4)))
print(policy_size)
policy = GruController(env.observation_size, env.action_size, 128)

# One training run per seed; the loop index doubles as the PRNG seed and as
# the suffix of the output files.
for i in range(1):
    # NOTE(review): the other training cells in this notebook unpack the result
    # of cem_apg as (normalizer_params, policy_params, rewards); here the first
    # element is named inference_fn. Confirm against cem_apg's actual return
    # signature.
    inference_fn, params, rewards = cem_apg(env_fn,
                                            20,
                                            key=jax.random.PRNGKey(i),
                                            episode_length = episode_length,
                                            action_repeat = action_repeat,
                                            apg_epochs = 30,
                                            cem_epochs = 1,
                                            batch_size = 1,
                                            truncation_length = None,
                                            learning_rate = 5e-4,
                                            clipping = 1e9,
                                            initial_std = 0.01,
                                            num_elite = 8,
                                            eps = 0.0,
                                            normalize_observations=True,
                                            policy = policy
                                           )

    # Context managers make sure each file is flushed and closed even if
    # pickling raises part-way through.
    with open(f"{save_dir}/{env_name}_policy{i}.pkl", 'wb') as fh:
        pickle.dump(params, fh)
    # No normalizer parameters are written by this cell (the other training
    # cells write a `{env_name}_normalize{i}.pkl` file).
    # NOTE(review): the seed index lands after the ".pkl" extension here; the
    # file name is kept unchanged so existing readers of these files still work.
    with open(f"{save_dir}/{env_name}_rewards.pkl{i}", 'wb') as fh:
        pickle.dump(rewards, fh)

64


In [None]:
# `import brax.model` only binds the name `brax`, so the `model.*` calls below
# raised NameError. The save/load helpers are taken from brax.io, the same
# package the notebook already imports `html` from.
from brax.io import model

# The top import cell only pulls individual functions out of ce_apg, so the
# module itself has to be bound before `ce_apg.make_inference_fn` can resolve.
# Assumes ce_apg exposes make_inference_fn — TODO confirm.
import ce_apg

# Round-trip the trained parameters through disk, then rebuild the policy's
# inference function for the current environment.
model.save_params('/tmp/params', params)
inference_fn = ce_apg.make_inference_fn(
    env.observation_size, env.action_size, True, policy)
params = model.load_params('/tmp/params')

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

# Roll the policy out from a fixed seed until the environment reports done,
# recording every visited state for rendering.
rollout = []
rng = jax.random.PRNGKey(seed=0)
state = jit_env_reset(rng=rng)
while not state.done:
  rollout.append(state)
  # Fresh sub-key per step; `rng` carries the remaining randomness forward.
  act_rng, rng = jax.random.split(rng)
  act = jit_inference_fn(params, state.obs, act_rng)
  state = jit_env_step(state, act)

HTML(html.render(env.sys, [s.qp for s in rollout]))

In [None]:
env_name = "acrobot"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d', 'ur5e', 'reacher', 'bball_1dof']

episode_length = 500
action_repeat = 1
env_fn = create_fn(env_name = env_name, action_repeat=action_repeat, batch_size=None, auto_reset=False)
env = env_fn()

# Smallest power of two >= 4 * observation_size. It is only printed for
# reference: the controller below is built with a hard-coded 32 as its third
# argument (presumably the hidden size — TODO confirm), not with policy_size.
policy_size = int(2**jnp.ceil(jnp.log2(env.observation_size*4)))
print(policy_size)
policy = GruController(env.observation_size, env.action_size, 32)

# One training run per seed; the loop index doubles as the PRNG seed and as
# the suffix of the output files.
for i in range(8):
    normalizer_params, policy_params, rewards = cem_apg(env_fn,
                                                        200,
                                                        key=jax.random.PRNGKey(i),
                                                        episode_length = episode_length,
                                                        action_repeat = action_repeat,
                                                        apg_epochs = 75,
                                                        cem_epochs = 1,
                                                        batch_size = 1,
                                                        truncation_length = None,
                                                        learning_rate = 5e-4,
                                                        clipping = 1e9,
                                                        initial_std = 0.01,
                                                        num_elite = 8,
                                                        eps = 0.0,
                                                        normalize_observations=True,
                                                        policy = policy
                                                       )

    # Context managers make sure each file is flushed and closed even if
    # pickling raises part-way through.
    with open(f"{save_dir}/{env_name}_policy{i}.pkl", 'wb') as fh:
        pickle.dump(policy_params, fh)
    with open(f"{save_dir}/{env_name}_normalize{i}.pkl", 'wb') as fh:
        pickle.dump(normalizer_params, fh)
    # NOTE(review): the seed index lands after the ".pkl" extension here; the
    # file name is kept unchanged so existing readers of these files still work.
    with open(f"{save_dir}/{env_name}_rewards.pkl{i}", 'wb') as fh:
        pickle.dump(rewards, fh)

In [None]:
env_name = "inverted_double_pendulum_swingup"  # @param ['ant', 'humanoid', 'fetch', 'grasp', 'halfcheetah', 'walker2d', 'ur5e', 'reacher', 'bball_1dof']

episode_length = 500
action_repeat = 1
env_fn = create_fn(env_name = env_name, action_repeat=action_repeat, batch_size=None, auto_reset=False)
env = env_fn()

# Smallest power of two >= 4 * observation_size. It is only printed for
# reference: the controller below is built with a hard-coded 128 as its third
# argument (presumably the hidden size — TODO confirm), not with policy_size.
policy_size = int(2**jnp.ceil(jnp.log2(env.observation_size*4)))
print(policy_size)
policy = GruController(env.observation_size, env.action_size, 128)

# One training run per seed; the loop index doubles as the PRNG seed and as
# the suffix of the output files.
for i in range(8):
    normalizer_params, policy_params, rewards = cem_apg(env_fn,
                                                        200,
                                                        key=jax.random.PRNGKey(i),
                                                        episode_length = episode_length,
                                                        action_repeat = action_repeat,
                                                        apg_epochs = 75,
                                                        cem_epochs = 1,
                                                        batch_size = 1,
                                                        truncation_length = None,
                                                        learning_rate = 5e-4,
                                                        clipping = 1e9,
                                                        initial_std = 0.01,
                                                        num_elite = 8,
                                                        eps = 0.0,
                                                        normalize_observations=True,
                                                        policy = policy
                                                       )

    # Context managers make sure each file is flushed and closed even if
    # pickling raises part-way through.
    with open(f"{save_dir}/{env_name}_policy{i}.pkl", 'wb') as fh:
        pickle.dump(policy_params, fh)
    with open(f"{save_dir}/{env_name}_normalize{i}.pkl", 'wb') as fh:
        pickle.dump(normalizer_params, fh)
    # NOTE(review): the seed index lands after the ".pkl" extension here; the
    # file name is kept unchanged so existing readers of these files still work.
    with open(f"{save_dir}/{env_name}_rewards.pkl{i}", 'wb') as fh:
        pickle.dump(rewards, fh)

In [None]:
import pickle

# NOTE(review): `policy_params2` is not assigned anywhere in the visible
# notebook, so this cell raises NameError on a fresh "Restart & Run All".
# Confirm which training result is meant to be saved here.
# The context manager closes the file even if pickling fails.
with open("inverted_double_pendulum_swingup.pickle", 'wb') as fh:
    pickle.dump(policy_params2, fh)

In [None]:
env_name = "inverted_double_pendulum_swingup"

# NOTE(review): `policy_params2`, `normalizer_params2` and `rewards2` are not
# assigned anywhere in the visible notebook, so this cell raises NameError on a
# fresh "Restart & Run All". Confirm which training results are meant here.
# Files are written to the current working directory; each context manager
# closes its file even if pickling fails.
with open(f"{env_name}_policy.pkl", 'wb') as fh:
    pickle.dump(policy_params2, fh)
with open(f"{env_name}_normalize.pkl", 'wb') as fh:
    pickle.dump(normalizer_params2, fh)
with open(f"{env_name}_rewards.pkl", 'wb') as fh:
    pickle.dump(rewards2, fh)