In [None]:
import os
import glob
import math

import numpy as np
import tensorflow as tf 
import matplotlib.pyplot as plt

In [None]:
files = glob.glob(os.path.join(os.getcwd(), "data", "*.tfrecord-*"))
print(f"Found {len(files)} tfrecord files...")

In [None]:
import tensorflow as tf

# Fixed tensor sizes of the serialized scenario examples. Every feature below is
# declared in terms of these, so a size change needs exactly one edit here.
num_map_samples = 30000    # roadgraph sample points per example
num_agents = 128           # agent slots in every `state/*` feature
num_past_steps = 10        # steps stored under `*/past/*`
num_current_steps = 1      # steps stored under `*/current/*`
num_future_steps = 80      # steps stored under `*/future/*`
num_traffic_lights = 16    # traffic-light slots in every `traffic_light_state/*` feature


def _fixed_len_feature(shape: list[int], dtype: tf.dtypes.DType) -> tf.io.FixedLenFeature:
    """Declare a fixed-shape feature with no default value."""
    return tf.io.FixedLenFeature(shape, dtype, default_value=None)


# Roadgraph samples: one row per sample point.
roadgraph_features = {
    'roadgraph_samples/dir': _fixed_len_feature([num_map_samples, 3], tf.float32),
    'roadgraph_samples/id': _fixed_len_feature([num_map_samples, 1], tf.int64),
    'roadgraph_samples/type': _fixed_len_feature([num_map_samples, 1], tf.int64),
    'roadgraph_samples/valid': _fixed_len_feature([num_map_samples, 1], tf.int64),
    'roadgraph_samples/xyz': _fixed_len_feature([num_map_samples, 3], tf.float32),
}

# dtypes of the per-timestep agent state fields; each field is stored once per
# time range (past / current / future) with shape [num_agents, num_steps].
_agent_step_dtypes = {
    'bbox_yaw': tf.float32,
    'height': tf.float32,
    'length': tf.float32,
    'timestamp_micros': tf.int64,
    'valid': tf.int64,
    'vel_yaw': tf.float32,
    'velocity_x': tf.float32,
    'velocity_y': tf.float32,
    'width': tf.float32,
    'x': tf.float32,
    'y': tf.float32,
    'z': tf.float32,
}

# Per-agent state for all `num_agents` slots (agent-major: [num_agents, ...]).
state_features = {
    'state/id': _fixed_len_feature([num_agents], tf.float32),
    'state/type': _fixed_len_feature([num_agents], tf.float32),
    'state/is_sdc': _fixed_len_feature([num_agents], tf.int64),
    'state/tracks_to_predict': _fixed_len_feature([num_agents], tf.int64),
    **{
        f'state/{time_range}/{name}': _fixed_len_feature([num_agents, num_steps], dtype)
        for time_range, num_steps in (
            ('current', num_current_steps),
            ('future', num_future_steps),
            ('past', num_past_steps),
        )
        for name, dtype in _agent_step_dtypes.items()
    },
}

# dtypes of the traffic-light fields, stored per time range (past / current).
_traffic_light_dtypes = {
    'state': tf.int64,
    'valid': tf.int64,
    'x': tf.float32,
    'y': tf.float32,
    'z': tf.float32,
}

# Traffic-light features are time-major: [num_steps, num_traffic_lights].
traffic_light_features = {
    f'traffic_light_state/{time_range}/{name}': _fixed_len_feature(
        [num_steps, num_traffic_lights], dtype
    )
    for time_range, num_steps in (
        ('current', num_current_steps),
        ('past', num_past_steps),
    )
    for name, dtype in _traffic_light_dtypes.items()
}

# Full parsing spec passed to tf.io.parse_single_example.
features_description = {
    **roadgraph_features,
    **state_features,
    **traffic_light_features,
}

In [None]:
def filter_stationary_trajs(trajs: np.ndarray, stol: float = 1e-6, xy_start: int = 0) -> np.ndarray:
    """Drop trajectories whose position never changes between consecutive steps.

    Args:
        trajs: Array of shape ``(n_trajs, n_steps, n_features)``.
        stol: Threshold on the *squared* per-step displacement. A trajectory is
            considered stationary when every step moves less than ``sqrt(stol)``.
        xy_start: Column index of x; x and y are read from columns ``xy_start``
            and ``xy_start + 1``. The default 0 keeps the original behavior.
            With the ``fields`` list used in this notebook, column 0 is
            ``timestamp_micros`` and x/y are columns 1/2, so pass ``xy_start=1``.

    Returns:
        The trajectories (along the first axis) that move at least once. A
        single-step trajectory has no displacement and is therefore dropped.
        Steps are not filtered by any validity flag: every step's coordinates
        count toward the displacement test.
    """
    xy = trajs[:, :, xy_start:xy_start + 2]
    sq_step_len = np.sum(np.diff(xy, axis=1) ** 2, axis=2)
    stationary_mask = np.all(sq_step_len < stol, axis=1)
    return trajs[~stationary_mask]

