# LLMIL

In [207]:
import jax
from jax import numpy as jnp, jit, vmap, random
import chex
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario
from einops import rearrange
from functional import partial

from tqdm import tqdm
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
import darkdetect, imageio

In [208]:
# globals
rcParams["font.family"] = "monospace"
rcParams["font.monospace"] = "Fira Code"
bg = "black" if darkdetect.isDark() else "white"
ink = "white" if bg == "black" else "black"
markers = {0: "o", 1: "s", 2: "D", 3: "^", 4: "<", 5: ">", 6: "+"}

## Classes

In [209]:
@chex.dataclass
class Config:
    env_name: str = "smax"
    scenario: str = "simple_wood_and_stone"

# Vanilla

In [210]:
# globals
n_envs = 6
n_steps = 100
scenario = "3s5z_vs_3s6z"

In [211]:
# iter
def step_fn(rng, env, obs_v, old_state_v):
    rng, act_rng, step_rng = random.split(rng, 3)
    act_keys = random.split(act_rng, env.num_agents * n_envs).reshape(
        env.num_agents, n_envs, -1
    )
    step_keys = random.split(step_rng, n_envs)
    actions = {
        a: action_fn(env, act_keys[i], obs_v[a], a) for i, a in enumerate(env.agents)
    }
    obs_v, state_v, reward_v, _, _ = vmap(env.step)(step_keys, old_state_v, actions)
    return (rng, env, obs_v, state_v), (step_keys, old_state_v, actions), reward_v


@partial(vmap, in_axes=(None, 0, 0, None))
def action_fn(env, rng, obs, a):
    # get observation, ignore it and rake random action.
    # next step is to call some bt stuff.
    return env.action_space(a).sample(rng)


def traj_fn(rng, env):
    rng, reset_rng = random.split(random.PRNGKey(0))
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)
    traj_state = (rng, env, obs_v, state_v)
    state_seq, reward_seq = [], []
    for step in tqdm(range(n_steps)):
        traj_state, state_v, reward_v = step_fn(*traj_state)
        state_seq, reward_seq = state_seq + [state_v], reward_seq + [reward_v]
    return state_seq, reward_seq

In [212]:
env = make("SMAX", scenario=map_name_to_scenario(scenario))
rng, key = random.split(random.PRNGKey(0))
state_seq, reward_seq = traj_fn(key, env)

100%|██████████| 100/100 [00:04<00:00, 21.77it/s]


