# L5 Closed-loop Gym-compatible Environment

This notebook demonstrates some of the aspects of our gym-compatible closed-loop environment.

You will learn about the inner workings of our L5Kit environment and how an RL policy can be used to roll out the environment.

Note: The training of different RL policies in our environment will be shown in a separate notebook.

![drivergym](../../docs/images/rl/drivergym.png)

In [None]:
#@title Download L5 Sample Dataset and install L5Kit
import os
RunningInCOLAB = 'google.colab' in str(get_ipython())
if RunningInCOLAB:
    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q
    !sh ./setup_notebook_colab.sh
    os.environ["L5KIT_DATA_FOLDER"] = open("./dataset_dir.txt", "r").read().strip()
else:
    os.environ["L5KIT_DATA_FOLDER"] = "/tmp/level5_data"
    print("Not running in Google Colab.")

In [None]:
import gym
import matplotlib.pyplot as plt
import torch
import numpy as np

import l5kit.environment
from l5kit.configs import load_config_data
from l5kit.environment.envs.l5_env import EpisodeOutputGym, SimulationConfigGym
from l5kit.environment.gym_metric_set import L2DisplacementYawMetricSet
from l5kit.visualization.visualizer.zarr_utils import episode_out_to_visualizer_scene_gym_cle
from l5kit.visualization.visualizer.visualizer import visualize

from bokeh.io import output_notebook, show
from prettytable import PrettyTable


### First, let's configure where our data lives!
The data is expected to live in a folder that can be configured using the `L5KIT_DATA_FOLDER` env variable. Your data folder is expected to contain subfolders for the aerial and semantic maps as well as the scenes (`.zarr` files). 
In this example, the env variable is set to the local data folder. You should make sure the path points to the correct location for you.
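Before loading anything, a small helper can confirm that the data folder exists and contains the expected subfolders. This is a minimal sketch. The subfolder names `aerial_map`, `semantic_map` and `scenes` are assumptions based on the description above. `check_data_folder` is a hypothetical helper and is not part of L5Kit.

```python
import os
from pathlib import Path
from typing import List

# Assumed subfolder names for the aerial maps, semantic maps and .zarr scenes.
EXPECTED_SUBFOLDERS = ("aerial_map", "semantic_map", "scenes")


def check_data_folder(env_var: str = "L5KIT_DATA_FOLDER") -> List[str]:
    """Return the expected subfolders that are missing from the folder named by `env_var`.

    Hypothetical helper: raises if the variable is unset or does not point to a directory.
    """
    folder = os.environ.get(env_var)
    if folder is None:
        raise KeyError(f"{env_var} is not set")
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"{root} is not a directory")
    return [name for name in EXPECTED_SUBFOLDERS if not (root / name).is_dir()]
```

For example, `print(check_data_folder() or "all expected subfolders found")` reports any subfolder that still needs to be downloaded.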

We built our code to work with a human-readable `yaml` config. This config file holds a lot of useful information. Here, however, we will only focus on the few fields that concern the creation of our gym environment.

In [None]:
# Dataset is assumed to be on the folder specified
# in the L5KIT_DATA_FOLDER environment variable

# get environment config
env_config_path = '../gym_config.yaml'
cfg = load_config_data(env_config_path)
print(cfg)

### We can look into our current configuration for interesting fields

When loaded in Python, the `yaml` file is converted into a Python `dict`.

`raster_params` contains all the information related to the transformation of the 3D world onto an image plane:
  - `raster_size`: the image plane size, in pixels
  - `pixel_size`: how many meters correspond to a pixel
  - `ego_center`: our raster is centered around an agent; this parameter moves the agent within the image plane
  - `map_type`: the rasterizer to be employed. We currently support a satellite-based and a semantic-based one.
  
The `raster_params` are used to determine the observation provided by our gym environment to the RL policy.
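The snippet below makes these parameters concrete. It derives the metric footprint of the raster and the pixel location of the ego from a `raster_params`-style dict. The values are illustrative placeholders and are not necessarily those in `gym_config.yaml`. The sketch also rests on three assumptions:

- `raster_size` is `[width, height]` in pixels.
- `pixel_size` is in meters per pixel.
- `ego_center` is given as a fraction of the raster size.

```python
from typing import Dict, List, Tuple

# Illustrative values only; read the real ones from cfg["raster_params"].
raster_params = {
    "raster_size": [112, 112],   # [width, height] in pixels
    "pixel_size": [0.5, 0.5],    # meters per pixel along x and y
    "ego_center": [0.25, 0.5],   # ego position as a fraction of the raster size (assumption)
}


def raster_extent_m(params: Dict[str, List[float]]) -> Tuple[float, float]:
    """Width and height of the area covered by the raster, in meters."""
    (w, h), (px, py) = params["raster_size"], params["pixel_size"]
    return w * px, h * py


def ego_pixel(params: Dict[str, List[float]]) -> Tuple[float, float]:
    """Pixel coordinates of the ego inside the raster."""
    (w, h), (cx, cy) = params["raster_size"], params["ego_center"]
    return w * cx, h * cy


print(raster_extent_m(raster_params))  # (56.0, 56.0)
print(ego_pixel(raster_params))        # (28.0, 56.0)
```

Assume the ego faces along the raster's x axis. With `ego_center = [0.25, 0.5]`, the ego then sits a quarter of the way along that axis, which leaves most of the observation for the road ahead.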

In [None]:
print("current raster_params:\n")
for k, v in cfg["raster_params"].items():
    print(f"{k}: {v}")

## Create L5 Closed-loop Environment

We will now create an instance of the L5Kit gym-compatible environment. As you can see, we need to provide the path to the configuration file of the environment. 

1. The `rescale_action` flag rescales the policy action based on dataset statistics. This helps the policy converge faster during training.
2. The `return_info` flag tells the environment to return the episode output every time an episode is rolled out.

Note: The environment is registered with gym when `l5kit.environment` is imported, so it is available to `gym.make` by name.
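Conceptually, rescaling maps a normalized policy output (for example in `[-1, 1]`) back to a physical displacement and yaw change. The sketch below expresses that idea as an affine map. The `ACTION_MEAN`/`ACTION_SCALE` numbers are made-up placeholders. This is neither L5Kit's actual implementation nor its actual statistics.

```python
import numpy as np

# Hypothetical per-dimension statistics for (dx [m], dy [m], dyaw [rad]).
ACTION_MEAN = np.array([0.5, 0.0, 0.0])
ACTION_SCALE = np.array([0.5, 0.1, 0.05])


def rescale(normalized_action: np.ndarray) -> np.ndarray:
    """Map a normalized action in [-1, 1] to physical units: mean + scale * action."""
    return ACTION_MEAN + ACTION_SCALE * np.clip(normalized_action, -1.0, 1.0)


def unscale(physical_action: np.ndarray) -> np.ndarray:
    """Inverse map, e.g. to normalize logged actions before using them as targets."""
    return (physical_action - ACTION_MEAN) / ACTION_SCALE


print(rescale(np.array([1.0, 0.0, -1.0])))  # [ 1.    0.   -0.05]
```

Keeping the policy's raw outputs in a small, centered range tends to make optimization easier. With `rescale_action=False`, as in this notebook, the policy is expected to output physical values directly.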


In [None]:
env = gym.make("L5-CLE-v0", env_config_path=env_config_path, rescale_action=False, return_info=True)


## Visualize an observation from the environment

Let us visualize an observation from the environment: we reset the environment and plot the observation it returns.

In [None]:
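obs = env.reset()
# The observation image is channel-first (C, H, W); move the channels last for plotting
im = obs["image"].transpose(1, 2, 0)
# Convert the multi-channel raster into an RGB image
im = env.dataset.rasterizer.to_rgb(im)

plt.imshow(im)
plt.show()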

## Rollout an episode from the environment

### The rollout of an episode in our environment takes place in three steps:

### Gym Environment Update
1. Reward Calculation (CLE): Given an action from the policy, the environment will calculate the reward received as a consequence of the action.
2. Internal State Update: Since we are rolling out the environment in closed-loop, the internal state of the ego is updated based on the action.
3. Raster rendering: A new raster image is rendered based on the predicted ego position and returned as the observation for the next time step.
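A toy, one-dimensional environment can illustrate the three stages of the environment update. Everything in it is a hypothetical simplification chosen to show the order of operations: `ToyClosedLoopEnv`, the negative-distance reward and the string observation. It is not how L5Kit computes rewards or renders rasters.

```python
from typing import List, Tuple


class ToyClosedLoopEnv:
    """Hypothetical 1-D closed-loop env in which the ego should track a logged trajectory."""

    def __init__(self, logged_positions: List[float]):
        self.logged = logged_positions
        self.t = 0
        self.ego = logged_positions[0]

    def reset(self) -> str:
        self.t = 0
        self.ego = self.logged[0]
        return self._render()

    def _render(self) -> str:
        # Stand-in for rendering a raster around the current ego position.
        return f"obs@t={self.t}:x={self.ego:.1f}"

    def step(self, action: float) -> Tuple[str, float, bool]:
        # Closed loop: the next ego state comes from the policy action, not from the log.
        next_ego = self.ego + action
        # 1. Reward: penalize deviation from the logged position at the next time step.
        reward = -abs(next_ego - self.logged[self.t + 1])
        # 2. Internal state update.
        self.ego = next_ego
        self.t += 1
        # 3. Render a new observation from the updated state.
        done = self.t == len(self.logged) - 1
        return self._render(), reward, done


toy_env = ToyClosedLoopEnv([0.0, 1.0, 2.0, 3.0])
toy_obs = toy_env.reset()
toy_rewards, done = [], False
while not done:
    toy_obs, reward, done = toy_env.step(1.5)  # constant action that overshoots the log
    toy_rewards.append(reward)
print(toy_rewards)  # [-0.5, -1.0, -1.5]
print(toy_obs)      # obs@t=3:x=4.5
```

Because the loop is closed, each error carries over into the next state, so the penalty grows over the episode.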

### Policy Forward Pass
The policy takes as input the observation provided by the environment and outputs the action via a forward pass.

### Inter-process communication
Usually, we deploy multiple subprocesses to roll out parallel environments and speed up rollouts during training. Each subprocess rolls out one environment. This setup requires an additional component called inter-process communication. The subprocess outputs (observations) are aggregated and passed to the main process, and the actions are passed back from the main process to the subprocesses.
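The round trip can be sketched with Python's `multiprocessing` pipes:

1. Each worker process owns one environment.
2. The main process sends one action per environment.
3. The main process collects the resulting observations.

This only illustrates the pattern that vectorized-environment wrappers such as Stable-Baselines3's `SubprocVecEnv` implement. `CounterEnv`, `worker` and `rollout_parallel` are hypothetical names. The sketch assumes a POSIX platform where the `"fork"` start method is available.

```python
import multiprocessing as mp
from typing import Callable, List


class CounterEnv:
    """Hypothetical stand-in for one environment: the state is a scalar position."""

    def __init__(self, start: float):
        self.start = start
        self.pos = start

    def reset(self) -> float:
        self.pos = self.start
        return self.pos

    def step(self, action: float) -> float:
        self.pos += action
        return self.pos


def worker(conn, start: float) -> None:
    """Subprocess loop: owns one env and answers reset/step/close commands."""
    env = CounterEnv(start)
    while True:
        cmd, data = conn.recv()
        if cmd == "reset":
            conn.send(env.reset())
        elif cmd == "step":
            conn.send(env.step(data))
        else:  # "close"
            conn.close()
            break


def rollout_parallel(starts: List[float], num_steps: int,
                     policy: Callable[[float], float] = lambda obs: 1.0) -> List[float]:
    ctx = mp.get_context("fork")  # assumption: "fork" is available (POSIX)
    pipes, procs = [], []
    for start in starts:
        parent_conn, child_conn = ctx.Pipe()
        proc = ctx.Process(target=worker, args=(child_conn, start), daemon=True)
        proc.start()
        pipes.append(parent_conn)
        procs.append(proc)
    try:
        for conn in pipes:
            conn.send(("reset", None))
        observations = [conn.recv() for conn in pipes]
        for _ in range(num_steps):
            # Main process -> subprocesses: one action per environment.
            for conn, obs in zip(pipes, observations):
                conn.send(("step", policy(obs)))
            # Subprocesses -> main process: aggregate the new observations.
            observations = [conn.recv() for conn in pipes]
    finally:
        for conn in pipes:
            conn.send(("close", None))
        for proc in procs:
            proc.join()
    return observations


print(rollout_parallel([0.0, 10.0], num_steps=3))  # [3.0, 13.0]
```

The policy runs once in the main process for all environments, so a neural policy can batch its forward pass. This is the main reason to keep it out of the workers.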

![rollout](../../docs/images/rl/policy_rollout.png)

### Dummy Policy

For this notebook, we will not train a policy but use a dummy policy instead. Our dummy policy moves the ego at 10 m/s along its direction of orientation.
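The per-step displacement is simply speed times step time. At 10 m/s and an assumed `step_time` of 0.1 s, for example, the ego advances 1 m per step. The sketch below integrates such an action `(dx, dy, dyaw)` into a world-frame pose. Treating the action as a displacement in the ego's local frame is an assumption made for illustration, and `integrate_action` is a hypothetical helper, not an L5Kit function.

```python
import math
from typing import Tuple

Pose = Tuple[float, float, float]  # (x [m], y [m], yaw [rad]) in the world frame


def integrate_action(pose: Pose, action: Tuple[float, float, float]) -> Pose:
    """Apply an ego-frame displacement (dx, dy, dyaw) to a world-frame pose (assumed convention)."""
    x, y, yaw = pose
    dx, dy, dyaw = action
    # Rotate the local displacement by the current heading before adding it.
    x += dx * math.cos(yaw) - dy * math.sin(yaw)
    y += dx * math.sin(yaw) + dy * math.cos(yaw)
    return x, y, yaw + dyaw


step_time, speed = 0.1, 10.0          # assumed step time [s] and desired speed [m/s]
pose: Pose = (0.0, 0.0, math.pi / 2)  # ego initially heading along +y
for _ in range(10):                   # one second of simulated driving
    pose = integrate_action(pose, (speed * step_time, 0.0, 0.0))
print(pose)  # approximately (0.0, 10.0, pi / 2)
```

Under this assumption, the dummy policy's constant zero `dy` and `dyaw` mean the ego never turns and keeps driving straight along its initial heading.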

In [None]:
class DummyPolicy(torch.nn.Module):
    """A policy that advances the ego by constant speed along x-direction.

    :param advance_x: the distance to advance per time-step
    """
    def __init__(self, advance_x: float = 0.0):
        super(DummyPolicy, self).__init__()
        self.advance_x = advance_x

    def forward(self, x):
        positions_and_yaws = torch.zeros(3,)
        positions_and_yaws[..., 0] = self.advance_x

        return positions_and_yaws.cpu().numpy()

# We multiply the desired speed by the step time (inverse of the data collection frequency) to get the per-step displacement
desired_speed = 10.0
dummy_policy = DummyPolicy(cfg["model_params"]["step_time"] * desired_speed)

Let us now roll out the environment using the dummy policy.

In [None]:
def rollout_episode(env, idx=0):
    """Rollout a particular scene index and return the simulation output.

    :param env: the gym environment
    :param idx: the scene index to be rolled out
    :return: the episode output of the rolled out scene
    """

    # Set the reset_scene_id to 'idx'
    env.reset_scene_id = idx
    
    # Rollout step-by-step
    obs = env.reset()
    while True:
        action = dummy_policy(obs)
        obs, _, done, info = env.step(action)
        if done:
            break
    
    # The episode outputs are present in the key "sim_outs"
    sim_out = info["sim_outs"][0]
    return sim_out

# Rollout one episode
sim_out = rollout_episode(env)

## Visualize the episode from the environment

We can easily visualize the outputs obtained by rolling out episodes in L5Kit using the Bokeh visualizer.

In [None]:
# The map API is accessed via the semantic rasterizer; this attribute path might differ for other rasterizers
map_API = env.dataset.rasterizer.sem_rast.mapAPI

def visualize_outputs(sim_outs, map_API):
    for sim_out in sim_outs: # for each scene
        vis_in = episode_out_to_visualizer_scene_gym_cle(sim_out, map_API)
        show(visualize(sim_out.scene_id, vis_in))

output_notebook()
visualize_outputs([sim_out], map_API)

## Calculate the performance metrics from the episode outputs

We can also compute quantitative metrics on the rolled-out episode outputs.
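The table below reports two standard trajectory metrics. Both are computed from the per-frame L2 displacement error between the simulated and logged ego positions:

- ADE (average displacement error) is the mean error.
- FDE (final displacement error) is the error at the last frame.

The NumPy sketch below mirrors the `[1:]` slicing in `quantify_outputs` by skipping the first frame when computing ADE. The toy trajectories and the helper `ade_fde` are hypothetical.

```python
import numpy as np


def ade_fde(pred: np.ndarray, gt: np.ndarray) -> tuple:
    """ADE and FDE from (T, 2) arrays of predicted and ground-truth positions."""
    errors = np.linalg.norm(pred - gt, axis=-1)  # per-frame L2 displacement error
    ade = errors[1:].mean()                      # skip the first frame, as in quantify_outputs
    fde = errors[-1]
    return float(ade), float(fde)


gt = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])
pred = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 1.0], [3.0, 2.0]])
print(ade_fde(pred, gt))  # (1.0, 2.0)
```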

In [None]:
def quantify_outputs(sim_outs, metric_set=None):
    metric_set = metric_set if metric_set is not None else L2DisplacementYawMetricSet()

    metric_set.evaluate(sim_outs)
    scene_results = metric_set.evaluator.scene_metric_results
    fields = ["scene_id", "FDE", "ADE"]
    table = PrettyTable(field_names=fields)
    tot_fde = 0.0
    tot_ade = 0.0
    for scene_id in scene_results:
        scene_metrics = scene_results[scene_id]
        ade_error = scene_metrics["displacement_error_l2"][1:].mean()
        fde_error = scene_metrics['displacement_error_l2'][-1]
        table.add_row([scene_id, round(fde_error.item(), 4), round(ade_error.item(), 4)])
        tot_fde += fde_error.item()
        tot_ade += ade_error.item()

    ave_fde = tot_fde / len(scene_results)
    ave_ade = tot_ade / len(scene_results)
    table.add_row(["Overall", round(ave_fde, 4), round(ave_ade, 4)])
    print(table)


quantify_outputs([sim_out])