# LLMIL

In [19]:
import jax
from jax import numpy as jnp, jit, vmap, random
import chex
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from einops import rearrange
from functional import partial

from tqdm import tqdm
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
import darkdetect, imageio

In [20]:
# globals
# Plot theme: monospace font, plus background ("bg") and foreground ("ink")
# colours that follow the OS light/dark setting.
rcParams["font.family"] = "monospace"
rcParams["font.monospace"] = "Fira Code"
bg, ink = ("black", "white") if darkdetect.isDark() else ("white", "black")
# scatter marker for each unit-type id (0..6)
markers = dict(enumerate(["o", "s", "D", "^", "<", ">", "+"]))

## Classes

In [25]:
@chex.dataclass
class Config:
    """Environment selection settings.

    NOTE(review): the default `scenario` ("simple_wood_and_stone") does not look
    like an SMAX map name (the rollout cells use "3s5z_vs_3s6z") even though
    `env_name` defaults to "smax" — confirm which default is intended.
    """

    env_name: str = "smax"  # which environment family to build
    scenario: str = "simple_wood_and_stone"  # map / scenario name within that environment

# Vanilla

In [26]:
# globals
# Rollout size and SMAX map used by the rollout and plotting helpers.
n_envs, n_steps = 6, 100  # parallel environments per batch, env steps per rollout
scenario = "3s5z_vs_3s6z"  # map name, resolved via map_name_to_scenario

In [27]:
# iter
def step_fn(rng, env, obs_v, old_state_v):
    """Advance all `n_envs` batched environments by one step with random actions.

    Returns a triple:
      * the carry ``(rng, env, obs_v, state_v)`` to feed into the next call,
      * the record ``(step_keys, old_state_v, actions)`` for this step, i.e. the
        pre-step state together with the actions applied to it,
      * the per-agent reward dict returned by ``env.step``.
    """
    rng, act_rng, step_rng = random.split(rng, 3)
    n_agents = env.num_agents
    # one action key per (agent, env) pair, grouped by agent
    act_keys = random.split(act_rng, n_agents * n_envs).reshape(n_agents, n_envs, -1)
    step_keys = random.split(step_rng, n_envs)
    actions = {}
    for agent_idx, agent in enumerate(env.agents):
        actions[agent] = action_fn(env, act_keys[agent_idx], obs_v[agent], agent)
    obs_v, new_state_v, reward_v, _, _ = vmap(env.step)(step_keys, old_state_v, actions)
    carry = (rng, env, obs_v, new_state_v)
    record = (step_keys, old_state_v, actions)
    return carry, record, reward_v


@partial(vmap, in_axes=(None, 0, 0, None))
def action_fn(env, rng, obs, a):
    """Sample a random action for agent `a` in every batched environment.

    Vectorised over the leading axis of `rng` and `obs`; `env` and the agent
    name `a` are shared across the batch. `obs` is currently ignored. The plan
    is to replace this with a "bt" (presumably behaviour-tree) policy.
    """
    space = env.action_space(a)
    return space.sample(rng)


def traj_fn(rng, env):
    """Roll out `n_steps` random-action steps in `n_envs` batched environments.

    Args:
        rng: PRNG key that seeds both the environment resets and every step.
        env: the (unbatched) environment; ``reset``/``step`` are vmapped over
            `n_envs` copies.

    Returns:
        state_seq: one entry per step, each the ``(step_keys, old_state_v, actions)``
            record produced by ``step_fn``.
        reward_seq: one per-agent reward dict per step.
    """
    # Split the caller's key (not a fixed PRNGKey(0)) so different seeds give
    # different rollouts.
    rng, reset_rng = random.split(rng)
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)
    traj_state = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):
        traj_state, step_record, reward_v = step_fn(*traj_state)
        # append in place; rebuilding the lists every step is O(n_steps**2)
        state_seq.append(step_record)
        reward_seq.append(reward_v)
    return state_seq, reward_seq

In [28]:
# Split a fixed root key (`key` seeds the rollout), build the SMAX env for
# `scenario`, then roll out a batch of random-action trajectories.
rng, key = random.split(random.PRNGKey(0))
env = make("SMAX", scenario=map_name_to_scenario(scenario))
state_seq, reward_seq = traj_fn(key, env)

100%|██████████| 100/100 [00:03<00:00, 31.16it/s]


