In [2]:
# CHANGE ME: Set this to the path to the NuPlan data directory
# A leading "~" is fine: getFiles() expands it with os.path.expanduser.
# Every regular file directly inside this directory is later opened as a
# SQLite database and queried for its `ego_pose` table.
NUPLAN_DATA_PATH = "~/nuplan-v1.1/splits/mini/"

In [128]:
import typing
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt

@dataclass
class State:
    """
    Heading and planar velocity of a vehicle at one instant.

    Instances are created both from logged ego poses (trajectory loading)
    and from the simulated MetaDrive vehicle (get_metadrive_state).
    """
    # Yaw angle. The log loader takes it from `as_euler('xyz')[:, 2]`
    # (radians, SciPy's default); the simulator value comes from
    # `heading_theta` -- presumably radians as well, TODO confirm.
    heading: float
    # Length-2 array: velocity along x and along y.
    # Units follow the source positions over seconds -- presumably m/s; TODO confirm.
    velocity: npt.NDArray[np.float64]


# A pair of consecutive states (state at step i, state at step i+1).
Observation: typing.TypeAlias = tuple[State, State]
# Two-component control input that is handed to `env.step`.
# NOTE(review): presumably (steering, throttle) as MetaDrive expects -- confirm.
Action: typing.TypeAlias = tuple[float, float]

In [147]:
import os
# One list of States per loaded log file; filled by the loading loop below.
# Re-running this cell resets it to an empty list.
trajectories: list[list[State]] = []

def getFiles(path: str) -> list[str]:
    """
    List the regular files directly inside a directory.

    Args:
        path: Directory to scan. A leading "~" is expanded to the user's home.

    Returns:
        Full paths of every regular file in `path` (sub-directories are
        skipped and nothing is searched recursively), sorted by file name so
        the order is identical on every run and every machine.

    Raises:
        OSError: propagated from os.listdir when `path` cannot be listed
        (e.g. FileNotFoundError, NotADirectoryError).
    """
    directory = os.path.expanduser(path)
    # os.listdir returns entries in arbitrary order; sorting keeps everything
    # derived from this list (trajectory order, the seeded shuffle and the
    # train/validation split) reproducible.
    candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]

# One-shot iterator over the data files. The loading loop consumes it, so
# re-running that loop continues with the files not yet visited (and does
# nothing once the iterator is exhausted); re-run this cell to start over.
file_iter = iter(getFiles(NUPLAN_DATA_PATH))

In [149]:
import sqlite3
from scipy.spatial.transform import Rotation
from scipy.spatial.transform import Slerp
from scipy.interpolate import UnivariateSpline


# Fewest pose samples a file must contain to be interpolated: the cubic
# UnivariateSpline (k=3) needs more data points than its degree, and Slerp
# needs at least two rotations.
MIN_POSE_SAMPLES = 4

# Turn every log file into one trajectory of States sampled at 10Hz.
for file_path in file_iter:
    time_micros = []
    quaternions = []
    xs = []
    ys = []

    # gather headings and positions from sqlite3 database
    # `with sqlite3.connect(...)` only scopes a transaction and leaves the
    # connection open, so the connection is closed explicitly instead.
    conn = sqlite3.connect(file_path)
    try:
        # ORDER BY: both interpolators need increasing timestamps and SQL
        # gives no row-order guarantee without it.
        pose_query = "SELECT timestamp, qw, qx, qy, qz, x, y FROM ego_pose ORDER BY timestamp"
        for (timestamp, qw, qx, qy, qz, x, y) in conn.cursor().execute(pose_query):
            time_micros.append(timestamp)
            # Rotation.from_quat expects scalar-last order (x, y, z, w)
            quaternions.append([qx, qy, qz, qw])
            xs.append(x)
            ys.append(y)
    finally:
        conn.close()

    if len(time_micros) < MIN_POSE_SAMPLES:
        reason = "it has no data" if len(time_micros) == 0 else f"it has only {len(time_micros)} pose samples"
        print(f"Skipping {file_path} because {reason}")
        continue

    # convert time to seconds
    times = np.array(time_micros) / 1e6

    # create interpolation objects
    rotation_interpolator = Slerp(times, Rotation.from_quat(quaternions))
    x_interpolator = UnivariateSpline(times, xs, k=3, s=1)
    y_interpolator = UnivariateSpline(times, ys, k=3, s=1)

    # velocity is the time derivative of the smoothed position splines
    xvel_interpolator = x_interpolator.derivative()
    yvel_interpolator = y_interpolator.derivative()

    # sample at 10Hz
    # (np.arange excludes the end point, so every sample lies inside the logged time range)
    sample_times = np.arange(times[0], times[-1], 0.1)

    # create heading and velocity arrays
    # third 'xyz' Euler angle = rotation about z, used as the heading
    headings = rotation_interpolator(sample_times).as_euler('xyz')[:, 2]
    velocities = np.stack([xvel_interpolator(sample_times), yvel_interpolator(sample_times)], axis=1)

    # create trajectory
    trajectory = [State(heading=h, velocity=v) for h, v in zip(headings, velocities)]
    trajectories.append(trajectory)
    print(f"Loaded trajectory of len {len(trajectory)} from {file_path}")

