In [2]:
import jax
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario, SMAX
from jaxmarl.viz.visualizer import SMAXVisualizer, Visualizer
from jaxmarl.environments.smax.smax_env import State as SMAXState  # use to unstack state carry from lax scan

from matplotlib import rcParams
from tqdm import tqdm
from functional import partial
import numpy as np
import matplotlib.pyplot as plt
import imageio
import imageio_ffmpeg  # Make sure this is installed
import darkdetect

import esch

In [3]:
# globals
rcParams['font.family'] = 'monospace'
rcParams['font.monospace'] = 'Fira Code'
bg = 'black' if darkdetect.isDark() else 'white'
ink = 'black' if bg == 'white' else 'white'
markers = {0: 'o', 1: 's', 2: 'D', 3: '^', 4: '<', 5: '>', 6: '+'}

# Vanilla

In [4]:
# globals
n_envs  = 6
n_steps = 100
scenario = '8m'

In [5]:
# iter
def step_fn(rng, env, obs_v, old_state_v):
    rng, act_rng, step_rng         = random.split(rng, 3)
    act_keys                       = random.split(act_rng, env.num_agents * n_envs).reshape(env.num_agents, n_envs, -1)
    step_keys                      = random.split(step_rng, n_envs)
    actions                        = {a: vmap(env.action_space(a).sample)(act_keys[i]) for i, a in enumerate(env.agents)}
    obs_v, state_v, reward_v, _, _ = vmap(env.step)(step_keys, old_state_v, actions)
    return (rng, env, obs_v, state_v), (step_keys, old_state_v, actions), reward_v

def traj_fn(rng, env):
    rng, reset_rng        = random.split(random.PRNGKey(0))
    reset_keys            = random.split(reset_rng, n_envs)
    obs_v, state_v        = vmap(env.reset)(reset_keys)
    traj_state            = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for step in tqdm(range(n_steps)):
        traj_state, state_v, reward_v = step_fn(*traj_state)
        state_seq, reward_seq         = state_seq + [state_v], reward_seq + [reward_v]
    return state_seq, reward_seq

In [6]:
env                   = make('SMAX', scenario=map_name_to_scenario(scenario))
rng, key              = random.split(random.PRNGKey(0))
state_seq, reward_seq = traj_fn(key, env,) 

100%|██████████| 100/100 [00:03<00:00, 32.76it/s]


In [9]:
def plot_fn(env, state_seq, reward_seq, expand=False):
    """
    given an environment, a state_seq, and a reward seq
    this function makes a small multiples plot of the trejectory.
    It save it as a .mp4 film in docs/figs along with an image of the last frame.
    """
    state_seq  = state_seq if not expand else vmap(env.expand_state_seq)(state_seq)  # expand states for smooth animation
    frames     = []                                                                  # frames to make .mp4
    unit_types = np.unique(np.array(state_seq[0][1].unit_types))                     # unit types for plotting markers
    for jdx, (_, state, action) in tqdm(enumerate(state_seq), total=len(state_seq)):
        reward    = reward_seq[jdx // 8 if expand else jdx]
        fills     = np.where(np.array(state.unit_teams) == 1, ink, 'None')             # team colors
        fig, axes = plt.subplots(2, 3, figsize=(17.92, 12.16), facecolor=bg, dpi=100)  # frame in .mp4 film
        for idx, ax in zip(range(n_envs), axes.flatten()):
            ax_fn(ax, state, reward, idx)  # setup axis theme
            for unit_type in unit_types:
                kdx = state.unit_types[idx, :] == unit_type     # unit_type idxs
                x   = state.unit_positions[idx, kdx, 0]         # x coord
                y   = state.unit_positions[idx, kdx, 1]         # y coord
                c   = fills[idx, kdx]                           # color 
                s   = state.unit_health[idx, kdx] ** 1.5 * 0.2  # size 
                ax.scatter(x, y, s=s, c=c, edgecolor=ink, marker=markers[unit_type.item()])
        frames.append(frame_fn(fig, jdx // 8 if expand else jdx))
    imageio.mimsave(f'docs/figs/worlds_{bg}{"_laggy" if not expand else ""}.mp4', frames, fps=24 if expand else 3)
    
def frame_fn(fig, idx):
    
    title = f"step : {str(idx).zfill(len(str(n_steps - 1)))}     model : random     env : {scenario}"
    fig.text(0.01, 0.5, title, va='center', rotation='vertical', fontsize=20, color=ink)
    plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.3, hspace=0.3)
    fig.canvas.draw()
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
    if idx == n_steps - 1: plt.savefig(f'docs/figs/worlds_{bg}.jpg', dpi=200)
    plt.close()   # close fig
    return frame

def ax_fn(ax, state, reward, idx):
    ally_rewards = sum([v[idx] for k, v in reward.items() if k.startswith('ally')])
    enemy_rewards = sum([v[idx] for k, v in reward.items() if k.startswith('enemy')])
    ax.set_xlabel("{:.3f} | {:.3f}".format(ally_rewards, enemy_rewards), color=ink)
    ax.set_title(f"simulation {idx+1}", color=ink)
    ax.set_facecolor(bg)
    ticks = np.arange(2, 31, 4)  # Assuming your grid goes from 0 to 32
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.tick_params(colors=ink, direction='in', length=6, width=1, which='both',
        top=True, bottom=True, left=True, right=True, labelleft=False, labelbottom=False)
    ax.spines['top'].set_color(ink)
    ax.spines['bottom'].set_color(ink)
    ax.spines['left'].set_color(ink)
    ax.spines['right'].set_color(ink)
    ax.set_aspect('equal')
    ax.set_xlim(-2, 34)
    ax.set_ylim(-2, 34)

In [8]:
plot_fn(env, state_seq, reward_seq, expand=False)

100%|██████████| 100/100 [00:12<00:00,  8.03it/s]
[rawvideo @ 0x7f8970f2dd40] Stream #0: not enough frames to estimate rate; consider increasing probesize


# Bullet proof

In [30]:
def dist_fn(pos):
    delta = (pos[None, :, :] - pos[:, None, :])
    dists = jnp.sqrt((delta ** 2).sum(axis=2))
    return dists

def target_fn(action, in_range):
    print(action.shape, in_range.shape, action, in_range[0])
    mask = in_range[action[None, :], jnp.arange(action.shape[0])]
    # print(mask)
    return mask 

def attacked_fn(action, in_range): # return vector of attacked agents
    pass

# def bullet_fn() returns bullet to render

for _, state, actions in state_seq:
    # IGNORE ACTIONS <= 4 as these are movement, NOT attack actions
    # which agents are close enough to shoot?
    dist     = vmap(dist_fn)(state.unit_positions)            # what are the distances between units
    ranges   = env.unit_type_attack_ranges[state.unit_types]  # what are the ranges of each unit
    in_range = dist < ranges[:, None, :]                      # mask of agents in range
    # targets  = {agent: target_fn(action, in_range[:, idx, :]) for idx, (agent, action) in enumerate(actions.items())}
    target_fn(actions['ally_0'], in_range[:, 0, :])
    break
    # targets  = {k: target_fn(action[k], in_range[i]) for i, (k, v) in enumerate(action.items())}

(6,) (6, 16) [12  4  1  2  4  4] [ True  True  True  True  True  True  True  True False False False False
 False False False False]