In [29]:
def plot_fn(env, state_seq, reward_seq, expand=False):
    """Render a small-multiples animation of a batch of rollouts and save it as .mp4.

    Each frame is a 2x3 grid with one panel per batched environment (at most 6 are
    drawn). Units get one marker per unit type, are filled for team 1 and hollow
    otherwise, and their marker area grows with ``unit_health``.

    The film is written to ``docs/figs/worlds_{bg}.mp4``, or to
    ``docs/figs/worlds_{bg}_laggy.mp4`` when not expanded. ``frame_fn`` also saves
    the last step as a ``.jpg``.

    Args:
        env: environment that produced the rollout (used for ``expand_state_seq``).
        state_seq: list of ``(keys, state, actions)`` records, one per env step.
        reward_seq: list of per-agent reward dicts, one per env step.
        expand: if True, expand every env step into several sub-step states via
            ``env.expand_state_seq`` and write a 24 fps film. Otherwise write one
            frame per env step at 3 fps.
    """
    if expand:
        state_seq = vmap(env.expand_state_seq)(state_seq)  # expand states for smooth animation
    # Frames rendered per env step. This used to be a hard-coded 8; deriving it from
    # the sequence lengths keeps reward/step lookups aligned with whatever
    # expand_state_seq produced.
    substeps = len(state_seq) // len(reward_seq) if expand and reward_seq else 1
    frames = []  # frames to make .mp4
    # unit types present in the first step (all envs), one marker per type
    unit_types = np.unique(np.array(state_seq[0][1].unit_types))
    for jdx, (_, state, _) in tqdm(enumerate(state_seq), total=len(state_seq)):
        step = jdx // substeps  # env step this frame belongs to
        reward = reward_seq[step]
        fills = np.where(np.array(state.unit_teams) == 1, ink, "None")  # team colors
        # 1792x1216 px at dpi=100; both are multiples of 16, presumably so the
        # video encoder does not resize the frames
        fig, axes = plt.subplots(2, 3, figsize=(17.92, 12.16), facecolor=bg, dpi=100)
        for idx, ax in zip(range(n_envs), axes.flatten()):
            ax_fn(ax, state, reward, idx)  # setup axis theme
            for unit_type in unit_types:
                kdx = state.unit_types[idx, :] == unit_type  # unit_type idxs
                x = state.unit_positions[idx, kdx, 0]  # x coord
                y = state.unit_positions[idx, kdx, 1]  # y coord
                c = fills[idx, kdx]  # color
                s = state.unit_health[idx, kdx] ** 1.5 * 0.2  # size
                ax.scatter(x, y, s=s, c=c, edgecolor=ink, marker=markers[unit_type])
        frames.append(frame_fn(fig, step))
    imageio.mimsave(
        f'docs/figs/worlds_{bg}{"_laggy" if not expand else ""}.mp4',
        frames,
        fps=24 if expand else 3,
    )


def frame_fn(fig, idx):
    """Finalise one animation frame and return it as an RGB array.

    Adds a vertical caption (step index, model, scenario) on the left, sets the
    subplot spacing, rasterises the figure and closes it. When `idx` is the last
    env step (``n_steps - 1``) the figure is also saved to
    ``docs/figs/worlds_{bg}.jpg``.

    Args:
        fig: matplotlib figure to finalise; it is closed before returning.
        idx: env-step index shown in the caption.

    Returns:
        ``(height, width, 3)`` uint8 array of the rendered figure.
    """
    title = f"step : {str(idx).zfill(len(str(n_steps - 1)))}     model : random     env : {scenario}"
    fig.text(0.01, 0.5, title, va="center", rotation="vertical", fontsize=20, color=ink)
    # Operate on `fig` explicitly rather than pyplot's "current figure", so the
    # figure we were given is the one adjusted, saved and closed.
    fig.subplots_adjust(
        left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.3, hspace=0.3
    )
    fig.canvas.draw()
    # buffer_rgba() already has shape (height, width, 4) in physical pixels, so no
    # manual reshape is needed. Copy the RGB channels so the frame owns its memory
    # instead of pinning the whole RGBA buffer.
    frame = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
    if idx == n_steps - 1:
        fig.savefig(f"docs/figs/worlds_{bg}.jpg", dpi=200)
    plt.close(fig)  # close fig
    return frame


def ax_fn(ax, state, reward, idx):
    """Theme one subplot and label it with the team reward totals of env `idx`.

    The x-label shows "<ally total> | <enemy total>", each summed over agents whose
    reward key starts with "ally" / "enemy". The title is the 1-based simulation
    number. The axes use the global `bg`/`ink` colours, inward unlabelled ticks on
    all four sides, an equal aspect ratio and a fixed [-2, 34] window.

    `state` is accepted but currently unused.
    """
    # TODO: add augmentation functions to put dynamic info/metrics around subplots
    team_totals = {
        team: sum(v[idx] for k, v in reward.items() if k.startswith(team))
        for team in ("ally", "enemy")
    }
    ax.set_xlabel("{:.3f} | {:.3f}".format(team_totals["ally"], team_totals["enemy"]), color=ink)
    ax.set_title(f"simulation {idx+1}", color=ink)
    ax.set_facecolor(bg)
    grid_ticks = np.arange(2, 31, 4)  # 2, 6, ..., 30; presumably a 0..32 map
    ax.set_xticks(grid_ticks)
    ax.set_yticks(grid_ticks)
    ax.tick_params(
        colors=ink, direction="in", length=6, width=1, which="both",
        top=True, bottom=True, left=True, right=True,
        labelleft=False, labelbottom=False,
    )
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_color(ink)
    ax.set_aspect("equal")
    ax.set_xlim(-2, 34)
    ax.set_ylim(-2, 34)

