In [1]:
import jax
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
import matplotlib.animation as animation
from jaxmarl.environments.smax import map_name_to_scenario, SMAX
from jaxmarl.environments.smax.smax_env import State as SMAXState  # use to unstack state carry from lax scan
from jaxmarl.environments.smax.heuristic_enemy import HeuristicPolicyState as EnemyState
from jaxmarl.environments.smax.heuristic_enemy_smax_env import State as State


from tqdm import tqdm
from functional import partial
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from copy import deepcopy

import esch

## Environment

In [2]:
def make_vtraj(config):
    """Build a function that rolls out ``config['n_envs']`` environments in parallel.

    Args:
        config: dict with the keys ``'env'`` (name passed to ``make``),
            ``'env_config'`` (keyword arguments for ``make``), ``'n_envs'``
            (number of vmapped environments) and ``'max_steps'`` (scan length).
            Side effect: ``config['n_agents']`` is set to
            ``env.num_agents * config['n_envs']``.

    Returns:
        ``(vtraj_fn, env)``. ``vtraj_fn(key)`` returns ``(runner_state, hist)``:
        ``runner_state`` is the final ``(env_state, obs, key)`` carry and ``hist``
        is the scan-stacked ``(key_step, env_state, actions)`` of every step, where
        ``env_state`` is the state the step started from.

    Note:
        Actions are currently a constant placeholder: every agent takes action 2.
        The sampled action is only used for its shape and dtype. This is the place
        to plug a real policy in.
    """
    env = make(config['env'], **config['env_config'])
    # Total number of per-agent PRNG keys needed per step (agents x envs).
    config['n_agents'] = env.num_agents * config['n_envs']

    def init_runner_state(key):
        """Reset every environment and return the initial ``(state, obs, key)`` carry."""
        key, key_reset = random.split(key)
        key_reset = random.split(key_reset, config['n_envs'])
        obsv, state = vmap(env.reset)(key_reset)
        return (state, obsv, key)

    def env_step(runner_state, _):
        """Advance all environments by one step (``jax.lax.scan`` body; no scan input)."""
        env_state, _last_obs, key = runner_state
        key, key_act = random.split(key)

        # One key per (agent, env) pair, laid out as (num_agents, n_envs, key_size).
        key_act = random.split(key_act, config['n_agents']).reshape((env.num_agents, config['n_envs'], -1))
        # Placeholder policy: sample only to obtain shape/dtype, then use the constant 2.
        actions = {
            agent: jnp.ones_like(vmap(env.action_space(agent).sample)(key_act[i])) + 1
            for i, agent in enumerate(env.agents)
        }

        key, key_step = random.split(key)
        key_step = random.split(key_step, config['n_envs'])

        obsv, next_state, _reward, _done, _infos = vmap(env.step)(key_step, env_state, actions)
        # Record the state the step STARTED from together with the key and actions
        # applied to it, so that replaying (key_step, env_state, actions) reproduces
        # this transition. This is the triple layout that env.expand_state_seq is
        # fed downstream (assumed to replay actions from the given state -- see the
        # JaxMARL SMAX docs). Previously the post-step state was recorded here.
        return (next_state, obsv, key), (key_step, env_state, actions)

    def vtraj_fn(key):
        """Run ``config['max_steps']`` steps from a fresh reset."""
        key, key_init = random.split(key)
        runner_state = init_runner_state(key_init)
        # scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])
        runner_state, hist = jax.lax.scan(env_step, runner_state, None, length=config['max_steps'])
        return runner_state, hist

    return vtraj_fn, env

In [3]:
# Scenario size and rollout settings; everything a reader might tune lives here.
env_config = dict(num_allies=40, num_enemies=40)
config = dict(
    max_steps=30,
    n_envs=6,
    env="SMAX",
    env_config=env_config,
)

# Build the parallel rollout function and compile it once.
vtraj_fn, env = make_vtraj(config)
vtraj = jit(vtraj_fn)

In [4]:
# Seed the notebook, then spend `key` on one batched rollout.
rng, key = random.split(random.PRNGKey(0))
runner_state, hist = vtraj(key)
# Final carry of the scan: last env state, last observations and the leftover key.
state, obs, key = runner_state

In [5]:
def expand_state_seq_fn(env, hist):
    """Unstack a scanned rollout into per-step triples and expand them with ``env``.

    Args:
        env: object exposing ``expand_state_seq``; it is vmapped over the leading
            axis that remains on every leaf after the step axis is indexed away.
        hist: ``(step_keys, states, actions)`` as stacked by ``jax.lax.scan``, i.e.
            every leaf carries the step index on its first axis.

    Returns:
        The result of ``vmap(env.expand_state_seq)`` applied to the list of
        per-step ``(key, state, actions)`` triples.
    """
    step_keys, states, actions = hist[0], hist[1], hist[2]
    # Take the number of steps from the data itself instead of a notebook global,
    # so the function stays correct if `config` changes or is not defined.
    n_steps = step_keys.shape[0]

    def at_step(i):
        # Index every leaf at step i; the container types (state dataclass,
        # action dict) are preserved by tree_map.
        return jax.tree_util.tree_map(lambda leaf: leaf[i], (step_keys, states, actions))

    xs = [at_step(i) for i in range(n_steps)]
    return vmap(env.expand_state_seq)(xs)

In [6]:
# Metadata attached to the rendering: agent label and "<allies>v<enemies>" scenario tag.
scenario = "{}v{}".format(env_config["num_allies"], env_config["num_enemies"])
info = dict(agent="random", scenario=scenario)

# Expand the recorded rollout and hand it to esch for display.
seqs = expand_state_seq_fn(env, hist)
esch.worlds_fn(seqs, info)