# LSTM - vanilla

In [1]:
best_model = "best_model23.pt"


import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm



train_file = np.load('data/train.npz')

train_data = train_file['data']
print("train_data's shape", train_data.shape)
test_file = np.load('data/test_input.npz')

test_data = test_file['data']
print("test_data's shape", test_data.shape)


class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, 2:4].copy(), dtype=torch.float32)  # (60, 2)
        
        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        # future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0), # (1,2)
            scale=torch.tensor(self.scale, dtype=torch.float32), # scalar e.g. 7.0
        )
        
        # print(f'x: {data_item.x.shape}')
        # print(f'y: {data_item.y.shape}')

        return data_item
    

class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()
        
        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item


torch.manual_seed(251)
np.random.seed(42)

scale = 7.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')


class AutoRegressiveLSTM(nn.Module):
    def __init__(self, input_dim=5, hidden_dim=512, output_dim=2, num_layers=1, window_size=20, future_steps=60):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.window_size = window_size
        self.future_steps = future_steps

        # Separate encoder for ego and each neighbor agent
        self.encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)
        self.neighbor_encoder = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True)

        # Decoder
        self.decoder = nn.LSTM(input_size=2, hidden_size=hidden_dim * 2, num_layers=num_layers, batch_first=True)
        self.out = nn.Linear(hidden_dim * 2, output_dim)

    def encode_trajectory_window(self, traj_window):
        """Encode a trajectory window using the appropriate encoder"""
        return self.encoder(traj_window)

    def encode_neighbor_window(self, neighbor_window):
        """Encode a neighbor trajectory window"""
        return self.neighbor_encoder(neighbor_window)

    def get_closest_neighbor(self, x, current_step):
        """Get the closest neighbor at the current step"""
        batch_size = x.size(0)
        
        # Get positions at current step
        ego_pos = x[:, 0, current_step, :2].unsqueeze(1)  # (batch, 1, 2)
        agent_pos = x[:, :, current_step, :2]  # (batch, 50, 2)
        dists = torch.norm(agent_pos - ego_pos, dim=-1)  # (batch, 50)
        dists[:, 0] = float('inf')  # mask out ego

        # Select closest neighbor
        _, neighbor_ids = torch.topk(dists, k=1, dim=1, largest=False)  # (batch, 1)
        return neighbor_ids[:, 0]  # (batch,)

    def forward(self, data, forcing_ratio=0.5):
        x = data.x[..., :5]  # Only pos & vel
        x = x.reshape(-1, 50, 50, 5)  # (batch, agents=50, time=50, features=5)
        batch_size = x.size(0)
        device = x.device

        if self.training:
            # During training, we have access to ground truth future
            future = data.y.view(batch_size, 60, 2)
            
            # We'll predict from step 40 to step 110 (total 70 steps)
            # First 10 steps (40-49) use existing data + some predictions
            # Next 60 steps (50-109) are pure predictions
            all_outputs = []
            
            # Initialize with last window_size steps from existing data
            current_step = 50 - self.window_size  # Start from step 40
            
            for pred_step in range(70):  # Predict 70 steps total
                # Determine the window end step
                window_end = current_step + self.window_size
                
                if window_end <= 50:
                    # Use existing data for the entire window
                    ego_window = x[:, 0, current_step:window_end, :]  # (batch, window_size, 5)
                    neighbor_id = self.get_closest_neighbor(x, window_end - 1)
                    neighbor_window = torch.stack([x[b, neighbor_id[b], current_step:window_end] 
                                                 for b in range(batch_size)], dim=0)
                else:
                    # Mix existing data and predictions
                    existing_steps = max(0, 50 - current_step)
                    pred_steps = self.window_size - existing_steps
                    
                    if existing_steps > 0:
                        # Use existing data for the first part
                        ego_existing = x[:, 0, current_step:50, :]  # (batch, existing_steps, 5)
                        neighbor_id = self.get_closest_neighbor(x, 49)
                        neighbor_existing = torch.stack([x[b, neighbor_id[b], current_step:50] 
                                                       for b in range(batch_size)], dim=0)
                        
                        # For predicted part, we need to construct features from predictions
                        if pred_steps > 0 and len(all_outputs) >= pred_steps:
                            # Get recent predictions and convert to features
                            recent_preds = torch.stack(all_outputs[-pred_steps:], dim=1)  # (batch, pred_steps, 2)
                            
                            # Create velocity from position differences
                            if pred_steps > 1:
                                pred_vel = recent_preds[:, 1:] - recent_preds[:, :-1]
                                pred_vel = torch.cat([pred_vel[:, :1], pred_vel], dim=1)  # Duplicate first vel
                            else:
                                # Use last known velocity
                                last_vel = x[:, 0, 49, 2:4].unsqueeze(1)
                                pred_vel = last_vel.repeat(1, pred_steps, 1)
                            
                            # Combine position and velocity (ignore other features for now)
                            ego_pred_features = torch.cat([recent_preds, pred_vel, 
                                                         torch.zeros(batch_size, pred_steps, 1, device=device)], dim=2)
                            
                            # Assume neighbor follows similar pattern (simplified)
                            neighbor_pred_features = ego_pred_features.clone()
                            
                            # Concatenate existing and predicted features
                            ego_window = torch.cat([ego_existing, ego_pred_features], dim=1)
                            neighbor_window = torch.cat([neighbor_existing, neighbor_pred_features], dim=1)
                        else:
                            # Not enough predictions yet, pad with zeros or repeat last
                            pad_size = pred_steps
                            ego_pad = ego_existing[:, -1:].repeat(1, pad_size, 1)
                            neighbor_pad = neighbor_existing[:, -1:].repeat(1, pad_size, 1)
                            ego_window = torch.cat([ego_existing, ego_pad], dim=1)
                            neighbor_window = torch.cat([neighbor_existing, neighbor_pad], dim=1)
                    else:
                        # Pure prediction mode - use only recent predictions
                        if len(all_outputs) >= self.window_size:
                            recent_preds = torch.stack(all_outputs[-self.window_size:], dim=1)
                            
                            # Create features from predictions
                            if self.window_size > 1:
                                pred_vel = recent_preds[:, 1:] - recent_preds[:, :-1]
                                pred_vel = torch.cat([pred_vel[:, :1], pred_vel], dim=1)
                            else:
                                pred_vel = torch.zeros_like(recent_preds)
                            
                            ego_window = torch.cat([recent_preds, pred_vel, 
                                                  torch.zeros(batch_size, self.window_size, 1, device=device)], dim=2)
                            neighbor_window = ego_window.clone()
                        else:
                            # Fallback: use last available data
                            ego_window = x[:, 0, -self.window_size:, :]
                            neighbor_id = self.get_closest_neighbor(x, 49)
                            neighbor_window = torch.stack([x[b, neighbor_id[b], -self.window_size:] 
                                                         for b in range(batch_size)], dim=0)
                
                # Encode the windows
                _, (ego_hidden, ego_cell) = self.encode_trajectory_window(ego_window)
                _, (neighbor_hidden, neighbor_cell) = self.encode_neighbor_window(neighbor_window)
                
                # Concatenate hidden states
                hidden = torch.cat([ego_hidden, neighbor_hidden], dim=2)
                cell = torch.cat([ego_cell, neighbor_cell], dim=2)
                
                # Decode one step
                if pred_step == 0:
                    decoder_input = x[:, 0, 49, :2].unsqueeze(1)  # Last known position
                else:
                    decoder_input = all_outputs[-1].unsqueeze(1)
                
                output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
                pred = self.out(output).squeeze(1)  # (batch, 2)
                
                # Teacher forcing for first 10 steps during training
                if pred_step < 10 and self.training and random.random() < forcing_ratio:
                    if pred_step < len(future):
                        pred = future[:, pred_step]
                
                all_outputs.append(pred)
                current_step += 1
            
            # Return only the 60 future steps we care about (steps 10-69 correspond to timesteps 50-109)
            outputs = torch.stack(all_outputs[10:], dim=1)  # (batch, 60, 2)
            
        else:
            # Inference mode: predict 60 steps into the future
            outputs = []
            
            # Start with the last window_size steps from input
            current_step = 50 - self.window_size
            
            for pred_step in range(self.future_steps):
                # Determine what data to use for the window
                if current_step + self.window_size <= 50:
                    # Use existing data
                    ego_window = x[:, 0, current_step:current_step + self.window_size, :]
                    neighbor_id = self.get_closest_neighbor(x, current_step + self.window_size - 1)
                    neighbor_window = torch.stack([x[b, neighbor_id[b], current_step:current_step + self.window_size] 
                                                 for b in range(batch_size)], dim=0)
                else:
                    # Mix existing data and predictions
                    existing_steps = max(0, 50 - current_step)
                    pred_steps = self.window_size - existing_steps
                    
                    if existing_steps > 0:
                        ego_existing = x[:, 0, current_step:50, :]
                        neighbor_id = self.get_closest_neighbor(x, 49)
                        neighbor_existing = torch.stack([x[b, neighbor_id[b], current_step:50] 
                                                       for b in range(batch_size)], dim=0)
                        
                        if pred_steps > 0 and len(outputs) >= pred_steps:
                            recent_preds = torch.stack(outputs[-pred_steps:], dim=1)
                            
                            if pred_steps > 1:
                                pred_vel = recent_preds[:, 1:] - recent_preds[:, :-1]
                                pred_vel = torch.cat([pred_vel[:, :1], pred_vel], dim=1)
                            else:
                                last_vel = x[:, 0, 49, 2:4].unsqueeze(1)
                                pred_vel = last_vel.repeat(1, pred_steps, 1)
                            
                            ego_pred_features = torch.cat([recent_preds, pred_vel,
                                                         torch.zeros(batch_size, pred_steps, 1, device=device)], dim=2)
                            neighbor_pred_features = ego_pred_features.clone()
                            
                            ego_window = torch.cat([ego_existing, ego_pred_features], dim=1)
                            neighbor_window = torch.cat([neighbor_existing, neighbor_pred_features], dim=1)
                        else:
                            pad_size = pred_steps
                            ego_pad = ego_existing[:, -1:].repeat(1, pad_size, 1)
                            neighbor_pad = neighbor_existing[:, -1:].repeat(1, pad_size, 1)
                            ego_window = torch.cat([ego_existing, ego_pad], dim=1)
                            neighbor_window = torch.cat([neighbor_existing, neighbor_pad], dim=1)
                    else:
                        # Pure prediction mode
                        if len(outputs) >= self.window_size:
                            recent_preds = torch.stack(outputs[-self.window_size:], dim=1)
                            
                            if self.window_size > 1:
                                pred_vel = recent_preds[:, 1:] - recent_preds[:, :-1]
                                pred_vel = torch.cat([pred_vel[:, :1], pred_vel], dim=1)
                            else:
                                pred_vel = torch.zeros_like(recent_preds)
                            
                            ego_window = torch.cat([recent_preds, pred_vel,
                                                  torch.zeros(batch_size, self.window_size, 1, device=device)], dim=2)
                            neighbor_window = ego_window.clone()
                        else:
                            ego_window = x[:, 0, -self.window_size:, :]
                            neighbor_id = self.get_closest_neighbor(x, 49)
                            neighbor_window = torch.stack([x[b, neighbor_id[b], -self.window_size:] 
                                                         for b in range(batch_size)], dim=0)
                
                # Encode
                _, (ego_hidden, ego_cell) = self.encode_trajectory_window(ego_window)
                _, (neighbor_hidden, neighbor_cell) = self.encode_neighbor_window(neighbor_window)
                
                hidden = torch.cat([ego_hidden, neighbor_hidden], dim=2)
                cell = torch.cat([ego_cell, neighbor_cell], dim=2)
                
                # Decode
                if pred_step == 0:
                    decoder_input = x[:, 0, 49, :2].unsqueeze(1)
                else:
                    decoder_input = outputs[-1].unsqueeze(1)
                
                output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
                pred = self.out(output).squeeze(1)
                
                outputs.append(pred)
                current_step += 1
            
            outputs = torch.stack(outputs, dim=1)  # (batch, 60, 2)

        return outputs


