# C2SIM

In [1]:
import jax
import chex
import ollama
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

from einops import rearrange
from functional import partial
from tqdm import tqdm

from src import plot_fn

## trajectories

In [2]:
# globals
# Rollout dimensions: number of vmapped environments and steps per trajectory.
n_envs, n_steps = 10, 100
# SMAX map name; in this notebook it is only referenced by a commented-out
# argument to make().
scenario = "3s5z_vs_3s6z"
# Team sizes forwarded to make("SMAX", ...).
num_allies = num_enemies = 10

In [3]:
def step_fn(rng, old_state_v, env, obs_v):
    """Advance every vmapped environment by one step using sampled actions.

    Args:
        rng: PRNG key; split into a carry key, an action key and a step key.
        old_state_v: batched environment state fed to ``env.step``.
        env: environment exposing ``num_agents``, ``agents`` and ``step``.
        obs_v: mapping from agent name to that agent's batched observation.

    Returns:
        A 4-tuple ``(obs, (rng, state), (step_keys, old_state_v, actions), reward)``
        where the second element is the carry for the next call and the third
        records the inputs that were passed to ``env.step``.
    """
    next_rng, act_rng, step_rng = random.split(rng, 3)

    # One action key per (agent, env) pair, indexed agent-first.
    n_agents = env.num_agents
    flat_act_keys = random.split(act_rng, n_agents * n_envs)
    act_keys = flat_act_keys.reshape(n_agents, n_envs, -1)

    # One step key per environment.
    step_keys = random.split(step_rng, n_envs)

    actions = {}
    for idx, agent in enumerate(env.agents):
        actions[agent] = action_fn(env, act_keys[idx], obs_v[agent], agent)

    batched_step = vmap(env.step)
    next_obs_v, next_state_v, reward_v, _, _ = batched_step(
        step_keys, old_state_v, actions
    )
    # NOTE(review): an earlier draft unwrapped `old_state_v.state` when
    # env.name != "SMAX"; that unwrapping is not performed here.
    return (
        next_obs_v,
        (next_rng, next_state_v),
        (step_keys, old_state_v, actions),
        reward_v,
    )


@partial(vmap, in_axes=(None, 0, 0, None))
def action_fn(env, rng, obs, a):
    """Sample one action for agent ``a`` from its action space.

    The function is vmapped over axis 0 of ``rng`` and ``obs``; ``env`` and
    ``a`` are passed through unbatched. ``obs`` is accepted but not read, so
    the action does not depend on the observation.
    """
    agent_space = env.action_space(a)
    sampled = agent_space.sample(rng)
    return sampled


def traj_fn(rng, env):
    """Roll out ``n_steps`` steps in ``n_envs`` vmapped copies of ``env``.

    Args:
        rng: PRNG key that seeds both the environment resets and every
            subsequent step. Different keys give different rollouts.
        env: environment exposing ``reset`` and whatever ``step_fn`` needs.

    Returns:
        ``(state_seq, reward_seq)``: two lists of length ``n_steps``. Each
        ``state_seq`` entry is the ``(step_keys, old_state_v, actions)`` tuple
        produced by ``step_fn`` for that step; each ``reward_seq`` entry is
        the matching batched reward.
    """
    # Derive all randomness from the caller's key. Splitting a fresh
    # PRNGKey(0) here would ignore `rng` and make every rollout identical.
    rng, reset_rng = random.split(rng)
    reset_keys = random.split(reset_rng, n_envs)
    obs_v, state_v = vmap(env.reset)(reset_keys)

    state_seq, reward_seq = [], []
    for _ in tqdm(range(n_steps)):
        # `transition` holds the inputs of this step; it gets its own name so
        # it does not shadow the carried environment state `state_v`.
        obs_v, (rng, state_v), transition, reward_v = step_fn(
            rng, state_v, env, obs_v
        )
        # Append in place; rebuilding the list each step is quadratic.
        state_seq.append(transition)
        reward_seq.append(reward_v)
    return state_seq, reward_seq

In [4]:
# Build the SMAX environment from explicit team sizes. The scenario-based
# argument below is disabled pending HeuristicSMAX support.
# scenario=map_name_to_scenario(scenario))  # TODO: allow HeuristicSMAX
env = make("SMAX", num_allies=num_allies, num_enemies=num_enemies)
# Split the seed key: `key` is passed to traj_fn, `rng` is not used in this cell.
rng, key = random.split(random.PRNGKey(0))
# Collect the per-step transition tuples and rewards returned by traj_fn.
state_seq, reward_seq = traj_fn(key, env)
# Render the collected rollout (plot_fn is imported from the local `src` package).
plot_fn(env, state_seq, reward_seq, expand=True)

100%|██████████| 100/100 [00:03<00:00, 29.46it/s]
100%|██████████| 100/100 [00:00<00:00, 133.46it/s]
100%|██████████| 800/800 [01:45<00:00,  7.58it/s]


## language model

In [5]:
# Ollama model tag used for the chat request.
model = "mistral:7b-instruct-q2_K"
# Prompt sent verbatim as the single user message. The opening delimiter is
# exactly three quotes; a fourth would put a stray `"` at the start of the
# prompt text.
template = """
return a .yaml behavior tree for a StarCraft II agent.
Return no other text.
"""
# Single-turn chat request; the reply text is read from
# response["message"]["content"].
response = ollama.chat(
    model=model,
    messages=[
        {
            "role": "user",
            "content": template,
        },
    ],
)
print(response["message"]["content"])

```yaml
version: "1.0"
actions:
- name: Select_Unit
  description: Selects a unit of the specified type.
  parameters:
    - name: unitType
      in: int
      required: true
      desc: The type of unit to select.
- name: Move_Unit
  description: Moves a unit to a specified location.
  parameters:
    - name: x
      in: float
      required: true
      desc: The x coordinate of the destination.
    - name: y
      in: float
      required: true
      desc: The y coordinate of the destination.
- name: Attack_Unit
  description: Attacks a specified enemy unit.
  parameters:
    - name: targetType
      in: int
      required: true
      desc: The type of the enemy unit to attack.
    - name: targetID
      in: int
      required: true
      desc: The ID of the enemy unit to attack.

tasks:
- name: Select_And_Move_Unit
  description: Selects a unit of the specified type and moves it to a specified location.
  requirements:
    - Select_Unit
    - Move_Unit
  parameters:
    - name: unit