In [7]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to the package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Test Multi-Agent Simulation

In [8]:
import jax
from jax import numpy as jnp
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import random

from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization
from rl_env.waymax_env import womd_loader, MultiAgentEnvironment
# Number of object slots per scenario; passed to the environment config
# (`max_num_objects=max_num_objects`) and used to size `obj_idx`.
max_num_objects = 32
from rl_env.sim_agent_actor import SimAgentMTR

from mtr.config import cfg, cfg_from_yaml_file
# Populate `cfg` from the YAML file; `cfg.MODEL` is later handed to SimAgentMTR.
# NOTE(review): the path is relative, so this presumably must be run from the
# repository root — confirm.
cfg = cfg_from_yaml_file('tools/cfgs/waymo/rl_finetune.yaml', cfg)

In [9]:
# Dataset configuration for a local copy of the WOMD v1.2 interactive
# validation TFRecords.
# NOTE(review): the path is an absolute, machine-specific location — adjust it
# (or make it configurable) before running this notebook elsewhere.
WOMD_1_2_0_VAL_LOCAL = _config.DatasetConfig(
    path='/Data/Dataset/Waymo/V1_2_tf/validation_interactive/validation_interactive_tfexample.tfrecord@150',
    # path='/Data/Dataset/Waymo/V1_2_tf/validation/validation_tfexample.tfrecord@150',
    max_num_rg_points=30000,
    data_format=_config.DataFormat.TFRECORD,
    # Reuse the notebook-wide constant so the dataset and the environment
    # always agree on the object count (this was a separate hard-coded 32).
    max_num_objects=max_num_objects,
    shuffle_seed=0,
)

# Scenario iterator; each item is unpacked as (scenario_id, scenario) when consumed.
data_iter = womd_loader(data_config=WOMD_1_2_0_VAL_LOCAL)


In [10]:

# Multi-agent environment setup.
init_steps = 11

# Dynamics model used by the environment. Every actor that interacts with the
# environment must produce actions compatible with this model.
dynamics_model = dynamics.DeltaLocal()

# Linear-combination reward: each named term is weighted by the given factor.
# NOTE(review): the earlier inline comments labelled 'overlap' as "positive is
# good" and 'offroad' / 'kinematics' as "negative is good", yet all three
# weights are +1.0 — confirm the intended signs against the reward
# implementation in MultiAgentEnvironment.
reward_config = _config.LinearCombinationRewardConfig(
    rewards={
        'overlap': 1.0,
        'offroad': 1.0,
        'kinematics': 1.0,
    }
)

# Start from the default environment config and override only what differs.
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    init_steps=init_steps,
    max_num_objects=max_num_objects,
    controlled_object=_config.ObjectType.MODELED,
    rewards=reward_config,
)

env = MultiAgentEnvironment(dynamics_model=dynamics_model, config=env_config)

In [11]:
@jax.jit
def check_controlled(state: datatypes.SimulatorState):
    """Return a boolean mask of the objects the learned policy should control.

    An object is selected when it is valid in the simulated trajectory at the
    current timestep and its object type id equals 1.

    Args:
        state: Current simulator state.

    Returns:
        Boolean array equal to ``valid_at_current_step & (object_types == 1)``.
    """
    # Validity of every object at the current simulation step.
    is_valid = state.sim_trajectory.valid[..., state.timestep]
    # Type id 1 is presumably the vehicle class (waymax ObjectTypeIds.VEHICLE)
    # — TODO confirm.
    is_vehicle = state.object_metadata.object_types == 1

    # The previous version also built an `is_modeled | is_sdc` mask but
    # immediately overwrote it with `is_valid`, so it never influenced the
    # result. That dead computation is dropped; the returned mask is unchanged.
    return jnp.logical_and(is_valid, is_vehicle)

In [12]:
# Build the two actors that jointly drive the scene during the rollout.

# Index of every object slot in a scenario.
obj_idx = jnp.arange(max_num_objects)

