# C2SIM

In [41]:
import jax
import chex
import ollama
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

from einops import rearrange
from functional import partial
from tqdm import tqdm
from typing import Callable, Any, Tuple, Dict, List
import yaml

from src import plot_fn

In [42]:
# globals -- notebook-wide configuration

# presumably the ollama model tag; no cell visible here reads it yet
model = "mistral:7b-instruct-q2_K"

# rollout size: number of vmapped environment copies, and steps per trajectory
n_envs, n_steps = 100, 100

# SMAX setup: `scenario` only appears in a commented-out line further down,
# while the team sizes are passed straight to make("SMAX", ...)
scenario = "3s5z_vs_3s6z"
num_allies, num_enemies = 10, 10

## behavior tree

In [44]:
# Read the behavior-tree description from data/bt.yaml into `bt`.
# `bt` is not used by the cells visible below: tree("bt") re-opens the same
# path on its own.
# NOTE(review): FullLoader can build more than plain dicts/lists/strings; if
# this file could ever come from an untrusted source, prefer yaml.safe_load.
with open("data/bt.yaml", "r") as f:
    bt = yaml.load(f, Loader=yaml.FullLoader)

In [45]:
@chex.dataclass
class Status:
    """Integer result codes returned by behavior-tree nodes.

    The node combinators in this notebook read these as class attributes
    (``Status.SUCCESS``), never from an instance, so in practice this class
    acts as a namespace of int constants.
    """

    # a sequence keeps ticking its children while they return this; a selector
    # stops at the first child that returns it
    SUCCESS: int = 0
    # returned by a selector when none of its children succeeded
    FAILURE: int = 1
    # not produced by any node visible here; presumably "still in progress"
    # -- TODO confirm intended semantics
    RUNNING: int = 2

In [46]:
# types
# A node is a callable that takes the current state and returns a
# (status, state) pair.
# NOTE(review): the status values actually compared by the nodes are the int
# class attributes of Status (0/1/2), not Status instances, so the `Status`
# in this alias is documentation rather than an accurate static type.
NodeFunc = Callable[[Any], Tuple[Status, Any]]

In [47]:
# Sequence node
def sequence(children: List[NodeFunc]) -> NodeFunc:
    """Combine child nodes into one node that runs them in order.

    The returned tick threads ``state`` through the children one after
    another. The first child whose status is anything other than
    ``Status.SUCCESS`` ends the tick, and that status is returned together
    with the state as it stands at that point. If every child succeeds (or
    the list is empty) the result is ``Status.SUCCESS``.
    """

    def tick(state: Any) -> Tuple[Status, Any]:
        for node in children:  # each child is itself a node function
            outcome, state = node(state)
            if outcome != Status.SUCCESS:
                break  # stop early and report this child's status
        else:
            # loop ran to completion: all children succeeded
            outcome = Status.SUCCESS
        return outcome, state

    return tick


# Selector node
def selector(children: List[NodeFunc]) -> NodeFunc:
    """Combine child nodes into one node that tries them in order (fallback).

    The returned tick threads ``state`` through the children one after
    another and only moves on to the next child when the current one returns
    ``Status.FAILURE``. Any other status -- ``Status.SUCCESS`` or
    ``Status.RUNNING`` -- ends the tick and is returned unchanged, mirroring
    how ``sequence`` propagates every non-SUCCESS status. If every child
    fails (or the list is empty) the result is ``Status.FAILURE``.

    Children are expected to return one of the ``Status`` codes.
    """

    def tick(state: Any) -> Tuple[Status, Any]:
        for child in children:
            status, state = child(state)
            # Fall through to the next child only on an explicit failure; a
            # RUNNING child must not be skipped in favour of its siblings.
            if status != Status.FAILURE:
                return status, state
        return Status.FAILURE, state

    return tick


# Action node
def action(fn: Callable) -> NodeFunc:
    """Wrap a plain function as a leaf node.

    ``fn`` is called with the state and whatever it returns is used as the
    node's status. The same state object that was passed in is returned
    alongside it; this wrapper never rebinds it.
    """

    def tick(state: Any) -> Tuple[Status, Any]:
        status = fn(state)
        return status, state

    return tick

In [48]:
# Tree


def tree(f_name: str) -> NodeFunc:
    """Build an executable behavior tree from ``data/{f_name}.yaml``.

    The YAML document is a nested mapping. Every node has a ``type`` key:

    * ``sequence`` / ``selector`` -- composite nodes with a ``children`` list
      of further nodes, built recursively.
    * ``action`` / ``condition`` -- leaves with a ``name`` key. The name is
      looked up in this module's globals and the function found there is
      wrapped with ``action`` (conditions are treated exactly like actions).

    Args:
        f_name: File stem of the YAML file inside ``data/``.

    Returns:
        The node function for the root of the tree.

    Raises:
        KeyError: A node lacks a required key, or a leaf ``name`` does not
            match any global.
        ValueError: A node has an unrecognised ``type``.
    """
    # NOTE(review): FullLoader can build more than plain dicts/lists/strings;
    # if tree files could ever come from an untrusted source, prefer
    # yaml.safe_load.
    with open(f"data/{f_name}.yaml", "r") as f:
        spec = yaml.load(f, Loader=yaml.FullLoader)

    def build(node: Dict) -> NodeFunc:
        kind = node["type"]
        if kind == "sequence":
            return sequence([build(child) for child in node["children"]])
        if kind == "selector":
            return selector([build(child) for child in node["children"]])
        if kind in ("action", "condition"):
            name = node["name"]
            try:
                fn = globals()[name]  # a little hacky: leaves are resolved by name
            except KeyError:
                raise KeyError(
                    f"behavior-tree leaf {name!r} is not defined as a global function"
                ) from None
            return action(fn)
        raise ValueError(f"Unknown node type: {kind}")

    return build(spec)


