In [85]:
import jax
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario, SMAX
from jaxmarl.viz.visualizer import SMAXVisualizer, Visualizer
from jaxmarl.environments.smax.smax_env import State as SMAXState  # use to unstack state carry from lax scan

from matplotlib import rcParams
from tqdm import tqdm
from functional import partial
import numpy as np
import matplotlib.pyplot as plt
import imageio
import imageio_ffmpeg  # Make sure this is installed
import darkdetect

import esch

In [86]:
# globals: monospace typography plus a background/foreground pair that follows the OS theme
rcParams.update({'font.family': 'monospace', 'font.monospace': 'Fira Code'})
# any falsy result from darkdetect.isDark() selects the light scheme
bg, ink = ('black', 'white') if darkdetect.isDark() else ('white', 'black')

# Vanilla

In [89]:
# globals: number of parallel environments and number of env steps per rollout
n_envs, n_steps = 6, 20

In [90]:
# iter
def step_fn(rng, env, obs_v, state_v):
    """Advance ``n_envs`` batched copies of ``env`` by one step with random actions.

    Args:
        rng: PRNG key; split into a fresh carry key, per-agent action keys and
            per-env step keys.
        env: multi-agent environment exposing ``agents``, ``num_agents``,
            ``action_space(agent)`` and ``step``.
        obs_v: observations batched over the ``n_envs`` axis (passed through
            only via the returned carry).
        state_v: environment states batched over the ``n_envs`` axis.

    Returns:
        ``(carry, record, reward_v)`` where

        * ``carry`` is ``(rng, env, obs_v, state_v)`` *after* the step, ready to
          be splatted into the next call;
        * ``record`` is ``(step_keys, prev_state_v, actions)``: the state the
          actions were applied to, together with the keys and actions used.
          This is the ``(key, pre-step state, actions)`` form consumed by
          jaxmarl's ``expand_state_seq``, which replays the actions from the
          recorded state;
        * ``reward_v`` is the reward output of ``env.step``, batched over envs.
    """
    rng, act_rng, step_rng = random.split(rng, 3)
    act_keys = random.split(act_rng, env.num_agents * n_envs).reshape(env.num_agents, n_envs, -1)
    step_keys = random.split(step_rng, n_envs)
    # `% 2` restricts every sampled action index to {0, 1}.
    actions = {a: vmap(env.action_space(a).sample)(act_keys[i]) % 2 for i, a in enumerate(env.agents)}
    # Keep the pre-step state: recording the post-step state would pair it with
    # the actions that produced it, so a replay would apply them a second time.
    prev_state_v = state_v
    obs_v, state_v, reward_v, _, _ = vmap(env.step)(step_keys, prev_state_v, actions)
    return (rng, env, obs_v, state_v), (step_keys, prev_state_v, actions), reward_v

def traj_fn(rng, env):
    """Roll out ``n_steps`` random-action steps in ``n_envs`` parallel copies of ``env``.

    Args:
        rng: PRNG key; drives the environment resets and every subsequent step.
        env: environment exposing a ``reset(key)`` that returns ``(obs, state)``
            and whatever ``step_fn`` needs.

    Returns:
        ``(state_seq, reward_seq)``: two lists of length ``n_steps``, where
        ``state_seq[t]`` is the second (record) output of ``step_fn`` at step
        ``t`` and ``reward_seq[t]`` its third (reward) output.
    """
    # Derive the reset keys from the caller's key. The original split a fixed
    # PRNGKey(0), which silently ignored `rng` and made every rollout identical.
    rng, reset_rng = random.split(rng)
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)
    carry = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):
        carry, record, reward_v = step_fn(*carry)
        # append() is O(1); `seq = seq + [x]` copied the whole list every step
        state_seq.append(record)
        reward_seq.append(reward_v)
    return state_seq, reward_seq

In [91]:
# Default SMAX scenario; the fixed seed keeps this rollout reproducible.
env = make('SMAX')
rng, key = random.split(random.PRNGKey(0), 2)
state_seq, reward_seq = traj_fn(key, env)

100%|██████████| 20/20 [00:04<00:00,  4.37it/s]


