In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [None]:
class TrajectoryDataset(Dataset):
    """Paired (past, future) trajectories loaded from an .npz archive.

    The archive must contain arrays named 'past' and 'future' whose first axis indexes
    samples. Both are copied into float32 tensors and kept in memory.

    Raises:
        ValueError: if 'past' and 'future' hold a different number of samples.
    """

    def __init__(self, npz_file):
        # np.load on an .npz returns a lazily-read archive; using it as a context manager
        # closes the underlying file once both arrays have been copied into tensors.
        with np.load(npz_file) as data:
            self.past = torch.tensor(data['past'], dtype=torch.float32)
            self.future = torch.tensor(data['future'], dtype=torch.float32)

        # Fail at load time rather than with an IndexError deep inside a DataLoader.
        if len(self.past) != len(self.future):
            raise ValueError(
                f"'past' has {len(self.past)} samples but 'future' has {len(self.future)}"
            )

    def __len__(self):
        return len(self.past)

    def __getitem__(self, idx):
        return self.past[idx], self.future[idx]

# The labeled "val" split is used as the training source. evaluate_model needs ground-truth
# futures, and the test_obs archive presumably contains observations only (TODO confirm),
# so a fixed, seeded fraction of this file is held out as the evaluation set.
DATA_PATH = "/content/drive/MyDrive/ECE271B project/val/processed_val_pit.npz"
HOLDOUT_FRACTION = 0.1
SPLIT_SEED = 0
BATCH_SIZE = 64

