# C2SIM

In [1]:
import jax
from jax import numpy as jnp, jit, vmap, random
import chex
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from einops import rearrange
from functional import partial

import seaborn as sns
from tqdm import tqdm
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
import darkdetect, imageio

In [2]:
# globals
rcParams["font.family"] = "monospace"
rcParams["font.monospace"] = "Fira Code"
bg = "black" if darkdetect.isDark() else "white"
ink = "white" if bg == "black" else "black"
markers = {0: "o", 1: "s", 2: "D", 3: "^", 4: "<", 5: ">", 6: "+"}

## Classes

In [3]:
@chex.dataclass
class Config:
    env_name: str = "smax"
    scenario: str = "simple_wood_and_stone"

# Vanilla

In [25]:
# globals
n_envs = 6
n_steps = 100
scenario = "3s5z_vs_3s6z"

In [26]:
# iter
def step_fn(rng, env, obs_v, old_state_v):
    rng, act_rng, step_rng = random.split(rng, 3)
    act_keys = random.split(act_rng, env.num_agents * n_envs).reshape(
        env.num_agents, n_envs, -1
    )
    step_keys = random.split(step_rng, n_envs)
    actions = {
        a: action_fn(env, act_keys[i], obs_v[a], a) for i, a in enumerate(env.agents)
    }
    obs_v, state_v, reward_v, _, _ = vmap(env.step)(step_keys, old_state_v, actions)
    return (rng, env, obs_v, state_v), (step_keys, old_state_v, actions), reward_v


@partial(vmap, in_axes=(None, 0, 0, None))
def action_fn(env, rng, obs, a):
    return env.action_space(a).sample(rng)


def traj_fn(rng, env):
    rng, reset_rng = random.split(random.PRNGKey(0))
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)
    traj_state = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):
        traj_state, state_v, reward_v = step_fn(*traj_state)
        state_seq, reward_seq = state_seq + [state_v], reward_seq + [reward_v]
    return state_seq, reward_seq

In [27]:
env = make("SMAX", scenario=map_name_to_scenario(scenario))
rng, key = random.split(random.PRNGKey(0))
state_seq, reward_seq = traj_fn(key, env)

100%|██████████| 100/100 [00:03<00:00, 31.80it/s]


# forensics

In [28]:
def dist_fn(pos):  # computing the distances between all ally and enemy agents
    delta = pos[None, :, :] - pos[:, None, :]
    dist = jnp.sqrt((delta**2).sum(axis=2))
    dist = dist[: env.num_allies, env.num_allies :]
    return {"ally": dist, "enemy": dist.T}


def range_fn(dists, ranges):  # computing what targets are in range
    ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
    enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
    return {"ally": ally_range, "enemy": enemy_range}


def target_fn(acts, in_range, team):  # computing the one hot valid targets
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
    t_targets = jnp.where(t_acts - 5 < 0, -1, t_acts - 5)  # first 5 are move actions
    t_attacks = jnp.eye(in_range[team].shape[2] + 1)[t_targets][:, :, :-1]
    return t_attacks * in_range[team]  # one hot valid targets


def attack_fn(state_seq, attacks=[]):  # one hot attack list [(team, sim, ally, enemy)]
    for _, state, acts in tqdm(state_seq):
        dists = vmap(dist_fn)(state.unit_positions)
        ranges = env.unit_type_attack_ranges[state.unit_types]
        in_range = vmap(range_fn)(dists, ranges)
        target = partial(target_fn, acts, in_range)
        attack = {"ally": target("ally"), "enemy": target("enemy")}
        attacks.append(attack)
    return attacks


def bullet_fn(state_seq, bullet_seq=[]):
    attack_seq = attack_fn(state_seq)

    def aux_fn(team):
        bullets = jnp.stack(jnp.where(one_hot[team] == 1)).T
        bullets = bullets.at[:, 2 if team == "ally" else 1].add(env.num_allies)
        return bullets

    for idx, (_, state, _) in tqdm(enumerate(state_seq), total=len(state_seq)):
        one_hot = attack_seq[idx]
        bullets = jnp.concatenate(list(map(aux_fn, ["ally", "enemy"])), axis=0)
        bullets_source = state.unit_positions[bullets[:, 0], bullets[:, 1], ...]
        bullets_target = state.unit_positions[bullets[:, 0], bullets[:, 2], ...]
        bullets = jnp.concatenate([bullets, bullets_source, bullets_target], axis=1)
        bullet_seq.append(bullets)
    return bullet_seq

