# EGO Vehicle Trajectory Endpoint Clustering

Extract the normalized endpoints of the last 6 seconds (the future segment) of EGO vehicle trajectories and cluster them with K-Means.

In [None]:
from pathlib import Path
import sys

# Locate the project root: the current directory or the nearest ancestor that
# contains a `src` directory, so the `from src...` import below works no matter
# which sub-directory the notebook is started from.
project_root = Path.cwd().resolve()
if not (project_root / "src").exists():
    for parent in project_root.parents:
        if (parent / "src").exists():
            project_root = parent
            break

# Put the project root first on sys.path (only once) so `src` is importable.
if (project_root / "src").exists() and str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

import torch
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from src.utils.data_visualization import get_available_files, DataLoader

sns.set_style('whitegrid')
print("✅ Libraries loaded")

## 1. Load data and extract EGO features

In [None]:
# Configuration
DATA_PATH = Path('/data1/xiaowei/code/DeMo/data/DeMo_processed')
N_FILES = 1000
TOTAL_TIMESTEPS = 110  # 11 seconds * 10 Hz
HISTORY_TIMESTEPS = 50  # First 5 seconds
FUTURE_TIMESTEPS = 60   # Last 6 seconds

# The extraction step keeps timesteps >= HISTORY_TIMESTEPS and rejects any
# scenario that does not yield exactly FUTURE_TIMESTEPS samples, so an
# inconsistent split would silently discard every file. Fail loudly instead.
if HISTORY_TIMESTEPS + FUTURE_TIMESTEPS != TOTAL_TIMESTEPS:
    raise ValueError(
        f"HISTORY_TIMESTEPS ({HISTORY_TIMESTEPS}) + FUTURE_TIMESTEPS ({FUTURE_TIMESTEPS}) "
        f"must equal TOTAL_TIMESTEPS ({TOTAL_TIMESTEPS})"
    )

# Get files: the first N_FILES of the training split
train_files = get_available_files(DATA_PATH, 'train')[:N_FILES]
print(f"📂 Loaded {len(train_files)} files")
if len(train_files) == 0:
    print(f"⚠️ No training files found under {DATA_PATH} - check DATA_PATH")

In [None]:
def normalize_trajectory(positions, heading_angle, *, degrees=None):
    """
    Normalize a trajectory into the frame of its first point.

    The first position is translated to the origin (the 5s position in this
    notebook). The trajectory is then rotated so that ``heading_angle`` points
    along +Y.

    Args:
        positions: array-like of shape (T, 2) - [x, y] positions. Lists,
            NumPy arrays and CPU tensors are accepted and converted to float.
        heading_angle: heading at the normalization timestamp.
        degrees: unit of ``heading_angle``.
            ``True`` means degrees and ``False`` means radians.
            ``None`` (default) keeps the legacy guess, which treats values
            with magnitude > pi as degrees. The guess misreads radian headings
            in (pi, 2*pi] and degree headings whose magnitude is below pi,
            so pass the unit explicitly when it is known.

    Returns:
        numpy float array of shape (T, 2) with normalized coordinates
        (T = 60 for the future segment used in this notebook).

    Raises:
        ValueError: if ``positions`` is not a non-empty (T, 2) array.
    """
    positions = np.asarray(positions, dtype=float)
    if positions.ndim != 2 or positions.shape[0] == 0 or positions.shape[1] != 2:
        raise ValueError(
            f"positions must have shape (T, 2) with T >= 1, got {positions.shape}"
        )

    # Translation: move the first (5s) position to the origin
    positions_translated = positions - positions[0]

    heading_angle = float(heading_angle)
    if degrees is None:
        # Legacy heuristic, kept as the default for backward compatibility
        degrees = abs(heading_angle) > np.pi
    if degrees:
        heading_angle = np.deg2rad(heading_angle)

    # Rotating by (pi/2 - heading) maps the heading direction onto +Y
    rotation_angle = np.pi / 2 - heading_angle
    cos_theta, sin_theta = np.cos(rotation_angle), np.sin(rotation_angle)
    rotation_matrix = np.array([
        [cos_theta, -sin_theta],
        [sin_theta, cos_theta],
    ])

    # Points are rows, so p @ R.T applies R to every point
    return positions_translated @ rotation_matrix.T


