In [55]:
# Import necessary libraries
import sys, os
import numpy as np
import pickle
import tensorflow as tf
import multiprocessing
import glob
from tqdm import tqdm
from waymo_open_dataset.protos import scenario_pb2
from waymo_types import object_type, lane_type, road_line_type, road_edge_type, signal_state, polyline_type

# Define the path to the dataset
dataset_path = '/scratch/rksing18/datasets/waymo/scenario/training'  # Replace with your dataset path
dataset_path = '/scratch/rksing18/datasets/waymo/scenario/testing'  # Replace with your dataset path
src_files = glob.glob(os.path.join(dataset_path, '*.tfrecord*'))
src_files.sort()


print(f"Found {len(src_files)} TFRecord files.")

# Function to parse a single TFRecord file
def parse_tfrecord(record):
    scenario = scenario_pb2.Scenario()
    scenario.ParseFromString(record.numpy())  # Deserialize the TFRecord string to a Scenario object
    return scenario.SerializeToString()  # Serialize the Scenario object to a string

# Wrapper function for tf.py_function
def parse_tfrecord_wrapper(record):
    serialized_scenario = tf.py_function(
        func=parse_tfrecord,
        inp=[record],
        Tout=tf.string  # Output is a serialized string
    )
    return serialized_scenario

# Create a TensorFlow dataset from the TFRecord files
raw_dataset = tf.data.TFRecordDataset(src_files, compression_type="")

# Parse the dataset using the wrapper
parsed_dataset = raw_dataset.map(parse_tfrecord_wrapper)

# Iterate through the dataset and print some information
for i, serialized_scenario in enumerate(parsed_dataset):
    scenario = scenario_pb2.Scenario()
    scenario.ParseFromString(serialized_scenario.numpy())  # Deserialize the string back to a Scenario object
    print(f"Scenario {i}: ID = {scenario.scenario_id}")
    if i == 5:  # Limit to 5 scenarios for demonstration
        break

Found 143 TFRecord files.
Scenario 0: ID = 53efd22f9e0bd276
Scenario 1: ID = 4b1f67e58e15e78c
Scenario 2: ID = 81cb7891dfe69679
Scenario 3: ID = 3e4bd1a8eab0cab7
Scenario 4: ID = 339f19cf983ffa68
Scenario 5: ID = 452c20d8f0e303cb


In [None]:
def decode_tracks_from_proto(tracks, tracks_to_predict):
    tracks_to_predict_ids = [x.track_index for x in tracks_to_predict]
    print(f"Tracks to predict: {tracks_to_predict_ids}")
    data = None
    for cur_data in tracks:
        cur_traj = [np.array([
            -1.0, -1.0, -1.0, -1.0,
            x.velocity_x, x.velocity_y,
            x.width, x.length, x.height,
            x.center_x, x.center_y, x.center_z,
            x.heading, 1.0 if cur_data.id in tracks_to_predict_ids else 0.0
        ], dtype=np.float32) for x in cur_data.states if x.valid]
        try:
            cur_traj = np.stack(cur_traj, axis=0)  # (num_timestamps, 10)
        except:
            print(f"Skipping {cur_data.id} as it has no valid states.")
            continue
        _data = np.hstack(( np.arange(1, cur_traj.shape[0] + 1, dtype=np.float32).reshape(-1, 1), 
                           np.full((cur_traj.shape[0], 1), cur_data.id, dtype=np.float32),
                           np.array([object_type.get(cur_data.object_type)]*cur_traj.shape[0]).reshape(-1,1), 
                           cur_traj))
        if data is None:
            data = _data
        else:
            data = np.vstack((data, _data))
    data = data[np.lexsort((data[:, 1].astype(float), data[:, 0].astype(float)))]
    return data

In [68]:
i=0
serialized_scenario = next(iter(parsed_dataset))
scenario = scenario_pb2.Scenario()
scenario.ParseFromString(serialized_scenario.numpy())  # Deserialize the string back to a Scenario object
print(f"\n=== Scenario {i}: ID = {scenario.scenario_id} ===")
try:
    track_infos = decode_tracks_from_proto(scenario.tracks, scenario.tracks_to_predict)
except:
    print("Error decoding tracks.")
    print(f"Scenario ID: {scenario.scenario_id}")
    print(scenario) 