def train_improved_model(model, train_dataloader, val_dataloader, 
                         device, criterion=nn.MSELoss(), 
                         lr=0.001, epochs=100, patience=15):
    """
    Improved training function with better debugging and early stopping
    """
    # Initialize optimizer with smaller learning rate
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Exponential decay scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    
    early_stopping_patience = patience
    best_val_loss = float('inf')
    no_improvement = 0
    
    # Save initial state for comparison
    initial_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    
    for epoch in tqdm.tqdm(range(epochs), desc="Epoch", unit="epoch"):
        # ---- Training ----
        model.train()
        train_loss = 0
        num_train_batches = 0
        forcing_ratio = max(0.0, 1.0 - epoch / 50)
        
        for batch in train_dataloader:
            batch = batch.to(device)
            pred = model(batch, forcing_ratio=forcing_ratio)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Check for NaN predictions
            if torch.isnan(pred).any():
                print(f"WARNING: NaN detected in predictions during training")
                continue
                
            loss = criterion(pred, y)
            
            # Check if loss is valid
            if torch.isnan(loss) or torch.isinf(loss):
                print(f"WARNING: Invalid loss value: {loss.item()}")
                continue
                
            optimizer.zero_grad()
            loss.backward()
            
            # More conservative gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            
            optimizer.step()
            train_loss += loss.item()
            num_train_batches += 1
        
        # Skip epoch if no valid batches
        if num_train_batches == 0:
            print("WARNING: No valid training batches in this epoch")
            continue
            
        train_loss /= num_train_batches
        
        # ---- Validation ----
        model.eval()
        val_loss = 0
        val_mae = 0
        val_mse = 0
        num_val_batches = 0
        
        # Sample predictions for debugging
        sample_input = None
        sample_pred = None
        sample_target = None
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_dataloader):
                batch = batch.to(device)
                pred = model(batch)
                y = batch.y.view(batch.num_graphs, 60, 2)
                
                # Store sample for debugging
                if batch_idx == 0 and sample_input is None:
                    sample_input = batch.x[0].cpu().numpy()
                    sample_pred = pred[0].cpu().numpy()
                    sample_target = y[0].cpu().numpy()
                
                # Skip invalid predictions
                if torch.isnan(pred).any():
                    print(f"WARNING: NaN detected in predictions during validation")
                    continue
                    
                batch_loss = criterion(pred, y).item()
                val_loss += batch_loss
                
                # Unnormalize for real-world metrics
                pred_unnorm = pred * batch.scale.view(-1, 1, 1)
                y_unnorm = y * batch.scale.view(-1, 1, 1)
                
                val_mae += nn.L1Loss()(pred_unnorm, y_unnorm).item()
                val_mse += nn.MSELoss()(pred_unnorm, y_unnorm).item()
                
                num_val_batches += 1
        
        # Skip epoch if no valid validation batches
        if num_val_batches == 0:
            print("WARNING: No valid validation batches in this epoch")
            continue
            
        val_loss /= num_val_batches
        val_mae /= num_val_batches
        val_mse /= num_val_batches
        
        # Update learning rate
        scheduler.step()
        
        # Print with more details
        tqdm.tqdm.write(
            f"Epoch {epoch:03d} | LR {optimizer.param_groups[0]['lr']:.6f} | "
            f"Train MSE {train_loss:.4f} | Val MSE (normalized) {val_loss:.4f} | "
            f"Val MAE (true) {val_mae:.4f} | Val MSE (true) {val_mse:.4f}"
        )
        
        # Debug output - first 3 predictions vs targets
        if epoch % 5 == 0:
            tqdm.tqdm.write(f"Sample pred first 3 steps: {sample_pred[:3]}")
            tqdm.tqdm.write(f"Sample target first 3 steps: {sample_target[:3]}")
            
            # Check if model weights are changing
            if epoch > 0:
                weight_change = False
                for name, param in model.named_parameters():
                    if param.requires_grad:
                        initial_param = initial_state_dict[name]
                        if not torch.allclose(param, initial_param, rtol=1e-4):
                            weight_change = True
                            break
                if not weight_change:
                    tqdm.tqdm.write("WARNING: Model weights barely changing!")
        
        # Relaxed improvement criterion - consider any improvement
        if val_loss < best_val_loss:
            tqdm.tqdm.write(f"Validation improved: {best_val_loss:.6f} -> {val_loss:.6f}")
            best_val_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), 'best_model.pth')  # Fixed variable name
        else:
            no_improvement += 1
            if no_improvement >= early_stopping_patience:
                print(f"Early stopping after {epoch+1} epochs without improvement")
                break
    
    # Load best model before returning
    model.load_state_dict(torch.load('best_model.pth'))
    return model