Loaded trajectory of len 3620 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.10.05.07.10.04_veh-52_01442_01802.db
Loaded trajectory of len 3844 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.06.09.12.39.51_veh-26_05620_06003.db
Loaded trajectory of len 3980 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.08.30.14.54.34_veh-40_00439_00835.db
Loaded trajectory of len 1850 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.10.11.02.57.41_veh-50_00352_00535.db
Loaded trajectory of len 4520 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.06.28.15.02.02_veh-38_02398_02848.db
Loaded trajectory of len 3940 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.10.06.07.26.10_veh-52_00006_00398.db
Loaded trajectory of len 4620 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.06.09.17.37.09_veh-12_00404_00864.db
Loaded trajectory of len 5510 from /home/fidgetsinner/nuplan-v1.1/splits/mini/2021.07.16.00.51.05_veh-17_01352_01901.db
Loaded trajectory of len 4200 from /home

In [150]:
import pickle
import os
import lzma

# pickle the trajectories
# The archive acts as a cache: an existing file is never overwritten, so
# delete it by hand to force a fresh export.
if not os.path.exists('nuplan_data/trajectories.pkl.xz'):
    # lzma.open does not create missing parent directories
    os.makedirs('nuplan_data', exist_ok=True)
    with lzma.open('nuplan_data/trajectories.pkl.xz', 'wb') as f:
        pickle.dump(trajectories, f)

In [None]:
import pickle
import lzma

# unpickle and decompress
# Load the cached archive when this kernel holds no trajectories yet: either
# the name was never defined, or it is still the empty list created before the
# (skipped) loading loop. Checking only for the name would leave that empty
# list in place and silently produce an empty dataset.
# SECURITY: pickle.load can execute arbitrary code -- only open archives that
# were written by this notebook.
if not globals().get("trajectories"):
    with lzma.open('nuplan_data/trajectories.pkl.xz', 'rb') as f:
        trajectories:list[list[State]] = pickle.load(f)

In [151]:
import random

# Pair every state with its direct successor inside the same trajectory
idm_data: list[Observation] = [
    (current, following)
    for trajectory_states in trajectories
    for current, following in zip(trajectory_states, trajectory_states[1:])
]

# 90:10 train-validation split
random.seed(0)
random.shuffle(idm_data)
split_index = int(len(idm_data)*0.9)
idm_train_data = idm_data[:split_index]
idm_validation_data = idm_data[split_index:]

In [152]:
# Report how many observation pairs ended up in each split
print(f"train data: {len(idm_train_data)}")
print(f"validation data: {len(idm_validation_data)}")

train data: 230236
validation data: 25582


In [None]:
import matplotlib.pyplot as plt
import metadrive
from metadrive import MetaDriveEnv
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F



def deviceof(m: nn.Module) -> torch.device:
    """
    Get the device of the given module

    The device of the first parameter is reported. A module without
    parameters falls back to its first buffer, and a module that has neither
    is reported as living on the CPU (instead of leaking a bare StopIteration
    out of `next`).
    """
    first_tensor = next(m.parameters(), None)
    if first_tensor is None:
        first_tensor = next(m.buffers(), None)
    if first_tensor is None:
        return torch.device("cpu")
    return first_tensor.device

