In [None]:
import time
from typing import Optional

import torch
from bokeh.io import output_notebook, show
from stable_baselines3.common import utils
from stable_baselines3.common.logger import Logger

from l5kit.cle.composite_metrics import CompositeMetricAggregator
from l5kit.cle.scene_type_agg import compute_cle_scene_type_aggregations, compute_scene_type_ade_fde
from l5kit.cle.validators import ValidationCountingAggregator
from l5kit.configs import load_config_data
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.data import MapAPI
from l5kit.dataset import EgoDataset
from l5kit.dataset import EgoDataset
from l5kit.environment.callbacks import L5KitEvalCallback
from l5kit.environment.gym_metric_set import CLEMetricSet
from l5kit.environment.utils import get_scene_types
from l5kit.rasterization import build_rasterizer
from l5kit.simulation.dataset import SimulationConfig
from l5kit.simulation.unroll import ClosedLoopSimulator
from l5kit.visualization.visualizer.visualizer import visualize
from l5kit.visualization.visualizer.zarr_utils import simulation_out_to_visualizer_scene, episode_out_to_visualizer_scene_gym_cle

# CSV mapping scene ids to scene types (not passed to eval_model by the cells below).
scene_id_to_type_path = '../../dataset_metadata/validate_turns_metadata.csv'

# Data manager; with a None root it presumably resolves files via the L5KIT_DATA_FOLDER env var -- TODO confirm.
dm = LocalDataManager(None)

# Experiment configuration and the rasterizer it describes.
cfg = load_config_data("./drivenet_config.yaml")
rasterizer = build_rasterizer(cfg, dm)

# Validation split: raw zarr, map API (used by the visualization cells), and the EgoDataset evaluated below.
eval_cfg = cfg["val_data_loader"]
eval_zarr_path = dm.require(eval_cfg["key"])
eval_zarr = ChunkedDataset(eval_zarr_path).open()
mapAPI = MapAPI.from_cfg(dm, cfg)
eval_dataset = EgoDataset(cfg, eval_zarr, rasterizer)

# First GPU when one is visible, otherwise CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def eval_model(model: torch.nn.Module, dataset: EgoDataset, logger: Optional[Logger], d_set: str, iter_num: int,
               num_scenes_to_unroll: int, num_simulation_steps: Optional[int] = None,
               enable_scene_type_aggregation: Optional[bool] = False,
               scene_id_to_type_path: Optional[str] = None,
               batch_unroll: int = 100) -> list:
    """Evaluate a DriveNet model in closed loop and report ADE / FDE.

    The ego is driven by ``model`` inside L5Kit's ``ClosedLoopSimulator`` on the first
    ``num_scenes_to_unroll`` scenes of ``dataset``, simulated ``batch_unroll`` scenes at a time.
    Every batch is fed to a ``CLEMetricSet`` evaluator, which is then reduced to ADE / FDE.
    The two metrics are printed and, when ``logger`` is given, also recorded and dumped at
    step ``iter_num``. Gradients are disabled only for the duration of this call.

    :param model: the trained model to evaluate (switched to eval mode in place)
    :param dataset: the dataset on which the model is evaluated
    :param logger: optional stable-baselines3 logger; ``None`` means metrics are only printed
    :param d_set: prefix for the reported metric names (e.g. "train" or "eval")
    :param iter_num: training iteration number passed to ``logger.dump`` (ignored when ``logger`` is None)
    :param num_scenes_to_unroll: number of scenes to evaluate (scene indices ``0 .. num_scenes_to_unroll - 1``)
    :param num_simulation_steps: number of steps to unroll each scene for; ``None`` is forwarded to
        ``SimulationConfig`` as-is (presumably "whole scene" -- TODO confirm against l5kit)
    :param enable_scene_type_aggregation: currently unused; kept for interface compatibility
    :param scene_id_to_type_path: currently unused; kept for interface compatibility
    :param batch_unroll: number of scenes simulated per ``ClosedLoopSimulator.unroll`` call
    :return: the simulation outputs of ALL unrolled scenes, in scene-index order
    :raises ValueError: if ``num_scenes_to_unroll`` or ``batch_unroll`` is not positive
    """
    # Fail fast: an empty range would otherwise leave the outputs unbound at return time.
    if num_scenes_to_unroll < 1:
        raise ValueError(f"num_scenes_to_unroll must be positive, got {num_scenes_to_unroll}")
    if batch_unroll < 1:
        raise ValueError(f"batch_unroll must be positive, got {batch_unroll}")

    model.eval()

    # Scope gradient disabling to this evaluation instead of flipping the global autograd
    # flag, which would otherwise silently leak into every later notebook cell.
    with torch.no_grad():
        # Closed-loop simulation: ego driven by `model`; use_agents_gt=True keeps the other
        # agents on their ground-truth behaviour.
        sim_cfg = SimulationConfig(use_ego_gt=False, use_agents_gt=True, disable_new_agents=False,
                                   distance_th_far=30, distance_th_close=15,
                                   num_simulation_steps=num_simulation_steps,
                                   start_frame_index=0, show_info=False)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        sim_loop = ClosedLoopSimulator(sim_cfg, dataset, device, model_ego=model, model_agents=None)

        metric_set = CLEMetricSet()

        # Unroll in batches; keep every batch's outputs (not just the last one) so the
        # caller can inspect all evaluated scenes.
        all_sim_outs = []
        for start_idx in range(0, num_scenes_to_unroll, batch_unroll):
            end_idx = min(num_scenes_to_unroll, start_idx + batch_unroll)
            batch_outs = sim_loop.unroll(list(range(start_idx, end_idx)))
            metric_set.evaluator.evaluate(batch_outs)
            all_sim_outs.extend(batch_outs)

        # Aggregate metrics (ADE, FDE)
        ade, fde = L5KitEvalCallback.compute_ade_fde(metric_set)

    print(f'{d_set}/ade', round(ade, 3))
    print(f'{d_set}/fde', round(fde, 3))

    if logger is not None:
        logger.record(f'{d_set}/ade', round(ade, 3))
        logger.record(f'{d_set}/fde', round(fde, 3))
        logger.dump(iter_num)

    return all_sim_outs