def extract_ego_trajectory(file_path):
    """
    Extract the normalized EGO trajectory over the last 6 seconds of a scenario.

    Args:
        file_path: Path to the data file

    Returns:
        numpy array of shape (FUTURE_TIMESTEPS, 2) - [normalized_x, normalized_y].
        Returns None in any of these cases:
        - the scenario fails to load;
        - the analysis helper returns None;
        - the future segment does not contain exactly FUTURE_TIMESTEPS samples.
    """
    loader = DataLoader()
    data = loader.load_scenario(file_path)

    if data is None:
        return None

    # Project helper used here only for the data it returns
    from src.utils.data_visualization import plot_ego_velocity_analysis

    # The helper is called once per file. Close any matplotlib figures it
    # opens so that extracting many files does not pile up open figures;
    # figures that existed before the call are left alone.
    # NOTE(review): confirm whether the helper actually creates figures.
    figures_before = set(plt.get_fignums())
    try:
        analysis_data = plot_ego_velocity_analysis(loader, show_acceleration=False, time_window=None)
    finally:
        for fig_num in set(plt.get_fignums()) - figures_before:
            plt.close(fig_num)

    if analysis_data is None:
        return None

    # Full-scenario EGO data
    positions_full = analysis_data['positions']
    timesteps_full = analysis_data['timesteps']
    angles_full = analysis_data['angles']

    # Keep the future segment: timesteps >= HISTORY_TIMESTEPS (50-109, the last 6 seconds)
    future_mask = timesteps_full >= HISTORY_TIMESTEPS

    if not future_mask.any():
        return None

    positions = positions_full[future_mask]
    angles = angles_full[future_mask]

    # Only accept complete future segments
    if len(positions) != FUTURE_TIMESTEPS or len(angles) != FUTURE_TIMESTEPS:
        return None

    # The heading at the first future timestep defines the +Y direction
    heading_start = float(angles[0])

    return normalize_trajectory(positions, heading_start)




In [None]:
# Extract all trajectories (full trajectories are kept for faster visualization later)
all_trajectories = []
valid_files = []

print("Extracting normalized trajectories...")
for i, file_path in enumerate(train_files, start=1):
    if i % 100 == 0:
        print(f"  Processing {i}/{len(train_files)}...")

    trajectory = extract_ego_trajectory(file_path)
    if trajectory is not None:
        all_trajectories.append(trajectory)
        valid_files.append(file_path.name)

if not all_trajectories:
    # Without this guard the endpoint slicing below fails with an opaque
    # IndexError on an empty 1-D array.
    raise RuntimeError(
        f"No valid EGO trajectories extracted from {len(train_files)} files; "
        "check DATA_PATH and the HISTORY_TIMESTEPS/FUTURE_TIMESTEPS settings"
    )

# Stack into one array: (n_samples, FUTURE_TIMESTEPS, 2)
all_trajectories = np.stack(all_trajectories)

# The last future position of each trajectory is the clustering feature: (n_samples, 2)
all_endpoints = all_trajectories[:, -1, :]

print(f"\n✅ Successfully extracted {len(all_trajectories)} trajectories")
print(f"Trajectory dimensions: {all_trajectories.shape}")
print(f"Endpoint dimensions: {all_endpoints.shape}")
print("💡 Clustering will be based on endpoints: [normalized_x, normalized_y]")
print("💡 Normalization: 5s position at origin (0,0), 5s heading aligned to +Y axis")

In [None]:
# K-Means clustering on trajectory endpoints
N_CLUSTERS = 6

# Standardize x and y to zero mean / unit variance before clustering, so
# both axes contribute on the same scale to K-Means' Euclidean distance.
scaler = StandardScaler()
endpoints_scaled = scaler.fit_transform(all_endpoints)

# Fixed seed and several initializations for reproducible clustering
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(endpoints_scaled)

print(f"✅ K-Means clustering completed with {N_CLUSTERS} clusters")
print("\n📊 Cluster Distribution:")
cluster_sizes = np.bincount(cluster_labels, minlength=N_CLUSTERS)
for i, count in enumerate(cluster_sizes):
    print(f"  Cluster {i}: {count} samples ({count/len(cluster_labels)*100:.1f}%)")

## 2. Visualize endpoint clustering results

In [None]:
# 1. Endpoint scatter plot, one color per cluster
plt.figure(figsize=(12, 10))

