In [1]:
import os

# Runtime/GPU knobs. This cell runs before the cell that imports jax, so the
# variables are already in the process environment when JAX initialises:
#   - expose only GPU 0,
#   - stop XLA from preallocating most of the GPU memory up front,
#   - let the PyTorch CUDA allocator use expandable segments.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

In [2]:
import jax
import jax.numpy as jnp
from camar import camar_v0
from rrt import RRT
import pandas as pd
import yaml
from tqdm import trange

from camar.render import SVG_Visualizer

In [3]:
yaml_name = "random_grid_h20_w20_a8_o60"
task_cfg_path = f"BenchMARL/benchmarl/conf/task/camar/{yaml_name}.yaml"
# No try/except here: the original printed yaml.YAMLError and carried on,
# which only turned a clear parse error into a NameError on `env_kwargs` below.
with open(task_cfg_path) as stream:
    env_kwargs = yaml.safe_load(stream)
# safe_load returns None for an empty file; fail loudly instead of with an
# opaque TypeError on the item assignment below.
if not isinstance(env_kwargs, dict):
    raise ValueError(
        f"{task_cfg_path} must contain a YAML mapping of camar_v0 kwargs, "
        f"got {type(env_kwargs).__name__}"
    )
# The map generator name is the first two "_"-separated tokens of the task
# name, e.g. "random_grid_h20_w20_a8_o60" -> "random_grid".
env_kwargs["map_generator"] = "_".join(yaml_name.split("_")[:2])

In [4]:
# Experiment knobs:
#   num_envs  - number of freshly reset environments (maps) to evaluate,
#   rrt_iters - iteration budget handed to the RRT planner,
#   visualize - if True, record the first rollout for SVG rendering and
#               skip metric collection / CSV export.
num_envs, rrt_iters, visualize = 100, 50_000, False

In [5]:
Kp = 1.7  # gain on the unit-length direction towards the current waypoint
Kd = 0.05  # damping gain on the agent velocity


@jax.jit
def pd_control(target_pos, pos, vel):
    """PD-style steering force towards each agent's current waypoint.

    The position error is normalised to a unit vector *per agent*, so every
    agent is pushed with magnitude ``Kp`` regardless of how far away its
    waypoint is. The derivative term damps motion towards a desired velocity
    of zero.

    Args:
        target_pos: waypoint positions; the last axis holds coordinates (the
            notebook passes a stacked per-agent array, one row per agent).
        pos: current agent positions, same shape as ``target_pos``.
        vel: current agent velocities, same shape as ``target_pos``.

    Returns:
        Force array with the same shape as ``target_pos``.
    """
    error = target_pos - pos
    # Normalise over the coordinate axis only. A bare ``jnp.linalg.norm(error)``
    # on the stacked per-agent error is the Frobenius norm of the whole matrix,
    # which couples agents: far agents take most of the force budget and near
    # agents get almost none. The epsilon avoids 0/0 when an agent sits on its
    # waypoint.
    direction = error / (jnp.linalg.norm(error, axis=-1, keepdims=True) + 1e-8)

    # Desired velocity is zero at every waypoint, so the D term is just -vel.
    target_vel = 0.0
    error_vel = target_vel - vel
    return Kp * direction + Kd * error_vel

In [6]:
# Root PRNG key for the whole benchmark; every env/episode key is split off it.
RANDOM_SEED = 6
key = jax.random.key(RANDOM_SEED)

In [7]:
env = camar_v0(**env_kwargs)
rrt_ = RRT(env, rrt_iters, env.agent_rad)
# An agent advances to its next waypoint once it is within this distance of
# the current one. Loop-invariant, so computed once rather than per env.
tolerance = env.goal_rad


def _extract_paths(rrt_state, last_idx, num_agents):
    """Backtrack every agent's RRT tree through ``rrt_state.parent``.

    Returns one list of node indices per agent, ordered from the node whose
    parent is -1 down to ``last_idx[agent]``. An agent whose ``last_idx`` is
    -1 gets the single-node path ``[0]`` (it just tracks node 0).
    """
    paths = []
    for agent in range(num_agents):
        node = int(last_idx[agent])
        if node == -1:
            paths.append([0])
            continue
        path = []
        while node != -1:
            path.append(node)
            node = int(rrt_state.parent[node, agent])
        paths.append(path[::-1])
    return paths


success_rate_s = []
flowtime_s = []
makespan_s = []
coordination_s = []
for env_id in trange(num_envs):
    # One independent key per consumer. The original passed the same key to
    # env.reset and rrt_.run and then split it again for env.step; JAX keys
    # must not be reused across random operations.
    key, reset_key, rrt_key, step_key = jax.random.split(key, 4)
    obs, state = env.reset(reset_key)

    rrt_state = rrt_.run(rrt_key, state.agent_pos, state.goal_pos, state.landmark_pos)
    last_idx = rrt_.find_last_idx(rrt_state)
    path_ids = _extract_paths(rrt_state, last_idx, env.num_agents)
    current_waypoint_idx = [0] * env.num_agents

    if visualize:
        state_seq = [state]
    for _step in range(env.max_steps + 1):
        target_pos = jnp.stack([
            rrt_state.pos[path_ids[agent][wp], agent]
            for agent, wp in enumerate(current_waypoint_idx)
        ])

        force = pd_control(target_pos, state.agent_pos, state.agent_vel)

        step_key, key_s = jax.random.split(step_key)
        # `_info` rather than `_` so IPython's last-output variable is not clobbered.
        obs, state, reward, done, _info = env.step(key_s, state, force)

        if visualize:
            state_seq.append(state)

        if done:
            break

        # One device->host transfer for all agents instead of one per agent.
        cur_dist = jnp.linalg.norm(target_pos - state.agent_pos, axis=-1)
        reached = (cur_dist < tolerance).tolist()
        for agent, hit in enumerate(reached):
            if hit and current_waypoint_idx[agent] < len(path_ids[agent]) - 1:
                current_waypoint_idx[agent] += 1

    if visualize:
        # Keep only the first rollout in `state_seq` for rendering; no metrics.
        break

    success_rate = state.on_goal.mean()
    flowtime = state.time_to_reach_goal.sum()
    makespan = state.time_to_reach_goal.max()
    coordination = 1 - (state.num_collisions / state.step).mean()

    success_rate_s.append(float(success_rate))
    flowtime_s.append(float(flowtime))
    makespan_s.append(float(makespan))
    coordination_s.append(float(coordination))

if not visualize:
    df_metrics = pd.DataFrame({
        "success_rate": success_rate_s,
        "flowtime": flowtime_s,
        "makespan": makespan_s,
        "coordination": coordination_s,
    })
    # Ensure the output directory exists; otherwise to_csv raises only after
    # the full (~40 min) benchmark has already run.
    os.makedirs("metrics_rrt", exist_ok=True)
    df_metrics.to_csv(f"metrics_rrt/{yaml_name}.csv")

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [39:27<00:00, 23.67s/it]


In [8]:
# `goal_reached` from the RRT run of the *last* evaluated env only (the
# displayed array has one entry per agent for this 8-agent task).
goal_reached_last_env = rrt_state.goal_reached
goal_reached_last_env

Array([ True, False,  True,  True, False,  True,  True, False], dtype=bool)

In [9]:
# Render the rollout recorded in `state_seq`; it only exists when visualize=True,
# in which case the benchmark loop stopped after the first env.
if visualize:
    visualizer = SVG_Visualizer(env, state_seq, animate_landmarks=False)
    visualizer.save_svg("test_rrt.svg")