# Introduction to SafePathNet

In this notebook you are going to train the multimodal prediction model presented in [Safe Real-World Autonomous Driving by Learning to Predict and Plan with a Mixture of Experts](https://arxiv.org/abs/2211.02131).
Please note that we are releasing the prediction part of our approach only, but the code can be easily extended to planning too.

You will train and test your model using the Woven by Toyota Prediction Dataset and [L5Kit](https://github.com/woven-planet/l5kit).
**Before starting, please download the [Woven by Toyota Prediction Dataset 2020](https://woven.toyota/en/prediction-dataset) and follow [the instructions](https://github.com/woven-planet/l5kit#download-the-datasets) to correctly organise it.**

### Model

From the paper:
```
The architecture of SafePathNet is similar to those of VectorNet [11] and DETR [7], combining an element-wise point encoder [23] and a Transformer [31]. The element-wise point encoder consists of two PointNet-like modules that are used to compress each input element from a set of points to a single feature vector of the same size. A series of Transformer Encoder layers are used to model the relationships between all input elements (SDV, road agents, static and dynamic map, route), encoded by the point encoder. Then, a series of Transformer Decoders are used to query agents features. We make use of a set of learnable embeddings to construct the queries of the Transformer Decoders. M learnable query embeddings are used to obtain a variable number of M different queries for each road agent. An agent-specific MLP decoder converts each agent feature to a future trajectory. In addition to trajectories, the decoder predicts a logit for each agent trajectory. For each element, the corresponding logits are converted to a probability distribution over the future trajectories by applying a softmax function. All road agents are modeled independently, but predicted jointly in parallel.
```
This is a diagram of the full model:

![model](../../docs/images/safepathnet/safepathnet_model.svg)
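To make the element-wise point encoder more concrete, here is a minimal NumPy sketch of a single PointNet-like step. A shared per-point linear layer with ReLU is followed by a max-pool over the points of each element, so an element with any number of points becomes one fixed-size vector. The weights are random and the function name `encode_element` is hypothetical. According to the paper, SafePathNet's actual encoder stacks two such modules, and it is implemented in PyTorch inside l5kit.

```python
import numpy as np


def encode_element(points: np.ndarray, weights: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Compress one input element (a set of points) into a single feature vector.

    points:  [num_points, point_dim], e.g. (x, y, yaw) samples of a polyline
    weights: [point_dim, feature_dim], shared by all points
    bias:    [feature_dim]
    """
    per_point = np.maximum(points @ weights + bias, 0.0)  # shared linear layer + ReLU
    return per_point.max(axis=0)  # symmetric max-pool: the order of the points does not matter


rng = np.random.default_rng(0)
weights = rng.normal(size=(3, 16))
bias = np.zeros(16)

# two elements with a different number of points end up with the same feature size
lane = rng.normal(size=(20, 3))
agent_history = rng.normal(size=(5, 3))
print(encode_element(lane, weights, bias).shape)           # (16,)
print(encode_element(agent_history, weights, bias).shape)  # (16,)
```

The max-pool makes the encoder independent of point order and point count. This lets elements with different numbers of points share a common feature size before they enter the Transformer.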


#### Inputs
Following previous works, SafePathNet is based on a vectorized representation of the world, centered on the ego location.
Please refer to the paper for more details.


#### Outputs
SafePathNet outputs a trajectory distribution (in the form of a set of trajectories and a probability distribution over them) for each road agent (including ego).
All road agents are modeled independently, but predicted jointly in parallel.
Each timestep is a tuple consisting of `(X, Y, yaw)`.
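The sketch below illustrates these shapes. It assumes a hypothetical output of `M` trajectories per agent, each with `T` timesteps of `(X, Y, yaw)`, plus one logit per trajectory; the sizes are made up. A softmax over the trajectory axis turns the logits into the per-agent probability distribution. From that distribution you can, for example, pick the most likely trajectory.

```python
import numpy as np

num_agents, num_trajectories, num_timesteps = 4, 3, 12  # illustrative sizes only

rng = np.random.default_rng(0)
# [num_agents, num_trajectories, num_timesteps, 3] with (X, Y, yaw) per timestep
trajectories = rng.normal(size=(num_agents, num_trajectories, num_timesteps, 3))
# [num_agents, num_trajectories]: one logit per trajectory
logits = rng.normal(size=(num_agents, num_trajectories))


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# per-agent probability distribution over its trajectories
probs = softmax(logits, axis=-1)

# most likely trajectory of each agent: [num_agents, num_timesteps, 3]
best = trajectories[np.arange(num_agents), probs.argmax(axis=-1)]
print(probs.sum(axis=-1))  # each row sums to 1
print(best.shape)          # (4, 12, 3)
```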

### Training

Our model represents a mixture of experts, composed of a set of experts and an expert selection function.
We train them jointly, using a winner-takes-all approach to avoid mode collapse.

From the paper:
```
Our model represents a MoE and predicts multiple trajectories for the SDV and each road agent, corresponding to N/M experts, and a probability distribution over each trajectory set, corresponding to expert selection. To train the experts and expert selection jointly while avoiding mode collapse, we use a winner-takes-all approach, somewhat similar to previous methods [10]. Similarly to DETR [7], we formulate a matching cost between predicted and target trajectories and probabilities, making the expert with minimal cost the winner.
```

We define our training objective as minimizing the distance between predicted and ground truth agents’ future trajectories (imitation loss) and the negative log likelihood of the selected trajectory (matching loss).
Please refer to the paper for more information.
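The winner-takes-all idea can be sketched for a single agent. This is a simplified NumPy illustration under our own assumptions, not the l5kit implementation. The function name `winner_takes_all_loss` and the exact cost formula are hypothetical. We name the weighting factor `cost_prob_coeff` after the model parameter of the same name, assuming it plays a similar role. The steps are:

- The matching cost of each predicted trajectory is its mean L1 distance to the ground truth, plus `cost_prob_coeff` times its negative log-probability.
- The lowest-cost trajectory wins.
- Only the winner receives the imitation loss.
- The matching loss is the negative log-likelihood of the winner.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    shifted = x - x.max()
    return shifted - np.log(np.exp(shifted).sum())


def winner_takes_all_loss(pred: np.ndarray, logits: np.ndarray, target: np.ndarray,
                          cost_prob_coeff: float = 0.1):
    """pred: [M, T, 3] trajectories, logits: [M], target: [T, 3] ground truth."""
    log_probs = log_softmax(logits)
    # mean L1 distance of every predicted trajectory to the ground truth: [M]
    distances = np.abs(pred - target[None]).mean(axis=(1, 2))
    # the matching cost combines trajectory distance and negative log-probability
    cost = distances + cost_prob_coeff * (-log_probs)
    winner = int(cost.argmin())
    imitation_loss = distances[winner]  # only the winning expert is pulled towards the target
    matching_loss = -log_probs[winner]  # expert selection learns to pick the winner
    return winner, imitation_loss, matching_loss


rng = np.random.default_rng(0)
target = rng.normal(size=(12, 3))
pred = np.stack([target + 5.0, target + 0.1, target - 3.0])  # trajectory 1 is the closest
logits = np.zeros(3)  # uniform probabilities, so the distance decides the winner
winner, imitation, matching = winner_takes_all_loss(pred, logits, target)
print(winner)  # 1
```

Gradients flow only through the winner's trajectory. Different experts are therefore free to specialise on different modes instead of all regressing towards the average.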

# Import packages
Import packages (requires a working installation of l5kit) and set random seeds.

In [None]:
import os
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from matplotlib import pyplot as plt
from tempfile import gettempdir
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import ChunkedDataset, LocalDataManager
from l5kit.dataset import EgoAgentDatasetVectorized
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import average_displacement_error_oracle, final_displacement_error_oracle
from l5kit.planning.vectorized.common import build_matrix, transform_points
from l5kit.prediction.vectorized.safepathnet_model import SafePathNetModel
from l5kit.vectorization.vectorizer_builder import build_vectorizer

torch.manual_seed(123)
np.random.seed(123)

# Prepare data path and load cfg

By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.

Then, we load our config file with relative paths and other configurations (vectorizer, training params...).

In [None]:
# Download L5 Sample Dataset
import os
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q
    !sh ./setup_notebook_colab.sh
    os.environ["L5KIT_DATA_FOLDER"] = open("./dataset_dir.txt", "r").read().strip()
else:
    print("Not running in Google Colab.")
    os.environ["L5KIT_DATA_FOLDER"] = "PATH_TO_DATASET"

In [None]:
# define local data manager
dm = LocalDataManager(None)

# load the experiment config
cfg = load_config_data("./config.yaml")
print("Configuration loaded.")
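Most cells below index into `cfg` directly, so a missing key only surfaces as a `KeyError` much later. A small, hypothetical helper such as `missing_config_keys` can check up front that the expected keys are present. The key list below is a subset of the keys read by the cells of this notebook, not an exhaustive schema of the l5kit config.

```python
from typing import Any, Dict, Iterable, List

# a subset of the keys read by the cells of this notebook (dotted paths into the nested config)
REQUIRED_KEYS = [
    "train_data_loader.key", "train_data_loader.batch_size",
    "train_data_loader.shuffle", "train_data_loader.num_workers",
    "val_data_loader.key", "val_data_loader.batch_size",
    "val_data_loader.shuffle", "val_data_loader.num_workers",
    "model_params.future_num_frames", "model_params.agent_num_trajectories",
    "model_params.cost_prob_coeff", "data_generation_params.other_agents_num",
    "train_params.num_epochs", "train_params.max_num_steps",
    "train_params.checkpoint_every_n_epochs", "raster_params.filter_agents_threshold",
]


def missing_config_keys(cfg: Dict[str, Any], dotted_keys: Iterable[str]) -> List[str]:
    """Return the dotted keys that cannot be resolved in the nested config dict."""
    missing = []
    for dotted in dotted_keys:
        node: Any = cfg
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# in the notebook, cfg comes from load_config_data above:
# assert not missing_config_keys(cfg, REQUIRED_KEYS), missing_config_keys(cfg, REQUIRED_KEYS)
toy_cfg = {"train_params": {"num_epochs": 10}}
print(missing_config_keys(toy_cfg, ["train_params.num_epochs", "train_params.max_num_steps"]))
# ['train_params.max_num_steps']
```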

# Initialize the training dataset

In [None]:
# INIT DATASET
train_zarr = ChunkedDataset(dm.require(cfg["train_data_loader"]["key"])).open()

vectorizer = build_vectorizer(cfg, dm)
train_dataset = EgoAgentDatasetVectorized(cfg, train_zarr, vectorizer)
print(train_dataset)

# Define the model
Let's define the SafePathNet model and move it to GPU, if available.

In [None]:
model = SafePathNetModel(
    history_num_frames_ego=cfg["model_params"]["history_num_frames_ego"],
    history_num_frames_agents=cfg["model_params"]["history_num_frames_agents"],
    num_timesteps=cfg["model_params"]["future_num_frames"],
    weights_scaling=cfg["model_params"]["weights_scaling"],
    criterion=nn.L1Loss(reduction="none"),
    disable_other_agents=cfg["model_params"]["disable_other_agents"],
    disable_map=cfg["model_params"]["disable_map"],
    disable_lane_boundaries=cfg["model_params"]["disable_lane_boundaries"],
    agent_num_trajectories=cfg["model_params"]["agent_num_trajectories"],
    max_num_agents=cfg["data_generation_params"]["other_agents_num"],
    cost_prob_coeff=cfg["model_params"]["cost_prob_coeff"],
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Model created and loaded on device:", device)

model = model.to(device)

# Prepare for training
Our `EgoAgentDatasetVectorized` inherits from the PyTorch `Dataset` class, so we can wrap it in a `DataLoader` to enable multi-processing.
It extends the dataset `EgoDatasetVectorized` to include ego as a road agent and to support agent prediction evaluation, while keeping the scene ego-centric.
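Each training step processes a whole batch, so the number of steps in an epoch is the number of batches, not the number of samples. The hypothetical helper below sketches the arithmetic that `DataLoader` follows for a map-style dataset. With the default `drop_last=False`, the last, smaller batch still counts as a step.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields per epoch (what len(dataloader) returns)."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


print(steps_per_epoch(1000, 64))                  # 16 (15 full batches + 1 batch of 40 samples)
print(steps_per_epoch(1000, 64, drop_last=True))  # 15
```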

In [None]:
train_cfg = cfg["train_data_loader"]
train_dataloader = DataLoader(train_dataset, shuffle=train_cfg["shuffle"], batch_size=train_cfg["batch_size"],
                              num_workers=train_cfg["num_workers"])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = cfg["train_params"]["num_epochs"]
num_steps_per_epoch = len(train_dataloader)  # number of batches per epoch, not number of samples
max_num_steps = min(num_steps_per_epoch, cfg["train_params"]["max_num_steps"])
num_steps_per_log = max(1, max_num_steps // 100)
checkpoint_every_n_epochs = cfg["train_params"]["checkpoint_every_n_epochs"]
num_warmup_epochs = cfg["train_params"]["num_epochs"] // 5

def lr_lambda_warmup_cosine(step: int) -> float:
    steps_per_epoch = max_num_steps
    total_steps = num_epochs * steps_per_epoch
    warmup_steps = num_warmup_epochs * steps_per_epoch

    if step < warmup_steps:  # warmup
        return step / warmup_steps
    else:
        steps_since_warmup = step - warmup_steps
        anneal_steps = total_steps - warmup_steps
        completion = steps_since_warmup / anneal_steps
        cosine_rate = float(0.5 * (1 + np.cos(completion * np.pi)))
        return cosine_rate

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_warmup_cosine)
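The function below is a standalone copy of the warmup-plus-cosine rule with explicit arguments (the parameter names are ours), evaluated at a few steps. The multiplier rises linearly from 0 to 1 during warmup, then follows half a cosine from 1 down to 0. `LambdaLR` multiplies the base learning rate (`1e-3` above) by this value at every step.

```python
import math


def warmup_cosine(step: int, total_steps: int, warmup_steps: int) -> float:
    """LR multiplier: linear warmup to 1, then cosine annealing down to 0."""
    if step < warmup_steps:
        return step / warmup_steps
    completion = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(completion * math.pi))


# e.g. 10 epochs of 100 steps, with 10 // 5 = 2 warmup epochs
total, warmup = 1000, 200
for step in (0, 100, 200, 600, 1000):
    print(step, round(warmup_cosine(step, total, warmup), 3))
# 0 0.0 / 100 0.5 / 200 1.0 / 600 0.5 / 1000 0.0
```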


# Training loop
Here, we purposely include a basic training loop. Many more components could be added to enrich logging and improve performance; still, this simple loop is enough to reach reasonable performance.
Please adapt the training length by changing `train_params` in the config file.

In [None]:
%matplotlib inline

model.train()
torch.set_grad_enabled(True)

loss_log = defaultdict(list)
lr_log = list()
progress_bar = tqdm(total=max_num_steps)

print(f"Starting training - {num_epochs} epochs")
print(f"An epoch is composed of {max_num_steps} steps")

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1} - Starting")
    progress_bar.reset()
    # index of the first log entry of this epoch (every logging step appends exactly one lr value)
    log_start = len(lr_log)
    for idx, data in enumerate(train_dataloader):
        if idx == max_num_steps:
            break

        # Forward pass
        data = {k: v.to(device) for k, v in data.items()}
        result = model(data)
        loss = result["loss"]

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # logging
        if idx % num_steps_per_log == 0:
            for key, res in result.items():
                loss_log[key].append(res.item())
            lr_log.append(lr_scheduler.get_last_lr()[0])

        progress_bar.update()
        progress_bar.set_description(f"loss: {loss.item():.5f} - loss(avg): {np.mean(loss_log['loss'][log_start:]):.5f}")

    if epoch % checkpoint_every_n_epochs == 0 or epoch + 1 == num_epochs:
        save_path = f"{gettempdir()}/safepathnet_model.{epoch}.pth"
        torch.save(model.state_dict(), save_path)
        print(f"Model saved at {save_path}.")

    # plot only the values logged during this epoch
    for key, values in loss_log.items():
        values_last_epoch = values[log_start:]
        plt.plot(np.arange(len(values_last_epoch)), values_last_epoch, label=key)
    plt.legend()
    plt.show()

    lr_log_last_epoch = lr_log[log_start:]
    plt.plot(np.arange(len(lr_log_last_epoch)), lr_log_last_epoch, label='learning rate')
    plt.legend()
    plt.show()
    
    print(f"Epoch {epoch + 1} - Ended")


## Plot the final train loss curve
We can plot the train loss against the iterations (logged every `num_steps_per_log` steps, i.e. on the order of 100 values per epoch) to check whether our model has converged.
We also plot the learning rate values used across the training.

In [None]:
%matplotlib inline

for key, loss in loss_log.items():
    plt.plot(np.arange(len(loss)), loss, label=key)
plt.legend()
plt.show()

plt.plot(np.arange(len(lr_log)), lr_log, label='learning rate')
plt.legend()
plt.show()

## Store the model in torchscript format

Let's also store the model in TorchScript format. This format allows us to reload the model and its weights without requiring the Python class definition.

In [None]:
model.eval()
jit_model = torch.jit.script(model.cpu())
path_to_save = f"{gettempdir()}/safepathnet_script.pth"
jit_model.save(path_to_save)
print(f"MODEL STORED at {path_to_save}")

# Evaluation

Following the challenge evaluation protocol, **the test set for the competition is "chopped" using the `chop_dataset` function**.

In [None]:
# GENERATE AND LOAD CHOPPED DATASET
num_frames_to_chop = 100
eval_cfg = cfg["val_data_loader"]
eval_base_path = os.path.join(os.environ["L5KIT_DATA_FOLDER"], f"{eval_cfg['key'].split('.')[0]}_chopped_{num_frames_to_chop}")
if not os.path.exists(eval_base_path):
    eval_base_path = create_chopped_dataset(
        dm.require(eval_cfg["key"]), 
        cfg["raster_params"]["filter_agents_threshold"], 
        num_frames_to_chop, 
        cfg["model_params"]["future_num_frames"], 
        MIN_FUTURE_STEPS)

The result is that **each scene has been reduced to only 100 frames**, and **only valid agents in the 100th frame will be used to compute the metrics**. Because the following frames in the scene have been chopped off, we can't simply look ahead to get the future of those agents.

In this example, we simulate this pipeline by running `chop_dataset` on the validation set. The function stores:
- a new chopped `.zarr` dataset, in which each scene has only the first 100 frames;
- a numpy mask array where only valid agents in the 100th frame are True;
- a ground-truth file with the future coordinates of those agents.

Please note how the total number of frames is now equal to the number of scenes multiplied by `num_frames_to_chop`.

The remaining frames in the scene have been successfully chopped off from the data.

**Note:** SafePathNet is able to predict future trajectories of all the agents in a scene in a single pass of the model, using the ego-centric reference frame.
Thus, we use a modified version of the EgoDataset that additionally returns the ids of the agents that are used for evaluation in the Prediction Challenge.
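The effect of chopping on the frame indices can be sketched in a few lines. This toy version is our own simplification, not `create_chopped_dataset` itself. It represents each scene by a hypothetical `(start, end)` frame interval and keeps only the first `num_frames_to_chop` frames of each. That is why the chopped dataset has `num_scenes * num_frames_to_chop` frames when every scene is long enough. As described above, the real function additionally writes the agent mask and the ground-truth file.

```python
from typing import List, Tuple


def chop_scenes(frame_intervals: List[Tuple[int, int]], num_frames_to_chop: int) -> List[Tuple[int, int]]:
    """Keep the first `num_frames_to_chop` frames of every scene.

    frame_intervals: [(start, end)] per scene, end exclusive, in the original frame indexing.
    Returns the intervals re-indexed into the new, chopped frame array.
    """
    chopped = []
    offset = 0
    for start, end in frame_intervals:
        kept = min(num_frames_to_chop, end - start)  # assumption: shorter scenes are kept whole
        chopped.append((offset, offset + kept))
        offset += kept
    return chopped


scenes = [(0, 248), (248, 497), (497, 745)]  # illustrative scene boundaries
chopped = chop_scenes(scenes, 100)
print(chopped)                              # [(0, 100), (100, 200), (200, 300)]
print(chopped[-1][1] == len(scenes) * 100)  # True
```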

In [None]:
eval_zarr_path = str(Path(eval_base_path) / Path(dm.require(eval_cfg["key"])).name)
eval_mask_path = str(Path(eval_base_path) / "mask.npz")
eval_gt_path = str(Path(eval_base_path) / "gt.csv")

eval_zarr = ChunkedDataset(eval_zarr_path).open()
eval_mask = np.load(eval_mask_path)["arr_0"]

vectorizer = build_vectorizer(cfg, dm)

# INIT DATASET AND LOAD MASK
eval_dataset = EgoAgentDatasetVectorized(cfg, eval_zarr, vectorizer, agents_mask=eval_mask, eval_mode=True)
print(eval_dataset)

eval_cfg = cfg["val_data_loader"]
eval_dataloader = DataLoader(eval_dataset, shuffle=eval_cfg["shuffle"], batch_size=eval_cfg["batch_size"],
                             num_workers=eval_cfg["num_workers"])


## Storing Predictions
There is a small catch to be aware of when saving the model predictions. The outputs of the model are coordinates in `ego` space, and we need to convert them into displacements in `world` space.

To do so, we first convert them back into `world` space and then subtract from each agent's predictions that agent's own `world` centroid coordinates.
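The conversion can be illustrated with 3x3 homogeneous transformation matrices. This NumPy sketch uses hypothetical helpers (`make_transform`, `apply_transform`) rather than l5kit's `build_matrix` and `transform_points`. A `world_from_ego` matrix maps the ego-space predictions to world space. Then the agent's world centroid is subtracted to obtain displacements.

```python
import numpy as np


def make_transform(x: float, y: float, yaw: float) -> np.ndarray:
    """3x3 homogeneous matrix mapping local coordinates into the parent frame."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, x],
                     [s, c, y],
                     [0.0, 0.0, 1.0]])


def apply_transform(points: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Transform [N, 2] points with a 3x3 homogeneous matrix."""
    homogeneous = np.concatenate([points, np.ones((len(points), 1))], axis=1)
    return (homogeneous @ matrix.T)[:, :2]


# ego at world position (10, 5), rotated by 90 degrees
world_from_ego = make_transform(10.0, 5.0, np.pi / 2)

# predicted future positions of one agent, in ego coordinates
pred_ego = np.array([[2.0, 0.0], [3.0, 0.0]])
# the agent's current centroid in world coordinates (ego-space position (1, 0))
agent_centroid_world = np.array([10.0, 6.0])

pred_world = apply_transform(pred_ego, world_from_ego)  # approximately [[10, 7], [10, 8]]
displacements = pred_world - agent_centroid_world       # approximately [[0, 1], [0, 2]]
print(np.round(pred_world, 6))
```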

In [None]:
# EVAL LOOP
model.eval()
model.to(device)
torch.set_grad_enabled(False)

# store information for evaluation
future_coords_offsets_pd = []
future_traj_confidence = []
timestamps = []
agent_ids = []
agent_of_interest_ids = []
missing_agent_of_interest_ids = []
missing_agent_of_interest_timestamp = []

# torch.isin is available only from PyTorch 1.10 - defining a simple alternative
def torch_isin(ar1, ar2):
    return (ar1[..., None] == ar2).any(-1)

# iterate over validation dataset
progress_bar = tqdm(eval_dataloader)
for data in progress_bar:
    data = {k: v.to(device) for k, v in data.items()}
    outputs = model(data)

    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 2]
    agent_xy = outputs["all_agent_positions"]
    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 1]
    agent_yaw = outputs["all_agent_yaws"]
    # [batch_size, max_num_agents, num_trajectories]
    agent_logits = outputs["agent_traj_logits"]
    
    # [batch_size, max_num_agents, num_trajectories, num_timesteps, 3]
    agent_pos = torch.cat((agent_xy, agent_yaw), dim=-1)

    # ego-centric agent coords must be converted to world frame first
    # [batch_size, 3, 3]
    world_from_agents = data["world_from_agent"].float()
    # [batch_size]
    world_from_agents_yaw = data["yaw"].float()
    # shape of data["all_other_agents_history_positions"]: [batch_size, num_agents, num_history_frames, 2]
    # [batch_size, num_agents, 1, 3]
    agent_t0_pos_yaw = torch.cat((data["all_other_agents_history_positions"][:, :, :1],
                                  data["all_other_agents_history_yaws"][:, :, :1]), dim=-1)
    agent_t0_avail = data["all_other_agents_history_availability"][:, :, :1]
    # [batch_size, num_agents, 1, 3]
    world_agent_t0_pos_yaw = transform_points(agent_t0_pos_yaw, world_from_agents, avail=agent_t0_avail,
                                              yaw=world_from_agents_yaw)
    world_agent_pos = transform_points(agent_pos.flatten(2,3), world_from_agents, avail=agent_t0_avail).view_as(agent_pos)

    # then can be converted to agent-relative
    world_agents_t0_pos_exp = world_agent_t0_pos_yaw[..., :2]
    world_agents_t0_yaw_exp = world_agent_t0_pos_yaw[..., 2]
    # [batch_size * max_num_agents, 3, 3]
    _, matrix = build_matrix(world_agents_t0_pos_exp.reshape(-1, 2), world_agents_t0_yaw_exp.reshape(-1))
    # [batch_size, max_num_agents, 3, 3]
    matrix = matrix.view(list(world_agent_t0_pos_yaw.shape[:2]) + [3, 3])
    # [batch_size * max_num_agents * num_trajectories * num_timesteps, 3, 3]
    matrix = matrix.unsqueeze(2).unsqueeze(2).expand(list(agent_pos.shape[:-1]) + [3, 3]).reshape(-1, 3, 3)
    coords_offset = transform_points(world_agent_pos.reshape(-1, 1, 1, 3), matrix, 
                                     avail=torch.ones_like(world_agent_pos.reshape(-1, 1, 1, 3)[..., 0]))
    coords_offset = coords_offset.view_as(world_agent_pos)
    
    # need to filter per agents of interest (from original prediction evaluation)
    agents_track_ids = data["all_other_agents_track_ids"]
    agents_of_interest = data["all_valid_agents_track_ids"]
    agents_track_ids_mask = torch.zeros_like(agents_track_ids, dtype=torch.bool)
    missing_agents_mask = torch.zeros_like(agents_of_interest, dtype=torch.bool)
    for batch_idx in range(agents_track_ids.shape[0]):
        # predicted agents that are also agents of interest (a track id of 0 marks an empty slot)
        agents_track_ids_mask[batch_idx] = torch_isin(agents_track_ids[batch_idx], agents_of_interest[batch_idx]) & \
                                           (agents_track_ids[batch_idx] != 0)
        # agents of interest that are not among the predicted agents
        missing_agents_mask[batch_idx] = ~torch_isin(agents_of_interest[batch_idx], agents_track_ids[batch_idx]) & \
                                         (agents_of_interest[batch_idx] != 0)
    # we may miss some agents due to the limit cfg["data_generation_params"]["other_agents_num"]; we will consider them stationary
    if torch.any(missing_agents_mask):
        # track ids of the missing agents, ordered by batch element (the mask already excludes id 0)
        missing_agents_ids = agents_of_interest[missing_agents_mask]
        missing_agent_of_interest_ids.append(missing_agents_ids.cpu())
        # repeat each sample's timestamp once per missing agent of that sample
        missing_timestamps = torch.repeat_interleave(data["timestamp"], missing_agents_mask.sum(-1))
        missing_agent_of_interest_timestamp.append(missing_timestamps.cpu())
    
    # move the valid data to CPU
    relevant_coords_offset = coords_offset[agents_track_ids_mask].cpu()
    traj_confidence = agent_logits[agents_track_ids_mask].cpu()
    relevant_agent_track_ids = agents_track_ids[agents_track_ids_mask].cpu()
    relevant_timestamps = data["timestamp"].unsqueeze(1).expand(agents_track_ids.shape)[agents_track_ids_mask].cpu()

    # add them to the result lists
    future_coords_offsets_pd.append(relevant_coords_offset)
    future_traj_confidence.append(traj_confidence)
    timestamps.append(relevant_timestamps)
    agent_ids.append(relevant_agent_track_ids)

# add the missing agents as stationary (only if any agent was missed)
if missing_agent_of_interest_ids:
    missing_agent_of_interest_ids = torch.cat(missing_agent_of_interest_ids, dim=0)
    missing_agent_of_interest_timestamp = torch.cat(missing_agent_of_interest_timestamp, dim=0)
    num_missing = missing_agent_of_interest_ids.shape[0]
    # zero offsets: the agent stays at its current position
    stationary_trajectories = torch.zeros([num_missing] + list(future_coords_offsets_pd[0].shape[1:]))
    # equal logits: they become a uniform distribution after the softmax below
    uniform_logits = torch.ones([num_missing] + list(future_traj_confidence[0].shape[1:]))
    agent_ids.append(missing_agent_of_interest_ids)
    future_coords_offsets_pd.append(stationary_trajectories)
    future_traj_confidence.append(uniform_logits)
    timestamps.append(missing_agent_of_interest_timestamp)

# concatenate all the results in a single np array    
future_coords_offsets_pd = torch.cat(future_coords_offsets_pd, dim=0).numpy()
future_traj_confidence = torch.cat(future_traj_confidence, dim=0).softmax(-1).numpy()
timestamps = torch.cat(timestamps, dim=0).numpy().astype(np.int64)
agent_ids = torch.cat(agent_ids, dim=0).numpy().astype(np.int64)
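The filtering and padding performed in the loop above reduce to two small operations, sketched here in NumPy with made-up track ids. The first is a broadcasting membership test (the same idea as `torch_isin`), where a track id of 0 marks an empty slot. The second gives equal logits to the agents the model could not predict; the softmax turns these into a uniform confidence over the trajectories.

```python
import numpy as np


def isin(ar1: np.ndarray, ar2: np.ndarray) -> np.ndarray:
    """Element-wise membership test via broadcasting, like torch_isin above."""
    return (ar1[..., None] == ar2).any(-1)


predicted_ids = np.array([3, 7, 9, 0, 0])  # agents the model predicted (0 = empty slot)
ids_of_interest = np.array([7, 9, 12, 0])  # agents required by the evaluation

keep_mask = isin(predicted_ids, ids_of_interest) & (predicted_ids != 0)
missing_mask = ~isin(ids_of_interest, predicted_ids) & (ids_of_interest != 0)
print(predicted_ids[keep_mask])       # [7 9]
print(ids_of_interest[missing_mask])  # [12]

# missing agents get stationary trajectories and equal logits
num_trajectories = 3
logits = np.ones(num_trajectories)
confidences = np.exp(logits) / np.exp(logits).sum()
print(confidences)  # uniform: each value is 1/3
```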


In [None]:
# verify that the number of predicted agents matches the number of agents in the ground-truth file
assert len(future_coords_offsets_pd) == 94694

In [None]:
num_gt_agents = 94694
print(f"Overall, we missed {len(missing_agent_of_interest_ids)} agents over a total of {num_gt_agents} agents "
      f"(~{len(missing_agent_of_interest_ids) / num_gt_agents:.5%})")

## Save results in csv format
After the model has predicted trajectories for our evaluation set, we can save them in a `csv` file compatible with the l5kit evaluation tool.

In [None]:
pred_path = f"{gettempdir()}/pred.csv"

write_pred_csv(pred_path,
               timestamps=timestamps,
               track_ids=agent_ids,
               coords=future_coords_offsets_pd[..., :2],
               confs=future_traj_confidence,
               max_modes=cfg["model_params"]["agent_num_trajectories"])

## Perform Evaluation
We can evaluate the model predictions with the existing metrics from l5kit, which support multimodal predictions. In our case, we're interested in the minimum Average Displacement Error (minADE) and the minimum Final Displacement Error (minFDE). Other metrics can be added from `l5kit.evaluation.metrics`.

In [None]:
# if you restart the notebook and want to evaluate an existing csv file, uncomment this cell using your csv path
# pred_path = 'PATH_TO_CSV'

In [None]:
# COMPUTE AND PRINT METRICS
metrics = compute_metrics_csv(eval_gt_path, pred_path, 
                              [average_displacement_error_oracle, final_displacement_error_oracle],
                              max_modes=cfg["model_params"]["agent_num_trajectories"])
for metric_name, metric_mean in metrics.items():
    print(metric_name, metric_mean)
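For reference, the two oracle metrics can be written out for one agent. minADE is the smallest mean Euclidean distance to the ground truth across all timesteps, taken over the predicted modes. minFDE is the smallest distance at the final timestep. This simplified NumPy version (the function name `min_ade_fde` is hypothetical) ignores the per-timestep availability masks handled by l5kit's `average_displacement_error_oracle` and `final_displacement_error_oracle`. Treat it as an illustration rather than a drop-in replacement.

```python
import numpy as np


def min_ade_fde(pred: np.ndarray, gt: np.ndarray):
    """pred: [num_modes, T, 2], gt: [T, 2]. Returns (minADE, minFDE) for one agent."""
    # Euclidean distance per mode and timestep: [num_modes, T]
    dist = np.linalg.norm(pred - gt[None], axis=-1)
    min_ade = dist.mean(axis=1).min()  # best mode by average displacement
    min_fde = dist[:, -1].min()        # best mode by final displacement
    return float(min_ade), float(min_fde)


gt = np.stack([np.arange(5.0), np.zeros(5)], axis=1)  # straight line along x
mode_a = gt + np.array([0.0, 1.0])                    # 1 m lateral offset at every step
mode_b = gt.copy()
mode_b[-1] += np.array([3.0, 0.0])                    # exact, except 3 m off at the end
min_ade, min_fde = min_ade_fde(np.stack([mode_a, mode_b]), gt)
print(min_ade, min_fde)  # 0.6 1.0
```

The two minima may come from different modes. Here `mode_b` has the best ADE and `mode_a` the best FDE.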


### Visualise Results
Coming soon.


# Congratulations on training and evaluating your SafePathNet model!

For more information on SafePathNet, please have a look at our paper  
[Safe Real-World Autonomous Driving by Learning to Predict and Plan with a Mixture of Experts](https://arxiv.org/abs/2211.02131).