# Draw a reproducible random subset when there are too many endpoints to render quickly
MAX_POINTS = 4000
rng = np.random.default_rng(42)
n_endpoints = len(all_endpoints)
if n_endpoints <= MAX_POINTS:
    endpoints_plot, labels_plot = all_endpoints, cluster_labels
else:
    sample_idx = rng.choice(n_endpoints, size=MAX_POINTS, replace=False)
    endpoints_plot, labels_plot = all_endpoints[sample_idx], cluster_labels[sample_idx]

colors = plt.cm.tab10(labels_plot)
plt.scatter(endpoints_plot[:, 0], endpoints_plot[:, 1],
            c=colors, s=36, alpha=0.6, linewidth=0)

# Map the K-Means centers from standardized space back to meters
cluster_centers_original = scaler.inverse_transform(kmeans.cluster_centers_)
for cid, (cx, cy) in enumerate(cluster_centers_original):
    plt.scatter(cx, cy, marker='X', s=400, c=[plt.cm.tab10(cid)],
                edgecolors='black', linewidth=1.5, zorder=10)
    label_box = dict(boxstyle='round,pad=0.4', facecolor='white', alpha=0.9)
    plt.text(cx, cy, f'C{cid}', fontsize=14, fontweight='bold',
             ha='center', va='center', bbox=label_box)

# Every normalized trajectory starts at the origin
plt.scatter(0, 0, marker='o', s=200, c='green',
            edgecolors='black', linewidth=1.5, zorder=10, label='Start (Origin)')

plt.xlabel('Normalized X Position (m)', fontsize=14)
plt.ylabel('Normalized Y Position (m)', fontsize=14)
plt.title('Trajectory Endpoint Clustering (Normalized to Origin)', fontsize=16, fontweight='bold')
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.axis('equal')
plt.tight_layout()
plt.show()

if len(endpoints_plot) >= n_endpoints:
    print(f"\nShowing all {n_endpoints} endpoints")
else:
    print(f"\nShowing {len(endpoints_plot)} of {n_endpoints} endpoints (random sample)")

print("\nCluster centers (in normalized coordinates):")
for i, center in enumerate(cluster_centers_original):
    print(f"  Cluster {i}: x={center[0]:6.2f}m, y={center[1]:6.2f}m")

In [None]:
# 2. Bar chart of how many samples landed in each cluster
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

cluster_ids = range(N_CLUSTERS)
cluster_counts = [np.sum(cluster_labels == cid) for cid in cluster_ids]
colors_bar = [plt.cm.tab10(cid) for cid in cluster_ids]

ax.bar(cluster_ids, cluster_counts, color=colors_bar, alpha=0.7, edgecolor='black')
ax.set_xticks(cluster_ids)
ax.set_xticklabels([f'Cluster {cid}' for cid in cluster_ids])
ax.set_xlabel('Cluster ID', fontsize=12)
ax.set_ylabel('Sample Count', fontsize=12)
ax.set_title('Sample Distribution Across Clusters', fontsize=14, fontweight='bold')
ax.grid(axis='y', alpha=0.3)

