In [1]:
best_model = "ensemble_model.pt"

In [2]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

In [3]:
train_file = np.load('../cse-251-b-2025/train.npz')

train_data = train_file['data']
print("train_data's shape", train_data.shape)
test_file = np.load('../cse-251-b-2025/test_input.npz')

test_data = test_file['data']
print("test_data's shape", test_data.shape)

train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)


In [4]:
class TrajectoryDatasetTrain_AngularAcc(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        self.data = data
        self.scale = scale
        self.augment = augment
        self.dt = 0.1  # seconds

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        hist = scene[:, :50, :4].copy()  # (agents=50, time_seq=50, 5)
        future = torch.tensor(scene[0, 50:, 0:4].copy(), dtype=torch.float32)  # (60, 2)

        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future[..., :2] = future[..., :2] @ R
                future[..., 2:4] = future[..., 2:4] @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1
                future[:, 2] *= -1

        v = hist[..., 2:4]  # velocity
        a = np.zeros_like(v)
        a[..., 1:] = (v[..., 1:] - v[..., :-1]) / self.dt

        theta = np.arctan2(v[..., 1], v[..., 0])[..., np.newaxis]
        omega = np.zeros_like(theta)
        omega[..., 1:] = (theta[..., 1:] - theta[..., :-1]) / self.dt

        alpha = np.zeros_like(omega)
        alpha[..., 1:] = (omega[..., 1:] - omega[..., :-1]) / self.dt

        hist_aug = np.concatenate([hist, a, theta, omega, alpha], axis=-1)  # shape (50, 50, 9)

        origin = hist_aug[0, 49, :2].copy()
        hist_aug[..., :2] -= origin
        hist_aug[..., :4] /= self.scale
        hist_aug[..., 4:6] /= self.scale / self.dt        # ax, ay
        hist_aug[..., 6] /= np.pi                         # θ ∈ [-π, π]
        hist_aug[..., 7] /= (np.pi / self.dt)             # ω
        hist_aug[..., 8] /= (np.pi / (self.dt ** 2))      # α
        future[..., :2] = future[..., :2] - origin
        future = future / self.scale

        return Data(
            x=torch.tensor(hist_aug, dtype=torch.float32),
            y=future,
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        
        
class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        self.data = data
        self.scale = scale
        self.dt = 0.1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        hist = self.data[idx][:, :50, :4].copy()  # (50, 50, 4: [x, y, vx, vy])

        v = hist[..., 2:4]
        a = np.zeros_like(v)
        a[..., 1:] = (v[..., 1:] - v[..., :-1]) / self.dt

        theta = np.arctan2(v[..., 1], v[..., 0])[..., np.newaxis]
        omega = np.zeros_like(theta)
        omega[..., 1:] = (theta[..., 1:] - theta[..., :-1]) / self.dt

        alpha = np.zeros_like(omega)
        alpha[..., 1:] = (omega[..., 1:] - omega[..., :-1]) / self.dt

        hist_aug = np.concatenate([hist, a, theta, omega, alpha], axis=-1)  # shape (50, 50, 9)

        origin = hist_aug[0, 49, :2].copy()
        hist_aug[..., :2] -= origin
        hist_aug[..., :4] /= self.scale
        hist_aug[..., 4:6] /= self.scale / self.dt
        hist_aug[..., 6] /= np.pi
        hist_aug[..., 7] /= (np.pi / self.dt)
        hist_aug[..., 8] /= (np.pi / (self.dt ** 2))

        return Data(
            x=torch.tensor(hist_aug, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

In [5]:
torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain_AngularAcc(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain_AngularAcc(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using Apple Silicon GPU


In [46]:
def convert_xy(pred_vel, origin=None):
    dt = 0.1  # seconds per step
    if origin:
        pred_pos = [origin]  # list of (B, 1, 2)
        for t in range(60):
            next_pos = pred_pos[-1] + pred_vel[:, t:t+1, :] * dt  # (B, 1, 2)
            pred_pos.append(next_pos)
        
        # Concatenate positions across time steps
        pred_xy = torch.cat(pred_pos[1:], dim=1)  # skip initial origin, get (B, 60, 2)
    else:
        pred_pos = [0]
        
        for t in range(60):
            next_vel = (pred_vel[:, t:t+1, :] - pred_pos[-1]) /dt
            # next_pos = pred_pos[-1] + pred_vel[:, t:t+1, :] * dt  # (B, 1, 2)
            pred_pos.append(next_vel)
        pred_vxvy = torch.cat(pred_pos[1:], dim=1)  # skip initial origin, get (B, 60, 2)
    return pred_vxvy

In [57]:
def convert_VXVY(pred_pos, origin=None):
    dt = 0.1
    pred_pos = pred_pos.reshape(-1, 60, 2)
    B = pred_pos.shape[0]
    zero_pos = torch.zeros((B, 1, 2), device=pred_pos.device, dtype=pred_pos.dtype)
    pos = torch.cat([zero_pos, pred_pos], dim=1)
    vel = torch.zeros_like(pos)                     # shape (B, 61, 2), with vel[:,0,:]=0
    vel[:, 1:, :] = (pos[:, 1:, :] - pos[:, :-1, :]) / dt
    pred_vxvy = vel[:, 1:, :]     # skip initial origin, get (B, 60, 2)
    return pred_vxvy

In [58]:
# Example of basic model that should work
class SimpleLSTM(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=512, output_dim=60*2):
        super(SimpleLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
        # Initialize weights properly
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.0)
        
    def forward(self, x):
        # x = data.x[..., :5]
        # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
        x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # Process through LSTM
        lstm_out, _ = self.lstm(x)
        
        # Extract final hidden state
        features = lstm_out[:, -1, :]
        
        # Process through prediction head
        features = self.relu(self.fc1(features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        out_vel = convert_VXVY(out.view(-1, 60, 2))
        return out_vel

In [59]:
# Example of basic model that should work
class VelocityModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=5, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(VelocityModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [60]:
# Example of basic model that should work
class AccelearationModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=8, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(AccelearationModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        self.input_dim = input_dim
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, self.input_dim)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [61]:
# Example of basic model that should work
class AngularAccelerationModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=9, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(AngularAccelerationModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        self.input_dim = input_dim
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, self.input_dim)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [72]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts, hidden_dim=64):
        super(GatingNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            # nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_experts)
        )

    def forward(self, context):
        logits = self.net(context)  # [batch_size, num_experts]
        weights = F.softmax(logits, dim=-1)
        return weights

class MoEModel(nn.Module):
    def __init__(self, expert_models, gating_input_dim, future_steps):
        super(MoEModel, self).__init__()
        self.expert_models = nn.ModuleList(expert_models)
        self.gating_net = GatingNetwork(gating_input_dim, len(expert_models))
        self.future_steps = future_steps
        self.context_encoder = nn.LSTM(input_size=9, hidden_size=gating_input_dim, batch_first=True)


    def forward(self, input_dict):
        """
        input_dict:
            - context: [batch_size, gating_input_dim] --> for gating net
            - inputs_per_model: list of input tensors, one per model
        Returns:
            - [batch_size, future_steps, 2]  # future positions
        """
        model_inputs = input_dict['inputs_per_model']
        batch_size = model_inputs[0].size(0)
        # context = input_dict['context']
        model_input =model_inputs[-1]
        ego_pos = model_input[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        ego_features = model_input[:, 0, :, :]
        agent_pos = model_input[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([model_input[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)
        
        ego_features[..., :4] = ego_features[..., :4] - neighbor_trajs[..., :4]

        context, (hidden, cell)= self.context_encoder(ego_features)
        # print('encoder done')
        weights = self.gating_net(hidden[:, -1, :])  # [batch_size, num_experts]
        # print('gating')
        expert_preds = []
        for i, model in enumerate(self.expert_models):
            # print(i)
            pred = model(model_inputs[i])  # [batch_size, future_steps, 2]
            expert_preds.append(pred.unsqueeze(1))  # [B, 1, T, 2]

        expert_preds = torch.cat(expert_preds, dim=1)  # [B, E, T, 2]
        weights = weights.unsqueeze(-1).unsqueeze(-1)  # [B, E, 1, 1]

        fused = torch.sum(expert_preds * weights, dim=1)  # [B, T, 2]
        return fused.view(-1, 60, 2)


In [73]:
def train_improved_model(model, train_dataloader, val_dataloader, 
                         device, criterion=nn.MSELoss(), 
                         lr=0.001, epochs=100, patience=15):
    """
    Improved training function with better debugging and early stopping
    """
    # Initialize optimizer with smaller learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Exponential decay scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    early_stopping_patience = patience
    best_val_loss = float('inf')
    no_improvement = 0
    
    # Save initial state for comparison
    initial_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch", unit="epoch"):
        # ---- Training ----
        model.train()
        train_loss = 0
        num_train_batches = 0
        
        for batch in train_dataloader:
            batch = batch.to(device)
            x = batch.x
            # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
            input_dict = {
                    "inputs_per_model": [
                        # x[..., :5].reshape(-1,50,50,5),
                        x[..., :5].reshape(-1,50,50,5),  # input for PositionModel
                        x[..., :8].reshape(-1,50,50,8),  # input for VelocityModel
                        x[..., :9].reshape(-1,50,50,9)  # input for AccelAngularModel
                    ]}
            pred = model(input_dict)
            y_all  = batch.y.view(batch.num_graphs, 60, 4)
            y = y_all[..., 2:4]
            
            # Check for NaN predictions
            if torch.isnan(pred).any():
                print(f"WARNING: NaN detected in predictions during training")
                continue
                
            loss = criterion(pred, y)
            
            # Check if loss is valid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"WARNING: Invalid loss value: {loss.item()}")
                continue
                
            optimizer.zero_grad()
            loss.backward()
            
            # More conservative gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1
        
        # Skip epoch if no valid batches
        if num_train_batches == 0:
            print("WARNING: No valid training batches in this epoch")
            continue
            
        train_loss /= num_train_batches
        
        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_mae = 0
        val_mse = 0
        num_val_batches = 0
        
        # Sample predictions for debugging
        sample_input = None
        sample_pred = None
        sample_target = None
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                batch = batch.to(device)
                x = batch.x
                # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
                input_dict = {
                    "inputs_per_model": [
                        # x[..., :5].reshape(-1,50,50,5),
                        x[..., :5].reshape(-1,50,50,5),  # input for PositionModel
                        x[..., :8].reshape(-1,50,50,8),  # input for VelocityModel
                        x[..., :9].reshape(-1,50,50,9)  # input for AccelAngularModel
                    ]}
                
                pred = model(input_dict)
                y_all  = batch.y.view(batch.num_graphs, 60, 4)
                y = y_all[..., 2:4]
                
                # Store sample for debugging
                if batch_idx == 0 and sample_input is None:
                    sample_input = batch.x[0].cpu().numpy()
                    sample_pred = pred[0].cpu().numpy()
                    sample_target = y[0].cpu().numpy()
                
                # Skip invalid predictions
                if torch.isnan(pred).any():
                    print(f"WARNING: NaN detected in predictions during validation")
                    continue
                    
                batch_loss = criterion(pred, y).item()
                val_loss += batch_loss
                
                # Unnormalize for real-world metrics
                pred_unnorm = pred * batch.scale.view(-1, 1, 1)
                y_unnorm = y * batch.scale.view(-1, 1, 1)
                
                val_mae += nn.L1Loss()(pred_unnorm, y_unnorm).item()
                val_mse += nn.MSELoss()(pred_unnorm, y_unnorm).item()
                
                num_val_batches += 1
        
        # Skip epoch if no valid validation batches
        if num_val_batches == 0:
            print("WARNING: No valid validation batches in this epoch")
            continue
            
        val_loss /= num_val_batches
        val_mae /= num_val_batches
        val_mse /= num_val_batches
        
        # Update learning rate
        scheduler.step()
        
        # Print with more details
        tqdm.tqdm.write(
            f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
            f"Train MSE {train_loss:.4f} | Val MSE {val_loss:.4f} | "
            f"Val MAE {val_mae:.4f} | Val MSE {val_mse:.4f}"
        )
        
        # Debug output - first 3 predictions vs targets
        if epoch % 5 == 0:
            tqdm.tqdm.write(f"Sample pred first 3 steps: {sample_pred[:3]}")
            tqdm.tqdm.write(f"Sample target first 3 steps: {sample_target[:3]}")
            
            # Check if model weights are changing
            if epoch > 0:
                weight_change = False
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        initial_param = initial_state_dict[name]
                        if not torch.allclose(param, initial_param, rtol=1e-4):
                            weight_change = True
                            break
                if not weight_change:
                    tqdm.tqdm.write("WARNING: Model weights barely changing!")
        
        # Relaxed improvement criterion - consider any improvement
        if val_loss < best_val_loss:
            tqdm.tqdm.write(f"Validation improved: {best_val_loss:.6f} -> {val_loss:.6f}")
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), best_model)
        else:
            no_improvement += 1
            if no_improvement >= early_stopping_patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    # Load best model before returning
    model.load_state_dict(torch.load(best_model))
    return model

In [74]:
# Example usage
def train_and_evaluate_model(model_list):
    # Create model
    expert_models = model_list
    moe = MoEModel(expert_models, gating_input_dim=64, future_steps=30)
    moe = moe.to(device)
    
    # Train with improved function
    train_improved_model(
        model=moe,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        # lr = 0.007 => 8.946
        lr=0.001,  # Lower learning rate
        patience=20,  # More patience
        epochs=150
    )
    
    # # Evaluate
    # moe.eval()
    # test_mse = 0
    # test_mse_xy = 0
    # with torch.no_grad():
    #     for batch in val_dataloader:
    #         batch = batch.to(device)
    #         pred = moe(batch)
    #         y_all  = batch.y.view(batch.num_graphs, 60, 4)
    #         y = y_all[..., 2:4]
    #         y_xy = y_all[..., 0:2]
    #         origin = batch.origin.unsqueeze(1)
    #         # Unnormalize
    #         pred = pred * batch.scale.view(-1, 1, 1)
    #         y = y * batch.scale.view(-1, 1, 1)
    #         y_xy = y_xy * batch.scale.view(-1, 1, 1)
            
    #         test_mse += nn.MSELoss()(pred, y).item()
    #         dt = 0.1  # seconds per step
    #         pred_pos = [origin]  # list of (B, 1, 2)

    #         for t in range(60):
    #             next_pos = pred_pos[-1] + pred[:, t:t+1, :] * dt  # (B, 1, 2)
    #             pred_pos.append(next_pos)

    #         # Concatenate positions across time steps
    #         pred_xy = torch.cat(pred_pos[1:], dim=1)
    #         test_mse_xy +=nn.MSELoss()(pred_xy, y_xy).item()
    # test_mse /= len(val_dataloader)
    # test_mse_xy /= len(val_dataloader)
    # print(f"Val MSE: {test_mse:.4f}")

    
    return moe

In [75]:
def val(model, laoder):
    model.eval()
    test_mse = 0
    test_mse_xy = 0
    with torch.no_grad():
        for batch in laoder:
            batch = batch.to(device)
            pred = model(batch)
            y_all  = batch.y.view(batch.num_graphs, 60, 4)
            y = y_all[..., 2:4]
            y_xy = y_all[..., 0:2]
            origin = batch.origin.unsqueeze(1)
            # Unnormalize
            pred = pred * batch.scale.view(-1, 1, 1)
            y = y * batch.scale.view(-1, 1, 1)
            y_xy = y_xy * batch.scale.view(-1, 1, 1)
            
            test_mse += nn.MSELoss()(pred, y).item()
            dt = 0.1  # seconds per step
            pred_pos = [0]  # list of (B, 1, 2)

            for t in range(60):
                next_pos = pred_pos[-1] + pred[:, t:t+1, :] * dt  # (B, 1, 2)
                pred_pos.append(next_pos)

            # Concatenate positions across time steps
            pred_xy = torch.cat(pred_pos[1:], dim=1)
            # print(test_mse)
            test_mse_xy += nn.MSELoss()(pred_xy, y_xy).item()
            # print(test_mse_xy)
    test_mse /= len(val_dataloader)
    test_mse_xy /= len(val_dataloader)
    print(f"Val MSE: {test_mse:.4f}")
    print(f"Val MSE: {test_mse_xy:.4f}")



In [76]:
train_and_evaluate_model(expert_models)

  future[..., :2] = future[..., :2] - origin
  future[..., :2] = future[..., :2] @ R
  future[..., 2:4] = future[..., 2:4] @ R
Epoch:   1%|          | 1/150 [00:28<1:10:29, 28.39s/epoch]

Epoch 000 | LR 0.000950 | Train MSE 0.1724 | Val MSE 0.0900 | Val MAE 1.4060 | Val MSE 4.4104
Sample pred first 3 steps: [[-0.01059493 -0.0205432 ]
 [-0.01287978 -0.01940733]
 [-0.01416496 -0.01945248]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: inf -> 0.090008


Epoch:   1%|▏         | 2/150 [00:56<1:09:01, 27.98s/epoch]

Epoch 001 | LR 0.000902 | Train MSE 0.0980 | Val MSE 0.0750 | Val MAE 1.2175 | Val MSE 3.6774
Validation improved: 0.090008 -> 0.075048


Epoch:   2%|▏         | 3/150 [01:23<1:07:59, 27.75s/epoch]

Epoch 002 | LR 0.000857 | Train MSE 0.0811 | Val MSE 0.0724 | Val MAE 1.1601 | Val MSE 3.5458
Validation improved: 0.075048 -> 0.072364


Epoch:   3%|▎         | 4/150 [01:51<1:07:22, 27.69s/epoch]

Epoch 003 | LR 0.000815 | Train MSE 0.0778 | Val MSE 0.0617 | Val MAE 1.1120 | Val MSE 3.0236
Validation improved: 0.072364 -> 0.061706


Epoch:   3%|▎         | 5/150 [02:19<1:07:08, 27.78s/epoch]

Epoch 004 | LR 0.000774 | Train MSE 0.0819 | Val MSE 0.0586 | Val MAE 1.0650 | Val MSE 2.8716
Validation improved: 0.061706 -> 0.058603


Epoch:   4%|▍         | 6/150 [02:47<1:06:48, 27.84s/epoch]

Epoch 005 | LR 0.000735 | Train MSE 0.0641 | Val MSE 0.0619 | Val MAE 1.0373 | Val MSE 3.0316
Sample pred first 3 steps: [[-0.03051857 -0.01122168]
 [-0.02775079 -0.01010021]
 [-0.02435523 -0.0092591 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   5%|▍         | 7/150 [03:14<1:06:22, 27.85s/epoch]

Epoch 006 | LR 0.000698 | Train MSE 0.0620 | Val MSE 0.0575 | Val MAE 1.0385 | Val MSE 2.8188
Validation improved: 0.058603 -> 0.057526


Epoch:   5%|▌         | 8/150 [03:42<1:06:03, 27.91s/epoch]

Epoch 007 | LR 0.000663 | Train MSE 0.0596 | Val MSE 0.0561 | Val MAE 1.0154 | Val MSE 2.7507
Validation improved: 0.057526 -> 0.056136


Epoch:   6%|▌         | 9/150 [04:11<1:05:50, 28.01s/epoch]

Epoch 008 | LR 0.000630 | Train MSE 0.0598 | Val MSE 0.0505 | Val MAE 0.9620 | Val MSE 2.4734
Validation improved: 0.056136 -> 0.050479


Epoch:   7%|▋         | 10/150 [04:39<1:05:36, 28.12s/epoch]

Epoch 009 | LR 0.000599 | Train MSE 0.0568 | Val MSE 0.0509 | Val MAE 0.9587 | Val MSE 2.4917


Epoch:   7%|▋         | 11/150 [05:07<1:05:03, 28.08s/epoch]

Epoch 010 | LR 0.000569 | Train MSE 0.0616 | Val MSE 0.0711 | Val MAE 1.0172 | Val MSE 3.4845
Sample pred first 3 steps: [[-0.02117983  0.00184315]
 [-0.01958773  0.00051857]
 [-0.01797159 -0.00179574]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   8%|▊         | 12/150 [05:38<1:06:22, 28.86s/epoch]

Epoch 011 | LR 0.000540 | Train MSE 0.0606 | Val MSE 0.0600 | Val MAE 1.0009 | Val MSE 2.9403


Epoch:   9%|▊         | 13/150 [06:08<1:06:47, 29.25s/epoch]

Epoch 012 | LR 0.000513 | Train MSE 0.0564 | Val MSE 0.0493 | Val MAE 0.9568 | Val MSE 2.4156
Validation improved: 0.050479 -> 0.049299


Epoch:   9%|▉         | 14/150 [06:38<1:07:11, 29.65s/epoch]

Epoch 013 | LR 0.000488 | Train MSE 0.0549 | Val MSE 0.0568 | Val MAE 0.9722 | Val MSE 2.7828


Epoch:  10%|█         | 15/150 [07:14<1:10:37, 31.39s/epoch]

Epoch 014 | LR 0.000463 | Train MSE 0.0543 | Val MSE 0.0563 | Val MAE 0.9728 | Val MSE 2.7565


Epoch:  11%|█         | 16/150 [07:47<1:11:16, 31.91s/epoch]

Epoch 015 | LR 0.000440 | Train MSE 0.0572 | Val MSE 0.0530 | Val MAE 0.9402 | Val MSE 2.5974
Sample pred first 3 steps: [[0.00641353 0.02593149]
 [0.00584396 0.02451914]
 [0.00558462 0.02436241]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  11%|█▏        | 17/150 [08:19<1:10:44, 31.92s/epoch]

Epoch 016 | LR 0.000418 | Train MSE 0.0536 | Val MSE 0.0530 | Val MAE 0.9381 | Val MSE 2.5962


Epoch:  12%|█▏        | 18/150 [08:53<1:11:40, 32.58s/epoch]

Epoch 017 | LR 0.000397 | Train MSE 0.0504 | Val MSE 0.0519 | Val MAE 0.9485 | Val MSE 2.5454


Epoch:  13%|█▎        | 19/150 [09:26<1:11:14, 32.63s/epoch]

Epoch 018 | LR 0.000377 | Train MSE 0.0515 | Val MSE 0.0519 | Val MAE 0.9406 | Val MSE 2.5424


Epoch:  13%|█▎        | 20/150 [09:59<1:11:05, 32.81s/epoch]

Epoch 019 | LR 0.000358 | Train MSE 0.0497 | Val MSE 0.0522 | Val MAE 0.9281 | Val MSE 2.5559


Epoch:  14%|█▍        | 21/150 [10:33<1:11:22, 33.20s/epoch]

Epoch 020 | LR 0.000341 | Train MSE 0.0512 | Val MSE 0.0493 | Val MAE 0.9207 | Val MSE 2.4181
Sample pred first 3 steps: [[0.00877132 0.0434375 ]
 [0.00695176 0.0412857 ]
 [0.00529175 0.0409074 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  15%|█▍        | 22/150 [11:07<1:11:32, 33.53s/epoch]

Epoch 021 | LR 0.000324 | Train MSE 0.0483 | Val MSE 0.0485 | Val MAE 0.9129 | Val MSE 2.3757
Validation improved: 0.049299 -> 0.048483


Epoch:  15%|█▌        | 23/150 [11:41<1:10:54, 33.50s/epoch]

Epoch 022 | LR 0.000307 | Train MSE 0.0488 | Val MSE 0.0486 | Val MAE 0.9100 | Val MSE 2.3813


Epoch:  16%|█▌        | 24/150 [12:15<1:10:59, 33.81s/epoch]

Epoch 023 | LR 0.000292 | Train MSE 0.0469 | Val MSE 0.0477 | Val MAE 0.8998 | Val MSE 2.3361
Validation improved: 0.048483 -> 0.047676


Epoch:  17%|█▋        | 25/150 [12:48<1:09:56, 33.57s/epoch]

Epoch 024 | LR 0.000277 | Train MSE 0.0492 | Val MSE 0.0528 | Val MAE 0.9232 | Val MSE 2.5878


Epoch:  17%|█▋        | 26/150 [13:22<1:09:25, 33.60s/epoch]

Epoch 025 | LR 0.000264 | Train MSE 0.0464 | Val MSE 0.0507 | Val MAE 0.9070 | Val MSE 2.4842
Sample pred first 3 steps: [[0.02188403 0.0196551 ]
 [0.0209751  0.01736747]
 [0.01835225 0.01692602]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  18%|█▊        | 27/150 [13:56<1:09:09, 33.74s/epoch]

Epoch 026 | LR 0.000250 | Train MSE 0.0476 | Val MSE 0.0518 | Val MAE 0.9200 | Val MSE 2.5365


Epoch:  19%|█▊        | 28/150 [14:30<1:08:37, 33.75s/epoch]

Epoch 027 | LR 0.000238 | Train MSE 0.0461 | Val MSE 0.0517 | Val MAE 0.9172 | Val MSE 2.5314


Epoch:  19%|█▉        | 29/150 [15:04<1:08:04, 33.75s/epoch]

Epoch 028 | LR 0.000226 | Train MSE 0.0455 | Val MSE 0.0506 | Val MAE 0.9019 | Val MSE 2.4772


Epoch:  20%|██        | 30/150 [15:37<1:07:14, 33.62s/epoch]

Epoch 029 | LR 0.000215 | Train MSE 0.0468 | Val MSE 0.0535 | Val MAE 0.9371 | Val MSE 2.6233


Epoch:  21%|██        | 31/150 [16:10<1:06:21, 33.45s/epoch]

Epoch 030 | LR 0.000204 | Train MSE 0.0450 | Val MSE 0.0502 | Val MAE 0.9090 | Val MSE 2.4622
Sample pred first 3 steps: [[0.02345403 0.0513161 ]
 [0.02187377 0.04915054]
 [0.01950306 0.04672467]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  21%|██▏       | 32/150 [16:44<1:06:05, 33.60s/epoch]

Epoch 031 | LR 0.000194 | Train MSE 0.0453 | Val MSE 0.0494 | Val MAE 0.8971 | Val MSE 2.4203


Epoch:  22%|██▏       | 33/150 [17:17<1:05:11, 33.44s/epoch]

Epoch 032 | LR 0.000184 | Train MSE 0.0456 | Val MSE 0.0489 | Val MAE 0.8878 | Val MSE 2.3952


Epoch:  23%|██▎       | 34/150 [17:51<1:04:58, 33.61s/epoch]

Epoch 033 | LR 0.000175 | Train MSE 0.0443 | Val MSE 0.0490 | Val MAE 0.8981 | Val MSE 2.4009


Epoch:  23%|██▎       | 35/150 [18:24<1:04:17, 33.55s/epoch]

Epoch 034 | LR 0.000166 | Train MSE 0.0435 | Val MSE 0.0484 | Val MAE 0.8719 | Val MSE 2.3717


Epoch:  24%|██▍       | 36/150 [18:58<1:03:43, 33.54s/epoch]

Epoch 035 | LR 0.000158 | Train MSE 0.0437 | Val MSE 0.0500 | Val MAE 0.8937 | Val MSE 2.4496
Sample pred first 3 steps: [[0.00356904 0.03973006]
 [0.00322617 0.03677312]
 [0.00273846 0.03568399]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  25%|██▍       | 37/150 [19:31<1:02:58, 33.44s/epoch]

Epoch 036 | LR 0.000150 | Train MSE 0.0430 | Val MSE 0.0501 | Val MAE 0.8924 | Val MSE 2.4528


Epoch:  25%|██▌       | 38/150 [20:04<1:02:05, 33.26s/epoch]

Epoch 037 | LR 0.000142 | Train MSE 0.0466 | Val MSE 0.0480 | Val MAE 0.8837 | Val MSE 2.3496


Epoch:  26%|██▌       | 39/150 [20:37<1:01:20, 33.16s/epoch]

Epoch 038 | LR 0.000135 | Train MSE 0.0425 | Val MSE 0.0482 | Val MAE 0.8877 | Val MSE 2.3611


Epoch:  27%|██▋       | 40/150 [21:10<1:00:34, 33.04s/epoch]

Epoch 039 | LR 0.000129 | Train MSE 0.0426 | Val MSE 0.0495 | Val MAE 0.8874 | Val MSE 2.4261


Epoch:  27%|██▋       | 41/150 [21:43<1:00:09, 33.11s/epoch]

Epoch 040 | LR 0.000122 | Train MSE 0.0428 | Val MSE 0.0489 | Val MAE 0.8817 | Val MSE 2.3943
Sample pred first 3 steps: [[-0.00234975  0.0416768 ]
 [-0.00169034  0.03899324]
 [-0.00185749  0.03612714]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  28%|██▊       | 42/150 [22:17<59:54, 33.28s/epoch]  

Epoch 041 | LR 0.000116 | Train MSE 0.0426 | Val MSE 0.0498 | Val MAE 0.8959 | Val MSE 2.4420


Epoch:  29%|██▊       | 43/150 [22:50<59:27, 33.34s/epoch]

Epoch 042 | LR 0.000110 | Train MSE 0.0419 | Val MSE 0.0485 | Val MAE 0.8866 | Val MSE 2.3752


Epoch:  29%|██▊       | 43/150 [23:24<58:13, 32.65s/epoch]

Epoch 043 | LR 0.000105 | Train MSE 0.0431 | Val MSE 0.0486 | Val MAE 0.8874 | Val MSE 2.3828
Early stopping after 44 epochs without improvement





MoEModel(
  (expert_models): ModuleList(
    (0): VelocityModel(
      (ego_encoder): LSTM(5, 512, batch_first=True)
      (neighbor_encoder): LSTM(5, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True)
    )
    (1): AccelearationModel(
      (ego_encoder): LSTM(8, 512, batch_first=True)
      (neighbor_encoder): LSTM(8, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True)
    )
    (2): AngularAccelerationModel(
      (ego_encoder): LSTM(9, 512, batch_first=True)
      (neighbor_encoder): LSTM(9, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu

In [None]:
simple_model_n = "best_model_xy.pt"
velocity_model_n = 'best_model0_retry.pt'
acceleration_model_n = 'best_model0_feat_eng.pt'
angular_model_n = 'best_model0_feat_eng_add_ang_acc.pt'


simple_model2 = torch.load(simple_model_n)
simple_model = SimpleLSTM().to(device)
simple_model.load_state_dict(simple_model2)

velocity_model2 = torch.load(velocity_model_n)
velocity_model = VelocityModel().to(device)
velocity_model.load_state_dict(velocity_model2)

acceleration_model2 = torch.load(acceleration_model_n)
acceleration_model = AccelearationModel().to(device)
acceleration_model.load_state_dict(acceleration_model2)

angular_model2 = torch.load(angular_model_n)
angular_model = AngularAccelerationModel().to(device)
angular_model.load_state_dict(angular_model2)


expert_models = [
                #    simple_model,
                   velocity_model,
                   acceleration_model,
                    angular_model
                     ]