In [None]:
from functools import partial

import jax
import mediapy
from utils import expert_step, run_and_log_scenario
from waymax import dynamics

from vmax import PATH_TO_PROJECT
from vmax.simulator import datasets, make_data_generator, make_env, vizualization


%load_ext autoreload
%autoreload 2

In [None]:
# --- Configuration --------------------------------------------------------
MAX_NUM_OBJECTS = 64
INCLUDE_SDC_PATHS = True
SEED = 0
# Dataset location relative to the project root (both parts are plain strings).
path_dataset = "/".join((PATH_TO_PROJECT, datasets.LOCAL_WOMD))

# --- Data pipeline --------------------------------------------------------
# Scenario stream consumed below via `next(data_generator)`.
data_generator = make_data_generator(
    path=path_dataset,
    max_num_objects=MAX_NUM_OBJECTS,
    include_sdc_paths=INCLUDE_SDC_PATHS,
    seed=SEED,
)

# --- Environment ----------------------------------------------------------
# Bicycle dynamics with normalized actions.
dynamics_model = dynamics.InvertibleBicycleModel(normalize_actions=True)

# Settings for the "road" observation type, grouped per input family
# (surrounding objects, roadgraph, traffic lights, target path).
# NOTE(review): units/semantics of the numeric knobs (e.g. meters_box,
# points_gap) are defined by vmax's observation builder — confirm there.
observation_config = dict(
    obs_past_num_steps=5,
    objects=dict(
        features=["waypoints", "velocity", "yaw", "size", "valid"],
        num_closest_objects=8,
    ),
    roadgraphs=dict(
        features=["waypoints", "direction", "types", "valid"],
        meters_box=dict(front=50, back=10, left=15, right=15),
        max_meters=50,
        roadgraph_top_k=256,
        interval=1,
    ),
    traffic_lights=dict(
        features=["waypoints", "state", "valid"],
        num_closest_traffic_lights=3,
    ),
    path_target=dict(
        features=["waypoints"],
        num_points=10,
        points_gap=5,
    ),
)

env = make_env(
    max_num_objects=MAX_NUM_OBJECTS,
    dynamics_model=dynamics_model,
    observation_type="road",
    observation_config=observation_config,
)

In [None]:
# Draw the next scenario from the stream (kept as `scenario` for inspection)
# and reset the environment on it to obtain the initial simulator state.
simulator_state = env.reset(scenario := next(data_generator))

In [None]:
# Render what the agent observes in the freshly reset state and show it inline.
img = vizualization.plot_input_agent(simulator_state, env)
mediapy.show_image(image=img)

In [None]:
# Bind `env` as the first positional argument of `expert_step` (kept as `_step`),
# then wrap the bound callable with `jax.jit`; compilation happens on first call.
_jitted_expert_step = jax.jit(_step := partial(expert_step, env))

In [None]:
# Roll out the expert step on a few fresh scenarios and play each one back as a video.
NUM_ROLLOUTS = 5
VIDEO_FPS = 10  # NOTE(review): presumably chosen to match the simulator step rate — confirm

for rollout_idx in range(NUM_ROLLOUTS):
    # Dedicated name so this cell does not overwrite the module-level `scenario`
    # that `simulator_state` was reset from in an earlier cell.
    rollout_scenario = next(data_generator)
    imgs = run_and_log_scenario(env, rollout_scenario, _jitted_expert_step)
    # Title each video so the rollouts can be told apart in the output.
    mediapy.show_video(imgs, fps=VIDEO_FPS, title=f"Rollout {rollout_idx}")