def normalize_angle(angle: float) -> float:
    """
    Wrap an angle given in radians into the half-open interval [-pi, pi)
    """
    full_turn = 2 * np.pi
    # shift by half a turn, wrap into [0, 2*pi), then shift back
    shifted = angle + np.pi
    return shifted % full_turn - np.pi

def get_metadrive_state(env: MetaDriveEnv) -> State:
    """
    Read the simulated vehicle's heading and planar velocity into a State
    """
    current_heading = env.vehicle.heading_theta
    # keep only the first two velocity components
    planar_velocity = env.vehicle.velocity[:2]
    return State(heading=current_heading, velocity=planar_velocity)

def next_state(env: MetaDriveEnv, s: State, a: Action) -> State:
    """
    Simulates a single step from state `s` under action `a` and returns the resulting state

    The environment is reset on every call, the vehicle is stepped a few
    times with a zero action, and only then are its velocity and heading
    overwritten with the values from `s` before one `env.step(a)`.

    Args:
        env: simulator to use; it is reset by this call.
        s: state to start from (its velocity and heading are applied to the vehicle).
        a: action passed to `env.step`.

    Returns:
        The State read back from the vehicle right after the step.
    """
    # reset
    env.reset()
    # keep the reset position, only the height is set explicitly
    # NOTE(review): 0.49 is presumably the vehicle's resting height -- confirm
    env.vehicle.set_position(env.vehicle.position, height=0.49)

    # allow car to settle
    for _ in range(5):
        env.step([0,0])

    # set the initial state
    env.vehicle.set_velocity(s.velocity)
    env.vehicle.set_heading_theta(s.heading)
    
    # run the simulator
    env.step(a)

    # get the new state
    s_prime = get_metadrive_state(env)

    # allow car to settle (if rendering)
    # these extra steps happen after s_prime was read, so they do not affect the result
    if env.config.use_render:
        for _ in range(10):
            env.step([0,0])

    return s_prime

def gen_random_action() -> Action:
    """
    Draws a random two-component action

    Both components are sampled from a normal distribution with mean 0 and
    standard deviation 0.5 using NumPy's global random state. The distribution
    is intended to resemble the action statistics of the driving data (the
    original note names the waymo dataset).
    """
    mean, std_dev, n_components = 0, 0.5, 2
    sampled = np.random.normal(mean, std_dev, n_components)
    return tuple(sampled)

def state_batch_to_tensor(states: list[State], device: torch.device) -> torch.Tensor:
    """
    Pack a batch of State objects into one float32 tensor

    Each row is the state's velocity components followed by cos(heading) and
    sin(heading), i.e. shape (batch_size, 4) for two-component velocities.
    """
    velocity_rows = np.stack([state.velocity for state in states])
    heading_values = [state.heading for state in states]

    velocity_t = torch.tensor(velocity_rows, dtype=torch.float32, device=device)
    heading_t = torch.tensor(heading_values, dtype=torch.float32, device=device)

    # the heading is encoded as (cos, sin) columns appended after the velocity
    cos_column = torch.cos(heading_t).unsqueeze(1)
    sin_column = torch.sin(heading_t).unsqueeze(1)
    return torch.cat([velocity_t, cos_column, sin_column], dim=1)

def action_batch_to_tensor(actions: list[Action], device: torch.device) -> torch.Tensor:
    """
    Pack a batch of actions into one float32 tensor of shape (batch_size, 2)
    """
    stacked_actions = np.stack(actions)
    return torch.tensor(stacked_actions, dtype=torch.float32, device=device)

def obs_batch_to_tensor(obs: list[Observation], device: torch.device) -> torch.Tensor:
    """
    Pack a batch of (State, State) observations into one float32 tensor of shape (batch_size, 4, 2)

    For every observation the four rows hold velocity[0], velocity[1],
    cos(heading) and sin(heading); the two columns are the first and the
    second state of the pair.
    """
    feature_matrices = [
        np.array([
            [before.velocity[0], after.velocity[0]],
            [before.velocity[1], after.velocity[1]],
            [np.cos(before.heading), np.cos(after.heading)],
            [np.sin(before.heading), np.sin(after.heading)],
        ])
        for before, after in obs
    ]
    stacked = np.stack(feature_matrices)
    return torch.tensor(stacked, dtype=torch.float32, device=device)

