In [1]:
import jax
import jax.numpy as jnp
import gymnax

In [2]:
jax.devices()

[gpu(id=0), gpu(id=1), gpu(id=2), gpu(id=3)]

In [3]:
gymnax.registered_envs

['CartPole-v1',
 'Pendulum-v1',
 'Acrobot-v1',
 'MountainCar-v0',
 'MountainCarContinuous-v0',
 'Asterix-MinAtar',
 'Breakout-MinAtar',
 'Freeway-MinAtar',
 'SpaceInvaders-MinAtar',
 'Catch-bsuite',
 'DeepSea-bsuite',
 'MemoryChain-bsuite',
 'UmbrellaChain-bsuite',
 'DiscountingChain-bsuite',
 'MNISTBandit-bsuite',
 'SimpleBandit-bsuite',
 'FourRooms-misc',
 'MetaMaze-misc',
 'PointRobot-misc',
 'BernoulliBandit-misc',
 'GaussianBandit-misc',
 'Reacher-misc',
 'Swimmer-misc',
 'Pong-misc']

In [4]:
env, params = gymnax.make("CartPole-v1")

In [5]:
rng, step_rng, reset_rng, action_key = jax.random.split(jax.random.PRNGKey(42), 4)

In [8]:
vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0))
vmap_action = jax.vmap(env.action_space(params).sample)(jax.random.split(step_rng, 8))
num_envs = 8
vmap_keys = jax.random.split(rng, num_envs)

obs, state = vmap_reset(vmap_keys, params)
n_obs, n_state, reward, done, _ = vmap_step(vmap_keys, state, vmap_action)
print(n_obs.shape)

(8, 4)


In [9]:
pmap_reset = jax.pmap(vmap_reset, in_axes=(0, None))
pmap_step = jax.pmap(vmap_step, in_axes=(0, 0, 0))
pmap_action = jax.pmap(jax.vmap(env.action_space(params).sample))(jnp.tile(jax.random.split(step_rng, 8), (4, 1, 1)))

map_keys = jnp.tile(vmap_keys, (4, 1, 1))
obs, state = pmap_reset(map_keys, params)
n_obs, n_state, reward, done, _ = pmap_step(map_keys, state, pmap_action)
print(n_obs.shape)

(4, 8, 4)


In [11]:
pmap_action

Array([[1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0, 0, 0]], dtype=int32)