In [None]:
def parse_single_tfrecord(record: bytes, fields: list[str], agent_type: int = 1) -> np.ndarray:
    """Parse one serialized scenario example into per-agent trajectories.

    Args:
        record: A single serialized example read from a TFRecord file (one
            record, not a path to the file).
        fields: Per-timestep state field names (e.g. ``"x"``, ``"valid"``). Each
            must exist under ``state/past/``, ``state/current/`` and
            ``state/future/`` in ``features_description``.
        agent_type: Value of ``state/type`` to keep. The default ``1`` is the
            value this notebook uses to select vehicles.

    Returns:
        Array of shape ``(n_selected_agents, n_steps, len(fields))``, where the
        steps are past, then current, then future (10 + 1 + 80 = 91 with the
        schema in ``features_description``). Columns follow the order of
        ``fields``. Mixing int64 fields (``timestamp_micros``, ``valid``) with
        float32 ones promotes the whole array to float64.

    Raises:
        ValueError: If ``fields`` is empty.
    """
    if not fields:
        raise ValueError("fields must name at least one state feature")

    parsed = tf.io.parse_single_example(record, features_description)
    # Convert once to a NumPy bool array so it indexes the NumPy features directly.
    agent_mask = (parsed["state/type"] == agent_type).numpy()

    per_field = []
    for field in fields:
        # Concatenate the time ranges along the step axis in chronological order.
        steps = [
            parsed[f"state/{time_range}/{field}"].numpy()
            for time_range in ("past", "current", "future")
        ]
        per_field.append(np.concatenate(steps, axis=1)[agent_mask])

    return np.stack(per_field, axis=2)

In [None]:
from scipy.interpolate import interp1d

def resample_traj_2d(
    traj: np.ndarray,
    num_samples: int,
    time_col: "int | None" = None,
    kind: str = "linear",
) -> np.ndarray:
    """Resample a 2D trajectory array to a fixed number of evenly spaced points.

    Every column is interpolated onto ``num_samples`` points spaced uniformly
    between the smallest and largest time value, so no extrapolation occurs.

    Args:
        traj: Array of shape ``(n_points, n_features)``.
        num_samples: Number of output rows (must be >= 0).
        time_col: Column holding the time axis (e.g. timestamps). When ``None``
            the row index ``0..n_points-1`` is used as time. The time column is
            itself resampled, so output rows keep their matching time value.
            NOTE: rows are not masked by any validity flag; drop invalid rows
            before calling.
        kind: Interpolation kind forwarded to ``scipy.interpolate.interp1d``.

    Returns:
        Float array of shape ``(num_samples, n_features)``.

    Raises:
        ValueError: If ``traj`` is not 2D, is empty, or ``num_samples`` < 0.
    """
    traj = np.asarray(traj)
    if traj.ndim != 2:
        raise ValueError(f"traj must be 2D (n_points, n_features), got shape {traj.shape}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    n_points = traj.shape[0]
    if n_points == 0:
        raise ValueError("cannot resample an empty trajectory")
    if n_points == 1:
        # A single point spans no time interval; repeat it instead of interpolating.
        return np.repeat(traj.astype(float), num_samples, axis=0)

    times = traj[:, time_col] if time_col is not None else np.arange(n_points, dtype=float)
    target_times = np.linspace(times.min(), times.max(), num=num_samples)
    # axis=0 interpolates every feature column in a single call.
    return interp1d(times, traj, axis=0, kind=kind)(target_times)

In [None]:
fields = ["timestamp_micros", "x", "y", "velocity_x", "velocity_y", "bbox_yaw", "valid"]

# Fail with a clear message instead of an IndexError when no data files were found.
if not files:
    raise FileNotFoundError("No *.tfrecord-* files found in the ./data directory")

dataset = tf.data.TFRecordDataset(files[0], compression_type="")

# Only the first serialized example of the first file is parsed here.
first_record = next(dataset.as_numpy_iterator(), None)
if first_record is None:
    raise ValueError(f"{files[0]} contains no records")

parsed = parse_single_tfrecord(first_record, fields)
print(parsed.shape)

In [None]:
traj_index = 0
traj = parsed[traj_index]

# Column positions come from the `fields` list used to build `parsed`.
# Column 0 is `timestamp_micros`, so x/y are NOT columns 0/1.
x_col, y_col = fields.index("x"), fields.index("y")

# Each trajectory is past (10 steps) + current (1 step) + future (80 steps),
# concatenated in that order by parse_single_tfrecord.
num_history_steps = 10 + 1

# Skip steps whose `valid` flag is 0 so they don't draw spurious segments.
if "valid" in fields:
    step_is_valid = traj[:, fields.index("valid")] > 0
else:
    step_is_valid = np.ones(len(traj), dtype=bool)

history = traj[:num_history_steps][step_is_valid[:num_history_steps]]
future = traj[num_history_steps:][step_is_valid[num_history_steps:]]

fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(history[:, x_col], history[:, y_col], label='Past Trajectory', color='blue', linestyle='--')
ax.plot(future[:, x_col], future[:, y_col], label='Actual Future Trajectory', color='red')
ax.set_xlabel('X Coordinate')
ax.set_ylabel('Y Coordinate')
ax.set_title(f"Past vs. Actual Future Trajectory (Example {traj_index})")
ax.axis('equal')
ax.legend()
ax.grid(True)
plt.show()