# Run
# Builds the tree described by data/bt.yaml. Leaf names in that file are
# resolved through globals(), so functions with those names must already be
# defined in the kernel before this cell runs.
# NOTE(review): tree() returns the callable produced by build(); the dict in
# the saved output below does not match that and looks like it predates the
# current tree() -- re-run the cell to refresh it.
tree("bt")

{'type': 'selector',
 'children': [{'type': 'sequence',
   'children': [{'type': 'action', 'name': 'go_left'}]},
  {'type': 'action', 'name': 'wander'}]}

## language model

## smax trajectories

In [49]:
def step_fn(rng, old_state_v, env, obs_v):
    """Sample one action per agent and advance every environment copy once.

    Returns a 4-tuple:
      * the new batched observations,
      * ``(rng, state_v)`` -- the carried-over key and the new batched state,
      * ``(step_keys, old_state_v, acts)`` -- the inputs that were fed to
        ``env.step`` for this step,
      * the batched rewards.
    """
    next_rng, agent_rng, env_rng = random.split(rng, 3)

    # one key per (agent, env) pair; the reshape puts the n_envs axis second
    agent_keys = random.split(agent_rng, env.num_agents * n_envs)
    agent_keys = agent_keys.reshape(-1, n_envs, 2)
    env_keys = random.split(env_rng, n_envs)

    # TODO: currently actions are taken in parallel across agents, but not with envs. fix?
    # (act_fn is vmapped over the leading axis of its keys and observations;
    # agents are handled by this Python loop.)
    actions = {}
    for idx, agent in enumerate(env.agents):
        actions[agent] = act_fn(env, agent_keys[idx], obs_v[agent], agent)

    stepped = vmap(env.step)(env_keys, old_state_v, actions)
    next_obs_v, next_state_v, reward_v, _, _ = stepped
    return (
        next_obs_v,
        (next_rng, next_state_v),
        (env_keys, old_state_v, actions),
        reward_v,
    )


@partial(vmap, in_axes=(None, 0, 0, None))  # batch over rng and obs; env and agent are shared
def act_fn(env, rng, obs, agent):
    """Draw a random action for ``agent`` from its action space.

    Because of the ``vmap`` decorator the caller passes a stack of keys and a
    stack of observations and gets back one sampled action per entry.
    ``obs`` is accepted but not consulted: the action is sampled from the
    action space using only the key.
    """
    space = env.action_space(agent)
    return space.sample(rng)


def traj_fn(rng, env):  # runs parallel trajectories and returns state seqs.
    """Roll out ``n_envs`` environment copies in parallel for ``n_steps`` steps.

    Args:
        rng: PRNG key that seeds both the environment resets and every
            subsequent step. Different keys give different rollouts; the same
            key reproduces the same rollout.
        env: Environment exposing ``reset``/``step`` plus whatever ``step_fn``
            reads from it.

    Returns:
        ``(state_seq, reward_seq)``: two lists with one entry per step.
        Each ``state_seq`` entry is the ``(step_keys, old_state_v, acts)``
        tuple returned by ``step_fn`` (the inputs to that step); each
        ``reward_seq`` entry is the batched reward for that step.
    """
    # Derive everything from the caller's key rather than a hard-coded seed,
    # so the `rng` argument actually controls the rollout.
    rng, reset_rng = random.split(rng)
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)

    traj_state = (rng, state_v)  # (key, batched env state) carried between steps
    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):  # take n steps in env and append to lists
        obs_v, traj_state, step_record, reward_v = step_fn(
            *traj_state, env=env, obs_v=obs_v
        )
        state_seq.append(step_record)
        reward_seq.append(reward_v)
        # probe_fn(state_seq)
    return state_seq, reward_seq

In [50]:
# scenario=map_name_to_scenario(scenario))  # TODO: allow HeuristicSMAX
# Build the SMAX environment with the team sizes from the globals cell.
env = make("SMAX", num_allies=num_allies, num_enemies=num_enemies)
# Split a fixed seed: `key` is handed to the rollout, `rng` is kept but not
# used by the cells visible here.
rng, key = random.split(random.PRNGKey(0))
# Collect the per-step records and rewards from the parallel rollout.
state_seq, reward_seq = traj_fn(key, env)
# plot_fn(env, state_seq, reward_seq, expand=True)

100%|██████████| 100/100 [00:04<00:00, 23.49it/s]


## language model