# Example usage
def train_and_evaluate_model(train_dataloader, val_dataloader, device):
    # Create model with sliding window
    model = AutoRegressiveLSTM(window_size=10, future_steps=60)
    model = model.to(device)
    
    # Train with improved function
    train_improved_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        device=device,
        lr=0.005,  # Lower learning rate
        patience=20,  # More patience
        epochs=150
    )
    
    # Evaluate
    model.eval()
    test_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            
            # Unnormalize
            pred = pred * batch.scale.view(-1, 1, 1)
            y = y * batch.scale.view(-1, 1, 1)
            
            test_mse += nn.MSELoss()(pred, y).item()
    
    test_mse /= len(val_dataloader)
    print(f"Val MSE: {test_mse:.4f}")
    
    return model

train_data's shape (10000, 50, 110, 6)
test_data's shape (2100, 50, 50, 6)
Using CUDA GPU


In [None]:
train_and_evaluate_model(train_dataloader, val_dataloader, device)

Epoch:   1%|          | 1/150 [01:12<2:59:27, 72.26s/epoch]

Epoch 000 | LR 0.004750 | Train MSE 0.4824 | Val MSE (normalized) 0.7825 | Val MAE (true) 4.4064 | Val MSE (true) 38.3420
Sample pred first 3 steps: [[-0.19138068  0.14863092]
 [-0.19138564  0.14862876]
 [-0.19138607  0.14862972]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: inf -> 0.782490


Epoch:   1%|▏         | 2/150 [02:23<2:56:51, 71.70s/epoch]

Epoch 001 | LR 0.004513 | Train MSE 0.4165 | Val MSE (normalized) 0.6330 | Val MAE (true) 3.8867 | Val MSE (true) 31.0162
Validation improved: 0.782490 -> 0.632983


Epoch:   2%|▏         | 3/150 [03:34<2:55:15, 71.54s/epoch]

Epoch 002 | LR 0.004287 | Train MSE 0.3612 | Val MSE (normalized) 1.2159 | Val MAE (true) 6.1735 | Val MSE (true) 59.5776


Epoch:   3%|▎         | 4/150 [04:46<2:54:11, 71.59s/epoch]

Epoch 003 | LR 0.004073 | Train MSE 0.5277 | Val MSE (normalized) 0.6483 | Val MAE (true) 4.1758 | Val MSE (true) 31.7689


Epoch:   3%|▎         | 5/150 [05:58<2:53:12, 71.67s/epoch]

Epoch 004 | LR 0.003869 | Train MSE 0.4206 | Val MSE (normalized) 0.9633 | Val MAE (true) 5.4157 | Val MSE (true) 47.2020


Epoch:   4%|▍         | 6/150 [07:10<2:52:15, 71.77s/epoch]

Epoch 005 | LR 0.003675 | Train MSE 0.2690 | Val MSE (normalized) 1.0011 | Val MAE (true) 5.4034 | Val MSE (true) 49.0552
Sample pred first 3 steps: [[0.01332489 0.23051494]
 [0.01339083 0.23090136]
 [0.01337924 0.2314953 ]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   5%|▍         | 7/150 [08:22<2:51:29, 71.95s/epoch]

Epoch 006 | LR 0.003492 | Train MSE 0.1723 | Val MSE (normalized) 0.6304 | Val MAE (true) 3.8997 | Val MSE (true) 30.8903
Validation improved: 0.632983 -> 0.630414


Epoch:   5%|▌         | 8/150 [09:35<2:51:05, 72.29s/epoch]

Epoch 007 | LR 0.003317 | Train MSE 0.1069 | Val MSE (normalized) 0.5166 | Val MAE (true) 3.5111 | Val MSE (true) 25.3138
Validation improved: 0.630414 -> 0.516609


Epoch:   6%|▌         | 9/150 [10:48<2:50:29, 72.55s/epoch]

Epoch 008 | LR 0.003151 | Train MSE 0.1030 | Val MSE (normalized) 0.4496 | Val MAE (true) 3.3337 | Val MSE (true) 22.0297
Validation improved: 0.516609 -> 0.449586


Epoch:   7%|▋         | 10/150 [12:02<2:49:47, 72.77s/epoch]

Epoch 009 | LR 0.002994 | Train MSE 0.0925 | Val MSE (normalized) 0.4197 | Val MAE (true) 3.1299 | Val MSE (true) 20.5650
Validation improved: 0.449586 -> 0.419695


Epoch:   7%|▋         | 11/150 [13:15<2:48:53, 72.90s/epoch]

Epoch 010 | LR 0.002844 | Train MSE 0.0902 | Val MSE (normalized) 0.4267 | Val MAE (true) 3.1892 | Val MSE (true) 20.9105
Sample pred first 3 steps: [[-0.5314613  -0.06865697]
 [-0.53145885 -0.068639  ]
 [-0.5314308  -0.06864282]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]


Epoch:   8%|▊         | 12/150 [14:29<2:48:19, 73.19s/epoch]

Epoch 011 | LR 0.002702 | Train MSE 0.0922 | Val MSE (normalized) 0.4187 | Val MAE (true) 3.1401 | Val MSE (true) 20.5163
Validation improved: 0.419695 -> 0.418701


Epoch:   9%|▊         | 13/150 [15:42<2:47:27, 73.34s/epoch]

Epoch 012 | LR 0.002567 | Train MSE 0.0924 | Val MSE (normalized) 0.4282 | Val MAE (true) 3.2894 | Val MSE (true) 20.9830


Epoch:   9%|▉         | 14/150 [16:56<2:46:40, 73.54s/epoch]

Epoch 013 | LR 0.002438 | Train MSE 0.0904 | Val MSE (normalized) 0.4040 | Val MAE (true) 3.1107 | Val MSE (true) 19.7956
Validation improved: 0.418701 -> 0.403993


Epoch:  10%|█         | 15/150 [18:10<2:45:53, 73.73s/epoch]

Epoch 014 | LR 0.002316 | Train MSE 0.0915 | Val MSE (normalized) 0.3725 | Val MAE (true) 2.9856 | Val MSE (true) 18.2536
Validation improved: 0.403993 -> 0.372523


Epoch:  11%|█         | 16/150 [19:25<2:44:53, 73.83s/epoch]

Epoch 015 | LR 0.002201 | Train MSE 0.0837 | Val MSE (normalized) 0.3680 | Val MAE (true) 3.0143 | Val MSE (true) 18.0342
Sample pred first 3 steps: [[0.10497974 0.02569452]
 [0.1047133  0.02604919]
 [0.10468215 0.02613487]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.372523 -> 0.368044


Epoch:  11%|█▏        | 17/150 [20:38<2:43:42, 73.86s/epoch]

Epoch 016 | LR 0.002091 | Train MSE 0.0887 | Val MSE (normalized) 0.3566 | Val MAE (true) 2.9581 | Val MSE (true) 17.4742
Validation improved: 0.368044 -> 0.356616


Epoch:  12%|█▏        | 18/150 [21:52<2:42:34, 73.90s/epoch]

Epoch 017 | LR 0.001986 | Train MSE 0.0865 | Val MSE (normalized) 0.3608 | Val MAE (true) 2.8972 | Val MSE (true) 17.6813


Epoch:  13%|█▎        | 19/150 [23:07<2:41:42, 74.07s/epoch]

Epoch 018 | LR 0.001887 | Train MSE 0.0874 | Val MSE (normalized) 0.3679 | Val MAE (true) 2.9411 | Val MSE (true) 18.0269


Epoch:  13%|█▎        | 20/150 [24:22<2:41:04, 74.34s/epoch]

Epoch 019 | LR 0.001792 | Train MSE 0.0895 | Val MSE (normalized) 0.3568 | Val MAE (true) 2.7832 | Val MSE (true) 17.4846


Epoch:  14%|█▍        | 21/150 [25:37<2:40:18, 74.56s/epoch]

Epoch 020 | LR 0.001703 | Train MSE 0.0882 | Val MSE (normalized) 0.3045 | Val MAE (true) 2.5961 | Val MSE (true) 14.9206
Sample pred first 3 steps: [[0.17706004 0.06006136]
 [0.1769722  0.06004719]
 [0.17703271 0.06004643]]
Sample target first 3 steps: [[-1.0263214e-05  5.1069769e-06]
 [-8.2099714e-06  2.6287930e-06]
 [ 1.4959690e-05 -1.9491419e-05]]
Validation improved: 0.356616 -> 0.304501


Epoch:  15%|█▍        | 22/150 [26:52<2:39:24, 74.72s/epoch]

Epoch 021 | LR 0.001618 | Train MSE 0.0838 | Val MSE (normalized) 0.2628 | Val MAE (true) 2.5034 | Val MSE (true) 12.8791
Validation improved: 0.304501 -> 0.262839


Epoch:  15%|█▌        | 23/150 [28:08<2:38:51, 75.05s/epoch]

Epoch 022 | LR 0.001537 | Train MSE 0.0833 | Val MSE (normalized) 0.2558 | Val MAE (true) 2.5225 | Val MSE (true) 12.5352
Validation improved: 0.262839 -> 0.255819


# Final Pred

In [13]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model2 = torch.load(best_model)
model = AutoRegressiveLSTM().to(device)
# optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.25) # You can try different schedulers
# criterion = nn.MSELoss()

model.load_state_dict(best_model2)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_vel_norm = model(batch)
        
        pred_vel = pred_vel_norm * batch.scale.view(-1,1,1) # (B, 60, 2)
        
        # Get origin in meters (position at t=49 for ego)
        # origin = batch.origin  # (B, 1, 2)
        origin = batch.origin.unsqueeze(1)  # Ensure shape is (B, 1, 2)
        
        # Integrate velocity to get position over 60 steps
        dt = 0.1  # seconds per step
        pred_pos = [origin]  # list of (B, 1, 2)
        
        for t in range(60):
            next_pos = pred_pos[-1] + pred_vel[:, t:t+1, :] * dt  # (B, 1, 2)
            pred_pos.append(next_pos)
        
        # Concatenate positions across time steps
        pred_xy = torch.cat(pred_pos[1:], dim=1)  # skip initial origin, get (B, 60, 2)

        pred_list.append(pred_xy.cpu().numpy())
        
        # pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission_lstm_simple_auto23.csv', index=True)