In [93]:
def plot_fn(env, state_seq, expand=False, out_path='docs/figs/out.mp4'):
    """Render a rollout as a 2x3 grid of per-environment scatter plots and save it as a video.

    Args:
        env: environment that produced ``state_seq``; only used for
            ``expand_state_seq`` when ``expand`` is true.
        state_seq: sequence of ``(keys, state, actions)`` records, one per env
            step. ``state`` is batched over the parallel envs:
            ``state.unit_positions`` is indexed as ``[env, unit, xy]`` and
            ``state.time[0]`` is shown as the step counter.
        expand: if true, interpolate each env step into intermediate frames via
            ``env.expand_state_seq`` (vmapped over the env axis) and encode at
            24 fps instead of 6 fps.
        out_path: file written by ``imageio.mimsave``.
    """
    if expand:
        state_seq = vmap(env.expand_state_seq)(state_seq)
    frames = []
    for _, state, _actions in tqdm(state_seq):
        # 17.92 x 12.16 in at 100 dpi -> 1792 x 1216 px; presumably chosen as
        # multiples of 16 for the video encoder — TODO confirm
        fig, axes = plt.subplots(2, 3, figsize=(17.92, 12.16), facecolor=bg, dpi=100)
        fig_fn(fig, state.time[0])
        for idx, ax in zip(range(n_envs), axes.flatten()):
            x, y = state.unit_positions[idx, :, 0], state.unit_positions[idx, :, 1]
            # Draw units in the foreground colour so they stay visible in both
            # themes (a fixed 'white' disappeared on the white light-mode bg).
            ax.scatter(x, y, c=ink)
            ax_fn(ax, idx)
        frames.append(frame_fn(fig))
    imageio.mimsave(out_path, frames, fps=24 if expand else 6)
    
def frame_fn(fig):
    """Rasterise ``fig`` into an RGB ``uint8`` array, then close the figure.

    Args:
        fig: matplotlib figure whose canvas supports ``buffer_rgba()``.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)`` holding the figure's
        pixels with the alpha channel dropped.
    """
    fig.canvas.draw()
    # buffer_rgba() already exposes an (H, W, 4) buffer sized in the canvas's
    # actual pixels, so no reshape via get_width_height() is needed (that call
    # can report logical rather than physical pixels on HiDPI canvases).
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Copy so the frame owns its data instead of viewing the canvas buffer.
    frame = rgba[:, :, :3].copy()
    # Close this figure specifically; a bare plt.close() closes whichever
    # figure pyplot considers current.
    plt.close(fig)
    return frame

def ax_fn(ax, idx):
    """Apply the shared look to one simulation panel.

    Sets the panel's x-label and title ``sim. {idx+1}`` in ``ink``, paints the
    face ``bg``, places inward ticks at 2, 6, ..., 30 on all four sides with no
    tick labels, colours every spine ``ink``, and fixes an equal-aspect view of
    [-2, 34] on both axes.
    """
    # NOTE(review): the values 4 and 2 in this label are hard-coded placeholders.
    ax.set_xlabel(f"{4} ret. | {2} ret.", color=ink)
    ax.set_title(f"sim. {idx+1}", color=ink)
    ax.set_facecolor(bg)

    # assumes the map spans 0..32 — TODO confirm against the scenario size
    grid_ticks = np.arange(2, 31, 4)
    ax.set_xticks(grid_ticks)
    ax.set_yticks(grid_ticks)
    ax.tick_params(
        colors=ink, direction='in', length=6, width=1, which='both',
        top=True, bottom=True, left=True, right=True,
        labelleft=False, labelbottom=False,
    )

    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color(ink)

    ax.set_aspect('equal')
    ax.set_xlim(-2, 34)
    ax.set_ylim(-2, 34)

def fig_fn(fig, time):
    """Label ``fig`` with the current step and set its subplot spacing.

    Args:
        fig: matplotlib figure holding the grid of panels.
        time: step counter rendered as ``"step {time}"`` in vertical text on the
            left side of the figure.
    """
    fig.text(0.1, 0.5, f"step {time}", va='center', rotation='vertical', fontsize=20, color=ink)
    # Adjust *this* figure; plt.subplots_adjust would act on whatever pyplot
    # considers the current figure, which need not be `fig`.
    fig.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.3, hspace=0.3)

In [94]:
# Render the rollout with interpolated sub-steps and write the video.
plot_fn(env=env, state_seq=state_seq, expand=True)

 69%|██████▉   | 111/160 [00:13<00:06,  7.80it/s]