In [None]:
from absl.testing import absltest
from absl.testing import parameterized
from brax import envs
from brax.training import sac

# Experiment settings; these names are reused by the rollout cells further down.
episode_length = 500
action_repeat = 1
env_name = 'acrobot'

# Factory for the training environment (auto-reset enabled).
train_env_fn = envs.create_fn(env_name, auto_reset=True)

# SAC hyperparameters, gathered in one place so they are easy to tweak.
sac_hparams = dict(
    num_timesteps=10_000_000,
    episode_length=episode_length,
    normalize_observations=True,
    action_repeat=action_repeat,
    discounting=0.99,
    learning_rate=3e-4,
    num_envs=128,
    batch_size=256,
    seed=1,
)

# sac.train returns three values, bound here to the policy inference function,
# the trained parameters, and the training metrics.
inference_fn, params, metrics = sac.train(
    environment_fn=train_env_fn,
    **sac_hparams,
)

In [None]:
import jax
import jax.numpy as jnp
# Build a single evaluation environment with the same settings used for training.
make_env = envs.create_fn(env_name, auto_reset=True)
env = make_env()

# Fixed seed for evaluation; split it into two separate keys.
key = jax.random.PRNGKey(0)
reset_key, inference_key = jax.random.split(key, 2)

# Initial environment state and its observation.
state = env.reset(reset_key)
obs = state.obs

In [None]:
import brax.jumpy as jp
import matplotlib.pyplot as plt


@jax.jit
def do_rnn_rollout(policy_params, key):
    """Roll the policy out in the notebook-level ``env`` for a fixed number of steps.

    The number of steps is ``episode_length // action_repeat``; both values,
    like ``env`` and ``inference_fn``, are read from notebook globals when the
    function is traced by ``jax.jit``.

    Args:
        policy_params: Parameters passed to ``inference_fn`` at every step.
        key: JAX PRNG key. It is split once into a key for ``env.reset`` and a
            key that seeds the per-step action sampling.

    Returns:
        A tuple ``(rewards, obs, acts, states)`` of per-step outputs stacked by
        the scan: the reward of the state produced by each step, the
        observation the action was computed from, the action itself, and the
        state produced by the step.
    """
    num_steps = episode_length // action_repeat

    # Use distinct keys for the reset and for action sampling; reusing one key
    # for both would correlate the initial state with the first actions.
    reset_key, step_key = jax.random.split(key)
    init_state = env.reset(reset_key)

    def do_one_step(carry, step_idx):
        # step_idx is supplied by the scan but not needed by the step itself.
        state, key = carry
        act_key, key = jax.random.split(key)
        # Use the function argument, not the global `params`: under jit a
        # global would be captured at trace time and the argument ignored.
        actions = inference_fn(policy_params, state.obs, act_key)
        nstate = env.step(state, actions)
        return (nstate, key), (nstate.reward, state.obs, actions, nstate)

    _, (rewards, obs, acts, states) = jp.scan(
        do_one_step, (init_state, step_key),
        jnp.arange(num_steps),
        length=num_steps)

    return rewards, obs, acts, states

In [None]:
# Draw a fresh key for this rollout and advance the notebook-level key.
key, reset_key = jax.random.split(key)
(rewards, obs, acts, states) = do_rnn_rollout(params, reset_key)

# Index of the first step whose resulting state is flagged done. Test for
# "any done" explicitly: jnp.where(..., size=1) pads with 0 when nothing
# matches, which cannot be told apart from a genuine done at step 0.
done_mask = states.done != 0
if bool(jnp.any(done_mask)):
    done_idx = int(jnp.argmax(done_mask))
else:
    done_idx = rewards.shape[0]

# Sum the rewards of the steps before the first done step.
# NOTE(review): the reward of the done step itself is excluded, as in the
# original slice -- confirm that this is the intended episode return.
rewards_sum = jnp.sum(rewards[:done_idx])

plt.plot(obs);
plt.title("Observations")
plt.xlabel("step")
plt.figure()
plt.plot(acts);
plt.title("Actions")
plt.xlabel("step")
print(rewards_sum)
print(states.done)

In [None]:
from IPython.display import HTML
from brax.io import html



def visualize(sys, qps, height=480):
  """Return an IPython ``HTML`` object wrapping the Brax render of a trajectory.

  Args:
    sys: System description passed as the first argument to ``html.render``.
    qps: Trajectory passed as the second argument to ``html.render``.
    height: Forwarded to ``html.render`` as ``height``.

  Returns:
    ``HTML`` built from the string produced by ``html.render``.
  """
  rendered_page = html.render(sys, qps, height=height)
  return HTML(rendered_page)

# The rollout returns one stacked structure whose leaves have the step index
# as their leading axis; the renderer is given a list with one entry per step.
# Use jax.tree_util: the top-level jax.tree_flatten / jax.tree_unflatten
# aliases are deprecated and absent from newer JAX releases.
qp_flat, qp_def = jax.tree_util.tree_flatten(states.qp)

num_frames = qp_flat[0].shape[0]
qp_list = [
    # leaf[i] selects step i along the leading axis of every leaf.
    jax.tree_util.tree_unflatten(qp_def, [leaf[i] for leaf in qp_flat])
    for i in range(num_frames)
]

visualize(env.sys, qp_list, height=800)