full_dataset = TrajectoryDataset(DATA_PATH)
n_holdout = max(1, int(len(full_dataset) * HOLDOUT_FRACTION))
# A dedicated generator keeps the split identical across runs, independent of the global RNG.
train_dataset, test_dataset = torch.utils.data.random_split(
    full_dataset,
    [len(full_dataset) - n_holdout, n_holdout],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

class PositionalEncoding(nn.Module):
    """Learned (not sinusoidal) additive position embedding.

    Holds a trainable table of shape (1, max_len, d_model), initialized from a standard
    normal, and adds its first `seq_len` rows to the input.

    Args:
        d_model: feature size of the inputs.
        max_len: longest sequence this module can encode.
    """

    def __init__(self, d_model, max_len=50):
        super().__init__()
        self.pe = nn.Parameter(torch.randn(1, max_len, d_model))

    def forward(self, x):
        """Add position embeddings to `x` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len exceeds max_len. Without this check the slice below would
                come up short and the addition would fail with an opaque broadcasting error.
        """
        seq_len = x.shape[1]
        max_len = self.pe.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        return x + self.pe[:, :seq_len, :]

In [None]:
class ProbabilisticAttentionTransformer(nn.Module):
    """Encoder-decoder Transformer that emits `num_samples` candidate trajectories per input.

    The encoder embeds the past trajectory. A softmax head over `num_samples` produces
    per-step weights; each sample's weights rescale the encoder memory, and that rescaled
    copy is decoded (cross-attending to the unscaled memory) into one candidate trajectory.

    NOTE(review): the decoder queries are derived from the past sequence, so every candidate
    has exactly as many steps as the input. Training targets must therefore have the same
    length as `past` — TODO confirm against the preprocessed .npz files.
    """

    def __init__(self, input_dim=3, output_dim=2, hidden_dim=256, num_layers=6, num_heads=8, num_samples=10, dropout=0.2):
        super().__init__()

        self.num_samples = num_samples

        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_encoding = PositionalEncoding(hidden_dim)

        # forward() never calls these two prototype layers directly; the stacks below are built
        # from copies of them. They are kept so existing state_dict keys still load.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim * 4, dropout=dropout, activation="gelu", batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim * 4, dropout=dropout, activation="gelu", batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)

        self.prob_attention = nn.Linear(hidden_dim, num_samples)
        self.output_proj = nn.Linear(hidden_dim, output_dim)

        self.layer_norm = nn.LayerNorm(hidden_dim)
        # Not used in forward(); kept so the module structure is unchanged.
        self.dropout = nn.Dropout(dropout)

    def forward(self, past_trajectory):
        """Predict `num_samples` candidate futures for each input trajectory.

        Args:
            past_trajectory: tensor of shape (batch, seq_len, input_dim).

        Returns:
            (predicted_futures, attn_weights): predicted_futures has shape
            (batch, num_samples, seq_len, output_dim); attn_weights has shape
            (batch, seq_len, num_samples) and sums to 1 over its last axis.
        """
        batch_size, seq_len, _ = past_trajectory.shape
        num_samples = self.num_samples

        embedded = self.layer_norm(self.pos_encoding(self.input_proj(past_trajectory)))
        memory = self.encoder(embedded)  # (batch, seq_len, hidden_dim)
        hidden_dim = memory.shape[-1]

        attn_weights = torch.softmax(self.prob_attention(memory), dim=-1)  # (batch, seq_len, num_samples)
        # One re-weighted copy of the memory per sample: (batch, num_samples, seq_len, hidden_dim).
        weighted_memory = memory.unsqueeze(1) * attn_weights.permute(0, 2, 1).unsqueeze(-1)

        # Fold the sample axis into the batch so the decoder stack runs once rather than once
        # per sample; sample i of batch element b lands at row b * num_samples + i.
        tgt = weighted_memory.reshape(batch_size * num_samples, seq_len, hidden_dim)
        memory_rep = memory.repeat_interleave(num_samples, dim=0)
        decoded = self.decoder(tgt, memory_rep)

        future_preds = self.output_proj(decoded).reshape(
            batch_size, num_samples, seq_len, self.output_proj.out_features
        )
        return future_preds, attn_weights

In [None]:
def probability_matching_loss(predicted_futures, future, attn_weights, eps=1e-8):
    """Best-of-K displacement plus a KL term that trains the sample-probability head.

    Args:
        predicted_futures: (batch, num_samples, seq_len, dim) candidate trajectories.
        future: ground truth that broadcasts against one sample, i.e. (batch, seq_len, dim).
        attn_weights: (batch, seq_len, num_samples) per-step sample weights, expected to be
            softmax outputs.
        eps: added to the averaged probabilities before the log. Without it, a probability
            that underflows to exactly 0 gives log(0) = -inf, and kl_div then returns NaN
            (0 * -inf) or inf.

    Returns:
        (minADE, prob_loss):
            minADE is the batch mean of the lowest time-averaged displacement over samples.
            prob_loss is KL(one_hot(best sample) || time-averaged attn_weights), using
            "batchmean" reduction.
    """
    # (batch, num_samples, seq_len): per-step L2 error of every sample.
    distances = torch.norm(predicted_futures - future.unsqueeze(1), dim=-1)
    # Best sample per trajectory by average displacement; the argmin indices carry no gradient.
    min_distances, indices = torch.min(distances.mean(dim=2), dim=1)  # (batch,)

    # One-hot target over samples marking the best one: (batch, num_samples).
    true_probs = torch.zeros_like(attn_weights[:, 0, :]).scatter_(1, indices.unsqueeze(1), 1.0)
    # Collapse the per-step weights into one distribution per trajectory: (batch, num_samples).
    pred_probs = attn_weights.mean(dim=1)

    prob_loss = torch.nn.functional.kl_div((pred_probs + eps).log(), true_probs, reduction="batchmean")

    return torch.mean(min_distances), prob_loss

def train_model(model, train_loader, num_epochs=50, lr=0.0005, alpha=1.0, beta=1.0, gamma=0.1, delta=0.1, device=None):
    """Train a multi-sample trajectory model and print per-epoch loss averages.

    The objective is
        alpha * minADE + beta * minFDE + gamma * prob_loss + delta * diversity_penalty
    where:
      * minADE and minFDE score only the best predicted sample.
      * prob_loss is the KL term from `probability_matching_loss`.
      * diversity_penalty is mean(exp(-gap)) over consecutive sample pairs, where gap is
        the time-averaged L2 distance between the two samples.
    The penalty is 1 when samples coincide and decays toward 0 as they spread, so
    minimizing it pushes samples apart while staying bounded.

    Args:
        model: module whose forward(past) returns (predicted_futures, attn_weights), shaped
            (batch, num_samples, seq_len, output_dim) and (batch, seq_len, num_samples).
        train_loader: iterable of (past, future) batches. `future` must broadcast against a
            single predicted sample, i.e. (batch, seq_len, output_dim).
        num_epochs, lr: AdamW schedule (weight_decay=1e-4).
        alpha, beta, gamma, delta: weights of the four loss terms.
        device: where batches are moved. Defaults to the device of the model's first
            parameter, or CPU if it has none.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    for epoch in range(num_epochs):
        # Set every epoch so an evaluation pass between epochs cannot leave dropout disabled.
        model.train()
        total_loss = total_minADE = total_minFDE = total_prob_loss = total_div_loss = 0.0
        num_seen = 0

        for past, future in train_loader:
            past, future = past.to(device), future.to(device)
            optimizer.zero_grad()
            predicted_futures, attn_weights = model(past)

            # --- minADE + probability matching ---
            # The first value from probability_matching_loss is the minADE (batch mean of the
            # best sample's time-averaged displacement), so it is reused instead of recomputed.
            minADE_loss, prob_loss = probability_matching_loss(predicted_futures, future, attn_weights)

            # --- minFDE ---
            final_displacement = torch.norm(predicted_futures[:, :, -1, :] - future[:, -1, :].unsqueeze(1), dim=-1)
            minFDE_loss = final_displacement.min(dim=1).values.mean()

            # --- Diversity ---
            if predicted_futures.shape[1] > 1:
                # Time-averaged L2 gap between consecutive samples: (batch, num_samples - 1).
                adjacent_gap = torch.norm(predicted_futures[:, 1:] - predicted_futures[:, :-1], dim=-1).mean(dim=-1)
                diversity_loss = adjacent_gap.mean()  # logged only; larger means more spread out
                diversity_penalty = torch.exp(-adjacent_gap).mean()
            else:
                # A single sample has no pairs to spread apart.
                diversity_loss = diversity_penalty = predicted_futures.new_zeros(())

            # --- Total loss ---
            # The diversity term must shrink as samples spread. Adding the raw gap with a
            # positive weight would instead pull all samples toward each other.
            loss = alpha * minADE_loss + beta * minFDE_loss + gamma * prob_loss + delta * diversity_penalty
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch does not skew the epoch averages.
            batch_size = past.shape[0]
            num_seen += batch_size
            total_loss += loss.item() * batch_size
            total_minADE += minADE_loss.item() * batch_size
            total_minFDE += minFDE_loss.item() * batch_size
            total_prob_loss += prob_loss.item() * batch_size
            total_div_loss += diversity_loss.item() * batch_size

        denom = max(num_seen, 1)
        print(f"Epoch [{epoch+1}/{num_epochs}] | Total Loss: {total_loss/denom:.4f} "
              f"| minADE: {total_minADE/denom:.4f} | minFDE: {total_minFDE/denom:.4f} "
              f"| Prob Loss: {total_prob_loss/denom:.4f} | Diversity: {total_div_loss/denom:.4f}")

In [None]:
def evaluate_model(model, test_loader, device=None):
    """Compute dataset-level minADE and minFDE, print them, and return them.

    Metrics are averaged per trajectory (each batch is weighted by its size), so a partial
    final batch does not bias the result the way a mean of batch means would.

    Args:
        model: module whose forward(past) returns (predicted_futures, attn_weights), with
            predicted_futures shaped (batch, num_samples, seq_len, output_dim).
        test_loader: iterable of (past, future) batches; `future` must broadcast against a
            single predicted sample, i.e. (batch, seq_len, output_dim).
        device: where batches are moved. Defaults to the device of the model's first
            parameter, or CPU if it has none.

    Returns:
        (minADE, minFDE) as Python floats.

    Raises:
        ValueError: if test_loader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    sum_ade, sum_fde, num_seen = 0.0, 0.0, 0

    with torch.no_grad():
        for past, future in test_loader:
            past, future = past.to(device), future.to(device)
            predicted_futures, _ = model(past)

            # (batch, num_samples, seq_len): per-step L2 error of every sample.
            distances = torch.norm(predicted_futures - future.unsqueeze(1), dim=-1)
            # Best sample per trajectory, judged separately for ADE and FDE.
            ade_per_traj = distances.mean(dim=2).min(dim=1).values  # (batch,)
            fde_per_traj = distances[:, :, -1].min(dim=1).values  # (batch,)

            sum_ade += ade_per_traj.sum().item()
            sum_fde += fde_per_traj.sum().item()
            num_seen += past.shape[0]

    if num_seen == 0:
        raise ValueError("test_loader yielded no samples")

    min_ade = sum_ade / num_seen
    min_fde = sum_fde / num_seen
    print(f"Evaluation -> minADE: {min_ade:.4f}, minFDE: {min_fde:.4f}")
    return min_ade, min_fde

# Seed torch's global RNG so weight initialization, dropout and DataLoader shuffling repeat
# run-to-run (GPU kernels may still introduce some nondeterminism).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ProbabilisticAttentionTransformer().to(device)

train_model(model, train_loader, num_epochs=50)
# Requires `test_loader` from the data-loading cell.
evaluate_model(model, test_loader)