In [10]:
best_model = "ensemble_model_4.pt"

In [11]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

In [12]:
train_file = np.load('../cse-251-b-2025/train.npz')

train_data = train_file['data']
print("train_data's shape", train_data.shape)
test_file = np.load('../cse-251-b-2025/test_input.npz')

test_data = test_file['data']
print("test_data's shape", test_data.shape)

train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)


In [13]:
class TrajectoryDatasetTrain_AngularAcc(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        self.data = data
        self.scale = scale
        self.augment = augment
        self.dt = 0.1  # seconds

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        hist = scene[:, :50, :4].copy()  # (agents=50, time_seq=50, 5)
        future = torch.tensor(scene[0, 50:, 0:4].copy(), dtype=torch.float32)  # (60, 2)

        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future[..., :2] = future[..., :2] @ R
                future[..., 2:4] = future[..., 2:4] @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1
                future[:, 2] *= -1

        v = hist[..., 2:4]  # velocity
        a = np.zeros_like(v)
        a[..., 1:] = (v[..., 1:] - v[..., :-1]) / self.dt

        theta = np.arctan2(v[..., 1], v[..., 0])[..., np.newaxis]
        omega = np.zeros_like(theta)
        omega[..., 1:] = (theta[..., 1:] - theta[..., :-1]) / self.dt

        alpha = np.zeros_like(omega)
        alpha[..., 1:] = (omega[..., 1:] - omega[..., :-1]) / self.dt

        hist_aug = np.concatenate([hist, a, theta, omega, alpha], axis=-1)  # shape (50, 50, 9)

        origin = hist_aug[0, 49, :2].copy()
        hist_aug[..., :2] -= origin
        hist_aug[..., :4] /= self.scale
        hist_aug[..., 4:6] /= self.scale / self.dt        # ax, ay
        hist_aug[..., 6] /= np.pi                         # θ ∈ [-π, π]
        hist_aug[..., 7] /= (np.pi / self.dt)             # ω
        hist_aug[..., 8] /= (np.pi / (self.dt ** 2))      # α
        future[..., :2] = future[..., :2] - origin
        future = future / self.scale

        return Data(
            x=torch.tensor(hist_aug, dtype=torch.float32),
            y=future,
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        
        
class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        self.data = data
        self.scale = scale
        self.dt = 0.1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        hist = self.data[idx][:, :50, :4].copy()  # (50, 50, 4: [x, y, vx, vy])

        v = hist[..., 2:4]
        a = np.zeros_like(v)
        a[..., 1:] = (v[..., 1:] - v[..., :-1]) / self.dt

        theta = np.arctan2(v[..., 1], v[..., 0])[..., np.newaxis]
        omega = np.zeros_like(theta)
        omega[..., 1:] = (theta[..., 1:] - theta[..., :-1]) / self.dt

        alpha = np.zeros_like(omega)
        alpha[..., 1:] = (omega[..., 1:] - omega[..., :-1]) / self.dt

        hist_aug = np.concatenate([hist, a, theta, omega, alpha], axis=-1)  # shape (50, 50, 9)

        origin = hist_aug[0, 49, :2].copy()
        hist_aug[..., :2] -= origin
        hist_aug[..., :4] /= self.scale
        hist_aug[..., 4:6] /= self.scale / self.dt
        hist_aug[..., 6] /= np.pi
        hist_aug[..., 7] /= (np.pi / self.dt)
        hist_aug[..., 8] /= (np.pi / (self.dt ** 2))

        return Data(
            x=torch.tensor(hist_aug, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

In [14]:
torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain_AngularAcc(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain_AngularAcc(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using Apple Silicon GPU


In [15]:
# def convert_xy(pred_vel, origin=None):
#     dt = 0.1  # seconds per step
#     if origin:
#         pred_pos = [origin]  # list of (B, 1, 2)
#         for t in range(60):
#             next_pos = pred_pos[-1] + pred_vel[:, t:t+1, :] * dt  # (B, 1, 2)
#             pred_pos.append(next_pos)
        
#         # Concatenate positions across time steps
#         pred_xy = torch.cat(pred_pos[1:], dim=1)  # skip initial origin, get (B, 60, 2)
#     else:
#         pred_pos = [0]
        
#         for t in range(60):
#             next_vel = (pred_vel[:, t:t+1, :] - pred_pos[-1]) /dt
#             # next_pos = pred_pos[-1] + pred_vel[:, t:t+1, :] * dt  # (B, 1, 2)
#             pred_pos.append(next_vel)
#         pred_vxvy = torch.cat(pred_pos[1:], dim=1)  # skip initial origin, get (B, 60, 2)
#     return pred_vxvy

In [27]:
def convert_VXVY(pred_pos):
    dt = 0.1
    pred_pos = pred_pos.reshape(-1, 60, 2)
    B = pred_pos.shape[0]
    zero_pos = torch.zeros((B, 1, 2), device=pred_pos.device, dtype=pred_pos.dtype)
    pos = torch.cat([zero_pos, pred_pos], dim=1)
    vel = torch.zeros_like(pos)                     # shape (B, 61, 2), with vel[:,0,:]=0
    vel[:, 1:, :] = (pos[:, 1:, :] - pos[:, :-1, :]) / dt
    pred_vxvy = vel[:, 1:, :]     # skip initial origin, get (B, 60, 2)
    # print(pred_vxvy[:4])
    # print(pred_vxvy.shape)
    return pred_vxvy

In [17]:
# Example of basic model that should work
class SimpleLSTM(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=512, output_dim=60*2):
        super(SimpleLSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
        # Initialize weights properly
        for name, param in self.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                nn.init.constant_(param, 0.0)
        
    def forward(self, x):
        # x = data.x[..., :5]
        # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
        x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # Process through LSTM
        lstm_out, _ = self.lstm(x)
        
        # Extract final hidden state
        features = lstm_out[:, -1, :]
        
        # Process through prediction head
        features = self.relu(self.fc1(features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        out_vel = convert_VXVY(out.view(-1, 60, 2))
        return out_vel

In [18]:
# Example of basic model that should work
class VelocityModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=5, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(VelocityModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [19]:
# Example of basic model that should work
class AccelearationModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=8, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(AccelearationModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        self.input_dim = input_dim
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, self.input_dim)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [20]:
# Example of basic model that should work
class AngularAccelerationModel(nn.Module):
    # 2 layers 9.77 val
    def __init__(self, input_dim=9, hidden_dim=512, output_dim=60*2, num_layers=1):
        super(AngularAccelerationModel, self).__init__()
        self.ego_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Add multi-layer prediction head for better results
        self.fc1 = nn.Linear(hidden_dim*2, hidden_dim*2)
        # TODO: remove dropout?
        self.dropout = nn.Dropout(0.1)  # Add dropout for regularization
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim*2, output_dim)
        
        self.input_dim = input_dim
        
        # Initialize weights properly
        # for name, param in self.named_parameters():
        #     if 'weight' in name:
        #         nn.init.xavier_normal_(param)
        #     elif 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        
        
    def forward(self, data):
        # x = data.x[..., :5]
        x = data.reshape(-1, 50, 50, self.input_dim)  # (batch_size, num_agents, seq_len, input_dim)
        batch_size = x.size(0)
        # x = x[:, 0, :, :]  # Only consider ego agent (index 0)
        
        # EGO AGENT
        ego_traj = x[:, 0, :, :]  # (batch, 50, 5)
        # Process through LSTM
        ego_lstm_out, _ = self.ego_encoder(ego_traj)
        # Extract final hidden state
        ego_features = ego_lstm_out[:, -1, :]
        
        # CLOSEST NEIGHBOR
        # ---- DISTANCES TO OTHER AGENTS ----
        ego_pos = x[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([x[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)

            neighbor_lstm_out, _ = self.neighbor_encoder(neighbor_trajs)  # both: (num_layers, batch, hidden_dim)

            neighbor_out_list.append(neighbor_lstm_out[:, -1, :])

        # ---- CONCATENATE HIDDEN AND CELL STATES ----
        all_features = torch.cat([ego_features] + neighbor_out_list, dim=1)  # (num_layers, batch, hidden_dim * 4)
        
        # Process through prediction head
        features = self.relu(self.fc1(all_features))
        features = self.dropout(features)
        out = self.fc2(features)
        
        # Reshape to (batch_size, 60, 2)
        return out.view(-1, 60, 2)

In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    def __init__(self, input_dim, num_experts, hidden_dim=64):
        super(GatingNetwork, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            # nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, num_experts)
        )

    def forward(self, context):
        logits = self.net(context)  # [batch_size, num_experts]
        weights = F.softmax(logits, dim=-1)
        return weights

class MoEModel(nn.Module):
    def __init__(self, expert_models, gating_input_dim, future_steps):
        super(MoEModel, self).__init__()
        self.expert_models = nn.ModuleList(expert_models)
        self.gating_net = GatingNetwork(gating_input_dim, len(expert_models))
        self.future_steps = future_steps
        self.context_encoder = nn.LSTM(input_size=9, hidden_size=gating_input_dim, batch_first=True)


    def forward(self, input_dict):
        """
        input_dict:
            - context: [batch_size, gating_input_dim] --> for gating net
            - inputs_per_model: list of input tensors, one per model
        Returns:
            - [batch_size, future_steps, 2]  # future positions
        """
        model_inputs = input_dict['inputs_per_model']
        batch_size = model_inputs[0].size(0)
        # context = input_dict['context']
        model_input =model_inputs[-1]
        ego_pos = model_input[:, 0, 49, :2].unsqueeze(1)  # (batch, 1, 2)
        ego_features = model_input[:, 0, :, :]
        agent_pos = model_input[:, :, 49, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego
        
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 3)
        
        # ---- ENCODE NEIGHBORS ----
        neighbor_out_list = []

        for i in range(1):
            idx = neighbor_ids[:, i]  # (batch,)
            neighbor_trajs = torch.stack([model_input[b, idx[b]] for b in range(batch_size)], dim=0)  # (batch, 50, 5)
        
        ego_features[..., :4] = ego_features[..., :4] - neighbor_trajs[..., :4]

        context, (hidden, cell)= self.context_encoder(ego_features)
        # print('encoder done')
        weights = self.gating_net(hidden[:, -1, :])  # [batch_size, num_experts]
        # print('gating')
        expert_preds = []
        for i, model in enumerate(self.expert_models):
            # print(i)
            pred = model(model_inputs[i])  # [batch_size, future_steps, 2]
            expert_preds.append(pred.unsqueeze(1))  # [B, 1, T, 2]

        expert_preds = torch.cat(expert_preds, dim=1)  # [B, E, T, 2]
        weights = weights.unsqueeze(-1).unsqueeze(-1)  # [B, E, 1, 1]

        fused = torch.sum(expert_preds * weights, dim=1)  # [B, T, 2]
        return fused.view(-1, 60, 2)


In [57]:
def train_improved_model(model, train_dataloader, val_dataloader, 
                         device, criterion=nn.MSELoss(), 
                         lr=0.001, epochs=100, patience=15):
    """
    Improved training function with better debugging and early stopping
    """
    # Initialize optimizer with smaller learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Exponential decay scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    early_stopping_patience = patience
    best_val_loss = float('inf')
    no_improvement = 0
    
    # Save initial state for comparison
    initial_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch", unit="epoch"):
        # ---- Training ----
        model.train()
        train_loss = 0
        num_train_batches = 0
        
        for batch in train_dataloader:
            batch = batch.to(device)
            x = batch.x
            # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
            input_dict = {
                    "inputs_per_model": [
                        x[..., :5].reshape(-1,50,50,5),
                        # x[..., :5].reshape(-1,50,50,5),  # input for PositionModel
                        # x[..., :8].reshape(-1,50,50,8),  # input for VelocityModel
                        x[..., :9].reshape(-1,50,50,9)  # input for AccelAngularModel
                    ]}
            pred = model(input_dict)
            y_all  = batch.y.view(batch.num_graphs, 60, 4)
            y = y_all[..., 2:4]
            
            # Check for NaN predictions
            if torch.isnan(pred).any():
                print(f"WARNING: NaN detected in predictions during training")
                continue
                
            loss = criterion(pred, y)
            
            # Check if loss is valid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"WARNING: Invalid loss value: {loss.item()}")
                continue
                
            optimizer.zero_grad()
            loss.backward()
            
            # More conservative gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1
        
        # Skip epoch if no valid batches
        if num_train_batches == 0:
            print("WARNING: No valid training batches in this epoch")
            continue
            
        train_loss /= num_train_batches
        
        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_mae = 0
        val_mse = 0
        num_val_batches = 0
        
        # Sample predictions for debugging
        sample_input = None
        sample_pred = None
        sample_target = None
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                batch = batch.to(device)
                x = batch.x
                # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
                input_dict = {
                    "inputs_per_model": [
                        x[..., :5].reshape(-1,50,50,5),
                        # x[..., :5].reshape(-1,50,50,5),  # input for PositionModel
                        # x[..., :8].reshape(-1,50,50,8),  # input for VelocityModel
                        x[..., :9].reshape(-1,50,50,9)  # input for AccelAngularModel
                    ]}
                
                pred = model(input_dict)
                y_all  = batch.y.view(batch.num_graphs, 60, 4)
                y = y_all[..., 2:4]
                
                # Store sample for debugging
                if batch_idx == 0 and sample_input is None:
                    sample_input = batch.x[0].cpu().numpy()
                    sample_pred = pred[0].cpu().numpy()
                    sample_target = y[0].cpu().numpy()
                
                # Skip invalid predictions
                if torch.isnan(pred).any():
                    print(f"WARNING: NaN detected in predictions during validation")
                    continue
                    
                batch_loss = criterion(pred, y).item()
                val_loss += batch_loss
                
                # Unnormalize for real-world metrics
                pred_unnorm = pred * batch.scale.view(-1, 1, 1)
                y_unnorm = y * batch.scale.view(-1, 1, 1)
                
                val_mae += nn.L1Loss()(pred_unnorm, y_unnorm).item()
                val_mse += nn.MSELoss()(pred_unnorm, y_unnorm).item()
                
                num_val_batches += 1
        
        # Skip epoch if no valid validation batches
        if num_val_batches == 0:
            print("WARNING: No valid validation batches in this epoch")
            continue
            
        val_loss /= num_val_batches
        val_mae /= num_val_batches
        val_mse /= num_val_batches
        
        # Update learning rate
        scheduler.step()
        
        # Print with more details
        tqdm.tqdm.write(
            f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
            f"Train MSE {train_loss:.4f} | Val MSE {val_loss:.4f} | "
            f"Val MAE {val_mae:.4f} | Val MSE {val_mse:.4f}"
        )
        
        # Debug output - first 3 predictions vs targets
        if epoch % 5 == 0:
            tqdm.tqdm.write(f"Sample pred first 3 steps: {sample_pred[:3]}")
            tqdm.tqdm.write(f"Sample target first 3 steps: {sample_target[:3]}")
            
            # Check if model weights are changing
            if epoch > 0:
                weight_change = False
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        initial_param = initial_state_dict[name]
                        if not torch.allclose(param, initial_param, rtol=1e-4):
                            weight_change = True
                            break
                if not weight_change:
                    tqdm.tqdm.write("WARNING: Model weights barely changing!")
        
        # Relaxed improvement criterion - consider any improvement
        if val_loss < best_val_loss:
            tqdm.tqdm.write(f"Validation improved: {best_val_loss:.6f} -> {val_loss:.6f}")
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), best_model)
        else:
            no_improvement += 1
            if no_improvement >= early_stopping_patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    # Load best model before returning
    model.load_state_dict(torch.load(best_model))
    return model

In [58]:
# Example usage
def train_and_evaluate_model(model_list):
    # Create model
    expert_models = model_list
    moe = MoEModel(expert_models, gating_input_dim=64, future_steps=30)
    moe = moe.to(device)
    
    # Train with improved function
    train_improved_model(
        model=moe,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        # lr = 0.007 => 8.946
        lr=0.004,  # Lower learning rate
        patience=20,  # More patience
        epochs=150
    )
    
    
    return moe

In [69]:
def val(model):
    model.eval()
    test_mse = 0
    test_mse_xy = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_dataloader):
            batch = batch.to(device)
            x = batch.x
            # x = x.reshape(-1, 50, 50, 5)  # (batch_size, num_agents, seq_len, input_dim)
            input_dict = {
                "inputs_per_model": [
                    x[..., :5].reshape(-1,50,50,5),
                    # x[..., :5].reshape(-1,50,50,5),  # input for PositionModel
                    # x[..., :8].reshape(-1,50,50,8),  # input for VelocityModel
                    x[..., :9].reshape(-1,50,50,9)  # input for AccelAngularModel
                ]}
            
            pred = model(input_dict)
            y_all  = batch.y.view(batch.num_graphs, 60, 4)
            y = y_all[..., 2:4]
            y_xy = y_all[..., :2]
            # Unnormalize
            pred = pred * batch.scale.view(-1, 1, 1)
            y = y * batch.scale.view(-1, 1, 1)
            y_xy = y_xy * batch.scale.view(-1, 1, 1)
            
            test_mse += nn.MSELoss()(pred, y).item()
            dt = 0.1  # seconds per step
            pred_pos = [0]  # list of (B, 1, 2)

            for t in range(60):
                next_pos = pred_pos[-1] + pred[:, t:t+1, :] * dt  # (B, 1, 2)
                pred_pos.append(next_pos)

            # Concatenate positions across time steps
            pred_xy = torch.cat(pred_pos[1:], dim=1)
            # print(test_mse)
            test_mse_xy += nn.MSELoss()(pred_xy, y_xy).item()
            # print(test_mse_xy)
    test_mse /= len(val_dataloader)
    test_mse_xy /= len(val_dataloader)
    print(f"Val MSE: {test_mse:.4f}")
    print(f"Val MSE: {test_mse_xy:.4f}")



In [60]:
simple_model_n = "best_model_xy.pt"
velocity_model_n = 'best_model0_retry.pt'
acceleration_model_n = 'best_model0_feat_eng.pt'
angular_model_n = 'best_model0_feat_eng_add_ang_acc.pt'


simple_model2 = torch.load(simple_model_n)
simple_model = SimpleLSTM().to(device)
simple_model.load_state_dict(simple_model2)

velocity_model2 = torch.load(velocity_model_n)
velocity_model = VelocityModel().to(device)
velocity_model.load_state_dict(velocity_model2)

acceleration_model2 = torch.load(acceleration_model_n)
acceleration_model = AccelearationModel().to(device)
acceleration_model.load_state_dict(acceleration_model2)

angular_model2 = torch.load(angular_model_n)
angular_model = AngularAccelerationModel().to(device)
angular_model.load_state_dict(angular_model2)


expert_models = [
                   simple_model,
                  #  velocity_model,
                  #  acceleration_model,
                    angular_model
                     ]

In [61]:
train_and_evaluate_model(expert_models)

  future[..., :2] = future[..., :2] - origin
  future[..., :2] = future[..., :2] @ R
  future[..., 2:4] = future[..., 2:4] @ R
Epoch:   1%|          | 1/150 [00:20<50:48, 20.46s/epoch]

Epoch 000 | LR 0.003800 | Train MSE 0.2295 | Val MSE 0.1015 | Val MAE 1.4632 | Val MSE 4.9728
Sample pred first 3 steps: [[-0.15537104  0.10191669]
 [-0.149544    0.08190986]
 [-0.15243182  0.0846969 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: inf -> 0.101485


Epoch:   1%|▏         | 2/150 [00:41<51:13, 20.77s/epoch]

Epoch 001 | LR 0.003610 | Train MSE 0.1094 | Val MSE 0.0847 | Val MAE 1.3308 | Val MSE 4.1493
Validation improved: 0.101485 -> 0.084680


Epoch:   2%|▏         | 3/150 [01:00<49:27, 20.19s/epoch]

Epoch 002 | LR 0.003429 | Train MSE 0.0892 | Val MSE 0.0664 | Val MAE 1.1737 | Val MSE 3.2519
Validation improved: 0.084680 -> 0.066365


Epoch:   3%|▎         | 4/150 [01:20<48:08, 19.78s/epoch]

Epoch 003 | LR 0.003258 | Train MSE 0.0772 | Val MSE 0.0850 | Val MAE 1.2307 | Val MSE 4.1664


Epoch:   3%|▎         | 5/150 [01:39<47:14, 19.55s/epoch]

Epoch 004 | LR 0.003095 | Train MSE 0.0868 | Val MSE 0.0624 | Val MAE 1.1333 | Val MSE 3.0585
Validation improved: 0.066365 -> 0.062419


Epoch:   4%|▍         | 6/150 [01:58<46:46, 19.49s/epoch]

Epoch 005 | LR 0.002940 | Train MSE 0.0809 | Val MSE 0.0602 | Val MAE 1.0897 | Val MSE 2.9499
Sample pred first 3 steps: [[0.01058621 0.00269172]
 [0.00960363 0.00412056]
 [0.01000907 0.00528579]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.062419 -> 0.060202


Epoch:   5%|▍         | 7/150 [02:17<46:16, 19.41s/epoch]

Epoch 006 | LR 0.002793 | Train MSE 0.0708 | Val MSE 0.0596 | Val MAE 1.0682 | Val MSE 2.9203
Validation improved: 0.060202 -> 0.059598


Epoch:   5%|▌         | 8/150 [02:37<45:45, 19.34s/epoch]

Epoch 007 | LR 0.002654 | Train MSE 0.0702 | Val MSE 0.0594 | Val MAE 1.0178 | Val MSE 2.9124
Validation improved: 0.059598 -> 0.059437


Epoch:   6%|▌         | 9/150 [02:56<45:33, 19.39s/epoch]

Epoch 008 | LR 0.002521 | Train MSE 0.0665 | Val MSE 0.0593 | Val MAE 1.0304 | Val MSE 2.9053
Validation improved: 0.059437 -> 0.059291


Epoch:   7%|▋         | 10/150 [03:15<45:06, 19.33s/epoch]

Epoch 009 | LR 0.002395 | Train MSE 0.0635 | Val MSE 0.0524 | Val MAE 0.9801 | Val MSE 2.5699
Validation improved: 0.059291 -> 0.052447


Epoch:   7%|▋         | 11/150 [03:35<44:46, 19.33s/epoch]

Epoch 010 | LR 0.002275 | Train MSE 0.0860 | Val MSE 0.0547 | Val MAE 0.9846 | Val MSE 2.6811
Sample pred first 3 steps: [[0.01088215 0.0059285 ]
 [0.00995551 0.00355508]
 [0.01241987 0.0049829 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   8%|▊         | 12/150 [03:54<44:19, 19.27s/epoch]

Epoch 011 | LR 0.002161 | Train MSE 0.0598 | Val MSE 0.0536 | Val MAE 1.0037 | Val MSE 2.6288


Epoch:   9%|▊         | 13/150 [04:13<43:52, 19.22s/epoch]

Epoch 012 | LR 0.002053 | Train MSE 0.0599 | Val MSE 0.0513 | Val MAE 0.9780 | Val MSE 2.5155
Validation improved: 0.052447 -> 0.051336


Epoch:   9%|▉         | 14/150 [04:32<43:28, 19.18s/epoch]

Epoch 013 | LR 0.001951 | Train MSE 0.0614 | Val MSE 0.0512 | Val MAE 0.9663 | Val MSE 2.5078
Validation improved: 0.051336 -> 0.051180


Epoch:  10%|█         | 15/150 [04:51<43:16, 19.23s/epoch]

Epoch 014 | LR 0.001853 | Train MSE 0.0585 | Val MSE 0.0548 | Val MAE 0.9651 | Val MSE 2.6871


Epoch:  11%|█         | 16/150 [05:10<42:54, 19.21s/epoch]

Epoch 015 | LR 0.001761 | Train MSE 0.0829 | Val MSE 0.0490 | Val MAE 0.9509 | Val MSE 2.3992
Sample pred first 3 steps: [[-0.02954382  0.01754522]
 [-0.02897093  0.01671624]
 [-0.0278479   0.01632467]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.051180 -> 0.048963


Epoch:  11%|█▏        | 17/150 [05:30<42:31, 19.18s/epoch]

Epoch 016 | LR 0.001672 | Train MSE 0.0636 | Val MSE 0.0531 | Val MAE 0.9881 | Val MSE 2.6005


Epoch:  12%|█▏        | 18/150 [05:49<42:20, 19.25s/epoch]

Epoch 017 | LR 0.001589 | Train MSE 0.0573 | Val MSE 0.0487 | Val MAE 0.9083 | Val MSE 2.3862
Validation improved: 0.048963 -> 0.048698


Epoch:  13%|█▎        | 19/150 [06:08<42:08, 19.30s/epoch]

Epoch 018 | LR 0.001509 | Train MSE 0.0542 | Val MSE 0.0473 | Val MAE 0.9105 | Val MSE 2.3156
Validation improved: 0.048698 -> 0.047257


Epoch:  13%|█▎        | 20/150 [06:28<41:52, 19.33s/epoch]

Epoch 019 | LR 0.001434 | Train MSE 0.0592 | Val MSE 0.0508 | Val MAE 0.9532 | Val MSE 2.4870


Epoch:  14%|█▍        | 21/150 [06:47<41:28, 19.29s/epoch]

Epoch 020 | LR 0.001362 | Train MSE 0.0543 | Val MSE 0.0477 | Val MAE 0.8953 | Val MSE 2.3363
Sample pred first 3 steps: [[-0.00819267  0.01534182]
 [-0.00774683  0.01448792]
 [-0.00657799  0.01353488]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  15%|█▍        | 22/150 [07:06<41:05, 19.26s/epoch]

Epoch 021 | LR 0.001294 | Train MSE 0.0524 | Val MSE 0.0512 | Val MAE 0.9326 | Val MSE 2.5068


Epoch:  15%|█▌        | 23/150 [07:25<40:39, 19.21s/epoch]

Epoch 022 | LR 0.001229 | Train MSE 0.0520 | Val MSE 0.0485 | Val MAE 0.9064 | Val MSE 2.3786


Epoch:  16%|█▌        | 24/150 [07:45<40:28, 19.28s/epoch]

Epoch 023 | LR 0.001168 | Train MSE 0.0506 | Val MSE 0.0519 | Val MAE 0.9177 | Val MSE 2.5436


Epoch:  17%|█▋        | 25/150 [08:04<39:54, 19.15s/epoch]

Epoch 024 | LR 0.001110 | Train MSE 0.0517 | Val MSE 0.0489 | Val MAE 0.8849 | Val MSE 2.3966


Epoch:  17%|█▋        | 26/150 [08:23<39:37, 19.18s/epoch]

Epoch 025 | LR 0.001054 | Train MSE 0.0487 | Val MSE 0.0496 | Val MAE 0.8790 | Val MSE 2.4313
Sample pred first 3 steps: [[-0.03678217  0.02879348]
 [-0.03600746  0.02728501]
 [-0.03615708  0.02573924]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  18%|█▊        | 27/150 [08:42<39:20, 19.19s/epoch]

Epoch 026 | LR 0.001001 | Train MSE 0.0589 | Val MSE 0.0495 | Val MAE 0.9033 | Val MSE 2.4240


Epoch:  19%|█▊        | 28/150 [09:01<38:56, 19.15s/epoch]

Epoch 027 | LR 0.000951 | Train MSE 0.0485 | Val MSE 0.0479 | Val MAE 0.8731 | Val MSE 2.3476


Epoch:  19%|█▉        | 29/150 [09:20<38:43, 19.20s/epoch]

Epoch 028 | LR 0.000904 | Train MSE 0.0493 | Val MSE 0.0488 | Val MAE 0.8989 | Val MSE 2.3928


Epoch:  20%|██        | 30/150 [09:40<38:24, 19.21s/epoch]

Epoch 029 | LR 0.000859 | Train MSE 0.0472 | Val MSE 0.0494 | Val MAE 0.8727 | Val MSE 2.4199


Epoch:  21%|██        | 31/150 [09:59<38:03, 19.19s/epoch]

Epoch 030 | LR 0.000816 | Train MSE 0.0493 | Val MSE 0.0484 | Val MAE 0.8815 | Val MSE 2.3740
Sample pred first 3 steps: [[-0.01530728  0.01720907]
 [-0.01464726  0.01604044]
 [-0.01380106  0.01559897]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  21%|██▏       | 32/150 [10:18<37:37, 19.13s/epoch]

Epoch 031 | LR 0.000775 | Train MSE 0.0480 | Val MSE 0.0450 | Val MAE 0.8666 | Val MSE 2.2060
Validation improved: 0.047257 -> 0.045020


Epoch:  22%|██▏       | 33/150 [10:37<37:21, 19.16s/epoch]

Epoch 032 | LR 0.000736 | Train MSE 0.0486 | Val MSE 0.0492 | Val MAE 0.8889 | Val MSE 2.4092


Epoch:  23%|██▎       | 34/150 [10:56<37:04, 19.17s/epoch]

Epoch 033 | LR 0.000699 | Train MSE 0.0475 | Val MSE 0.0444 | Val MAE 0.8461 | Val MSE 2.1759
Validation improved: 0.045020 -> 0.044406


Epoch:  23%|██▎       | 35/150 [11:15<36:42, 19.15s/epoch]

Epoch 034 | LR 0.000664 | Train MSE 0.0464 | Val MSE 0.0461 | Val MAE 0.8711 | Val MSE 2.2586


Epoch:  24%|██▍       | 36/150 [11:39<39:02, 20.55s/epoch]

Epoch 035 | LR 0.000631 | Train MSE 0.0515 | Val MSE 0.0445 | Val MAE 0.8646 | Val MSE 2.1825
Sample pred first 3 steps: [[-0.00275062  0.01670329]
 [-0.00197171  0.01606143]
 [-0.00162905  0.01569856]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  25%|██▍       | 37/150 [12:00<38:53, 20.65s/epoch]

Epoch 036 | LR 0.000600 | Train MSE 0.0467 | Val MSE 0.0460 | Val MAE 0.8536 | Val MSE 2.2549


Epoch:  25%|██▌       | 38/150 [12:19<37:36, 20.15s/epoch]

Epoch 037 | LR 0.000570 | Train MSE 0.0493 | Val MSE 0.0449 | Val MAE 0.8489 | Val MSE 2.1998


Epoch:  26%|██▌       | 39/150 [27:55<9:05:36, 294.92s/epoch]

Epoch 038 | LR 0.000541 | Train MSE 0.0450 | Val MSE 0.0450 | Val MAE 0.8576 | Val MSE 2.2050


Epoch:  27%|██▋       | 40/150 [28:14<6:28:48, 212.08s/epoch]

Epoch 039 | LR 0.000514 | Train MSE 0.0449 | Val MSE 0.0453 | Val MAE 0.8515 | Val MSE 2.2217


Epoch:  27%|██▋       | 41/150 [28:35<4:41:15, 154.82s/epoch]

Epoch 040 | LR 0.000488 | Train MSE 0.0443 | Val MSE 0.0461 | Val MAE 0.8543 | Val MSE 2.2570
Sample pred first 3 steps: [[-0.01371375  0.0194    ]
 [-0.01290339  0.01819277]
 [-0.01228211  0.01727289]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  28%|██▊       | 42/150 [28:54<3:25:27, 114.14s/epoch]

Epoch 041 | LR 0.000464 | Train MSE 0.0442 | Val MSE 0.0448 | Val MAE 0.8499 | Val MSE 2.1958


Epoch:  29%|██▊       | 43/150 [29:13<2:32:45, 85.66s/epoch] 

Epoch 042 | LR 0.000441 | Train MSE 0.0434 | Val MSE 0.0457 | Val MAE 0.8570 | Val MSE 2.2385


Epoch:  29%|██▉       | 44/150 [29:33<1:56:09, 65.75s/epoch]

Epoch 043 | LR 0.000419 | Train MSE 0.0439 | Val MSE 0.0455 | Val MAE 0.8464 | Val MSE 2.2299


Epoch:  30%|███       | 45/150 [29:52<1:30:38, 51.79s/epoch]

Epoch 044 | LR 0.000398 | Train MSE 0.0433 | Val MSE 0.0449 | Val MAE 0.8405 | Val MSE 2.1993


Epoch:  31%|███       | 46/150 [30:11<1:12:58, 42.10s/epoch]

Epoch 045 | LR 0.000378 | Train MSE 0.0445 | Val MSE 0.0439 | Val MAE 0.8383 | Val MSE 2.1503
Sample pred first 3 steps: [[-0.01276849  0.01766177]
 [-0.01260876  0.01677577]
 [-0.01225994  0.01617947]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.044406 -> 0.043883


Epoch:  31%|███▏      | 47/150 [30:31<1:00:31, 35.26s/epoch]

Epoch 046 | LR 0.000359 | Train MSE 0.0441 | Val MSE 0.0428 | Val MAE 0.8397 | Val MSE 2.0965
Validation improved: 0.043883 -> 0.042785


Epoch:  32%|███▏      | 48/150 [30:50<51:47, 30.47s/epoch]  

Epoch 047 | LR 0.000341 | Train MSE 0.0428 | Val MSE 0.0433 | Val MAE 0.8265 | Val MSE 2.1221


Epoch:  33%|███▎      | 49/150 [31:09<45:35, 27.09s/epoch]

Epoch 048 | LR 0.000324 | Train MSE 0.0420 | Val MSE 0.0430 | Val MAE 0.8246 | Val MSE 2.1078


Epoch:  33%|███▎      | 50/150 [31:29<41:18, 24.78s/epoch]

Epoch 049 | LR 0.000308 | Train MSE 0.0425 | Val MSE 0.0428 | Val MAE 0.8166 | Val MSE 2.0956
Validation improved: 0.042785 -> 0.042767


Epoch:  34%|███▍      | 51/150 [31:48<38:13, 23.16s/epoch]

Epoch 050 | LR 0.000292 | Train MSE 0.0420 | Val MSE 0.0433 | Val MAE 0.8308 | Val MSE 2.1194
Sample pred first 3 steps: [[-0.00512389  0.01589687]
 [-0.00487048  0.01521665]
 [-0.00464458  0.01483158]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  35%|███▍      | 52/150 [32:07<35:47, 21.92s/epoch]

Epoch 051 | LR 0.000278 | Train MSE 0.0427 | Val MSE 0.0422 | Val MAE 0.8193 | Val MSE 2.0687
Validation improved: 0.042767 -> 0.042218


Epoch:  35%|███▌      | 53/150 [32:26<34:12, 21.16s/epoch]

Epoch 052 | LR 0.000264 | Train MSE 0.0414 | Val MSE 0.0422 | Val MAE 0.8218 | Val MSE 2.0660
Validation improved: 0.042218 -> 0.042163


Epoch:  36%|███▌      | 54/150 [32:46<32:56, 20.59s/epoch]

Epoch 053 | LR 0.000251 | Train MSE 0.0431 | Val MSE 0.0447 | Val MAE 0.8508 | Val MSE 2.1911


Epoch:  37%|███▋      | 55/150 [33:05<32:03, 20.25s/epoch]

Epoch 054 | LR 0.000238 | Train MSE 0.0422 | Val MSE 0.0432 | Val MAE 0.8330 | Val MSE 2.1164


Epoch:  37%|███▋      | 56/150 [33:25<31:19, 20.00s/epoch]

Epoch 055 | LR 0.000226 | Train MSE 0.0414 | Val MSE 0.0429 | Val MAE 0.8234 | Val MSE 2.1020
Sample pred first 3 steps: [[-0.00217451  0.01056751]
 [-0.00174761  0.0104585 ]
 [-0.00163846  0.0100283 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  38%|███▊      | 57/150 [33:44<30:51, 19.91s/epoch]

Epoch 056 | LR 0.000215 | Train MSE 0.0409 | Val MSE 0.0427 | Val MAE 0.8258 | Val MSE 2.0903


Epoch:  39%|███▊      | 58/150 [34:04<30:31, 19.90s/epoch]

Epoch 057 | LR 0.000204 | Train MSE 0.0412 | Val MSE 0.0431 | Val MAE 0.8241 | Val MSE 2.1129


Epoch:  39%|███▉      | 59/150 [34:24<30:05, 19.84s/epoch]

Epoch 058 | LR 0.000194 | Train MSE 0.0408 | Val MSE 0.0434 | Val MAE 0.8267 | Val MSE 2.1264


Epoch:  40%|████      | 60/150 [34:43<29:40, 19.78s/epoch]

Epoch 059 | LR 0.000184 | Train MSE 0.0410 | Val MSE 0.0431 | Val MAE 0.8271 | Val MSE 2.1142


Epoch:  41%|████      | 61/150 [35:03<29:26, 19.85s/epoch]

Epoch 060 | LR 0.000175 | Train MSE 0.0413 | Val MSE 0.0429 | Val MAE 0.8234 | Val MSE 2.1036
Sample pred first 3 steps: [[0.00031498 0.01535163]
 [0.00068092 0.01470681]
 [0.00079019 0.0143331 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  41%|████▏     | 62/150 [35:23<28:55, 19.72s/epoch]

Epoch 061 | LR 0.000166 | Train MSE 0.0417 | Val MSE 0.0427 | Val MAE 0.8262 | Val MSE 2.0947


Epoch:  42%|████▏     | 63/150 [35:43<28:34, 19.70s/epoch]

Epoch 062 | LR 0.000158 | Train MSE 0.0412 | Val MSE 0.0427 | Val MAE 0.8184 | Val MSE 2.0904


Epoch:  43%|████▎     | 64/150 [36:02<28:20, 19.77s/epoch]

Epoch 063 | LR 0.000150 | Train MSE 0.0402 | Val MSE 0.0429 | Val MAE 0.8178 | Val MSE 2.1019


Epoch:  43%|████▎     | 65/150 [36:23<28:11, 19.89s/epoch]

Epoch 064 | LR 0.000143 | Train MSE 0.0417 | Val MSE 0.0428 | Val MAE 0.8194 | Val MSE 2.0975


Epoch:  44%|████▍     | 66/150 [36:43<27:56, 19.96s/epoch]

Epoch 065 | LR 0.000135 | Train MSE 0.0399 | Val MSE 0.0430 | Val MAE 0.8208 | Val MSE 2.1066
Sample pred first 3 steps: [[-0.00286125  0.01311345]
 [-0.00251129  0.01271136]
 [-0.00228604  0.01240343]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  45%|████▍     | 67/150 [37:03<27:39, 20.00s/epoch]

Epoch 066 | LR 0.000129 | Train MSE 0.0401 | Val MSE 0.0424 | Val MAE 0.8109 | Val MSE 2.0798


Epoch:  45%|████▌     | 68/150 [37:22<27:04, 19.81s/epoch]

Epoch 067 | LR 0.000122 | Train MSE 0.0407 | Val MSE 0.0431 | Val MAE 0.8279 | Val MSE 2.1124


Epoch:  46%|████▌     | 69/150 [37:42<26:51, 19.89s/epoch]

Epoch 068 | LR 0.000116 | Train MSE 0.0405 | Val MSE 0.0430 | Val MAE 0.8225 | Val MSE 2.1062


Epoch:  47%|████▋     | 70/150 [38:02<26:16, 19.70s/epoch]

Epoch 069 | LR 0.000110 | Train MSE 0.0404 | Val MSE 0.0431 | Val MAE 0.8218 | Val MSE 2.1098


Epoch:  47%|████▋     | 71/150 [38:21<25:59, 19.73s/epoch]

Epoch 070 | LR 0.000105 | Train MSE 0.0399 | Val MSE 0.0426 | Val MAE 0.8138 | Val MSE 2.0862
Sample pred first 3 steps: [[-0.00109431  0.01034892]
 [-0.00084768  0.01009248]
 [-0.00069709  0.00964039]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  48%|████▊     | 72/150 [38:41<25:35, 19.68s/epoch]

Epoch 071 | LR 0.000100 | Train MSE 0.0398 | Val MSE 0.0439 | Val MAE 0.8244 | Val MSE 2.1507


Epoch:  48%|████▊     | 72/150 [39:01<42:16, 32.52s/epoch]

Epoch 072 | LR 0.000095 | Train MSE 0.0402 | Val MSE 0.0427 | Val MAE 0.8152 | Val MSE 2.0945
Early stopping after 73 epochs without improvement





MoEModel(
  (expert_models): ModuleList(
    (0): SimpleLSTM(
      (lstm): LSTM(5, 512, batch_first=True)
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=512, out_features=120, bias=True)
    )
    (1): AngularAccelerationModel(
      (ego_encoder): LSTM(9, 512, batch_first=True)
      (neighbor_encoder): LSTM(9, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True)
    )
  )
  (gating_net): GatingNetwork(
    (net): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (context_encoder): LSTM(9, 64, batch_first=True)
)

In [66]:
best_model_nn

MoEModel(
  (expert_models): ModuleList(
    (0): SimpleLSTM(
      (lstm): LSTM(5, 512, batch_first=True)
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=512, out_features=120, bias=True)
    )
    (1): AngularAccelerationModel(
      (ego_encoder): LSTM(9, 512, batch_first=True)
      (neighbor_encoder): LSTM(9, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True)
    )
  )
  (gating_net): GatingNetwork(
    (net): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.3, inplace=False)
      (3): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (context_encoder): LSTM(9, 64, batch_first=True)
)

In [70]:
best_model_nn = "ensemble_model_4.pt"
best_model2 = torch.load(best_model_nn)
best_model = MoEModel(expert_models, gating_input_dim=64, future_steps=30).to(device)
best_model.load_state_dict(best_model2)
val(best_model)


  future[..., :2] = future[..., :2] - origin


Val MSE: 2.0660
Val MSE: 12.1586


In [33]:
# val(model, val_dataloader)
train_and_evaluate_model(expert_models)

  future[..., :2] = future[..., :2] - origin
  future[..., :2] = future[..., :2] @ R
  future[..., 2:4] = future[..., 2:4] @ R
Epoch:   1%|          | 1/150 [00:30<1:16:31, 30.82s/epoch]

Epoch 000 | LR 0.000950 | Train MSE 0.1605 | Val MSE 0.0888 | Val MAE 1.4146 | Val MSE 4.3534
Sample pred first 3 steps: [[0.02323055 0.06422044]
 [0.02527423 0.06352042]
 [0.02442436 0.06019098]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: inf -> 0.088845


Epoch:   1%|▏         | 2/150 [00:59<1:12:55, 29.56s/epoch]

Epoch 001 | LR 0.000902 | Train MSE 0.1034 | Val MSE 0.0754 | Val MAE 1.2532 | Val MSE 3.6958
Validation improved: 0.088845 -> 0.075424


Epoch:   2%|▏         | 3/150 [01:27<1:10:57, 28.97s/epoch]

Epoch 002 | LR 0.000857 | Train MSE 0.0844 | Val MSE 0.0651 | Val MAE 1.1327 | Val MSE 3.1886
Validation improved: 0.075424 -> 0.065073


Epoch:   3%|▎         | 4/150 [01:56<1:09:53, 28.72s/epoch]

Epoch 003 | LR 0.000815 | Train MSE 0.0805 | Val MSE 0.0627 | Val MAE 1.1104 | Val MSE 3.0707
Validation improved: 0.065073 -> 0.062668


Epoch:   3%|▎         | 5/150 [02:25<1:09:58, 28.96s/epoch]

Epoch 004 | LR 0.000774 | Train MSE 0.0732 | Val MSE 0.0613 | Val MAE 1.0860 | Val MSE 3.0041
Validation improved: 0.062668 -> 0.061308


Epoch:   3%|▎         | 5/150 [02:54<1:09:58, 28.96s/epoch]

Epoch 005 | LR 0.000735 | Train MSE 0.0693 | Val MSE 0.0603 | Val MAE 1.0607 | Val MSE 2.9544
Sample pred first 3 steps: [[-0.02051372  0.02869587]
 [-0.01995567  0.02716169]
 [-0.019794    0.02499215]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   4%|▍         | 6/150 [02:55<1:10:00, 29.17s/epoch]

Validation improved: 0.061308 -> 0.060294


Epoch:   5%|▍         | 7/150 [03:24<1:09:25, 29.13s/epoch]

Epoch 006 | LR 0.000698 | Train MSE 0.0689 | Val MSE 0.0560 | Val MAE 1.0288 | Val MSE 2.7416
Validation improved: 0.060294 -> 0.055952


Epoch:   5%|▌         | 8/150 [03:53<1:08:56, 29.13s/epoch]

Epoch 007 | LR 0.000663 | Train MSE 0.0675 | Val MSE 0.0541 | Val MAE 0.9977 | Val MSE 2.6523
Validation improved: 0.055952 -> 0.054128


Epoch:   6%|▌         | 9/150 [04:22<1:08:24, 29.11s/epoch]

Epoch 008 | LR 0.000630 | Train MSE 0.0604 | Val MSE 0.0534 | Val MAE 0.9970 | Val MSE 2.6189
Validation improved: 0.054128 -> 0.053446


Epoch:   7%|▋         | 10/150 [04:50<1:07:21, 28.87s/epoch]

Epoch 009 | LR 0.000599 | Train MSE 0.0617 | Val MSE 0.0521 | Val MAE 0.9744 | Val MSE 2.5527
Validation improved: 0.053446 -> 0.052096


Epoch:   7%|▋         | 11/150 [05:19<1:06:56, 28.90s/epoch]

Epoch 010 | LR 0.000569 | Train MSE 0.0616 | Val MSE 0.0526 | Val MAE 0.9556 | Val MSE 2.5772
Sample pred first 3 steps: [[-0.00066602  0.0168576 ]
 [-0.00211234  0.01877681]
 [-0.00065202  0.01713165]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   8%|▊         | 12/150 [05:48<1:06:42, 29.00s/epoch]

Epoch 011 | LR 0.000540 | Train MSE 0.0592 | Val MSE 0.0508 | Val MAE 0.9620 | Val MSE 2.4910
Validation improved: 0.052096 -> 0.050837


Epoch:   9%|▊         | 13/150 [06:18<1:06:23, 29.08s/epoch]

Epoch 012 | LR 0.000513 | Train MSE 0.0562 | Val MSE 0.0527 | Val MAE 0.9860 | Val MSE 2.5809


Epoch:   9%|▉         | 14/150 [06:46<1:05:31, 28.91s/epoch]

Epoch 013 | LR 0.000488 | Train MSE 0.0595 | Val MSE 0.0568 | Val MAE 0.9719 | Val MSE 2.7829


Epoch:  10%|█         | 15/150 [07:15<1:05:05, 28.93s/epoch]

Epoch 014 | LR 0.000463 | Train MSE 0.0554 | Val MSE 0.0531 | Val MAE 0.9495 | Val MSE 2.6025


Epoch:  11%|█         | 16/150 [07:43<1:04:07, 28.71s/epoch]

Epoch 015 | LR 0.000440 | Train MSE 0.0557 | Val MSE 0.0516 | Val MAE 0.9248 | Val MSE 2.5296
Sample pred first 3 steps: [[-0.00915118  0.02023512]
 [-0.01177842  0.02204712]
 [-0.00880585  0.02283294]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  11%|█▏        | 17/150 [08:12<1:03:23, 28.60s/epoch]

Epoch 016 | LR 0.000418 | Train MSE 0.0547 | Val MSE 0.0522 | Val MAE 0.9112 | Val MSE 2.5580


Epoch:  12%|█▏        | 18/150 [08:41<1:03:21, 28.80s/epoch]

Epoch 017 | LR 0.000397 | Train MSE 0.0543 | Val MSE 0.0515 | Val MAE 0.9366 | Val MSE 2.5251


Epoch:  13%|█▎        | 19/150 [09:09<1:02:32, 28.65s/epoch]

Epoch 018 | LR 0.000377 | Train MSE 0.0537 | Val MSE 0.0490 | Val MAE 0.9341 | Val MSE 2.4033
Validation improved: 0.050837 -> 0.049047


Epoch:  13%|█▎        | 20/150 [09:38<1:02:09, 28.69s/epoch]

Epoch 019 | LR 0.000358 | Train MSE 0.0516 | Val MSE 0.0488 | Val MAE 0.9124 | Val MSE 2.3908
Validation improved: 0.049047 -> 0.048792


Epoch:  14%|█▍        | 21/150 [10:07<1:01:38, 28.67s/epoch]

Epoch 020 | LR 0.000341 | Train MSE 0.0517 | Val MSE 0.0483 | Val MAE 0.9019 | Val MSE 2.3650
Sample pred first 3 steps: [[-0.01446448  0.02343647]
 [-0.01430107  0.02225365]
 [-0.01381221  0.02183931]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.048792 -> 0.048265


Epoch:  15%|█▍        | 22/150 [10:35<1:01:02, 28.62s/epoch]

Epoch 021 | LR 0.000324 | Train MSE 0.0533 | Val MSE 0.0467 | Val MAE 0.8971 | Val MSE 2.2884
Validation improved: 0.048265 -> 0.046702


Epoch:  15%|█▌        | 23/150 [11:04<1:00:37, 28.64s/epoch]

Epoch 022 | LR 0.000307 | Train MSE 0.0523 | Val MSE 0.0478 | Val MAE 0.9064 | Val MSE 2.3436


Epoch:  16%|█▌        | 24/150 [11:33<1:00:12, 28.67s/epoch]

Epoch 023 | LR 0.000292 | Train MSE 0.0514 | Val MSE 0.0509 | Val MAE 0.9100 | Val MSE 2.4940


Epoch:  17%|█▋        | 25/150 [12:01<59:40, 28.64s/epoch]  

Epoch 024 | LR 0.000277 | Train MSE 0.0520 | Val MSE 0.0513 | Val MAE 0.9122 | Val MSE 2.5123


Epoch:  17%|█▋        | 26/150 [12:29<58:49, 28.46s/epoch]

Epoch 025 | LR 0.000264 | Train MSE 0.0494 | Val MSE 0.0476 | Val MAE 0.9057 | Val MSE 2.3344
Sample pred first 3 steps: [[0.00506167 0.02080267]
 [0.00452659 0.02338865]
 [0.00669341 0.02334145]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  18%|█▊        | 27/150 [12:57<58:10, 28.38s/epoch]

Epoch 026 | LR 0.000250 | Train MSE 0.0535 | Val MSE 0.0485 | Val MAE 0.8971 | Val MSE 2.3774


Epoch:  19%|█▊        | 28/150 [13:26<57:37, 28.34s/epoch]

Epoch 027 | LR 0.000238 | Train MSE 0.0495 | Val MSE 0.0483 | Val MAE 0.8915 | Val MSE 2.3679


Epoch:  19%|█▉        | 29/150 [13:55<57:31, 28.53s/epoch]

Epoch 028 | LR 0.000226 | Train MSE 0.0490 | Val MSE 0.0470 | Val MAE 0.8969 | Val MSE 2.3020


Epoch:  20%|██        | 30/150 [14:23<57:01, 28.51s/epoch]

Epoch 029 | LR 0.000215 | Train MSE 0.0488 | Val MSE 0.0475 | Val MAE 0.8871 | Val MSE 2.3283


Epoch:  21%|██        | 31/150 [14:52<56:56, 28.71s/epoch]

Epoch 030 | LR 0.000204 | Train MSE 0.0490 | Val MSE 0.0475 | Val MAE 0.8971 | Val MSE 2.3294
Sample pred first 3 steps: [[0.00206064 0.00948643]
 [0.00164194 0.00883075]
 [0.00291455 0.00983321]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  21%|██▏       | 32/150 [15:21<56:33, 28.76s/epoch]

Epoch 031 | LR 0.000194 | Train MSE 0.0498 | Val MSE 0.0460 | Val MAE 0.8840 | Val MSE 2.2561
Validation improved: 0.046702 -> 0.046043


Epoch:  22%|██▏       | 33/150 [15:50<55:56, 28.69s/epoch]

Epoch 032 | LR 0.000184 | Train MSE 0.0491 | Val MSE 0.0449 | Val MAE 0.8817 | Val MSE 2.1998
Validation improved: 0.046043 -> 0.044894


Epoch:  23%|██▎       | 34/150 [16:18<55:35, 28.76s/epoch]

Epoch 033 | LR 0.000175 | Train MSE 0.0465 | Val MSE 0.0462 | Val MAE 0.8818 | Val MSE 2.2662


Epoch:  23%|██▎       | 35/150 [16:48<55:32, 28.98s/epoch]

Epoch 034 | LR 0.000166 | Train MSE 0.0456 | Val MSE 0.0467 | Val MAE 0.8801 | Val MSE 2.2867


Epoch:  24%|██▍       | 36/150 [17:18<55:25, 29.17s/epoch]

Epoch 035 | LR 0.000158 | Train MSE 0.0467 | Val MSE 0.0491 | Val MAE 0.8948 | Val MSE 2.4039
Sample pred first 3 steps: [[0.00645726 0.01411067]
 [0.01016316 0.01386602]
 [0.00725953 0.01486521]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  25%|██▍       | 37/150 [17:47<55:16, 29.35s/epoch]

Epoch 036 | LR 0.000150 | Train MSE 0.0486 | Val MSE 0.0471 | Val MAE 0.8836 | Val MSE 2.3087


Epoch:  25%|██▌       | 38/150 [18:17<54:40, 29.29s/epoch]

Epoch 037 | LR 0.000142 | Train MSE 0.0475 | Val MSE 0.0482 | Val MAE 0.8846 | Val MSE 2.3604


Epoch:  26%|██▌       | 39/150 [18:46<54:21, 29.38s/epoch]

Epoch 038 | LR 0.000135 | Train MSE 0.0464 | Val MSE 0.0467 | Val MAE 0.8684 | Val MSE 2.2885


Epoch:  27%|██▋       | 40/150 [19:17<54:28, 29.72s/epoch]

Epoch 039 | LR 0.000129 | Train MSE 0.0470 | Val MSE 0.0468 | Val MAE 0.8725 | Val MSE 2.2956


Epoch:  27%|██▋       | 41/150 [19:46<53:44, 29.59s/epoch]

Epoch 040 | LR 0.000122 | Train MSE 0.0463 | Val MSE 0.0467 | Val MAE 0.8729 | Val MSE 2.2892
Sample pred first 3 steps: [[-0.00219257  0.00896662]
 [-0.00244757  0.00797808]
 [-0.0015401   0.00724611]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  28%|██▊       | 42/150 [20:16<53:21, 29.65s/epoch]

Epoch 041 | LR 0.000116 | Train MSE 0.0454 | Val MSE 0.0460 | Val MAE 0.8585 | Val MSE 2.2518


Epoch:  29%|██▊       | 43/150 [20:45<52:28, 29.42s/epoch]

Epoch 042 | LR 0.000110 | Train MSE 0.0454 | Val MSE 0.0473 | Val MAE 0.8703 | Val MSE 2.3164


Epoch:  29%|██▉       | 44/150 [21:14<52:05, 29.48s/epoch]

Epoch 043 | LR 0.000105 | Train MSE 0.0455 | Val MSE 0.0471 | Val MAE 0.8702 | Val MSE 2.3086


Epoch:  30%|███       | 45/150 [21:43<51:17, 29.31s/epoch]

Epoch 044 | LR 0.000099 | Train MSE 0.0452 | Val MSE 0.0466 | Val MAE 0.8635 | Val MSE 2.2844


Epoch:  31%|███       | 46/150 [22:12<50:28, 29.12s/epoch]

Epoch 045 | LR 0.000094 | Train MSE 0.0453 | Val MSE 0.0481 | Val MAE 0.8690 | Val MSE 2.3572
Sample pred first 3 steps: [[0.00351993 0.0037683 ]
 [0.00342034 0.00357878]
 [0.00537532 0.00557312]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  31%|███▏      | 47/150 [22:40<49:43, 28.97s/epoch]

Epoch 046 | LR 0.000090 | Train MSE 0.0458 | Val MSE 0.0466 | Val MAE 0.8654 | Val MSE 2.2833


Epoch:  32%|███▏      | 48/150 [23:09<49:10, 28.92s/epoch]

Epoch 047 | LR 0.000085 | Train MSE 0.0450 | Val MSE 0.0462 | Val MAE 0.8677 | Val MSE 2.2651


Epoch:  33%|███▎      | 49/150 [23:39<49:05, 29.17s/epoch]

Epoch 048 | LR 0.000081 | Train MSE 0.0448 | Val MSE 0.0464 | Val MAE 0.8695 | Val MSE 2.2742


Epoch:  33%|███▎      | 50/150 [24:10<49:41, 29.82s/epoch]

Epoch 049 | LR 0.000077 | Train MSE 0.0445 | Val MSE 0.0452 | Val MAE 0.8622 | Val MSE 2.2151


Epoch:  34%|███▍      | 51/150 [24:46<52:08, 31.60s/epoch]

Epoch 050 | LR 0.000073 | Train MSE 0.0439 | Val MSE 0.0448 | Val MAE 0.8658 | Val MSE 2.1931
Sample pred first 3 steps: [[0.00551333 0.00920374]
 [0.00919386 0.00797803]
 [0.00618015 0.00748086]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.044894 -> 0.044757


Epoch:  35%|███▍      | 52/150 [25:22<53:56, 33.03s/epoch]

Epoch 051 | LR 0.000069 | Train MSE 0.0450 | Val MSE 0.0446 | Val MAE 0.8610 | Val MSE 2.1871
Validation improved: 0.044757 -> 0.044634


Epoch:  35%|███▌      | 53/150 [25:59<55:20, 34.24s/epoch]

Epoch 052 | LR 0.000066 | Train MSE 0.0449 | Val MSE 0.0458 | Val MAE 0.8618 | Val MSE 2.2447


Epoch:  36%|███▌      | 54/150 [26:37<56:09, 35.10s/epoch]

Epoch 053 | LR 0.000063 | Train MSE 0.0448 | Val MSE 0.0452 | Val MAE 0.8640 | Val MSE 2.2140


Epoch:  37%|███▋      | 55/150 [27:08<53:55, 34.06s/epoch]

Epoch 054 | LR 0.000060 | Train MSE 0.0434 | Val MSE 0.0458 | Val MAE 0.8623 | Val MSE 2.2466


Epoch:  37%|███▋      | 56/150 [27:38<51:07, 32.63s/epoch]

Epoch 055 | LR 0.000057 | Train MSE 0.0440 | Val MSE 0.0456 | Val MAE 0.8596 | Val MSE 2.2352
Sample pred first 3 steps: [[0.0051758  0.012721  ]
 [0.00613823 0.01201835]
 [0.00637956 0.00965884]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  38%|███▊      | 57/150 [28:07<48:53, 31.55s/epoch]

Epoch 056 | LR 0.000054 | Train MSE 0.0439 | Val MSE 0.0462 | Val MAE 0.8665 | Val MSE 2.2643


Epoch:  39%|███▊      | 58/150 [28:36<47:27, 30.95s/epoch]

Epoch 057 | LR 0.000051 | Train MSE 0.0443 | Val MSE 0.0455 | Val MAE 0.8599 | Val MSE 2.2279


Epoch:  39%|███▉      | 59/150 [29:05<46:01, 30.35s/epoch]

Epoch 058 | LR 0.000048 | Train MSE 0.0428 | Val MSE 0.0455 | Val MAE 0.8588 | Val MSE 2.2305


Epoch:  40%|████      | 60/150 [29:34<44:45, 29.84s/epoch]

Epoch 059 | LR 0.000046 | Train MSE 0.0432 | Val MSE 0.0451 | Val MAE 0.8580 | Val MSE 2.2112


Epoch:  41%|████      | 61/150 [30:02<43:34, 29.38s/epoch]

Epoch 060 | LR 0.000044 | Train MSE 0.0433 | Val MSE 0.0453 | Val MAE 0.8564 | Val MSE 2.2211
Sample pred first 3 steps: [[0.00288346 0.00550511]
 [0.00445258 0.00484881]
 [0.00365209 0.00316707]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  41%|████▏     | 62/150 [30:31<42:53, 29.25s/epoch]

Epoch 061 | LR 0.000042 | Train MSE 0.0436 | Val MSE 0.0450 | Val MAE 0.8567 | Val MSE 2.2060


Epoch:  42%|████▏     | 63/150 [31:03<43:44, 30.17s/epoch]

Epoch 062 | LR 0.000039 | Train MSE 0.0442 | Val MSE 0.0445 | Val MAE 0.8555 | Val MSE 2.1817
Validation improved: 0.044634 -> 0.044525


Epoch:  43%|████▎     | 64/150 [31:36<44:22, 30.96s/epoch]

Epoch 063 | LR 0.000038 | Train MSE 0.0445 | Val MSE 0.0444 | Val MAE 0.8566 | Val MSE 2.1773
Validation improved: 0.044525 -> 0.044435


Epoch:  43%|████▎     | 65/150 [32:09<44:38, 31.51s/epoch]

Epoch 064 | LR 0.000036 | Train MSE 0.0457 | Val MSE 0.0447 | Val MAE 0.8537 | Val MSE 2.1887


Epoch:  44%|████▍     | 66/150 [32:42<44:40, 31.91s/epoch]

Epoch 065 | LR 0.000034 | Train MSE 0.0433 | Val MSE 0.0446 | Val MAE 0.8533 | Val MSE 2.1831
Sample pred first 3 steps: [[-0.00245391 -0.00049837]
 [-0.0022184   0.00185758]
 [ 0.0005749   0.00169089]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  45%|████▍     | 67/150 [33:15<44:47, 32.38s/epoch]

Epoch 066 | LR 0.000032 | Train MSE 0.0431 | Val MSE 0.0448 | Val MAE 0.8541 | Val MSE 2.1957


Epoch:  45%|████▌     | 68/150 [33:48<44:28, 32.54s/epoch]

Epoch 067 | LR 0.000031 | Train MSE 0.0425 | Val MSE 0.0448 | Val MAE 0.8550 | Val MSE 2.1966


Epoch:  46%|████▌     | 69/150 [34:20<43:50, 32.48s/epoch]

Epoch 068 | LR 0.000029 | Train MSE 0.0436 | Val MSE 0.0448 | Val MAE 0.8553 | Val MSE 2.1940


Epoch:  47%|████▋     | 70/150 [34:53<43:12, 32.40s/epoch]

Epoch 069 | LR 0.000028 | Train MSE 0.0428 | Val MSE 0.0445 | Val MAE 0.8525 | Val MSE 2.1815


Epoch:  47%|████▋     | 71/150 [35:25<42:32, 32.31s/epoch]

Epoch 070 | LR 0.000026 | Train MSE 0.0445 | Val MSE 0.0446 | Val MAE 0.8564 | Val MSE 2.1869
Sample pred first 3 steps: [[-0.00166019  0.00243511]
 [-0.00057432  0.00439369]
 [ 0.00103963  0.00336192]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  48%|████▊     | 72/150 [35:57<42:00, 32.31s/epoch]

Epoch 071 | LR 0.000025 | Train MSE 0.0436 | Val MSE 0.0445 | Val MAE 0.8555 | Val MSE 2.1783


Epoch:  49%|████▊     | 73/150 [36:30<41:31, 32.36s/epoch]

Epoch 072 | LR 0.000024 | Train MSE 0.0437 | Val MSE 0.0442 | Val MAE 0.8515 | Val MSE 2.1661
Validation improved: 0.044435 -> 0.044207


Epoch:  49%|████▉     | 74/150 [37:02<41:03, 32.41s/epoch]

Epoch 073 | LR 0.000022 | Train MSE 0.0432 | Val MSE 0.0445 | Val MAE 0.8517 | Val MSE 2.1800


Epoch:  50%|█████     | 75/150 [37:35<40:32, 32.43s/epoch]

Epoch 074 | LR 0.000021 | Train MSE 0.0428 | Val MSE 0.0444 | Val MAE 0.8519 | Val MSE 2.1773


Epoch:  51%|█████     | 76/150 [38:07<40:01, 32.45s/epoch]

Epoch 075 | LR 0.000020 | Train MSE 0.0435 | Val MSE 0.0445 | Val MAE 0.8501 | Val MSE 2.1818
Sample pred first 3 steps: [[0.00058047 0.00501341]
 [0.00072036 0.00713411]
 [0.00391781 0.00625786]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  51%|█████▏    | 77/150 [38:39<39:19, 32.33s/epoch]

Epoch 076 | LR 0.000019 | Train MSE 0.0429 | Val MSE 0.0444 | Val MAE 0.8510 | Val MSE 2.1774


Epoch:  52%|█████▏    | 78/150 [39:12<38:53, 32.41s/epoch]

Epoch 077 | LR 0.000018 | Train MSE 0.0432 | Val MSE 0.0441 | Val MAE 0.8491 | Val MSE 2.1630
Validation improved: 0.044207 -> 0.044143


Epoch:  53%|█████▎    | 79/150 [39:44<38:22, 32.43s/epoch]

Epoch 078 | LR 0.000017 | Train MSE 0.0431 | Val MSE 0.0444 | Val MAE 0.8505 | Val MSE 2.1767


Epoch:  53%|█████▎    | 80/150 [40:17<37:52, 32.46s/epoch]

Epoch 079 | LR 0.000017 | Train MSE 0.0429 | Val MSE 0.0446 | Val MAE 0.8551 | Val MSE 2.1878


Epoch:  54%|█████▍    | 81/150 [40:49<37:20, 32.47s/epoch]

Epoch 080 | LR 0.000016 | Train MSE 0.0423 | Val MSE 0.0446 | Val MAE 0.8537 | Val MSE 2.1847
Sample pred first 3 steps: [[-2.3291912e-05  1.6833981e-03]
 [ 3.6748010e-04  3.5775788e-03]
 [ 2.4321121e-03  2.4224645e-03]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  55%|█████▍    | 82/150 [41:22<36:52, 32.53s/epoch]

Epoch 081 | LR 0.000015 | Train MSE 0.0433 | Val MSE 0.0444 | Val MAE 0.8528 | Val MSE 2.1740


Epoch:  55%|█████▌    | 83/150 [41:55<36:23, 32.60s/epoch]

Epoch 082 | LR 0.000014 | Train MSE 0.0423 | Val MSE 0.0443 | Val MAE 0.8538 | Val MSE 2.1695


Epoch:  56%|█████▌    | 84/150 [42:28<36:00, 32.74s/epoch]

Epoch 083 | LR 0.000013 | Train MSE 0.0431 | Val MSE 0.0446 | Val MAE 0.8558 | Val MSE 2.1873


Epoch:  57%|█████▋    | 85/150 [43:01<35:29, 32.76s/epoch]

Epoch 084 | LR 0.000013 | Train MSE 0.0423 | Val MSE 0.0443 | Val MAE 0.8501 | Val MSE 2.1712


Epoch:  57%|█████▋    | 86/150 [43:34<35:10, 32.98s/epoch]

Epoch 085 | LR 0.000012 | Train MSE 0.0425 | Val MSE 0.0443 | Val MAE 0.8510 | Val MSE 2.1730
Sample pred first 3 steps: [[0.00137017 0.00110037]
 [0.00242188 0.00250947]
 [0.00379516 0.00209268]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  58%|█████▊    | 87/150 [44:07<34:41, 33.04s/epoch]

Epoch 086 | LR 0.000012 | Train MSE 0.0428 | Val MSE 0.0443 | Val MAE 0.8497 | Val MSE 2.1685


Epoch:  59%|█████▊    | 88/150 [45:11<43:32, 42.14s/epoch]

Epoch 087 | LR 0.000011 | Train MSE 0.0430 | Val MSE 0.0443 | Val MAE 0.8512 | Val MSE 2.1691


Epoch:  59%|█████▉    | 89/150 [45:42<39:26, 38.80s/epoch]

Epoch 088 | LR 0.000010 | Train MSE 0.0426 | Val MSE 0.0442 | Val MAE 0.8497 | Val MSE 2.1653


Epoch:  60%|██████    | 90/150 [56:50<3:47:35, 227.59s/epoch]

Epoch 089 | LR 0.000010 | Train MSE 0.0428 | Val MSE 0.0442 | Val MAE 0.8493 | Val MSE 2.1655


Epoch:  61%|██████    | 91/150 [57:21<2:45:50, 168.66s/epoch]

Epoch 090 | LR 0.000009 | Train MSE 0.0425 | Val MSE 0.0443 | Val MAE 0.8507 | Val MSE 2.1717
Sample pred first 3 steps: [[8.5605774e-05 1.0373497e-03]
 [1.0734636e-03 2.4974151e-03]
 [2.6500868e-03 1.8775389e-03]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  61%|██████▏   | 92/150 [57:52<2:03:17, 127.54s/epoch]

Epoch 091 | LR 0.000009 | Train MSE 0.0425 | Val MSE 0.0444 | Val MAE 0.8502 | Val MSE 2.1748


Epoch:  62%|██████▏   | 93/150 [58:24<1:33:52, 98.81s/epoch] 

Epoch 092 | LR 0.000008 | Train MSE 0.0418 | Val MSE 0.0443 | Val MAE 0.8500 | Val MSE 2.1719


Epoch:  63%|██████▎   | 94/150 [58:56<1:13:31, 78.77s/epoch]

Epoch 093 | LR 0.000008 | Train MSE 0.0416 | Val MSE 0.0443 | Val MAE 0.8485 | Val MSE 2.1686


Epoch:  63%|██████▎   | 95/150 [59:30<59:44, 65.18s/epoch]  

Epoch 094 | LR 0.000008 | Train MSE 0.0424 | Val MSE 0.0443 | Val MAE 0.8502 | Val MSE 2.1683


Epoch:  64%|██████▍   | 96/150 [1:00:02<49:46, 55.31s/epoch]

Epoch 095 | LR 0.000007 | Train MSE 0.0426 | Val MSE 0.0444 | Val MAE 0.8502 | Val MSE 2.1733
Sample pred first 3 steps: [[-0.00033303  0.00245556]
 [ 0.00051652  0.00374627]
 [ 0.00205963  0.00333297]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:  65%|██████▍   | 97/150 [1:00:34<42:47, 48.45s/epoch]

Epoch 096 | LR 0.000007 | Train MSE 0.0431 | Val MSE 0.0444 | Val MAE 0.8507 | Val MSE 2.1732


Epoch:  65%|██████▍   | 97/150 [1:01:07<33:23, 37.81s/epoch]

Epoch 097 | LR 0.000007 | Train MSE 0.0426 | Val MSE 0.0442 | Val MAE 0.8497 | Val MSE 2.1670
Early stopping after 98 epochs without improvement





MoEModel(
  (expert_models): ModuleList(
    (0): SimpleLSTM(
      (lstm): LSTM(5, 512, batch_first=True)
      (fc1): Linear(in_features=512, out_features=512, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=512, out_features=120, bias=True)
    )
    (1): VelocityModel(
      (ego_encoder): LSTM(5, 512, batch_first=True)
      (neighbor_encoder): LSTM(5, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True)
    )
    (2): AngularAccelerationModel(
      (ego_encoder): LSTM(9, 512, batch_first=True)
      (neighbor_encoder): LSTM(9, 512, batch_first=True)
      (fc1): Linear(in_features=1024, out_features=1024, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (relu): ReLU()
      (fc2): Linear(in_features=1024, out_features=120, bias=True