# Write each count just above its bar (offset = 1% of the tallest bar)
label_offset = max(cluster_counts) * 0.01
for bar_x, bar_height in zip(cluster_ids, cluster_counts):
    ax.text(bar_x, bar_height + label_offset, str(bar_height),
            ha='center', va='bottom', fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# 3. Visualize full trajectories for each cluster (OPTIMIZED for speed)
from matplotlib.collections import LineCollection

MAX_TRAJ_PER_CLUSTER = 50  # trajectories drawn per panel
DOWNSAMPLE_STEP = 3        # draw every 3rd point (60 -> ~20 points)
N_COLS = 3

# Size the grid from N_CLUSTERS. A fixed 2x3 layout raises IndexError for
# N_CLUSTERS > 6 and leaves empty panels for N_CLUSTERS < 6.
n_rows = max(1, -(-N_CLUSTERS // N_COLS))  # ceiling division
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(6 * N_COLS, 6 * n_rows), squeeze=False)
axes = axes.flatten()

# Downsampled point indices that always include the final point. A plain
# [::3] slice of 60 points stops at index 57, so it never draws the endpoint
# that was actually clustered.
n_timesteps = all_trajectories.shape[1]
keep_idx = np.unique(np.append(np.arange(0, n_timesteps, DOWNSAMPLE_STEP), n_timesteps - 1))

print("\nVisualizing trajectories from pre-loaded data...")
print("💡 Using LineCollection for 10x faster rendering")

for cluster_id in range(N_CLUSTERS):
    ax = axes[cluster_id]
    color = plt.cm.tab10(cluster_id)

    # Trajectories assigned to this cluster
    cluster_mask = cluster_labels == cluster_id
    cluster_trajectories = all_trajectories[cluster_mask]

    # Limit the number of trajectories drawn per panel
    n_plot = min(MAX_TRAJ_PER_CLUSTER, len(cluster_trajectories))

    # (n_plot, len(keep_idx), 2): one polyline per trajectory
    segments = cluster_trajectories[:n_plot][:, keep_idx, :]

    # A single LineCollection draws every polyline as one artist, which is
    # much faster than one plot() call per trajectory.
    if len(segments):
        ax.add_collection(LineCollection(segments, colors=color, alpha=0.3, linewidths=1))

    center = cluster_centers_original[cluster_id]

    # Axis limits cover the drawn lines, the cluster center and the origin,
    # plus a 5 m margin. This also works for an empty cluster.
    extent_points = np.vstack([segments.reshape(-1, 2), center[None, :2], np.zeros((1, 2))])
    ax.set_xlim(extent_points[:, 0].min() - 5, extent_points[:, 0].max() + 5)
    ax.set_ylim(extent_points[:, 1].min() - 5, extent_points[:, 1].max() + 5)

    # Cluster center endpoint
    ax.scatter(center[0], center[1],
               marker='X', s=400, c=[color],
               edgecolors='black', linewidth=2, zorder=10)

    # Origin (normalized start position)
    ax.scatter(0, 0, marker='o', s=150, c='green',
               edgecolors='black', linewidth=1.5, zorder=10)

    ax.set_xlabel('Normalized X (m)', fontsize=10)
    ax.set_ylabel('Normalized Y (m)', fontsize=10)
    ax.set_title(f'Cluster {cluster_id}\n({np.sum(cluster_mask)} samples, showing {n_plot})',
                 fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal', adjustable='box')

# Hide grid cells that have no cluster
for spare_ax in axes[N_CLUSTERS:]:
    spare_ax.set_visible(False)

plt.tight_layout()
plt.show()
print("✅ Trajectory visualization completed (optimized rendering)")

In [None]:
# 4. Endpoint clustering statistics
print("\n" + "=" * 60)
print("📊 Endpoint Clustering Statistics")
print("=" * 60)

for cluster_id in range(N_CLUSTERS):
    cluster_mask = cluster_labels == cluster_id
    cluster_endpoints = all_endpoints[cluster_mask]
    center = cluster_centers_original[cluster_id]

    print(f"\nCluster {cluster_id}:")
    print(f"  Count: {np.sum(cluster_mask)} samples")
    print(f"  Center: ({center[0]:6.2f}, {center[1]:6.2f}) m")

    if len(cluster_endpoints) == 0:
        # Means and stds of an empty selection would be NaN and raise RuntimeWarnings
        print("  (no samples - statistics skipped)")
        continue

    mean_x, mean_y = cluster_endpoints.mean(axis=0)
    std_x, std_y = cluster_endpoints.std(axis=0)

    # Distance from the origin (the normalized start position)
    mean_distance = np.hypot(cluster_endpoints[:, 0], cluster_endpoints[:, 1]).mean()

    # Angle measured from +Y (forward), positive toward +X. Angles wrap at
    # ±180°, so they are averaged as unit vectors (circular mean). An
    # arithmetic mean would report about 0° for endpoints that sit behind
    # the start on both sides of the ±180° seam.
    angles_rad = np.arctan2(cluster_endpoints[:, 0], cluster_endpoints[:, 1])
    mean_angle = np.degrees(np.arctan2(np.sin(angles_rad).mean(), np.cos(angles_rad).mean()))

    print(f"  Mean endpoint: ({mean_x:6.2f}, {mean_y:6.2f}) m")
    print(f"  Std X/Y: ({std_x:5.2f}, {std_y:5.2f}) m")
    print(f"  Avg distance from origin: {mean_distance:6.2f} m")
    print(f"  Avg angle from forward: {mean_angle:6.1f}°")

print("=" * 60)