In [None]:
# Evaluate the one-cycle-schedule checkpoint (the default-schedule checkpoint is evaluated in a later cell).
# Make sure drivenet_config.yaml (loaded in the first cell) matches this model's architecture.
model_path = "./checkpoints/drivenet_h0_p05_onecycle_schedule_step5_wt_rew_full_757425_steps.pt"
num_scenes_to_unroll = 5
eval_iter_num = 2000000  # iteration number forwarded to eval_model; no logger is passed, so it is not recorded

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles a full module object -- only load trusted checkpoints.
model = torch.load(model_path, map_location=device).to(device)
model = model.eval()

start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes
sim_outs = eval_model(model, eval_dataset, None, "eval", eval_iter_num, num_scenes_to_unroll,
                      num_simulation_steps=None)
print("Time: ", time.perf_counter() - start_time)

In [None]:
# Render each closed-loop scene from the one-cycle checkpoint evaluation as an interactive Bokeh plot.
output_notebook()
for scene_out in sim_outs:
    scene_vis = simulation_out_to_visualizer_scene(scene_out, mapAPI)
    show(visualize(scene_out.scene_id, scene_vis))

In [None]:
# Evaluate the default-schedule checkpoint (the one-cycle-schedule checkpoint is evaluated in an earlier cell).
# Make sure drivenet_config.yaml (loaded in the first cell) matches this model's architecture.
model_path = "./checkpoints/drivenet_h0_p05_default_schedule_step5_full_757425_steps.pt"
num_scenes_to_unroll = 5
eval_iter_num = 2000000  # iteration number forwarded to eval_model; no logger is passed, so it is not recorded

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE: torch.load unpickles a full module object -- only load trusted checkpoints.
model = torch.load(model_path, map_location=device).to(device)
model = model.eval()

start_time = time.perf_counter()  # monotonic clock, unaffected by system clock changes
sim_outs = eval_model(model, eval_dataset, None, "eval", eval_iter_num, num_scenes_to_unroll,
                      num_simulation_steps=None)
print("Time: ", time.perf_counter() - start_time)

In [None]:
# Render each closed-loop scene from the default-schedule checkpoint evaluation as an interactive Bokeh plot.
output_notebook()
for scene_out in sim_outs:
    scene_vis = simulation_out_to_visualizer_scene(scene_out, mapAPI)
    show(visualize(scene_out.scene_id, scene_vis))