In [216]:
def plot_fn(env, state_seq, reward_seq, expand=False):
    """
    given an environment, a state_seq, and a reward seq
    this function makes a small multiples plot of the trejectory.
    It save it as a .mp4 film in docs/figs along with an image of the last frame.
    """
    state_seq = state_seq if not expand else vmap(env.expand_state_seq)(state_seq)
    frames, returns = [], return_fn(reward_seq)
    unit_types = np.unique(np.array(state_seq[0][1].unit_types))
    fills = np.where(np.array(state_seq[0][1].unit_teams) == 1, ink, "None")
    for i, (_, state, action) in tqdm(enumerate(state_seq), total=len(state_seq)):
        fig, axes = plt.subplots(2, 3, figsize=(18, 12), facecolor=bg, dpi=100)
        for j, ax in zip(range(n_envs), axes.flatten()):
            ax_fn(ax, state, returns, i, j)  # setup axis theme
            for unit_type in unit_types:
                idx = state.unit_types[j, :] == unit_type  # unit_type idxs
                x = state.unit_positions[j, idx, 0]  # x coord
                y = state.unit_positions[j, idx, 1]  # y coord
                c = fills[j, idx]  # color
                s = state.unit_health[j, idx] ** 1.5 * 0.2  # size
                ax.scatter(x, y, s=s, c=c, edgecolor=ink, marker=markers[unit_type])
        frames.append(frame_fn(fig, i // 8 if expand else i))
    fname = f"docs/figs/worlds_{bg}{'_laggy' if not expand else ''}.mp4"
    imageio.mimsave(fname, frames, fps=24 if expand else 3)


def return_fn(reward_seq):
    reward = [reward_fn(reward) for reward in reward_seq]
    ally = jnp.stack([v[0] for v in reward]).cumsum(axis=0)
    enemy = jnp.stack([v[1] for v in reward]).cumsum(axis=0)
    return {"ally": ally, "enemy": enemy}


def reward_fn(reward):
    ally_rewards = jnp.stack([v for k, v in reward.items() if k.startswith("ally")])
    enemy_rewards = jnp.stack([v for k, v in reward.items() if k.startswith("enemy")])
    return ally_rewards.sum(axis=0), enemy_rewards.sum(axis=0)


def frame_fn(fig, idx):
    title = f"step : {str(idx).zfill(len(str(n_steps - 1)))}     model : random     env : {scenario}"
    fig.text(0.01, 0.5, title, va="center", rotation="vertical", fontsize=20, color=ink)
    sublot_params = {"hspace": 0.3, "wspace": 0.3, "left": 0.05, "right": 0.95}
    plt.subplots_adjust(**sublot_params)
    fig.canvas.draw()
    frame = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    frame = frame.reshape(fig.canvas.get_width_height()[::-1] + (4,))[:, :, :3]
    if idx == n_steps - 1:
        plt.savefig(f"docs/figs/worlds_{bg}.jpg", dpi=200)
    plt.close()  # close fig
    return frame


tick_params = {
    "colors": ink,
    "direction": "in",
    "length": 6,
    "width": 1,
    "which": "both",
    "top": True,
    "bottom": True,
    "left": True,
    "right": True,
    "labelleft": False,
    "labelbottom": False,
}


def ax_fn(ax, state, returns, i, j):
    ally_return = returns["ally"][i, j]
    enemy_return = returns["enemy"][i, j]
    ax.set_xlabel("{:.3f} | {:.3f}".format(ally_return, enemy_return), color=ink)
    ax.set_title(f"simulation {i+1}", color=ink)
    ax.set_facecolor(bg)
    ticks = np.arange(2, 31, 4)  # Assuming your grid goes from 0 to 32
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.tick_params(**tick_params)
    ax.spines["top"].set_color(ink)
    ax.spines["bottom"].set_color(ink)
    ax.spines["left"].set_color(ink)
    ax.spines["right"].set_color(ink)
    ax.set_aspect("equal")
    ax.set_xlim(-2, 34)
    ax.set_ylim(-2, 34)

In [217]:
plot_fn(env, state_seq, reward_seq, expand=False)

100%|██████████| 100/100 [00:15<00:00,  6.57it/s]
[rawvideo @ 0x158639640] Stream #0: not enough frames to estimate rate; consider increasing probesize


# Bullet proof

In [282]:
def dist_fn(pos):  # computing the distances between all ally and enemy agents
    delta = pos[None, :, :] - pos[:, None, :]
    dist = jnp.sqrt((delta**2).sum(axis=2))
    dist = dist[: env.num_allies, env.num_allies :]
    return {"ally": dist, "enemy": dist.T}


def range_fn(dists, ranges):  # computing what targets are in range
    ally_range = dists["ally"] < ranges[: env.num_allies][:, None]
    enemy_range = dists["enemy"] < ranges[env.num_allies :][:, None]
    return {"ally": ally_range, "enemy": enemy_range}


def target_fn(acts, in_range, team):  # computing the one hot valid targets
    t_acts = jnp.stack([v for k, v in acts.items() if k.startswith(team)]).T
    t_targets = jnp.where(t_acts - 5 < 0, -1, t_acts - 5)  # 5 is the attack range
    t_attacks = jnp.eye(in_range[team].shape[2] + 1)[t_targets][:, :, :-1]
    return t_attacks


@jit
def attack_fn(state_seq, attacks=[]):  # one hot attack list [(team, sim, ally, enemy)]
    for _, state, acts in tqdm(state_seq):
        dists = vmap(dist_fn)(state.unit_positions)
        ranges = env.unit_type_attack_ranges[state.unit_types]
        in_range = vmap(range_fn)(dists, ranges)
        target = partial(target_fn, acts, in_range)
        attack = {"ally": target("ally"), "enemy": target("enemy")}
        attacks.append(attack)
    return attacks


attacks = attack_fn(state_seq)

100%|██████████| 100/100 [00:00<00:00, 160.56it/s]


In [286]:
def bullet_fn(state_seq, attacks, bullet_seq=[]):  # MUST BE EXPANDED STATE_SEQ!!!
    for attack in tqdm(attacks):
        ally = jnp.stack(jnp.where(attack["ally"] == 1)).T
        enemy = jnp.stack(jnp.where(attack["enemy"] == 1)).T
        bullets = jnp.concatenate([ally, enemy], axis=0)
        # get bullet positions from ally and enemy positions
        print(bullets.shape)


bullets = bullet_fn(state_seq, attacks)

100%|██████████| 100/100 [00:00<00:00, 647.57it/s]

(65, 3)
(55, 3)
(63, 3)
(63, 3)
(69, 3)
(65, 3)
(62, 3)
(56, 3)
(69, 3)
(73, 3)
(64, 3)
(76, 3)
(62, 3)
(61, 3)
(62, 3)
(64, 3)
(64, 3)
(67, 3)
(72, 3)
(58, 3)
(62, 3)
(71, 3)
(61, 3)
(64, 3)
(66, 3)
(59, 3)
(63, 3)
(69, 3)
(63, 3)
(61, 3)
(68, 3)
(63, 3)
(65, 3)
(64, 3)
(65, 3)
(58, 3)
(68, 3)
(62, 3)
(62, 3)
(61, 3)
(66, 3)
(65, 3)
(64, 3)
(65, 3)
(57, 3)
(65, 3)
(63, 3)
(64, 3)
(65, 3)
(59, 3)
(65, 3)
(68, 3)
(69, 3)
(63, 3)
(68, 3)
(67, 3)
(67, 3)
(59, 3)
(69, 3)
(70, 3)
(60, 3)
(58, 3)
(65, 3)
(58, 3)
(60, 3)
(59, 3)
(54, 3)
(63, 3)
(56, 3)
(65, 3)
(67, 3)
(62, 3)
(70, 3)
(57, 3)
(58, 3)
(66, 3)
(59, 3)
(74, 3)
(57, 3)
(67, 3)
(65, 3)
(58, 3)
(65, 3)
(63, 3)
(67, 3)
(72, 3)
(60, 3)
(68, 3)
(66, 3)
(62, 3)
(64, 3)
(61, 3)
(66, 3)
(73, 3)
(60, 3)
(66, 3)
(60, 3)
(62, 3)
(64, 3)
(71, 3)



