In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import pickle
import lzma
from metadrive import MetaDriveEnv
from concurrent.futures import ProcessPoolExecutor
from utils.env import State, Transition, Action, next_state, normalize_angle


In [2]:
def generate_2d_velocity(rng: np.random.Generator, v:float) -> tuple[float, float]:
    """
    Draw a random 2D velocity (vx, vy).

    The speed is sampled uniformly from [0, v) and the direction uniformly from
    [-pi, pi); the speed is drawn first, then the direction.
    """
    speed, angle = rng.uniform(0, v), rng.uniform(-np.pi, np.pi)
    vx = speed * np.cos(angle)
    vy = speed * np.sin(angle)
    return vx, vy

def gen_random_state_action(rng: np.random.Generator) -> tuple[State, Action]:
    """
    Sample a random (state, action) pair from a two-mode mixture intended to
    resemble the distributions found in the Waymo dataset.

    With probability 0.5 the vehicle is "parked":
        - speed drawn uniformly below 1;
        - heading uniform in [-pi, pi), independent of the velocity;
        - Action(uniform(-1, 1), normal(0, 0.1)).
    Otherwise it is "moving":
        - speed drawn uniformly below 35;
        - heading is the direction of travel plus normal(0, 0.2) noise, passed
          through normalize_angle;
        - Action(normal(0, 0.3), normal(0.2, 0.3)).
    """
    is_parked = rng.uniform(0, 1) < 0.5
    if is_parked:
        velocity = generate_2d_velocity(rng, 1)
        heading = rng.uniform(-np.pi, np.pi)
        action_args = (rng.uniform(-1, 1), rng.normal(0, 0.1))
    else:
        velocity = generate_2d_velocity(rng, 35)
        travel_direction = np.arctan2(velocity[1], velocity[0])
        heading = normalize_angle(travel_direction + rng.normal(0, 0.2))
        action_args = (rng.normal(0, 0.3), rng.normal(0.2, 0.3))

    action = Action(*action_args)
    return State(heading=heading, velocity=velocity), action

def generate_data(n: int, thread_id: int) -> list[Transition]:
    """
    Simulate `n` random one-step transitions in a fresh MetaDrive environment.

    Each call creates its own environment and closes it before returning. The
    notebook runs this function in ProcessPoolExecutor worker processes.

    Args:
        n: number of transitions to generate; `n <= 0` returns an empty list
            without starting a simulator.
        thread_id: seed for this call's numpy Generator. Despite the name, the
            workers are processes, not threads. Distinct seeds give distinct
            sample streams.

    Returns:
        A list of `n` Transition(s0, a, s1). `(s0, a)` come from
        gen_random_state_action and `s1 = next_state(env, s0, a)`.
    """
    if n <= 0:
        # Nothing to simulate; avoid paying for simulator start-up and shutdown.
        return []

    rng = np.random.default_rng(thread_id)
    env = MetaDriveEnv(config={"on_continuous_line_done": False, "use_render": False})
    try:
        dataset: list[Transition] = []
        for _ in range(n):
            s0, a = gen_random_state_action(rng)
            s1 = next_state(env, s0, a)
            dataset.append(Transition(s0, a, s1))
    finally:
        # Close the environment even if next_state raises, so the simulator is
        # not left open in the worker process.
        env.close()
    return dataset

In [3]:
NUM_SCENARIOS = 2_000_000
MAX_WORKERS = 16

transition_data: list[Transition] = []

# Base seed; worker i seeds its RNG with `seed + i`, so each worker draws a distinct stream.
seed = 16

# File written by the save cell below. Generating NUM_SCENARIOS simulated transitions is
# expensive, and the save cell never overwrites an existing file. Regenerating while this
# file exists would therefore only produce data that is thrown away, so load it instead.
TRANSITION_DATA_PATH = "./data/transition_data.pkl.xz"

# Distribute the data evenly among workers: the first `leftover_size` workers take one extra.
batch_size, leftover_size = divmod(NUM_SCENARIOS, MAX_WORKERS)
n_scenarios_per_worker = [
    batch_size + 1 if i < leftover_size else batch_size for i in range(MAX_WORKERS)
]

# Per-worker RNG seeds, passed to generate_data as `thread_id` (the workers are processes).
thread_ids = [seed + i for i in range(MAX_WORKERS)]

if os.path.exists(TRANSITION_DATA_PATH):
    # NOTE: pickle.load can execute arbitrary code; only load a file this notebook produced.
    with lzma.open(TRANSITION_DATA_PATH, "rb") as f:
        transition_data = pickle.load(f)
    # Only the size is checked: a cache from a run with a different NUM_SCENARIOS is rejected,
    # but one from a different seed is not. Delete the file to force regeneration.
    assert len(transition_data) == NUM_SCENARIOS, (
        f"{TRANSITION_DATA_PATH} holds {len(transition_data)} transitions, "
        f"expected {NUM_SCENARIOS}; delete it to regenerate"
    )
else:
    # Generate the data in parallel
    with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # executor.map yields results in submission order, so the dataset order is reproducible.
        for batch in executor.map(generate_data, n_scenarios_per_worker, thread_ids):
            transition_data.extend(batch)
    assert len(transition_data) == NUM_SCENARIOS

[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Render Mode: none[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] Render Mode: none[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Render Mode: none[0m
[38;20m[INFO] MetaDrive version: 0.4.1.2[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] Render Mode: none[0m
[38;20m[INFO] Sensors: [][0m
[38;20m[INFO] MetaDr

In [4]:
# save data
# Persist the generated transitions. An existing file is never overwritten; delete it to save
# a fresh run (e.g. after changing NUM_SCENARIOS or seed).
save_path = "./data/transition_data.pkl.xz"
if os.path.exists(save_path):
    print(f"{save_path} already exists; not overwriting it")
else:
    # lzma.open(..., "wb") fails when the parent directory is missing, so create it first.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    # Write to a temporary file and rename it into place. An interrupted (slow) xz compression
    # then cannot leave a truncated file at save_path that would block every future save.
    tmp_path = save_path + ".tmp"
    with lzma.open(tmp_path, "wb") as f:
        pickle.dump(transition_data, f)
    os.replace(tmp_path, save_path)