In [30]:
# Render the expanded (smooth, 24 fps) animation of the rollout batch.
plot_fn(
    env,
    state_seq,
    reward_seq,
    expand=True,
)

100%|██████████| 800/800 [02:05<00:00,  6.37it/s]


# Bullet proof

In [144]:
def dist_fn(pos, num_allies=None):
    """Euclidean distance between every ally and every enemy unit in one env.

    Args:
        pos: ``(num_units, d)`` unit positions; the first `num_allies` rows are
            allies and the rest are enemies.
        num_allies: ally/enemy split point in `pos`. Defaults to the global
            ``env.num_allies``, so ``vmap(dist_fn)(positions)`` keeps working.

    Returns:
        ``{"ally": (num_allies, num_enemies), "enemy": (num_enemies, num_allies)}``.
        Row ``i`` holds unit ``i``'s distance to each opposing unit.
    """
    if num_allies is None:
        num_allies = env.num_allies
    allies, enemies = pos[:num_allies], pos[num_allies:]
    # only the ally x enemy block is needed, so skip the full pairwise matrix
    delta = allies[:, None, :] - enemies[None, :, :]
    dist = jnp.sqrt((delta**2).sum(axis=-1))  # row is ally, col is enemy
    return {"ally": dist, "enemy": dist.T}  # distance is symmetric


def range_fn(dists, ranges, num_allies=None):
    """Which opposing units each unit can reach with its attack, in one env.

    Args:
        dists: output of ``dist_fn``: ``{"ally": (A, E), "enemy": (E, A)}``.
        ranges: ``(A + E,)`` attack range per unit, allies first.
        num_allies: ally/enemy split point in `ranges`. Defaults to the global
            ``env.num_allies``, so ``vmap(range_fn)(dists, ranges)`` keeps working.

    Returns:
        ``{"ally": (A, E) bool, "enemy": (E, A) bool}``, True where the distance
        is strictly less than the attacker's range.
    """
    if num_allies is None:
        num_allies = env.num_allies
    ally_range = dists["ally"] < ranges[:num_allies][:, None]
    enemy_range = dists["enemy"] < ranges[num_allies:][:, None]
    return {"ally": ally_range, "enemy": enemy_range}


def target_fn(acts, in_range, team, n_move_actions=5):
    """One-hot encode which opposing unit each `team` unit attacks, per env.

    Args:
        acts: per-agent action dict. Values are ``(n_envs,)`` int arrays; keys
            start with "ally" or "enemy".
        in_range: output of ``vmap(range_fn)``. ``in_range[team]`` is a
            ``(n_envs, n_team, n_opponents)`` bool array.
        team: "ally" or "enemy".
        n_move_actions: index of the first attack action. Lower actions count as
            non-attacks; action ``n_move_actions + j`` targets opponent ``j``.
            NOTE(review): 5 presumably matches SMAX's movement + stop actions;
            confirm against ``env.num_movement_actions``.

    Returns:
        ``(n_envs, n_team, n_opponents)`` float array. An entry is 1 where the unit
        attacks that opponent *and* the opponent is within its range, else 0.
    """
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T  # (n_envs, n_team)
    # non-attack actions get target -1, which selects the extra last column dropped below
    t_targets = jnp.where(t_acts < n_move_actions, -1, t_acts - n_move_actions)
    n_opponents = in_range[team].shape[2]
    t_attacks = jnp.eye(n_opponents + 1)[t_targets][:, :, :-1]
    # Mask out-of-range targets. Previously `in_range` was only used for its shape,
    # so attacks on unreachable units were reported as well.
    return t_attacks * in_range[team]


def bullet_fn(acts, in_range):
    """Attack one-hots for both teams, ``{"ally": ..., "enemy": ...}``, via ``target_fn``."""
    return {team: target_fn(acts, in_range, team) for team in ("ally", "enemy")}


@jit
def attack_fn(state, acts):
    """Per-env attack one-hots for one step of a batched state.

    Looks up each unit's attack range from its type, computes ally/enemy
    distances and in-range masks per env, then encodes the chosen targets.
    """
    unit_ranges = env.unit_type_attack_ranges[state.unit_types]  # attack range per unit
    pairwise = vmap(dist_fn)(state.unit_positions)  # ally<->enemy distances
    reachable = vmap(range_fn)(pairwise, unit_ranges)  # in-range masks
    return bullet_fn(acts, reachable)


# Pre-compute the attack one-hots ("bullets") for every recorded step, so they can be
# drawn alongside the trajectory. Previously each iteration overwrote the last.
bullet_seq = []  # one {"ally": ..., "enemy": ...} dict per recorded step
for _, state, acts in tqdm(state_seq):
    bullets = attack_fn(state, acts)
    bullet_seq.append(bullets)

100%|██████████| 100/100 [00:00<00:00, 2107.10it/s]