# plotting

In [23]:
def plot_fn(env, state_seq, reward_seq, expand=False):
    bullet_seq = bullet_fn(state_seq) if expand else None
    state_seq = state_seq if not expand else vmap(env.expand_state_seq)(state_seq)
    frames, returns = [], return_fn(reward_seq)
    unit_types = np.unique(np.array(state_seq[0][1].unit_types))
    fills = np.where(np.array(state_seq[0][1].unit_teams) == 1, ink, "None")
    for i, (_, state, action) in tqdm(enumerate(state_seq), total=len(state_seq)):
        fig, axes = plt.subplots(2, 3, figsize=(18, 12), facecolor=bg, dpi=100)
        # [[ 1.        1.       13.       16.139389 19.965769 21.064068 21.61848 ]]  # sim, source, target, source_pos, target_pos
        # [[ 1.        1.       13.       16.139389 19.965769 21.064068 21.61848 ]]  # sim, source, target, source_pos, target_pos
        bullets = bullet_seq[i // 8] if expand else None
        for j, ax in zip(range(n_envs), axes.flatten()):
            ax_fn(ax, state, returns, bullets, i, j)  # setup axis theme
            for unit_type in unit_types:
                idx = state.unit_types[j, :] == unit_type  # unit_type idxs
                x = state.unit_positions[j, idx, 0]  # x coord
                y = state.unit_positions[j, idx, 1]  # y coord
                c = fills[j, idx]  # color
                s = state.unit_health[j, idx] ** 1.5 * 0.1  # size
                ax.scatter(x, y, s=s, c=c, edgecolor=ink, marker=markers[unit_type])
        frames.append(frame_fn(fig, i // 8 if expand else i))
    fname = f"docs/figs/worlds_{bg}{'_laggy' if not expand else ''}.mp4"
    imageio.mimsave(fname, frames, fps=24 if expand else 3)


def return_fn(reward_seq):
    reward = [reward_fn(reward) for reward in reward_seq]
    ally = jnp.stack([v[0] for v in reward]).cumsum(axis=0)
    enemy = jnp.stack([v[1] for v in reward]).cumsum(axis=0)
    return {"ally": ally, "enemy": enemy}


def reward_fn(reward):
    ally_rewards = jnp.stack([v for k, v in reward.items() if k.startswith("ally")])
    enemy_rewards = jnp.stack([v for k, v in reward.items() if k.startswith("enemy")])
    return ally_rewards.sum(axis=0), enemy_rewards.sum(axis=0)


def frame_fn(fig, idx):
    title = f"step : {str(idx).zfill(len(str(n_steps - 1)))}     model : random     env : {scenario}"
    fig.text(0.01, 0.5, title, va="center", rotation="vertical", fontsize=20, color=ink)
    sublot_params = {"hspace": 0.3, "wspace": 0.3, "left": 0.05, "right": 0.95}
    plt.subplots_adjust(**sublot_params)
    fig.canvas.draw()
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
    if idx == n_steps - 1:
        plt.savefig(f"docs/figs/worlds_{bg}.jpg", dpi=200)
    plt.close()  # close fig
    return frame


tick_params = {
    "colors": ink,
    "direction": "in",
    "length": 6,
    "width": 1,
    "which": "both",
    "top": True,
    "bottom": True,
    "left": True,
    "right": True,
    "labelleft": False,
    "labelbottom": False,
}


def ax_fn(ax, state, returns, bullets, i, j):
    if bullets is not None:
        idx = bullets[:, 0] == j
        alpha = i % 8 / 8
        # if there are any bullets pring (yay)
        if idx.any():
            print(9)
        pos = (1 - alpha) * bullets[idx, 3:5] + alpha * bullets[idx, 5:]
        ax.scatter(pos[:, 0], pos[:, 1], s=10, c=ink, marker="o")
    ally_return = returns["ally"][i, j]
    enemy_return = returns["enemy"][i, j]
    ax.set_xlabel("{:.3f} | {:.3f}".format(ally_return, enemy_return), color=ink)
    ax.set_title(f"simulation {j+1}", color=ink)
    ax.set_facecolor(bg)
    ticks = np.arange(2, 31, 4)  # Assuming your grid goes from 0 to 32
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.tick_params(**tick_params)
    ax.spines["top"].set_color(ink)
    ax.spines["bottom"].set_color(ink)
    ax.spines["left"].set_color(ink)
    ax.spines["right"].set_color(ink)
    ax.set_aspect("equal")
    ax.set_xlim(-2, 34)
    ax.set_ylim(-2, 34)

In [24]:
plot_fn(env, state_seq, reward_seq, expand=True)

100%|██████████| 48/48 [00:00<00:00, 211.39it/s]
100%|██████████| 48/48 [00:00<00:00, 412.97it/s]
 59%|█████▊    | 225/384 [00:35<01:02,  2.55it/s]

9
9


 59%|█████▉    | 227/384 [00:36<00:41,  3.80it/s]

9
9


 60%|█████▉    | 229/384 [00:36<00:30,  5.02it/s]

9
9


 60%|██████    | 231/384 [00:36<00:25,  5.94it/s]

9
9


 61%|██████    | 233/384 [00:37<00:23,  6.53it/s]

9
9


 61%|██████    | 235/384 [00:37<00:21,  6.78it/s]

9
9


 62%|██████▏   | 237/384 [00:37<00:21,  6.90it/s]

9
9


 62%|██████▏   | 239/384 [00:37<00:21,  6.78it/s]

9
9


 69%|██████▉   | 265/384 [00:41<00:16,  7.41it/s]

9
9


 70%|██████▉   | 267/384 [00:41<00:15,  7.42it/s]

9
9


 70%|███████   | 269/384 [00:42<00:15,  7.36it/s]

9
9


 71%|███████   | 271/384 [00:42<00:15,  7.28it/s]

9
9


 88%|████████▊ | 337/384 [00:52<00:06,  7.07it/s]

9
9


 88%|████████▊ | 339/384 [00:52<00:06,  7.10it/s]

9
9


 89%|████████▉ | 341/384 [00:52<00:06,  7.15it/s]

9
9


 89%|████████▉ | 343/384 [00:53<00:05,  6.93it/s]

9
9


 92%|█████████▏| 353/384 [00:54<00:04,  7.40it/s]

9
9


 92%|█████████▏| 355/384 [00:54<00:04,  7.24it/s]

9


 93%|█████████▎| 356/384 [00:56<00:14,  1.99it/s]

9
9


 93%|█████████▎| 358/384 [00:56<00:08,  3.08it/s]

9
9


 94%|█████████▍| 360/384 [00:56<00:05,  4.23it/s]

9
9


 94%|█████████▍| 362/384 [00:57<00:04,  5.16it/s]

9
9


 95%|█████████▍| 364/384 [00:57<00:03,  6.01it/s]

9
9


 95%|█████████▌| 366/384 [00:57<00:02,  6.56it/s]

9
9


 96%|█████████▌| 368/384 [00:58<00:02,  6.83it/s]

9
9
9


 96%|█████████▋| 370/384 [00:58<00:02,  6.84it/s]

9
9
9
9


 97%|█████████▋| 372/384 [00:58<00:01,  7.08it/s]

9
9
9
9


 97%|█████████▋| 374/384 [00:58<00:01,  7.34it/s]

9
9
9
9


 98%|█████████▊| 376/384 [00:59<00:01,  7.47it/s]

9
9


100%|██████████| 384/384 [01:01<00:00,  6.23it/s]


# Bullet proof

In [10]:
bullet_fn(state_seq)

100%|██████████| 48/48 [00:00<00:00, 66.63it/s]
100%|██████████| 48/48 [00:00<00:00, 67.13it/s]


[Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([], dtype=float32),
 Array([[ 1.      ,  1.      , 13.      , 16.139389, 19.965769, 21.064068,
         21.61848 ]], dtype=float32),
 Array([[ 2.      , 10.      ,  6.      , 18.466928, 18.552805, 13.438275,
         15.700323]], dtype=float32),
 Array([], dtype=f