In [3]:
import jax
from jax import jit, numpy as jnp
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses

from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization

# Remote WOD motion v1.3.0 training shards. Nothing in the visible code reads
# this name; the loader config below points at local files instead.
path2 = "gs://waymo_open_dataset_motion_v_1_3_0/uncompressed/tf_example/training/training_tfexample.tfrecord@1000"

# Object cap shared by the data loader config and the environment config.
max_num_objects = 32

# Start from the stock training loader config and override only what differs.
# NOTE(review): the "@5" suffix presumably denotes a 5-shard file set - confirm.
data_loader_config = dataclasses.replace(
    _config.WOD_1_1_0_TRAINING,
    path="./data/training_tfexample.tfrecord@5",
    max_num_rg_points=30000,
    max_num_objects=max_num_objects,
)

# Sim agents of type IDM, assigned to the ObjectType.NON_SDC objects.
sim_agent_config = _config.SimAgentConfig(
    controlled_objects=_config.ObjectType.NON_SDC,
    agent_type=_config.SimAgentType.IDM,
)

# Default environment config with the object cap and sim agents filled in.
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    sim_agents=[sim_agent_config],
    max_num_objects=max_num_objects,
)


# Generator yielding simulator states built from the loader config.
data_iter = dataloader.simulator_state_generator(config=data_loader_config)

# The same dynamics model instance is handed to the environment and to the
# expert actor created below.
dynamics_model = dynamics.InvertibleBicycleModel(normalize_actions=True)
sim_agent_actors = agents.create_sim_agents_from_config(sim_agent_config)

# One sim-agent actor, paired with one (empty) params entry.
env = _env.PlanningAgentEnvironment(
    config=env_config,
    dynamics_model=dynamics_model,
    sim_agent_params=[{}],
    sim_agent_actors=[sim_agent_actors],
)

# Index vector 0..max_num_objects-1 (not used by the visible code).
obj_idx = jnp.arange(max_num_objects)

# Actor built by create_expert_actor; it supplies the actions in the rollout.
actor = agents.create_expert_actor(dynamics_model=dynamics_model)

# Advance the data iterator five times and keep the fifth example.
for _ in range(5):
    scenario = next(data_iter)

# Wrap the per-step callables with jit once, before the rollout loop.
jit_step = jit(env.step)
jit_select_action = jit(actor.select_action)

# Reset the environment on the chosen scenario; `states` collects every
# simulator state, starting with the post-reset one.
state = env.reset(scenario)
states = [state]

# Fixed PRNG key, used to initialise the actor and reused on every step.
rng = jax.random.key(0)
actor_state = actor.init(rng, state)


for _ in tqdm(range(80)):  # Run for 8 seconds (80 steps)
    # Query the expert actor on the current state (empty params dict).
    actor_output = jit_select_action({}, state, actor_state, rng)

    # Reduce the actor output with the is_sdc one-hot mask so that only the
    # SDC entry is passed on to the environment.
    action = datatypes.select_by_onehot(
        actor_output,
        state.object_metadata.is_sdc,
    )

    # Advance the simulation one step and record the resulting state.
    state = jit_step(state, action.action)
    states.append(state)

# Render one image per recorded state (use_log_traj=False) and show the
# sequence as a 10 fps video.
imgs = [
    visualization.plot_simulator_state(frame_state, use_log_traj=False)
    for frame_state in states
]
mediapy.show_video(imgs, fps=10)


100%|██████████| 80/80 [00:00<00:00, 80.69it/s]


0
This browser does not support the video tag.
