In [None]:
from collections import OrderedDict
import os
import torch
from l5kit.simulation.unroll import ClosedLoopSimulator
from l5kit.data import LocalDataManager, ChunkedDataset, filter_agents_by_frames
from l5kit.dataset import EgoDataset
from l5kit.rasterization import build_rasterizer
from tbsim.algos.l5kit_algos import L5TrafficModel
from tbsim.configs.base import ExperimentConfig
from tbsim.configs.l5kit_config import L5KitTrainConfig, L5KitEnvConfig, L5RasterizedPlanningConfig
from tbsim.envs.env_l5kit import EnvL5KitSimulation
from tbsim.utils.config_utils import translate_l5kit_cfg
from tbsim.utils.env_utils import rollout_episodes

from l5kit.visualization.visualizer.zarr_utils import simulation_out_to_visualizer_scene
from l5kit.visualization.visualizer.visualizer import visualize
from bokeh.io import output_notebook, show
from l5kit.data import MapAPI

In [None]:
# Experiment configuration assembled from the default L5Kit train / env / algo sub-configs.
cfg = ExperimentConfig(
    train_config=L5KitTrainConfig(),
    env_config=L5KitEnvConfig(),
    algo_config=L5RasterizedPlanningConfig(),
)

# LocalDataManager(None) presumably falls back to the L5KIT_DATA_FOLDER environment
# variable, so it is exported (as an absolute path) before the manager is created.
data_root = os.path.abspath(cfg.train.dataset_path)
os.environ["L5KIT_DATA_FOLDER"] = data_root
dm = LocalDataManager(None)

# Translate the tbsim config into the l5kit config format, then build the
# rasterizer and map API from it.
l5_config = translate_l5kit_cfg(cfg)
rasterizer = build_rasterizer(l5_config, dm)
mapAPI = MapAPI.from_cfg(dm, l5_config)

# Open the validation split and wrap it as an ego-centric dataset.
valid_zarr_path = dm.require(cfg.train.dataset_valid_key)
eval_zarr = ChunkedDataset(valid_zarr_path).open()
env_dataset = EgoDataset(l5_config, eval_zarr, rasterizer)

# Closed-loop simulation environment. NOTE(review): num_scenes=10 is a hard-coded
# choice; consider lifting it into a config constant.
env = EnvL5KitSimulation(
    cfg.env,
    dataset=env_dataset,
    seed=cfg.seed,
    num_scenes=10,
)

In [None]:
# Preferred compute device (first GPU if available, else CPU).
# NOTE(review): `device` is not used below. The model keeps load_from_checkpoint's
# default placement and is never moved. Presumably rollout_episodes handles tensor
# placement — TODO confirm, or pass map_location=device / call model.to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoint to evaluate. It can be overridden with the TBSIM_CKPT_PATH environment
# variable, so the notebook is not tied to one machine's home directory. The default
# is the originally hard-coded checkpoint.
CKPT_PATH = os.environ.get(
    "TBSIM_CKPT_PATH",
    "/home/danfeix/workspace/tbsim/l5_rasterized_trained_models/vanilla1/20211214235039/models/iter35999_ep0_simADE1.97.ckpt",
)
if not os.path.isfile(CKPT_PATH):
    # Fail early with an actionable message instead of a deep loader traceback.
    raise FileNotFoundError(
        f"Checkpoint not found: {CKPT_PATH!r}. Set TBSIM_CKPT_PATH to a valid .ckpt file."
    )

# Input modality shapes the model is constructed with: a single image whose channel
# count comes from the rasterizer.
# NOTE(review): 224x224 is hard-coded — confirm it matches the raster size in l5_config.
modality_shapes = OrderedDict(image=(rasterizer.num_channels(), 224, 224))
model = L5TrafficModel.load_from_checkpoint(
    CKPT_PATH,
    algo_config=cfg.algo,
    modality_shapes=modality_shapes,
)

# Put the model in eval mode. The trailing `;` suppresses the (very large) module repr
# that would otherwise be rendered as the cell's output.
model.eval();

In [None]:
# Closed-loop rollout of the loaded model in `env`. The third positional argument
# is presumably the number of episodes to roll out — confirm against rollout_episodes.
NUM_EPISODES = 1
stats, info = rollout_episodes(env, model, NUM_EPISODES)

In [None]:
# Route bokeh output into the notebook, then render one interactive plot per
# simulated scene recorded in the rollout info.
output_notebook()
sim_states = info["l5_sim_states"]
for scene_out in sim_states:
    scene_vis = simulation_out_to_visualizer_scene(scene_out, mapAPI)
    show(visualize(scene_out.scene_id, scene_vis))