In [None]:
def generate_data(s0_batch: list[State]) -> list[tuple[State, Action, State]]:
    """
    Simulates one random-action step from every initial state of a batch

    A private MetaDriveEnv is created for the batch and is always closed
    again, even when a simulation step raises.

    Args:
        s0_batch: initial states to start from.

    Returns:
        One (s0, a, s1) tuple per initial state, in input order, where `a` is
        the random action applied and `s1` the state returned by next_state.
    """
    import multiprocessing

    # Forked pool workers start with a copy of the parent's NumPy global RNG
    # state, so without reseeding every worker would draw the very same action
    # sequence. Only reseed inside child processes so that a direct call from
    # the main process keeps whatever seed the caller has set.
    if multiprocessing.parent_process() is not None:
        np.random.seed()

    env = MetaDriveEnv(config={"on_continuous_line_done": False, "use_render": False})
    dataset: list[tuple[State, Action, State]] = []
    try:
        for s0 in s0_batch:
            a = gen_random_action()
            s1 = next_state(env, s0, a)
            dataset.append((s0, a, s1))
    finally:
        # release the simulator even if a step failed
        env.close()
    return dataset

In [None]:
from concurrent.futures import ProcessPoolExecutor
from metadrive import MetaDriveEnv

MAX_WORKERS = 16

# (s0, a, s1) samples produced by the simulator for the training split
mm_train_data: list[tuple[State, Action, State]] = []

# Distribute the data evenly among workers; the first worker also takes the remainder
batch_size, leftover_size = divmod(len(idm_train_data), MAX_WORKERS)
n_scenarios_per_worker = [batch_size + leftover_size] + [batch_size]*(MAX_WORKERS - 1)

# Distribute the initial states among workers
# (only the first state of every observation pair is used as a start state)
idm_train_data_iter = iter(idm_train_data)
s0_batch_per_worker = []
for n_scenarios in n_scenarios_per_worker:
    s0_batch_per_worker.append([next(idm_train_data_iter)[0] for _ in range(n_scenarios)])

# Generate the data in parallel
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    for batch in executor.map(generate_data, s0_batch_per_worker):
        mm_train_data.extend(batch)

In [None]:
from concurrent.futures import ProcessPoolExecutor
from metadrive import MetaDriveEnv

MAX_WORKERS = 16

# (s0, a, s1) samples produced by the simulator for the validation split
mm_validation_data: list[tuple[State, Action, State]] = []

with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    batch_size, leftover_size = divmod(len(idm_validation_data), MAX_WORKERS)

    # Distribute the data evenly among workers (the first worker also takes the remainder)
    n_scenarios_per_worker = [batch_size]*MAX_WORKERS
    n_scenarios_per_worker[0] += leftover_size

    # Distribute the initial states among workers
    # The start states must come from the validation split: drawing them from
    # idm_train_data would make the "validation" set a subset of the training
    # start states and leak training data into the evaluation.
    idm_validation_data_iter = iter(idm_validation_data)
    s0_batch_per_worker = [[next(idm_validation_data_iter)[0] for _ in range(n_scenarios)] for n_scenarios in n_scenarios_per_worker]

    # Generate the data in parallel
    for batch in executor.map(generate_data, s0_batch_per_worker):
        mm_validation_data.extend(batch)

In [None]:
import pickle
import os
import lzma

# pickle the data
# Each archive acts as a cache: an existing file is never overwritten, so
# delete it by hand to force a fresh export.
# NOTE(review): the directory is called 'waymo_data' although the states in
# this notebook are loaded from nuPlan logs -- confirm the location is intended.
if not os.path.exists('waymo_data/mm_train_data.pkl.xz'):
    # lzma.open does not create missing parent directories
    os.makedirs('waymo_data', exist_ok=True)
    with lzma.open('waymo_data/mm_train_data.pkl.xz', 'wb') as f:
        pickle.dump(mm_train_data, f)

if not os.path.exists('waymo_data/mm_validation_data.pkl.xz'):
    os.makedirs('waymo_data', exist_ok=True)
    with lzma.open('waymo_data/mm_validation_data.pkl.xz', 'wb') as f:
        pickle.dump(mm_validation_data, f)

