# ML Simulation Training
In this notebook you are going to train an ML policy.

However, instead of using examples from the SDV as training data, you will use the other agents around it.

This may sound like a small difference, but it has profound implications:
- by using data from multiple sources you're **including much more variability in your training data**;
- two agents may have taken different choices at the same intersection, leading to **multi-modal data**;
- the **quality of the annotated data is expected to be noticeably lower** compared to the SDV, as we're relying on a perception system.
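
To make the multi-modal point concrete, here is a minimal sketch that assumes a plain mean-squared-error objective and two hypothetical agents who turn in opposite directions at the same intersection. The offsets and the candidate grid are invented for illustration; the sketch only shows that a single prediction minimising the squared error lands between the two modes.

```python
# Two hypothetical agents at the same intersection: lateral offset (metres) after the turn.
left_turn, right_turn = -5.0, 5.0
targets = [left_turn, right_turn]


def mse(prediction, targets):
    """Mean squared error of one scalar prediction against all targets."""
    return sum((prediction - t) ** 2 for t in targets) / len(targets)


# Scan candidate predictions from -5.0 to 5.0 in 0.5 m steps
# and keep the one with the lowest mean squared error.
candidates = [x / 2 for x in range(-10, 11)]
best = min(candidates, key=lambda p: mse(p, targets))

print(best, mse(best, targets))  # prints "0.0 25.0"
```

The best single prediction is "go straight", which neither agent actually did. This is one reason multi-modal data is harder to fit than single-source data.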

Still, the payoff is even greater than for planning: this policy can potentially drive all the agents in the scene, not just the SDV.

![simulation-example](https://github.com/lyft/l5kit/blob/master/images/simulation/simulation_example.svg?raw=1)


In [None]:
from tempfile import gettempdir
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset
from l5kit.rasterization import build_rasterizer
from l5kit.geometry import transform_points
from l5kit.visualization import TARGET_POINTS_COLOR, draw_trajectory
from l5kit.planning.model import PlanningModel

import os

## Prepare data path and load cfg

By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.

Then, we load our config file with relative paths and other configurations (rasteriser, training params...).
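
As a rough guide to what the config needs to contain, the sketch below lists the entries that this notebook's code reads from `cfg`, written as a Python dict. All values are illustrative placeholders rather than recommended settings, and `missing_keys` is a hypothetical helper that is not part of L5Kit. The rasteriser needs further settings that are not shown here.

```python
# Entries this notebook reads from `cfg`; the values are illustrative placeholders only.
example_cfg = {
    "model_params": {"model_architecture": "resnet50", "future_num_frames": 12},
    "train_data_loader": {
        "key": "scenes/sample.zarr",  # path relative to L5KIT_DATA_FOLDER
        "batch_size": 12,
        "shuffle": True,
        "num_workers": 4,
    },
    "train_params": {"max_num_steps": 5},
}

REQUIRED_KEYS = [
    ("model_params", "model_architecture"),
    ("model_params", "future_num_frames"),
    ("train_data_loader", "key"),
    ("train_data_loader", "batch_size"),
    ("train_data_loader", "shuffle"),
    ("train_data_loader", "num_workers"),
    ("train_params", "max_num_steps"),
]


def missing_keys(cfg):
    """Return the (section, key) pairs that the notebook reads but `cfg` lacks."""
    return [(s, k) for s, k in REQUIRED_KEYS if k not in cfg.get(s, {})]


missing = missing_keys(example_cfg)
print(missing)  # an empty list means every entry used in this notebook is present
```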

In [None]:
#@title Download L5 Sample Dataset and install L5Kit
import os
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q
    !sh ./setup_notebook_colab.sh
    os.environ["L5KIT_DATA_FOLDER"] = open("./dataset_dir.txt", "r").read().strip()
else:
    print("Not running in Google Colab.")
    os.environ["L5KIT_DATA_FOLDER"] = "/tmp/l5kit_data"

In [None]:
dm = LocalDataManager(None)
# get config
cfg = load_config_data("./config.yaml")

In [None]:
# rasterisation
rasterizer = build_rasterizer(cfg, dm)

# ===== INIT DATASET
train_zarr = ChunkedDataset(dm.require(cfg["train_data_loader"]["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)

# plot some examples
for idx in range(0, len(train_dataset), len(train_dataset) // 10):
    data = train_dataset[idx]
    im = rasterizer.to_rgb(data["image"].transpose(1, 2, 0))
    target_positions = transform_points(data["target_positions"], data["raster_from_agent"])
    draw_trajectory(im, target_positions, TARGET_POINTS_COLOR)
    plt.imshow(im)
    plt.axis('off')
    plt.show()


In [None]:
model = PlanningModel(
    model_arch=cfg["model_params"]["model_architecture"],
    num_input_channels=rasterizer.num_channels(),
    num_targets=3 * cfg["model_params"]["future_num_frames"],  # (X, Y, yaw) for each future frame
    weights_scaling=[1., 1., 1.],
    criterion=nn.MSELoss(reduction="none"),
)
print(model)

# Prepare for training
Our `AgentDataset` inherits from PyTorch's `Dataset`, so we can use it inside a `DataLoader` to enable multi-processing.

In [None]:
train_cfg = cfg["train_data_loader"]
train_dataloader = DataLoader(train_dataset, shuffle=train_cfg["shuffle"], batch_size=train_cfg["batch_size"], 
                             num_workers=train_cfg["num_workers"])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(train_dataset)

# Training loop
Here, we purposely include a bare-bones training loop. Many more components can be added to enrich logging and improve performance, to name a few:
- learning rate drop;
- loss weight tuning;
- importance sampling.


Still, the sheer size of our dataset ensures that reasonable performance can be obtained even with this simple loop.
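
As an illustration of the first item, a learning rate drop can be as simple as a step decay. The sketch below is framework-free and uses a hypothetical helper `step_decay_lr`. The base rate matches the `1e-3` passed to Adam in this notebook, while `drop_every` and `gamma` are assumed values, not tuned ones.

```python
def step_decay_lr(step, base_lr=1e-3, drop_every=1000, gamma=0.1):
    """Learning rate at iteration `step`: multiplied by `gamma` every `drop_every` steps."""
    return base_lr * gamma ** (step // drop_every)


# The rate stays flat within each block of `drop_every` steps, then drops by a factor of 10.
schedule = [step_decay_lr(s) for s in (0, 999, 1000, 2500)]
print(schedule)
```

In PyTorch, `torch.optim.lr_scheduler.StepLR` applies the same rule to an optimizer. It would be stepped once per iteration, after `optimizer.step()`.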

In [None]:
tr_it = iter(train_dataloader)
progress_bar = tqdm(range(cfg["train_params"]["max_num_steps"]))
losses_train = []
model.train()
torch.set_grad_enabled(True)

for _ in progress_bar:
    try:
        data = next(tr_it)
    except StopIteration:
        tr_it = iter(train_dataloader)
        data = next(tr_it)
    # Forward pass
    data = {k: v.to(device) for k, v in data.items()}
    result = model(data)
    loss = result["loss"]
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses_train.append(loss.item())
    progress_bar.set_description(f"loss: {loss.item()} loss(avg): {np.mean(losses_train)}")

### Plot the train loss curve
We can plot the train loss against the iterations (batch-wise) to check if our model has converged.
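
Batch-wise losses are usually noisy, so a smoothed curve can make the trend easier to read. The sketch below uses a hypothetical helper `moving_average` with a trailing window. The window size is an assumption that should be adapted to the number of training steps.

```python
def moving_average(values, window=50):
    """Trailing mean over the `window` most recent values, one output per input."""
    smoothed, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            # Drop the value that just fell out of the window.
            running -= values[i - window]
        smoothed.append(running / min(i + 1, window))
    return smoothed


smoothed = moving_average([4.0, 2.0, 3.0, 1.0], window=2)
print(smoothed)  # prints [4.0, 3.0, 2.5, 2.0]
```

The result of `moving_average(losses_train)` can be passed to `plt.plot` alongside the raw losses.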

In [None]:
plt.plot(np.arange(len(losses_train)), losses_train, label="train loss")
plt.legend()
plt.show()

# Store the model

Let's store the model as TorchScript. This format allows us to re-load the model and weights later without requiring the class definition.
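
To show what re-loading without the class definition looks like, here is a minimal sketch that uses a tiny stand-in module rather than the model trained above. The stand-in and the file name are assumptions for illustration only.

```python
import os
from tempfile import gettempdir

import torch
from torch import nn

# Stand-in module for illustration only; in this notebook you would script `model` instead.
tiny = nn.Sequential(nn.Linear(4, 2))
path = os.path.join(gettempdir(), "tiny_scripted_example.pt")  # hypothetical file name
torch.jit.script(tiny).save(path)

# Re-loading needs only the file, not the Python class that defined the module.
reloaded = torch.jit.load(path, map_location="cpu").eval()

x = torch.zeros(1, 4)
with torch.no_grad():
    same = torch.allclose(tiny(x), reloaded(x))
print(same)
```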

**Take note of the path: you will use it later to evaluate your simulation model!**

In [None]:
to_save = torch.jit.script(model.cpu())
path_to_save = f"{gettempdir()}/simulation_model.pt"
to_save.save(path_to_save)
print(f"MODEL STORED at {path_to_save}")

# Congratulations on training your first ML policy for simulation!
### What's Next

Now that your model is trained and safely stored, you can use it to control the agents around ego. We have a notebook just for that.

### [Simulation evaluation](./simulation_test.ipynb)
In this notebook a `planning_model` will control the SDV, while the `simulation_model` you just trained will be used for all other agents.

Don't worry if you don't have the resources required to train a model: we provide pre-trained models just below.

## Pre-trained models
We provide a collection of pre-trained models for the simulation task:
- [simulation model](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/simulation_models/simulation_model_20210416_5steps.pt) trained on agents over the semantic rasteriser with history of 0.5s;
- [planning model](https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/models/simulation_models/planning_model_20210421_5steps.pt) trained on the AV over the semantic rasteriser with history of 0.5s.

To use one of the models, simply download the corresponding `.pt` file and load it in the evaluation notebooks.