In [37]:
import jax
from jax import numpy as jnp, jit
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

import numpy as np
from PIL import Image

In [38]:
def traj(env, state, key, max_steps=2000):
    """Roll out a single (un-batched) episode with sampled actions.

    At every step each agent's action is drawn from ``env.action_space(agent)``
    using its own PRNG sub-key, and the environment is advanced with
    ``env.step``.

    Args:
        env: environment exposing ``agents``, ``action_space(agent)`` and
            ``step(key, state, actions)``.
        state: environment state to start from.
        key: JAX PRNG key; split afresh on every step.
        max_steps: upper bound on the number of environment steps.

    Returns:
        ``(state_seq, returns)`` where ``state_seq`` lists the state seen
        *before* each step (the state produced by the last step is not
        included) and ``returns`` maps each agent to its summed reward.
    """
    agents = env.agents
    visited = []
    totals = dict.fromkeys(agents, 0)

    for _ in range(max_steps):
        visited.append(state)

        # One sub-key for action sampling, one for the environment transition.
        key, sample_key, step_key = jax.random.split(key, 3)
        per_agent_keys = jnp.array(jax.random.split(sample_key, len(agents)))

        actions = {}
        for idx, name in enumerate(agents):
            actions[name] = env.action_space(name).sample(per_agent_keys[idx])

        _, state, reward, done, _ = env.step(step_key, state, actions)

        for name in agents:
            totals[name] += reward[name]

        # "__all__" flags the end of the episode for every agent at once.
        if done["__all__"]:
            break

    return visited, totals

In [63]:
def make_vtraj(config):
    """Build a batched rollout function driven by sampled actions.

    The environment is created with ``make(config['env'], **config['env_config'])``
    and ``config['n_envs']`` copies of it are stepped in lock-step through
    ``jax.vmap`` for ``config['max_steps']`` iterations of ``jax.lax.scan``.

    Side effect: ``config['n_agents']`` is set on the caller's dict to
    ``env.num_agents * config['n_envs']`` (the number of action keys drawn
    per step).

    Args:
        config: dict with keys ``'env'``, ``'env_config'``, ``'n_envs'`` and
            ``'max_steps'``.

    Returns:
        ``vtraj(key)``, which returns the final scan carry
        ``(env_state, obs, key)`` — the whole tuple, not only the env state.
    """
    env = make(config['env'], **config['env_config'])
    config['n_agents'] = env.num_agents * config['n_envs']

    def init_runner_state(key):
        """Reset every environment copy and return the initial carry."""
        key, reset_key = jax.random.split(key)
        reset_keys = jax.random.split(reset_key, config['n_envs'])
        first_obs, first_state = jax.vmap(env.reset)(reset_keys)
        return (first_state, first_obs, key)

    def env_step(runner_state, _):
        """Advance all environments by one step; emits no per-step output."""
        env_state, _prev_obs, key = runner_state
        key, sample_key = jax.random.split(key)

        # One sampling key per (agent, env) pair, laid out agent-major so that
        # agent_keys[idx] is the batch of keys for the idx-th agent.
        flat_keys = jax.random.split(sample_key, config['n_agents'])
        agent_keys = flat_keys.reshape((env.num_agents, config['n_envs'], -1))
        actions = {}
        for idx, agent in enumerate(env.agents):
            batched_sample = jax.vmap(env.action_space(agent).sample)
            actions[agent] = batched_sample(agent_keys[idx])

        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, config['n_envs'])

        # Reward, done flags and infos are discarded; only state and obs carry on.
        next_obs, next_state, _reward, _done, _infos = jax.vmap(env.step)(
            step_keys, env_state, actions
        )
        return (next_state, next_obs, key), None

    def vtraj(key):
        """Run the full batched rollout and return the final carry."""
        _, init_key = jax.random.split(key)
        initial_carry = init_runner_state(init_key)
        # scan returns (final_carry, stacked_outputs); the outputs are None here.
        final_carry, _ = jax.lax.scan(
            env_step, initial_carry, None, length=config['max_steps']
        )
        return final_carry

    return vtraj


In [64]:
# Fixed seed so the batched rollout is reproducible.
key = jax.random.PRNGKey(0)

# Rollout settings: scan length, number of parallel envs, and environment id.
config = {
    "max_steps": 2000,
    "n_envs": 100,
    "env": "HeuristicEnemySMAX",
    "env_config": {},
}

# Compile the rollout once, then run it. The result is the final scan carry
# (env_state, obs, key) returned by the function built in make_vtraj.
vtraj = jit(make_vtraj(config))
runner_state = vtraj(key)

In [74]:
# runner_state is the final scan carry (env_state, obs, key), so index 0 is the
# batched environment state. The leading axis of the shape is n_envs (100);
# presumably the remaining axes are (units, xy) — TODO confirm against SMAX.
runner_state[0].state.unit_positions.shape

(100, 10, 2)