# DriveNet Closed-Loop Evaluation

**Note: this notebook assumes you've already run the [training notebook](./drivenet_train.ipynb) and stored your model successfully.**

### What is closed-loop evaluation?
In closed-loop evaluation DriveNet is in **full control of the AV**. At each time step, we predict the future trajectory and then move the AV to the first pose of DriveNet's prediction. The next prediction then starts from that new pose.
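The loop described above can be sketched in a few lines. This is a toy illustration, not the evaluation code used later in the notebook: `toy_policy` is a hypothetical stand-in for DriveNet that always predicts a straight path, and poses are plain NumPy arrays rather than l5kit frames.

```python
import numpy as np


def toy_policy(position, yaw, horizon=5, step=1.0):
    """Hypothetical stand-in for DriveNet: predicts `horizon` future (x, y)
    positions and yaws by driving straight ahead in fixed increments."""
    steps = np.arange(1, horizon + 1)[:, None] * step        # (horizon, 1)
    direction = np.array([np.cos(yaw), np.sin(yaw)])         # (2,)
    positions = position + steps * direction                 # (horizon, 2)
    yaws = np.full(horizon, yaw)
    return positions, yaws


def closed_loop_rollout(policy, start_position, start_yaw, num_steps):
    """Closed-loop unroll: at every step the AV is moved to the FIRST
    predicted pose, and the next prediction starts from that new pose."""
    position = np.asarray(start_position, dtype=float)
    yaw = float(start_yaw)
    poses = [(position.copy(), yaw)]
    for _ in range(num_steps):
        pred_positions, pred_yaws = policy(position, yaw)
        # execute only the first predicted pose, discard the rest
        position, yaw = pred_positions[0].copy(), float(pred_yaws[0])
        poses.append((position.copy(), yaw))
    return poses


poses = closed_loop_rollout(toy_policy, start_position=(0.0, 0.0), start_yaw=0.0, num_steps=3)
print([p.tolist() for p, _ in poses])  # [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]]
```

Only the first predicted pose is executed; the rest of the trajectory is re-predicted at the next step from wherever the AV actually ended up, so small errors can compound over time. This is exactly the behaviour closed-loop evaluation is meant to expose.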


### What is a good closed-loop metric?
Metrics are particularly challenging in this setting. We would like to penalise some kinds of drift (e.g. going off-road or into the opposite lane) while allowing others (e.g. different speed profiles).

At Lyft L5, we use a substantial set of metrics to capture dangerous manoeuvres and behaviours.

For the sake of simplicity, in this notebook we use a very simple proxy to detect whether our model drives sensibly. It is composed of two metrics:

#### Collisions
Our AV should avoid collisions with other agents. This sounds trivial, but it is more complex than it looks: while our AV is fully controlled by DriveNet, other agents are not and simply follow their recorded trajectories (a setting we call **log replay**).

If our AV is slower than the recorded one, agents driving behind it might bump into it. Clearly, this would not happen in a real setting, where other agents react to our behaviour.
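To make the log-replay issue concrete, here is a toy 1D example with made-up numbers: the recorded AV drives at 10 m/s with a follower 8 m behind, while the re-simulated AV drives at only 6 m/s. Because the follower keeps replaying its logged positions, the gap shrinks until the two vehicles overlap, even though the AV did nothing unsafe. The speeds, the gap and the 4.5 m vehicle length are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical 1D log: positions along the lane in metres, one sample every 0.1 s.
dt = 0.1
times = np.arange(31) * dt          # 0.0 s .. 3.0 s
logged_ego = 10.0 * times           # recorded AV drives at 10 m/s
logged_follower = logged_ego - 8.0  # recorded follower keeps an 8 m gap

# In log replay only the AV is re-simulated; here it drives at 6 m/s instead.
simulated_ego = 6.0 * times

vehicle_length = 4.5  # assumed length of both vehicles, in metres
gap = simulated_ego - logged_follower  # centre-to-centre distance
# bumpers touch once the centre distance drops below one vehicle length
first_overlap = int(np.argmax(gap < vehicle_length))
print(f"rear collision at t = {times[first_overlap]:.1f} s")  # rear collision at t = 0.9 s
```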

Still, this metric is useful to assess how well DriveNet follows the rules of the road.
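At its core, a collision check between the AV and another agent can be reduced to testing whether two oriented bounding boxes overlap. The sketch below uses the separating axis theorem on boxes built from a centroid, a yaw and a (length, width) extent. It is a generic illustration and not necessarily how l5kit's `detect_collision` works; that function's result, as used in the playback cell, also carries a collision type such as front, side or rear. The helper names `box_corners` and `boxes_overlap` are hypothetical.

```python
import numpy as np


def box_corners(centroid, yaw, extent):
    """Corners (4, 2) of a 2D oriented box; `extent` is (length, width) in metres."""
    half_l, half_w = extent[0] / 2.0, extent[1] / 2.0
    local = np.array([[half_l, half_w], [half_l, -half_w], [-half_l, -half_w], [-half_l, half_w]])
    rot = np.array([[np.cos(yaw), -np.sin(yaw)], [np.sin(yaw), np.cos(yaw)]])
    return local @ rot.T + np.asarray(centroid, dtype=float)


def boxes_overlap(corners_a, corners_b):
    """Separating axis theorem for two convex quadrilaterals: they overlap
    unless the projections on some edge normal are disjoint."""
    for corners in (corners_a, corners_b):
        for i in range(4):
            edge = corners[(i + 1) % 4] - corners[i]
            axis = np.array([-edge[1], edge[0]])  # normal to this edge
            proj_a = corners_a @ axis
            proj_b = corners_b @ axis
            if proj_a.max() < proj_b.min() or proj_b.max() < proj_a.min():
                return False  # found a separating axis
    return True


ego = box_corners((0.0, 0.0), 0.0, (4.5, 2.0))
touching = box_corners((4.0, 0.5), np.pi / 8, (4.5, 2.0))
far_away = box_corners((10.0, 0.0), 0.0, (4.5, 2.0))
print(boxes_overlap(ego, touching), boxes_overlap(ego, far_away))  # True False
```

Checking only the edge normals of the two boxes is sufficient because both shapes are convex.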

However, if we only considered collisions, we might under-penalise potentially dangerous situations such as driving off-road.

#### Distance from Reference Trajectory
To address this issue, we require DriveNet to loosely stick to the original trajectory in the data. By setting the right threshold on the distance, we can allow for different speed profiles and small steering deviations while penalising large deviations such as driving off-road.
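One way to implement this check, sketched here under the assumption that the reference is the logged AV path given as a 2D polyline, is to measure the distance from the AV to the closest point on that path and compare it with a threshold. The threshold `MAX_DEVIATION_M` and the example coordinates are hypothetical.

```python
import numpy as np


def distance_to_polyline(point, polyline):
    """Smallest Euclidean distance from a 2D point to a polyline given as an (N, 2) array."""
    point = np.asarray(point, dtype=float)
    starts, ends = polyline[:-1], polyline[1:]
    seg = ends - starts
    # guard against zero-length segments (repeated points in the log)
    seg_len_sq = np.maximum((seg ** 2).sum(axis=1), 1e-12)
    # projection of the point on each segment, clamped to the segment itself
    t = np.clip(((point - starts) * seg).sum(axis=1) / seg_len_sq, 0.0, 1.0)
    closest = starts + t[:, None] * seg
    return float(np.linalg.norm(closest - point, axis=1).min())


reference = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 5.0]])  # hypothetical logged path
MAX_DEVIATION_M = 2.0  # hypothetical threshold, to be tuned

for av_position in [(5.0, 1.0), (5.0, 3.5)]:
    d = distance_to_polyline(av_position, reference)
    print(av_position, round(d, 2), "ok" if d <= MAX_DEVIATION_M else "too far")
```

Measuring the distance to the path, rather than to the time-aligned logged position, is what makes this tolerant to different speed profiles: an AV driving the same route more slowly stays close to the polyline even though it lags behind the logged position at every timestamp.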


In [None]:
import os

import numpy as np
import torch
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset, filter_agents_by_frames
from l5kit.dataset import EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.geometry import transform_points, yaw_as_rotation33
from l5kit.visualization import PREDICTED_POINTS_COLOR, draw_trajectory
from l5kit.drivenet.model import DriveNetModel
from l5kit.drivenet.utils import detect_collision

## Prepare Data path and load cfg

By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.

Then, we load our config file with relative paths and other configurations (rasteriser, training params...).

In [None]:
# set env variable for data
os.environ["L5KIT_DATA_FOLDER"] = "/tmp/l5kit_data"
dm = LocalDataManager(None)
# get config
cfg = load_config_data("./drivenet_config.yaml")

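## Load the Model

We load the model saved by the training notebook, switch it to evaluation mode and disable gradient computation, since we only run inference.
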
In [None]:
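model_path = "/tmp/drivenet.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location allows a model saved on GPU to be loaded on a CPU-only machine.
# The whole module was pickled, so recent PyTorch versions may also need weights_only=False.
model = torch.load(model_path, map_location=device).to(device)
model = model.eval()
torch.set_grad_enabled(False)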

## Load the Evaluation Data
Unlike training and open-loop evaluation, this setting is intrinsically sequential: the input at each frame depends on the model's prediction at the previous one. As such, we won't use the batching and parallel loading offered by PyTorch's `DataLoader`.

In [None]:
# ===== INIT DATASET
eval_cfg = cfg["val_data_loader"]
rasterizer = build_rasterizer(cfg, dm)
eval_zarr = ChunkedDataset(dm.require(eval_cfg["key"])).open()
eval_dataset = EgoDataset(cfg, eval_zarr, rasterizer)
print(eval_dataset)

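# Unroll a Scene

In this cell we unroll scene 10 of the evaluation dataset, limited to the first fifth of its frames. At each frame we predict a trajectory, store the rasterised image with the prediction drawn on top, and check for collisions. If no collision occurred, the AV pose in the next frame is overwritten with the first predicted pose; otherwise the next frame keeps its ground-truth pose.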
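The mutation step relies on `transform_points` with `world_from_agent` and on `yaw_as_rotation33`. As an illustration of the geometry involved, here is a minimal NumPy re-implementation, assuming predictions are expressed in the AV's own frame (x pointing forward) and that `world_from_agent` is a 2D homogeneous transform built from the AV centroid and yaw. The helper names and the pose values are hypothetical; this is not l5kit's code.

```python
import numpy as np


def world_from_agent_matrix(centroid, yaw):
    """3x3 homogeneous transform taking agent-frame (x, y) points to world coordinates."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, centroid[0]],
                     [s, c, centroid[1]],
                     [0.0, 0.0, 1.0]])


def apply_transform(points, matrix):
    """Apply a 3x3 homogeneous transform to an (N, 2) array of points."""
    homogeneous = np.hstack([points, np.ones((len(points), 1))])
    return (homogeneous @ matrix.T)[:, :2]


def rotation33_from_yaw(yaw):
    """Rotation of `yaw` radians about the z axis, as a 3x3 matrix."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


# Hypothetical current AV pose (heading along world +y) and agent-frame predictions.
ego_centroid, ego_yaw = np.array([100.0, 50.0]), np.pi / 2
pred_positions_agent = np.array([[1.0, 0.0], [2.0, 0.1]])
pred_yaws_agent = np.array([0.05, 0.1])

pred_positions_world = apply_transform(pred_positions_agent, world_from_agent_matrix(ego_centroid, ego_yaw))
# predicted yaws are relative to the current heading, so they are added to it
next_yaw = ego_yaw + pred_yaws_agent[0]
next_rotation = rotation33_from_yaw(next_yaw)
print(np.round(pred_positions_world, 3))
```

One metre ahead of an AV heading along world +y ends up at (100, 51), which is what the next frame's translation would be set to.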


In [None]:
# ==== EVAL LOOP
# NOTE: the loop below mutates the frames of scene_dataset in place
scene_dataset = eval_dataset.get_scene_dataset(10)

images = []
collisions = []

# only the first fifth of the scene's frames is unrolled
for frame_idx in tqdm(range(len(scene_dataset) // 5)):
    data = scene_dataset[frame_idx]
    del data["host_id"]  # not needed by the model
    # build a batch of size 1 and move its tensors to the model's device
    data_batch = default_collate([data])
    data_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in data_batch.items()}
    result = model(data_batch)
    predicted_positions = result["positions"].detach().cpu().numpy().squeeze()
    predicted_yaws = result["yaws"].detach().cpu().numpy().squeeze()

    ## store image for future plot
    im_ego = rasterizer.to_rgb(data["image"].transpose(1, 2, 0))
    draw_trajectory(im_ego, transform_points(predicted_positions, data["raster_from_agent"]), PREDICTED_POINTS_COLOR)
    images.append(im_ego[::-1])

    ## detect collisions between the AV and the (log-replayed) agents of this frame
    agents_frame = filter_agents_by_frames(scene_dataset.dataset.frames[frame_idx], scene_dataset.dataset.agents)[0]
    collision = detect_collision(data["centroid"], data["yaw"], data["extent"], agents_frame)
    collisions.append(collision)

    ## mutate the next frame with the first predicted pose;
    ## after a collision the next frame is left untouched, i.e. the AV is reset to its ground-truth pose
    frame_mutate_idx = frame_idx + 1

    if not collision and frame_mutate_idx < len(scene_dataset):
        # predictions are in the agent frame: convert positions to world coordinates
        # and yaws from relative to absolute
        pred_positions_m = transform_points(predicted_positions, data["world_from_agent"])
        pred_angles_rad = predicted_yaws + data["yaw"]

        scene_dataset.dataset.frames[frame_mutate_idx]["ego_translation"][:2] = pred_positions_m[0]
        scene_dataset.dataset.frames[frame_mutate_idx]["ego_rotation"] = yaw_as_rotation33(pred_angles_rad[0])

# Qualitative Evaluation: Visualise the Closed Loop

We can visualise the frames stored in the previous cell. As we are mutating future positions, DriveNet is in full control of the AV as it moves through the annotated scene. Frames in which a collision was detected are replaced by a solid colour: red for front, green for side and blue for rear collisions.

The playback speed has been decreased here to better show the predicted trajectory at each time step.
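Alongside the playback, the `collisions` list filled by the evaluation loop can be reduced to simple counts. Below is a minimal sketch, assuming (as the playback cell does) that a falsy entry means no collision and that the first item of a non-empty entry is the collision type. The function name `summarise_collisions` and the example entries are made up.

```python
from collections import Counter


def summarise_collisions(collisions):
    """Count frames, colliding frames, and colliding frames per collision type."""
    counts = Counter(c[0] for c in collisions if c)  # c[0] is the collision type
    return {"frames": len(collisions), "collision_frames": sum(counts.values()), **counts}


# made-up entries mimicking the shape used by the playback cell
example = [(), ("rear", "agent_12"), (), (), ("front", "agent_7"), ("rear", "agent_12")]
print(summarise_collisions(example))
# {'frames': 6, 'collision_frames': 3, 'rear': 2, 'front': 1}
```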

In [None]:
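from time import sleep

import numpy as np
from IPython.display import clear_output, display
from PIL import Image  # `import PIL` alone does not guarantee PIL.Image is loaded

# solid colour shown instead of the raster when a collision occurs, keyed by collision type
collision_colors = {
    "front": np.asarray((255, 0, 0), dtype=np.uint8),  # red
    "side": np.asarray((0, 255, 0), dtype=np.uint8),  # green
    "rear": np.asarray((0, 0, 255), dtype=np.uint8),  # blue
}

for frame, collision in zip(images, collisions):
    clear_output(wait=True)
    if collision:
        # collision[0] holds the collision type
        frame = np.zeros_like(frame) + collision_colors[collision[0]]
    display(Image.fromarray(frame))
    sleep(0.1)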