# C2SIM

In [2]:
import jax
from jax import numpy as jnp, jit, vmap, random
import chex
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from einops import rearrange
from functional import partial

import seaborn as sns
from tqdm import tqdm
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
import darkdetect, imageio

from src import bullet_fn, plot_fn

## trajectories

In [3]:
# globals
# number of environments run side by side; determines how many reset/step keys are split
n_envs = 6
# number of iterations of the rollout loop in traj_fn
n_steps = 100
# map name that is passed to map_name_to_scenario when the environment is created
scenario = "3s5z_vs_3s6z"

In [4]:
# iter
def step_fn(rng, env, obs_v, old_state_v):
    """Advance all ``n_envs`` environments by one step with sampled actions.

    Args:
        rng: PRNG key; split here into a carry key, an action key and a step key.
        env: environment providing ``num_agents``, ``agents`` and ``step``.
        obs_v: per-agent observations, indexed by agent name.
        old_state_v: batched environment state that ``env.step`` is applied to.

    Returns:
        A 3-tuple ``(carry, record, reward_v)`` where ``carry`` is
        ``(rng, env, obs_v, state_v)`` for the next call, ``record`` is
        ``(step_keys, old_state_v, actions)`` describing this step, and
        ``reward_v`` is the reward output of the vmapped ``env.step``.
    """
    next_rng, act_rng, step_rng = random.split(rng, 3)

    # one key per (agent, environment) pair, laid out agent-major
    n_agents = env.num_agents
    flat_act_keys = random.split(act_rng, n_agents * n_envs)
    act_keys = flat_act_keys.reshape(n_agents, n_envs, -1)

    # one key per environment for the transition itself
    step_keys = random.split(step_rng, n_envs)

    actions = {}
    for idx, agent in enumerate(env.agents):
        actions[agent] = action_fn(env, act_keys[idx], obs_v[agent], agent)

    batched_step = vmap(env.step)
    # the last two outputs of env.step are not used here
    next_obs_v, next_state_v, reward_v, _, _ = batched_step(
        step_keys, old_state_v, actions
    )

    carry = (next_rng, env, next_obs_v, next_state_v)
    record = (step_keys, old_state_v, actions)
    return carry, record, reward_v


@partial(vmap, in_axes=(None, 0, 0, None))
def action_fn(env, rng, obs, a):
    """Sample an action for agent ``a`` from its action space.

    The decorator maps this function over the leading axis of ``rng`` and
    ``obs``; ``env`` and ``a`` are passed through unmapped. ``obs`` is accepted
    but not consulted when sampling.
    """
    space = env.action_space(a)
    sampled = space.sample(rng)
    return sampled


def traj_fn(rng, env):
    """Roll out ``n_steps`` steps of sampled actions over ``n_envs`` environments.

    Args:
        rng: PRNG key that seeds the environment resets and every later step.
        env: environment providing ``reset`` plus whatever ``step_fn`` uses.

    Returns:
        ``(state_seq, reward_seq)``: two lists of length ``n_steps``. Each
        ``state_seq`` entry is the second element returned by ``step_fn`` (the
        ``(step_keys, old_state_v, actions)`` tuple for that step) and each
        ``reward_seq`` entry is the third (the rewards for that step).
    """
    # Derive everything from the caller's key. A hard-coded PRNGKey(0) here
    # would make every call produce the same rollout regardless of `rng`.
    rng, reset_rng = random.split(rng)
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)

    traj_state = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):
        traj_state, step_record, reward_v = step_fn(*traj_state)
        # append in place; rebuilding the lists each step is quadratic
        state_seq.append(step_record)
        reward_seq.append(reward_v)
    return state_seq, reward_seq

In [6]:
# build the SMAX environment for the map named in `scenario`
env = make("SMAX", scenario=map_name_to_scenario(scenario))
# split a seed-0 key in two; `key` is handed to traj_fn, `rng` is not used again in this cell
rng, key = random.split(random.PRNGKey(0))
# collect the per-step records and rewards of the rollout
state_seq, reward_seq = traj_fn(key, env)
# plot_fn is imported from src; presumably it renders the rollout — TODO confirm what `expand=True` does
plot_fn(env, state_seq, reward_seq, expand=True)

100%|██████████| 100/100 [00:03<00:00, 31.85it/s]
100%|██████████| 100/100 [00:00<00:00, 101.07it/s]
 74%|███████▍  | 590/800 [01:26<00:28,  7.32it/s]

## plotting