In [2]:
import jax
from jax import numpy as jnp, jit, vmap, random
from jaxmarl import make
from jaxmarl.environments.smax import map_name_to_scenario

import ollama

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

## Environment

In [46]:
def make_vtraj(config):
    """Build a rollout function that steps ``config['n_envs']`` environments in parallel.

    Actions are currently random: each agent's action is drawn with the ``sample``
    method of its action space.

    Args:
        config: dict with keys ``'env'`` (first argument to ``make``), ``'env_config'``
            (keyword arguments to ``make``), ``'n_envs'`` and ``'max_steps'``.
            Side effect: ``config['n_agents']`` is set to
            ``env.num_agents * config['n_envs']``.

    Returns:
        A function ``vtraj(key)`` returning ``((env_state, obs, key), seqs)``: the final
        scan carry, and a one-element list holding the per-step environment states
        stacked along a leading axis of length ``config['max_steps']``.
    """
    env = make(config['env'], **config['env_config'])

    # Read the settings once, so every closure below agrees on them even if the
    # caller edits `config` before the first (traced) call.
    n_envs    = config['n_envs']
    max_steps = config['max_steps']
    n_agents  = env.num_agents * n_envs
    config['n_agents'] = n_agents  # still published on the caller's dict, as before

    def init_runner_state(key):
        """Reset every environment and return the initial ``(state, obs, key)`` carry."""
        key, key_reset = random.split(key)
        key_reset      = random.split(key_reset, n_envs)
        obsv, state    = vmap(env.reset)(key_reset)
        return (state, obsv, key)

    def env_step(runner_state, _):
        """One scan step: sample an action per (agent, env) pair, then step all environments.

        The second argument is the per-step scan input; the scan below is driven with
        ``xs=None``, so it carries no data.
        """
        env_state, _last_obs, key = runner_state
        key, key_act              = random.split(key)

        # One key per (agent, env) pair. Reusing the split result's own trailing shape
        # (rather than reshaping with -1) keeps this valid both for raw uint32 keys,
        # which have a trailing data axis, and for typed keys, which have none.
        key_act = random.split(key_act, n_agents)
        key_act = key_act.reshape((env.num_agents, n_envs) + key_act.shape[1:])
        actions = {agent: vmap(env.action_space(agent).sample)(key_act[i])
                   for i, agent in enumerate(env.agents)}

        key, key_step = random.split(key)
        key_step      = random.split(key_step, n_envs)

        obsv, env_state, _, _, _ = vmap(env.step)(key_step, env_state, actions)

        # The per-step output is wrapped in a list so the stacked result keeps the
        # one-element-list structure that callers already receive.
        return (env_state, obsv, key), [env_state]

    def vtraj(key):
        """Reset all environments, then roll them forward for ``max_steps`` steps."""
        _, key_init  = random.split(key)
        runner_state = init_runner_state(key_init)
        runner_state, seqs = jax.lax.scan(env_step, runner_state, None, length=max_steps)
        return runner_state, seqs

    return vtraj

In [47]:
# Rollout settings: 10 parallel environments, 20 steps each, on the '8m' scenario.
config = {
    "max_steps": 20,
    "n_envs": 10,
    "env": "HeuristicEnemySMAX",
    "env_config": {"scenario": map_name_to_scenario("8m")},
}

# make_vtraj also writes config['n_agents'] as a side effect; the returned
# rollout function is jit-compiled here.
vtraj = jit(make_vtraj(config))

In [48]:
# Seed the rollout with a fixed key.
key = random.PRNGKey(0)

# vtraj returns the final (state, obs, key) carry plus the recorded per-step states;
# note that `key` is rebound to the key carried out of the rollout.
(state, obs, key), seqs = vtraj(key)

# Display the recorded states.
seqs

[State(state=State(unit_positions=Array([[[[ 7.0332026, 14.865866 ],
          [ 9.087245 , 17.137732 ],
          [ 5.3754845, 17.132877 ],
          ...,
          [22.892427 , 15.713569 ],
          [22.60216  , 14.782666 ],
          [22.462015 , 16.318407 ]],
 
         [[ 6.3874907, 15.461169 ],
          [ 6.7988763, 14.433478 ],
          [ 7.1029334, 15.535973 ],
          ...,
          [22.689745 , 17.088522 ],
          [21.59261  , 16.708757 ],
          [23.542906 , 16.27849  ]],
 
         [[ 9.354386 , 16.190022 ],
          [ 9.695629 , 14.113747 ],
          [ 9.8478985, 14.928266 ],
          ...,
          [22.933716 , 17.70464  ],
          [23.62261  , 14.534169 ],
          [20.70372  , 14.21932  ]],
 
         ...,
 
         [[ 7.4901304, 16.644552 ],
          [ 7.666708 , 14.93498  ],
          [ 5.7572713, 15.727257 ],
          ...,
          [21.097435 , 14.836933 ],
          [22.925898 , 17.180439 ],
          [20.74834  , 15.519462 ]],
 
         [[ 6.2

In [36]:
# Shape of the unit-position array in the final rollout state. The recorded output
# is (10, 16, 2); the leading 10 matches config['n_envs'].
# NOTE(review): the remaining axes are presumably (units, xy) — confirm against SMAX.
jnp.shape(state.state.unit_positions)

(10, 16, 2)

## Behavior tree

In [6]:
""" 
# Define a basic action node
def leaf_fn(state, condition):
    return jnp.array(state[0] > state[1]), state

# Define a selector node (returns success if any child succeeds)
def node_fn(nodes, state, kind='selector'):
    # success, state = [leaf(s) for leaf, s in zip(nodes, state)], []
    # vmap over nodes and states
    return success, state  # jnp.any(success), state if kind == 'selector' else jnp.all(success), state

# Example usage
def tree_fn(state):
    return node_fn([leaf_fn for _ in range(config['n_agents'])], state)
    
key, key_init  = random.split(key)
key_init       = random.split(key_init, config['n_agents'])
success, state = tree_fn(key_init)
success """

" \n# Define a basic action node\ndef leaf_fn(state, condition):\n    return jnp.array(state[0] > state[1]), state\n\n# Define a selector node (returns success if any child succeeds)\ndef node_fn(nodes, state, kind='selector'):\n    # success, state = [leaf(s) for leaf, s in zip(nodes, state)], []\n    # vmap over nodes and states\n    return success, state  # jnp.any(success), state if kind == 'selector' else jnp.all(success), state\n\n# Example usage\ndef tree_fn(state):\n    return node_fn([leaf_fn for _ in range(config['n_agents'])], state)\n    \nkey, key_init  = random.split(key)\nkey_init       = random.split(key_init, config['n_agents'])\nsuccess, state = tree_fn(key_init)\nsuccess "

In [19]:
# Shape check: appending one (10, 10) entry to a (1, 10, 10) stack along axis 0.
# NOTE: this rebinds `seqs`, which an earlier cell also uses for the rollout output.
seqs = jnp.zeros(shape=(1, 10, 10))
seq = jnp.zeros(shape=(10, 10))

# Give `seq` a leading length-1 axis so it lines up with `seqs`, then join them.
jnp.concatenate([seqs, jnp.expand_dims(seq, axis=0)], axis=0).shape

(2, 10, 10)