In [1]:
import os

# Runtime configuration; must run before JAX (and any CUDA library) is imported.
#   CUDA_VISIBLE_DEVICES           -> expose only GPU 0 to this process
#   XLA_PYTHON_CLIENT_PREALLOCATE  -> stop XLA from grabbing most GPU memory up front
#   PYTORCH_CUDA_ALLOC_CONF        -> PyTorch allocator option (only matters if torch is loaded)
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)

In [14]:
from camar import camar_v0
from camar.render import SVG_Visualizer
import jax
import jax.numpy as jnp

In [18]:
# Build the CAMAR environment used throughout this notebook.
# GRAIN_FACTOR is forwarded unchanged to camar_v0(grain_factor=...); its exact
# meaning (presumably map granularity/resolution) is defined by camar — TODO confirm.
GRAIN_FACTOR = 8
env = camar_v0(grain_factor=GRAIN_FACTOR)

In [19]:
# Reset once with a fixed PRNG key (seed 0) to get a single scene to inspect.
reset_key = jax.random.key(0)
obs, state = env.reset(reset_key)

In [None]:
# Render the current `state` and write it to an SVG file (path relative to the CWD).
visualizer = SVG_Visualizer(env, state)
visualizer.save_svg("test2.svg")

In [21]:
# Observation budget and agent / landmark radii reported by the environment.
# `max_obs` sizes the top_k query in the nearest-object check further down.
(
    env.max_obs,
    env.agent_rad,
    env.landmark_rad,
)

(49, 0.08571428571428572, 0.028571428571428574)

In [33]:
# Sanity check over many seeds, run on the freshly reset scene for each seed:
#   1. each agent's nearest object (among agents + landmarks) is at distance 0,
#      i.e. the agent itself appears first;
#   2. jax.lax.top_k returns its values already sorted (largest first).
# Failing seeds are collected and shown as the cell result; `([], [])` means all passed.
# `nearest_dists` / `nearest_ids` from the last seed are left defined for later cells.
NUM_SEEDS = 100

bad_seeds = []       # seeds where some agent's nearest object is not at distance 0
unsorted_seeds = []  # seeds where top_k output was not in descending order

for seed in range(NUM_SEEDS):
    obs, state = jax.block_until_ready(env.reset(jax.random.key(seed)))

    agent_pos = state.agent_pos
    landmark_pos = state.landmark_pos

    objects = jnp.vstack((agent_pos, landmark_pos))  # (num_objects, 2)

    # (1, num_objects, 2) - (num_agents, 1, 2) -> (num_agents, num_objects, 2)
    ego_objects = objects[None, :, :] - agent_pos[:, None, :]

    # (num_agents, num_objects, 2) -> (num_agents, num_objects)
    dists = jnp.linalg.norm(ego_objects, axis=-1)

    # top_k selects the LARGEST values, so negate to get the nearest objects.
    # NOTE: `nearest_dists` therefore holds NEGATED distances (all <= 0).
    nearest_dists, nearest_ids = jax.lax.top_k(-dists, env.max_obs + 1)  # (num_agents, env.max_obs + 1)

    if not jnp.allclose(nearest_dists[:, 0], 0):
        bad_seeds.append(seed)

    # Check ordering for every seed, not just whichever seed the loop ended on.
    if not jnp.all(jnp.sort(nearest_dists, axis=1, descending=True) == nearest_dists):
        unsorted_seeds.append(seed)

bad_seeds, unsorted_seeds

In [48]:
# `nearest_dists` is whatever the seed loop above left behind, so this inspects
# only the LAST seed. It holds negated distances from top_k, so "descending"
# here means nearest object first.
(jnp.sort(nearest_dists, axis=1, descending=True) == nearest_dists).all()

Array(True, dtype=bool)