=== Scenario 0: ID = 53efd22f9e0bd276 ===
Tracks to predict: [15, 8, 18]
[]
Skipping 294 as it has no valid states.
Error decoding tracks.
Scenario ID: 53efd22f9e0bd276
timestamps_seconds: 0.0
timestamps_seconds: 0.09956
timestamps_seconds: 0.1991
timestamps_seconds: 0.29862
timestamps_seconds: 0.39811
timestamps_seconds: 0.49753
timestamps_seconds: 0.59696
timestamps_seconds: 0.69635
timestamps_seconds: 0.7957
timestamps_seconds: 0.89505
timestamps_seconds: 0.99435
tracks {
  id: 259
  object_type: TYPE_VEHICLE
  states {
    center_x: 316.8645935058594
    center_y: 7953.30615234375
    center_z: 121.39199676987984
    length: 12.407690048217773
    width: 3.0261433124542236
    height: 3.3666744232177734
    heading: 1.554057240486145
    velocity_x: -0.0439453125
    velocity_y: -0.72265625
    valid: true
  }
  states {
    center_x: 316.8626708984375
    center_y: 7952.96142578125
    center_z: 121.43236809897911
    length: 11.541391372680664
    width: 3.012834310531616
    he

In [52]:
np.savetxt("output.txt", track_infos, fmt='%s')

In [49]:
scenario.tracks_to_predict[0].track_index

0

In [None]:
import numpy as np
import cv2
import os

def plot_map_features(map_infos, save_path="map.png", meta_path="meta.txt"):
    color_map = {
        'lane': [219, 225, 200],
        'road_line': [40, 40, 40],
        'road_edge': [0, 0, 0],
        'stop_sign': [0, 0, 255],
        'crosswalk': [116, 116, 116],
        'speed_bump': [53, 209, 243],
        'rest': [243, 240, 255]
    }

    dark_color_map = {
        'lane': [246, 9, 0],
        'road_line': [219, 225, 200],
        # 'road_line': [20, 20, 20],
        # 'road_edge': [179, 185, 160],
        'road_edge': [20, 20, 20],
        'stop_sign': [0, 0, 255],
        'crosswalk': [116, 116, 116],
        'speed_bump': [53, 209, 243],
        'rest': [20, 20, 20]
    }

    thickness_map = {
        'lane': 12,
        'road_line': 3,
        'road_edge': 2,
        'stop_sign': 4,
        'crosswalk': 5,
        'speed_bump': 4,
    }

    all_polylines = map_infos.get("all_polylines", None)
    if all_polylines is None or len(all_polylines) == 0:
        print("No polylines found.")
        return

    # Calculate bounds with margin
    margin = 0
    scale = 4
    x = all_polylines[:, 0]
    y = all_polylines[:, 1]
    x_min, x_max = np.round(x.min() - margin), np.round(x.max() + margin)
    y_min, y_max = np.round(y.min() - margin), np.round(y.max() + margin)
    x_size = int(np.round((x_max - x_min) * scale))
    y_size = int(np.round((y_max - y_min) * scale))

    def draw_map(canvas, cmap):
        for feature_type, color in cmap.items():
            if feature_type == 'rest':
                continue
            thickness = thickness_map.get(feature_type, 1)
            for item in map_infos[feature_type]:
                start, end = item['polyline_index']
                polyline = all_polylines[start:end]
                pts = np.round((polyline[:, :2] - [x_min, y_min]) * scale).astype(int)
                for i in range(len(pts) - 1):
                    cv2.line(canvas, tuple(pts[i]), tuple(pts[i + 1]), color=color, thickness=thickness)
        return canvas

    # Generate visual canvas
    vis_canvas = np.ones((y_size, x_size, 3), dtype=np.uint8) * np.array(color_map['rest'], dtype=np.uint8)
    vis_canvas = draw_map(vis_canvas, color_map)

    # Generate dark canvas
    dark_canvas = np.ones((y_size, x_size, 3), dtype=np.uint8) * np.array(dark_color_map['rest'], dtype=np.uint8)
    dark_canvas = draw_map(dark_canvas, dark_color_map)

    # Save metadata
    meta = np.array([x_min, y_min, scale])
    np.savetxt(meta_path, meta, fmt='%.2f')

    # Create separate filenames
    folder, base = os.path.split(save_path)
    name = os.path.splitext(base)[0]
    vis_path = os.path.join(folder, f"vis_{name}.png")
    dark_path = os.path.join(folder, f"{name}.png")

    # Save images
    cv2.imwrite(vis_path, vis_canvas)
    cv2.imwrite(dark_path, dark_canvas)
    print(f"Saved visual map to {vis_path}")
    print(f"Saved dark map to {dark_path}")