# Expert actor created by waymax; it is responsible for exactly those objects
# that `check_controlled` does NOT select.
actor_gt = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: jnp.logical_not(check_controlled(state)),
)

# SimAgentMTR policy restored from a checkpoint; it drives every object that
# `check_controlled` selects.
# Alternative checkpoints (previously commented out in place):
#   'output/bc_atten_4_freeze/epoch=9-step=302840.ckpt'
#   'output/bc_atten_2_freeze_bicycle/epoch=12-step=393692.ckpt'
#   'output/bc_atten_2_unfreeze_bicycle/epoch=10-step=333124.ckpt'
#   'output/bc_atten_2_freeze/epoch=49-step=152600.ckpt'
#   'output/bc_atten_0_freeze/epoch=31-step=97664.ckpt'
#   'output/bc_atten_0_unfreeze/epoch=31-step=97664.ckpt'
actor_policy = SimAgentMTR(
    model_config=cfg.MODEL,
    model_path='output/bc_atten_4_unfreeze_state/epoch=19-step=1211340.ckpt',
    is_controlled_func=check_controlled,
)

actors = [actor_policy, actor_gt]

# Compiled version of the environment step.
jit_step = jax.jit(env.step)
# NOTE: despite the name, these are the actors' plain `select_action` methods;
# no jax.jit is applied to them here.
jit_select_action_list = [each_actor.select_action for each_actor in actors]

==> Loading parameters from checkpoint output/bc_atten_4_unfreeze_state/epoch=19-step=1211340.ckpt to CPU
==> Checkpoint trained from version: None
The number of disk ckpt keys: 487


In [13]:
# Pull the next (id, scenario) pair from the dataset iterator.
scenario: datatypes.SimulatorState
scenario_id, scenario = next(data_iter)

# Plot the scenario with `use_log_traj=True`, highlighting the MODELED objects.
img = visualization.plot_simulator_state(
    scenario,
    use_log_traj=True,
    highlight_obj=_config.ObjectType.MODELED,
)
mediapy.show_image(img)

# Indices of the objects flagged as modeled in the scenario metadata.
modeled_mask = scenario.object_metadata.is_modeled
print(jnp.where(modeled_mask))


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


(Array([0, 1], dtype=int32),)


In [14]:
# Roll the scenario forward until no timesteps remain, keeping every state.
# scenario : datatypes.SimulatorState = next(data_iter)
states = [env.reset(scenario)]

# Pre-bind so the summary below is well-defined even when the loop body never
# runs, and cannot silently reuse a value left over from an earlier execution
# of this cell.
policy_output = None
for _ in range(states[0].remaining_timesteps):
    current_state = states[-1]

    # Each actor proposes actions for the objects it controls; the two
    # proposals are merged into a single action for the environment step.
    gt_output = actor_gt.select_action({}, current_state, None, None)
    policy_output = actor_policy.select_action({}, current_state, None, None)
    action = agents.merge_actions([policy_output, gt_output])
    next_state = jit_step(current_state, action)

    # print(env.reward(current_state, action))
    states.append(next_state)

# Report which objects the policy controlled on the final step.
if policy_output is None:
    print("Sim Agent Ids: none (the scenario had no remaining timesteps)")
else:
    print("Sim Agent Ids:", jnp.where(policy_output.is_controlled))

Sim Agent Ids: (Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],      dtype=int32),)


In [15]:
# Render one frame per simulated state (simulated rather than logged
# trajectories, MODELED objects highlighted) and play them back as a video.
plot_kwargs = dict(use_log_traj=False, highlight_obj=_config.ObjectType.MODELED)

imgs = []
for state in states:
    frame = visualization.plot_simulator_state(state, **plot_kwargs)
    imgs.append(frame)

mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [16]:
# mediapy.show_image(visualization.plot_simulator_state(states[60], use_log_traj=False, highlight_obj = _config.ObjectType.MODELED))