In [1]:
%matplotlib inline
%load_ext autoreload
# Re-import edited project modules (e.g. utils/, src/) automatically before each cell runs.
%autoreload 2
# Render inline figures at high (retina) resolution.
# NOTE(review): no visible cell plots anything; drop the matplotlib magics if that stays true.
%config InlineBackend.figure_format = 'retina'

In [2]:
import os
import os.path as osp

# Emulate `num_devices` CPU devices on a single host. XLA_FLAGS must be set before
# JAX initialises its backend (hence before `import jax` below); changing it later
# in the same kernel has no effect.
num_devices = 4
# Merge with any flags already present instead of clobbering them, and drop a
# previous device-count flag so re-running this cell does not duplicate it.
os.environ['XLA_FLAGS'] = " ".join(
    [flag for flag in os.environ.get('XLA_FLAGS', '').split()
     if not flag.startswith("--xla_force_host_platform_device_count=")]
    + [f"--xla_force_host_platform_device_count={num_devices}"]
)

import numpy as np

import jax
import jax.numpy as jnp
from jax import lax
import gymnax

# Sanity check: should report `num_devices` emulated CPU devices.
jax.device_count(), jax.devices()

(4, [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)])

In [69]:
from utils.models import get_model_ready
from utils.helpers import load_config, save_pkl_object
from src.control.dynamics import cartpole_step,CartPoleEnvState,cartpole_kinematics

# Load the PPO expert config for CartPole-v1 and initialise the policy network.
# No training happens here: this cell only builds the config, model and parameters.
ppo_yaml = osp.join("expert_agents", "CartPole-v1", "ppo.yaml")
seed_id = 123
lrate = 5e-04

# NOTE(review): load_config is project code; presumably it overrides the YAML's
# seed / learning rate with the values above — confirm in utils.helpers.
config = load_config(ppo_yaml, seed_id, lrate)

rng = jax.random.PRNGKey(config.train_config.seed_id)
# Split off a dedicated key for parameter initialisation so `rng` stays usable later.
rng, rng_init = jax.random.split(rng)
model, policy_params = get_model_ready(rng_init, config.train_config)

In [27]:
def batch_step(x, x_dot, theta, theta_dot, action, step_fn=None):
    """Apply a single-state step function to a batch of states with ``jax.vmap``.

    Every argument is mapped over its leading axis, so all five inputs must share
    the same leading (batch) dimension.

    Args:
        x, x_dot, theta, theta_dot: batched state components.
        action: batched actions, one per state.
        step_fn: per-state step function called as
            ``step_fn(x, x_dot, theta, theta_dot, action)``. Defaults to
            ``cartpole_step``, looked up at call time.

    Returns:
        The outputs of ``step_fn`` stacked along a new leading batch axis. With
        ``cartpole_step`` and 1-D inputs of length B this is a (B, 4) array.
    """
    if step_fn is None:
        step_fn = cartpole_step
    # in_axes=0 maps every positional argument over axis 0.
    return jax.vmap(step_fn, in_axes=0)(x, x_dot, theta, theta_dot, action)
    

In [28]:
# Single initial state: every component is a length-1 array holding 0.1.
state = CartPoleEnvState(jnp.array([0.1]),jnp.array([0.1]),jnp.array([0.1]),jnp.array([0.1]))
# Matching length-1 action array.
action = jnp.array([0])

In [29]:
# Advance the single state by one step. cartpole_step returns a plain Array, not a
# CartPoleEnvState, so the result gets its own name: overwriting `state` would make
# this cell fail on re-run (`state.x` would no longer exist).
next_state_single = cartpole_step(state.x, state.x_dot, state.theta, state.theta_dot, action)
next_state_single

Array([[ 0.102     ],
       [-0.09640235],
       [ 0.102     ],
       [ 0.42248276]], dtype=float64)

In [30]:
# Random batch of states and actions for exercising batch_step. A seeded NumPy
# Generator keeps the batch reproducible on Restart & Run All; it is named `np_rng`
# so the JAX key `rng` from the config cell is not clobbered.
batch_size = 5
np_rng = np.random.default_rng(seed=0)
state_batch = CartPoleEnvState(jnp.array(np_rng.standard_normal(batch_size)),
                               jnp.array(np_rng.standard_normal(batch_size)),
                               jnp.array(np_rng.standard_normal(batch_size)),
                               jnp.array(np_rng.standard_normal(batch_size)))
# Actions drawn uniformly from {0, 1} (upper bound is exclusive).
action_batch = jnp.array(np_rng.integers(0, 2, batch_size))

In [31]:
# Inspect the random batch: each field is a length-5 vector.
state_batch

CartPoleEnvState(x=Array([0.16887683, 0.25342474, 1.43330185, 0.66684305, 0.1711656 ],      dtype=float64), x_dot=Array([-0.62648993,  0.64402032,  1.04853453, -2.29503117, -1.10335393],      dtype=float64), theta=Array([-0.04119768,  0.94428266, -0.43576316,  0.09866285, -0.54856945],      dtype=float64), theta_dot=Array([ 1.03033543, -0.89861142,  0.07671684,  1.30767577,  1.20013876],      dtype=float64))

In [32]:
# Step every state in the batch at once. Despite the name, the result is a plain
# Array with one row per batch element ((5, 4) here), not a CartPoleEnvState.
next_state  = batch_step(
  state_batch.x,state_batch.x_dot,state_batch.theta,state_batch.theta_dot, action_batch
)

In [35]:
# One row per batch element; columns follow the (x, x_dot, theta, theta_dot) argument order.
# NOTE(review): column order is inferred from the argument order; confirm in cartpole_step.
next_state

Array([[ 0.15634703, -0.82104023, -0.02059098,  1.30980456],
       [ 0.26630514,  0.45194747,  0.92631043, -0.49152349],
       [ 1.45427254,  0.86133846, -0.43422882,  0.20717214],
       [ 0.62094243, -2.49125528,  0.12481636,  1.62954035],
       [ 0.14909852, -0.9065003 , -0.52456668,  0.79487308]],      dtype=float64)

In [70]:
# Jit-Compiled Episode Rollout
from functools import partial

# Bind the step function with partial so jax.jit only receives array arguments;
# a Python function is not a valid traced jit argument.
# NOTE(review): cartpole_kinematics presumably applies cartpole_step once per action
# in the sequence; confirm in src.control.dynamics.
jit_rollout = jax.jit(partial(cartpole_kinematics,cartpole_step=cartpole_step))

In [71]:
# Single initial state (every component 0.1) and a random {0, 1} action sequence of
# length `num_rollout_steps`. The seeded Generator keeps the rollout below reproducible.
num_rollout_steps = 10
state = CartPoleEnvState(jnp.array([0.1]), jnp.array([0.1]), jnp.array([0.1]), jnp.array([0.1]))
action = jnp.array(np.random.default_rng(seed=1).integers(0, 2, num_rollout_steps))

In [72]:
# Roll the action sequence forward from `state`; each returned trajectory has one row
# per action (x_rollout below has shape (10, 1)).
x_rollout, x_dot_rollout, theta_rollout, theta_dot_rollout = jit_rollout(state.x,state.x_dot,state.theta,state.theta_dot,action)

In [73]:
# Trajectory of the `x` component (presumably cart position), one row per rollout step.
x_rollout

Array([[0.102     ],
       [0.10007195],
       [0.09421575],
       [0.08443037],
       [0.07071426],
       [0.06085865],
       [0.05485238],
       [0.05268567],
       [0.04657173],
       [0.03651157]], dtype=float64)