In [1]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
import tqdm

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
train_npz = np.load('/content/drive/MyDrive/Colab Notebooks/train.npz')
train_data = train_npz['data']
test_npz  = np.load('/content/drive/MyDrive/Colab Notebooks/test_input.npz')
test_data  = test_npz['data']

In [3]:
print(train_data.shape, test_data.shape)

# Split once for later use
X_train = train_data[..., :50, :]
Y_train = train_data[:, 0, 50:, :2]

(10000, 50, 110, 6) (2100, 50, 50, 6)


In [4]:
class TrajectoryDatasetTrain(Dataset):
    def __init__(self, data, scale=10.0, augment=True):
        """
        data: Shape (N, 50, 110, 6) Training data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        augment: Whether to apply data augmentation (only for training)
        """
        self.data = data
        self.scale = scale
        self.augment = augment

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        scene = self.data[idx]
        # Getting 50 historical timestamps and 60 future timestamps
        hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
        future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (60, 2)

        # Data augmentation(only for training)
        if self.augment:
            if np.random.rand() < 0.5:
                theta = np.random.uniform(-np.pi, np.pi)
                R = np.array([[np.cos(theta), -np.sin(theta)],
                              [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
                # Rotate the historical trajectory and future trajectory
                hist[..., :2] = hist[..., :2] @ R
                hist[..., 2:4] = hist[..., 2:4] @ R
                future = future @ R
            if np.random.rand() < 0.5:
                hist[..., 0] *= -1
                hist[..., 2] *= -1
                future[:, 0] *= -1

        # Use the last timeframe of the historical trajectory as the origin
        origin = hist[0, 49, :2].copy()  # (2,)
        hist[..., :2] = hist[..., :2] - origin
        future = future - origin

        # Normalize the historical trajectory and future trajectory
        hist[..., :4] = hist[..., :4] / self.scale
        future = future / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            y=future.type(torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )

        return data_item


class TrajectoryDatasetTest(Dataset):
    def __init__(self, data, scale=10.0):
        """
        data: Shape (N, 50, 110, 6) Testing data
        scale: Scale for normalization (suggested to use 10.0 for Argoverse 2 data)
        """
        self.data = data
        self.scale = scale

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Testing data only contains historical trajectory
        scene = self.data[idx]  # (50, 50, 6)
        hist = scene.copy()

        origin = hist[0, 49, :2].copy()
        hist[..., :2] = hist[..., :2] - origin
        hist[..., :4] = hist[..., :4] / self.scale

        data_item = Data(
            x=torch.tensor(hist, dtype=torch.float32),
            origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
            scale=torch.tensor(self.scale, dtype=torch.float32),
        )
        return data_item

In [36]:
# angel correct
# class TrajectoryDatasetTrain(Dataset):
#     def __init__(self, data, scale=10.0, augment=True):
#         self.data = data
#         self.scale = scale
#         self.augment = augment

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         scene = self.data[idx]
#         hist = scene[:, :50, :].copy()    # (agents=50, time_seq=50, 6)
#         future = torch.tensor(scene[0, 50:, :2].copy(), dtype=torch.float32)  # (60, 2)

#         # Data augmentation(only for training)
#         if self.augment:
#             if np.random.rand() < 0.5:
#                 theta = np.random.uniform(-np.pi, np.pi)
#                 R = np.array([[np.cos(theta), -np.sin(theta)],
#                               [np.sin(theta),  np.cos(theta)]], dtype=np.float32)
#                 # Rotate the historical trajectory and future trajectory
#                 hist[..., :2] = hist[..., :2] @ R
#                 hist[..., 2:4] = hist[..., 2:4] @ R
#                 future = future @ R
#             if np.random.rand() < 0.5:
#                 hist[..., 0] *= -1
#                 hist[..., 2] *= -1
#                 future[:, 0] *= -1

#         # Use the last timeframe of the historical trajectory as the origin
#         origin = hist[0, 49, :2].copy()  # (2,)
#         hist[..., :2] = hist[..., :2] - origin
#         future = future - origin

#         # =========== 关键：将所有agent对齐ego agent的朝向 =============
#         # 使用ego最后一帧的速度方向作为朝向参考
#         ego_v = hist[0, 49, 2:4].copy()  # (vx, vy)
#         norm = np.linalg.norm(ego_v)
#         if norm > 1e-3:
#             theta = np.arctan2(ego_v[1], ego_v[0])   # ego当前朝向角（相对x轴）
#             R = np.array([
#                 [np.cos(-theta), -np.sin(-theta)],
#                 [np.sin(-theta),  np.cos(-theta)]
#             ], dtype=np.float32)
#             hist[..., :2] = hist[..., :2] @ R
#             hist[..., 2:4] = hist[..., 2:4] @ R
#             future = future @ R
#         # ============================================================

#         # Normalize the historical trajectory and future trajectory
#         hist[..., :4] = hist[..., :4] / self.scale
#         future = future / self.scale

#         data_item = Data(
#             x=torch.tensor(hist, dtype=torch.float32),
#             y=future.type(torch.float32),
#             origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),
#             scale=torch.tensor(self.scale, dtype=torch.float32),
#         )

#         return data_item


# class TrajectoryDatasetTest(Dataset):
#     def __init__(self, data, scale=10.0):
#         """
#         data: shape (N, 50, 110, 6)
#         scale: normalization factor (default: 10.0)
#         """
#         self.data = data
#         self.scale = scale

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         scene = self.data[idx]  # (50, 110, 6)
#         hist = scene[:, :50, :].copy()  # 只取前 50 帧历史，shape (50, 50, 6)

#         # 平移使 ego agent 的最后一帧作为原点
#         origin = hist[0, 49, :2].copy()
#         hist[..., :2] -= origin

#         # 对齐朝向：让 ego agent 的速度方向对齐 x 轴
#         theta = 0.0
#         ego_v = hist[0, 49, 2:4].copy()
#         norm = np.linalg.norm(ego_v)
#         if norm > 1e-3:
#             theta = np.arctan2(ego_v[1], ego_v[0])
#             R = np.array([
#                 [np.cos(-theta), -np.sin(-theta)],
#                 [np.sin(-theta),  np.cos(-theta)]
#             ], dtype=np.float32)
#             hist[..., :2] = hist[..., :2] @ R
#             hist[..., 2:4] = hist[..., 2:4] @ R

#         # 缩放归一化
#         hist[..., :4] /= self.scale

#         # 构造 PyG 图结构（单个样本 = 一整张图）
#         data_item = Data(
#             x=torch.tensor(hist, dtype=torch.float32),  # (50, 50, 6)
#             origin=torch.tensor(origin, dtype=torch.float32).unsqueeze(0),  # (1, 2)
#             scale=torch.tensor(self.scale, dtype=torch.float32),
#             theta=torch.tensor(theta, dtype=torch.float32).unsqueeze(0)  # (1,)
#         )
#         return data_item

In [5]:
torch.manual_seed(251)
np.random.seed(42)

scale = 1.0

N = len(train_data)
val_size = int(0.1 * N)
train_size = N - val_size

train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=lambda x: Batch.from_data_list(x))
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=lambda x: Batch.from_data_list(x))

# Set device for training speedup
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')

Using CUDA GPU


In [6]:
import torch

In [7]:
class LSTM(nn.Module):
    def __init__(self, input_dim=6, hidden_dim=128, output_dim=60 * 2):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x = data.x
        x= x.reshape(-1, 50, 50, 6)  # (batch_size, num_agents, seq_len, input_dim)
        x = x[:, 0, :, :] # Only Consider ego agent index 0

        lstm_out, _ = self.lstm(x)
        # lstm_out is of shape (batch_size, seq_len, hidden_dim) and we want the last time step output
        out = self.fc(lstm_out[:, -1, :])
        return out.view(-1, 60, 2)

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv

class ImprovedLSTM_GCN(nn.Module):
    """
    升级版 LSTM+GCN 轨迹预测器
    """
    def __init__(self, input_dim=6, hidden_dim=128, pred_len=60, gcn_out_dim=32, dropout=0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.pred_len = pred_len

        # LSTM用于ego历史序列
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            dropout=dropout
        )

        # GCN用于全局最后一帧邻居关系
        self.gcn1 = GCNConv(input_dim, 32)
        self.gcn2 = GCNConv(32, gcn_out_dim)

        # 全连接
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim + gcn_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, pred_len * 2)
        )

    def build_graph_batch(self, x_last, max_neighbor_dist=None):
        """
        x_last: (B, N, 6)
        Returns a PyG Batch object.
        """
        B, N, C = x_last.shape
        data_list = []
        for b in range(B):
            node_feat = x_last[b]  # (N,6)
            src, dst = [], []
            for i in range(N):
                for j in range(N):
                    if i == j:
                        continue
                    # 可选：只连距离小于阈值的邻居
                    if max_neighbor_dist is not None:
                        pos_i = node_feat[i, :2]
                        pos_j = node_feat[j, :2]
                        dist = torch.norm(pos_i - pos_j)
                        if dist > max_neighbor_dist:
                            continue
                    src.append(i)
                    dst.append(j)
            edge_index = torch.tensor([src, dst], dtype=torch.long, device=node_feat.device)
            data = Data(x=node_feat, edge_index=edge_index)
            data_list.append(data)
        batch = Batch.from_data_list(data_list)
        return batch

    def forward(self, data):
        """
        data.x: (B*A, T, D) or (B, A, T, D)
        data.num_graphs: B
        """
        x = data.x
        B = data.num_graphs

        # —— 1) 恢复成 (B, A=50, T, D) ——
        x = x.view(B, 50, 50, -1)  # (B,50,50,D)

        # —— 2) 提取 ego agent 的历史序列 ——
        ego_hist = x[:, 0, :, :]  # (B, 50, D)

        # —— 3) LSTM 编码（两层）——
        lstm_out, _ = self.lstm(ego_hist)       # (B,50,hidden_dim)
        last_hidden = lstm_out[:, -1, :]        # (B, hidden_dim)

        # —— 4) 用GCN建图对最后一帧所有agent聚合 ——
        x_last = x[:, :, -1, :6]  # (B,50,6)
        graph_batch = self.build_graph_batch(x_last)
        h = F.relu(self.gcn1(graph_batch.x, graph_batch.edge_index))
        h = self.gcn2(h, graph_batch.edge_index)
        h_nodes = h.view(B, 50, -1)  # (B, 50, gcn_out_dim)
        ego_gcn = h_nodes[:, 0, :]   # (B, gcn_out_dim)



        # —— 5) 拼接 LSTM 和 GCN ——

        fusion = torch.cat([last_hidden, ego_gcn], dim=-1)  # (B, hidden_dim+gcn_out_dim)


        # —— 6) 全连接预测 ——
        out = self.fc(fusion)                      # (B, pred_len*2)
        out = out.view(B, self.pred_len, 2)        # (B, pred_len, 2)
        return out
# ====== 测试用例 ======
if __name__ == "__main__":
    class DummyData:
        pass

    B, A, T, D = 4, 50, 50, 6
    pred_len = 60
    torch.manual_seed(123)
    dummy = DummyData()
    dummy.x = torch.randn(B*A, T, D)
    dummy.num_graphs = B

    model = ImprovedLSTM_GCN(input_dim=6, hidden_dim=128, pred_len=pred_len, gcn_out_dim=32, dropout=0.1)
    output = model(dummy)
    print(output.shape)  # (B, pred_len, 2)

torch.Size([4, 60, 2])


In [12]:
#without angel correct
model = ImprovedLSTM_GCN(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    dropout=0.1,
).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.75) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [13]:
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            val_loss += criterion(pred, y).item()

            # show MAE and MSE with unnormalized data
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, y).item()
            val_mse += nn.MSELoss()(pred, y).item()

    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    scheduler.step()
    # scheduler.step(val_loss)

    tqdm.tqdm.write(f"Epoch {epoch:03d} | Learning rate {optimizer.param_groups[0]['lr']:.6f} | train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f}, | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    if val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            print("Early stop!")
            break

  future = future @ R
  future = future - origin
Epoch:   1%|          | 1/100 [00:16<26:50, 16.27s/epoch]

Epoch 000 | Learning rate 0.001000 | train normalized MSE 170.9074 | val normalized MSE  60.5126, | val MAE   4.8404 | val MSE  60.5126


Epoch:   2%|▏         | 2/100 [00:31<25:15, 15.47s/epoch]

Epoch 001 | Learning rate 0.001000 | train normalized MSE  34.8927 | val normalized MSE  29.9572, | val MAE   3.2393 | val MSE  29.9572


Epoch:   3%|▎         | 3/100 [00:46<24:35, 15.21s/epoch]

Epoch 002 | Learning rate 0.001000 | train normalized MSE  23.4613 | val normalized MSE  16.9462, | val MAE   2.2581 | val MSE  16.9462


Epoch:   4%|▍         | 4/100 [01:01<24:11, 15.12s/epoch]

Epoch 003 | Learning rate 0.001000 | train normalized MSE  19.1280 | val normalized MSE  15.1718, | val MAE   2.0964 | val MSE  15.1718


Epoch:   5%|▌         | 5/100 [01:16<23:52, 15.08s/epoch]

Epoch 004 | Learning rate 0.001000 | train normalized MSE  18.0514 | val normalized MSE  16.4666, | val MAE   2.3713 | val MSE  16.4666


Epoch:   6%|▌         | 6/100 [01:31<23:39, 15.10s/epoch]

Epoch 005 | Learning rate 0.001000 | train normalized MSE  16.8153 | val normalized MSE  14.4466, | val MAE   2.1053 | val MSE  14.4466


Epoch:   7%|▋         | 7/100 [01:46<23:20, 15.06s/epoch]

Epoch 006 | Learning rate 0.001000 | train normalized MSE  16.1434 | val normalized MSE  13.8521, | val MAE   2.0102 | val MSE  13.8521


Epoch:   8%|▊         | 8/100 [02:01<23:02, 15.03s/epoch]

Epoch 007 | Learning rate 0.001000 | train normalized MSE  15.5910 | val normalized MSE  14.9744, | val MAE   2.1371 | val MSE  14.9744


Epoch:   9%|▉         | 9/100 [02:16<22:48, 15.03s/epoch]

Epoch 008 | Learning rate 0.001000 | train normalized MSE  15.5527 | val normalized MSE  13.0326, | val MAE   1.8985 | val MSE  13.0326


Epoch:  10%|█         | 10/100 [02:31<22:34, 15.05s/epoch]

Epoch 009 | Learning rate 0.001000 | train normalized MSE  14.8115 | val normalized MSE  13.1999, | val MAE   1.9755 | val MSE  13.1999


Epoch:  11%|█         | 11/100 [02:46<22:18, 15.04s/epoch]

Epoch 010 | Learning rate 0.001000 | train normalized MSE  14.6474 | val normalized MSE  12.6197, | val MAE   1.9484 | val MSE  12.6196


Epoch:  12%|█▏        | 12/100 [03:01<22:03, 15.04s/epoch]

Epoch 011 | Learning rate 0.001000 | train normalized MSE  14.4516 | val normalized MSE  12.1046, | val MAE   1.9231 | val MSE  12.1046


Epoch:  13%|█▎        | 13/100 [03:16<21:47, 15.02s/epoch]

Epoch 012 | Learning rate 0.001000 | train normalized MSE  14.1813 | val normalized MSE  12.0891, | val MAE   1.8531 | val MSE  12.0890


Epoch:  14%|█▍        | 14/100 [03:31<21:30, 15.01s/epoch]

Epoch 013 | Learning rate 0.001000 | train normalized MSE  14.2669 | val normalized MSE  12.4216, | val MAE   1.8634 | val MSE  12.4216


Epoch:  15%|█▌        | 15/100 [03:46<21:15, 15.01s/epoch]

Epoch 014 | Learning rate 0.001000 | train normalized MSE  13.6424 | val normalized MSE  11.4618, | val MAE   1.7979 | val MSE  11.4618


Epoch:  16%|█▌        | 16/100 [04:01<21:01, 15.02s/epoch]

Epoch 015 | Learning rate 0.001000 | train normalized MSE  13.4688 | val normalized MSE  12.2315, | val MAE   1.8633 | val MSE  12.2315


Epoch:  17%|█▋        | 17/100 [04:16<20:46, 15.02s/epoch]

Epoch 016 | Learning rate 0.001000 | train normalized MSE  13.3648 | val normalized MSE  11.6649, | val MAE   1.8798 | val MSE  11.6649


Epoch:  18%|█▊        | 18/100 [04:31<20:30, 15.01s/epoch]

Epoch 017 | Learning rate 0.001000 | train normalized MSE  13.3631 | val normalized MSE  10.6332, | val MAE   1.7143 | val MSE  10.6332


Epoch:  19%|█▉        | 19/100 [04:46<20:15, 15.00s/epoch]

Epoch 018 | Learning rate 0.001000 | train normalized MSE  13.0165 | val normalized MSE  12.4433, | val MAE   1.9511 | val MSE  12.4433


Epoch:  20%|██        | 20/100 [05:01<19:59, 15.00s/epoch]

Epoch 019 | Learning rate 0.000750 | train normalized MSE  13.0456 | val normalized MSE  10.8398, | val MAE   1.7412 | val MSE  10.8398


Epoch:  21%|██        | 21/100 [05:16<19:44, 14.99s/epoch]

Epoch 020 | Learning rate 0.000750 | train normalized MSE  12.4311 | val normalized MSE  10.6840, | val MAE   1.7316 | val MSE  10.6840


Epoch:  22%|██▏       | 22/100 [05:31<19:29, 15.00s/epoch]

Epoch 021 | Learning rate 0.000750 | train normalized MSE  12.0708 | val normalized MSE  10.2082, | val MAE   1.6575 | val MSE  10.2082


Epoch:  23%|██▎       | 23/100 [05:46<19:12, 14.97s/epoch]

Epoch 022 | Learning rate 0.000750 | train normalized MSE  11.9364 | val normalized MSE  10.1090, | val MAE   1.6624 | val MSE  10.1090


Epoch:  24%|██▍       | 24/100 [06:01<18:56, 14.95s/epoch]

Epoch 023 | Learning rate 0.000750 | train normalized MSE  11.7786 | val normalized MSE  10.5625, | val MAE   1.7262 | val MSE  10.5625


Epoch:  25%|██▌       | 25/100 [06:16<18:41, 14.95s/epoch]

Epoch 024 | Learning rate 0.000750 | train normalized MSE  11.5391 | val normalized MSE  10.3477, | val MAE   1.7224 | val MSE  10.3477


Epoch:  26%|██▌       | 26/100 [06:31<18:30, 15.00s/epoch]

Epoch 025 | Learning rate 0.000750 | train normalized MSE  11.4640 | val normalized MSE  10.1881, | val MAE   1.6306 | val MSE  10.1881


Epoch:  27%|██▋       | 27/100 [06:46<18:14, 14.99s/epoch]

Epoch 026 | Learning rate 0.000750 | train normalized MSE  11.4005 | val normalized MSE  10.3097, | val MAE   1.6454 | val MSE  10.3097


Epoch:  28%|██▊       | 28/100 [07:01<18:00, 15.01s/epoch]

Epoch 027 | Learning rate 0.000750 | train normalized MSE  11.2130 | val normalized MSE  10.4064, | val MAE   1.6578 | val MSE  10.4064


Epoch:  29%|██▉       | 29/100 [07:16<17:45, 15.00s/epoch]

Epoch 028 | Learning rate 0.000750 | train normalized MSE  11.2589 | val normalized MSE  10.1848, | val MAE   1.6811 | val MSE  10.1848


Epoch:  30%|███       | 30/100 [07:31<17:29, 14.99s/epoch]

Epoch 029 | Learning rate 0.000750 | train normalized MSE  11.2187 | val normalized MSE   9.8738, | val MAE   1.5446 | val MSE   9.8738


Epoch:  31%|███       | 31/100 [07:46<17:14, 14.99s/epoch]

Epoch 030 | Learning rate 0.000750 | train normalized MSE  11.2597 | val normalized MSE   9.4084, | val MAE   1.5248 | val MSE   9.4084


Epoch:  32%|███▏      | 32/100 [08:01<16:58, 14.98s/epoch]

Epoch 031 | Learning rate 0.000750 | train normalized MSE  11.0842 | val normalized MSE   9.7266, | val MAE   1.5723 | val MSE   9.7266


Epoch:  33%|███▎      | 33/100 [08:16<16:45, 15.01s/epoch]

Epoch 032 | Learning rate 0.000750 | train normalized MSE  10.7069 | val normalized MSE   9.3960, | val MAE   1.4877 | val MSE   9.3959


Epoch:  34%|███▍      | 34/100 [08:31<16:29, 15.00s/epoch]

Epoch 033 | Learning rate 0.000750 | train normalized MSE  10.9230 | val normalized MSE  10.4352, | val MAE   1.6280 | val MSE  10.4352


Epoch:  35%|███▌      | 35/100 [08:46<16:14, 15.00s/epoch]

Epoch 034 | Learning rate 0.000750 | train normalized MSE  10.8443 | val normalized MSE   9.2376, | val MAE   1.5971 | val MSE   9.2376


Epoch:  36%|███▌      | 36/100 [09:01<16:00, 15.01s/epoch]

Epoch 035 | Learning rate 0.000750 | train normalized MSE  10.8814 | val normalized MSE   9.5558, | val MAE   1.6005 | val MSE   9.5558


Epoch:  37%|███▋      | 37/100 [09:16<15:46, 15.02s/epoch]

Epoch 036 | Learning rate 0.000750 | train normalized MSE  10.8347 | val normalized MSE   9.2593, | val MAE   1.5209 | val MSE   9.2593


Epoch:  38%|███▊      | 38/100 [09:31<15:30, 15.00s/epoch]

Epoch 037 | Learning rate 0.000750 | train normalized MSE  10.7054 | val normalized MSE   9.4171, | val MAE   1.5361 | val MSE   9.4171


Epoch:  39%|███▉      | 39/100 [09:46<15:14, 14.99s/epoch]

Epoch 038 | Learning rate 0.000750 | train normalized MSE  10.5242 | val normalized MSE  10.4179, | val MAE   1.6984 | val MSE  10.4179


Epoch:  40%|████      | 40/100 [10:01<14:58, 14.98s/epoch]

Epoch 039 | Learning rate 0.000563 | train normalized MSE  10.6550 | val normalized MSE   9.7334, | val MAE   1.5870 | val MSE   9.7334


Epoch:  41%|████      | 41/100 [10:16<14:44, 14.99s/epoch]

Epoch 040 | Learning rate 0.000563 | train normalized MSE  10.1452 | val normalized MSE   9.2418, | val MAE   1.5144 | val MSE   9.2418


Epoch:  42%|████▏     | 42/100 [10:31<14:30, 15.01s/epoch]

Epoch 041 | Learning rate 0.000563 | train normalized MSE  10.0213 | val normalized MSE   9.4892, | val MAE   1.5183 | val MSE   9.4892


Epoch:  43%|████▎     | 43/100 [10:46<14:15, 15.01s/epoch]

Epoch 042 | Learning rate 0.000563 | train normalized MSE  10.0847 | val normalized MSE   9.1077, | val MAE   1.4978 | val MSE   9.1077


Epoch:  44%|████▍     | 44/100 [11:01<14:00, 15.01s/epoch]

Epoch 043 | Learning rate 0.000563 | train normalized MSE   9.9737 | val normalized MSE   9.3561, | val MAE   1.5053 | val MSE   9.3561


Epoch:  45%|████▌     | 45/100 [11:16<13:45, 15.01s/epoch]

Epoch 044 | Learning rate 0.000563 | train normalized MSE   9.8341 | val normalized MSE   9.1630, | val MAE   1.4927 | val MSE   9.1630


Epoch:  46%|████▌     | 46/100 [11:31<13:31, 15.02s/epoch]

Epoch 045 | Learning rate 0.000563 | train normalized MSE   9.8983 | val normalized MSE   9.2921, | val MAE   1.5119 | val MSE   9.2921


Epoch:  47%|████▋     | 47/100 [11:46<13:16, 15.03s/epoch]

Epoch 046 | Learning rate 0.000563 | train normalized MSE   9.8338 | val normalized MSE   8.9558, | val MAE   1.4975 | val MSE   8.9558


Epoch:  48%|████▊     | 48/100 [12:01<13:00, 15.01s/epoch]

Epoch 047 | Learning rate 0.000563 | train normalized MSE   9.7293 | val normalized MSE   8.5051, | val MAE   1.4246 | val MSE   8.5051


Epoch:  49%|████▉     | 49/100 [12:16<12:45, 15.00s/epoch]

Epoch 048 | Learning rate 0.000563 | train normalized MSE   9.8451 | val normalized MSE   9.5873, | val MAE   1.5271 | val MSE   9.5873


Epoch:  50%|█████     | 50/100 [12:31<12:29, 14.99s/epoch]

Epoch 049 | Learning rate 0.000563 | train normalized MSE   9.7873 | val normalized MSE   9.3437, | val MAE   1.5185 | val MSE   9.3436


Epoch:  51%|█████     | 51/100 [12:46<12:15, 15.01s/epoch]

Epoch 050 | Learning rate 0.000563 | train normalized MSE   9.5801 | val normalized MSE   8.9445, | val MAE   1.4424 | val MSE   8.9445


Epoch:  52%|█████▏    | 52/100 [13:01<12:06, 15.13s/epoch]

Epoch 051 | Learning rate 0.000563 | train normalized MSE   9.7773 | val normalized MSE   8.4136, | val MAE   1.4553 | val MSE   8.4136


Epoch:  53%|█████▎    | 53/100 [13:17<11:55, 15.22s/epoch]

Epoch 052 | Learning rate 0.000563 | train normalized MSE   9.5489 | val normalized MSE   8.9405, | val MAE   1.5242 | val MSE   8.9405


Epoch:  54%|█████▍    | 54/100 [13:32<11:42, 15.27s/epoch]

Epoch 053 | Learning rate 0.000563 | train normalized MSE   9.6607 | val normalized MSE   8.3759, | val MAE   1.4144 | val MSE   8.3759


Epoch:  55%|█████▌    | 55/100 [13:47<11:29, 15.32s/epoch]

Epoch 054 | Learning rate 0.000563 | train normalized MSE   9.4085 | val normalized MSE   8.5277, | val MAE   1.3821 | val MSE   8.5277


Epoch:  56%|█████▌    | 56/100 [14:03<11:15, 15.35s/epoch]

Epoch 055 | Learning rate 0.000563 | train normalized MSE   9.5519 | val normalized MSE   8.7604, | val MAE   1.4882 | val MSE   8.7604


Epoch:  57%|█████▋    | 57/100 [14:18<11:00, 15.36s/epoch]

Epoch 056 | Learning rate 0.000563 | train normalized MSE   9.4420 | val normalized MSE   8.6859, | val MAE   1.4193 | val MSE   8.6859


Epoch:  58%|█████▊    | 58/100 [14:34<10:47, 15.42s/epoch]

Epoch 057 | Learning rate 0.000563 | train normalized MSE   9.5504 | val normalized MSE   8.7400, | val MAE   1.4047 | val MSE   8.7400


Epoch:  59%|█████▉    | 59/100 [14:49<10:32, 15.44s/epoch]

Epoch 058 | Learning rate 0.000563 | train normalized MSE   9.5636 | val normalized MSE   8.8364, | val MAE   1.4556 | val MSE   8.8364


Epoch:  60%|██████    | 60/100 [15:05<10:17, 15.45s/epoch]

Epoch 059 | Learning rate 0.000422 | train normalized MSE   9.4348 | val normalized MSE   8.6086, | val MAE   1.4622 | val MSE   8.6086


Epoch:  61%|██████    | 61/100 [15:20<10:03, 15.47s/epoch]

Epoch 060 | Learning rate 0.000422 | train normalized MSE   9.2674 | val normalized MSE   8.7434, | val MAE   1.4147 | val MSE   8.7434


Epoch:  62%|██████▏   | 62/100 [15:36<09:48, 15.48s/epoch]

Epoch 061 | Learning rate 0.000422 | train normalized MSE   8.9580 | val normalized MSE   8.9165, | val MAE   1.4721 | val MSE   8.9165


Epoch:  63%|██████▎   | 63/100 [15:51<09:32, 15.46s/epoch]

Epoch 062 | Learning rate 0.000422 | train normalized MSE   9.0880 | val normalized MSE   8.5480, | val MAE   1.3855 | val MSE   8.5480


Epoch:  63%|██████▎   | 63/100 [16:07<09:28, 15.35s/epoch]

Epoch 063 | Learning rate 0.000422 | train normalized MSE   9.1467 | val normalized MSE   8.7322, | val MAE   1.4184 | val MSE   8.7322
Early stop!





In [14]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
model = ImprovedLSTM_GCN(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    dropout=0.1,
).to(device)

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)

        # Reshape the prediction to (N, 60, 2)
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission.csv', index=True)

In [24]:
#with angel correct
model = ImprovedLSTM_GCN(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    dropout=0.1,
).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.75) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [25]:
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            val_loss += criterion(pred, y).item()

            # show MAE and MSE with unnormalized data
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, y).item()
            val_mse += nn.MSELoss()(pred, y).item()

    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    scheduler.step()
    # scheduler.step(val_loss)

    tqdm.tqdm.write(f"Epoch {epoch:03d} | Learning rate {optimizer.param_groups[0]['lr']:.6f} | train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f}, | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    if val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            print("Early stop!")
            break

  future = future @ R
  future = future - origin
  future = future @ R
Epoch:   1%|          | 1/100 [00:17<28:40, 17.37s/epoch]

Epoch 000 | Learning rate 0.001000 | train normalized MSE 182.8710 | val normalized MSE  38.1784, | val MAE   3.7449 | val MSE  38.1784


Epoch:   2%|▏         | 2/100 [00:34<28:14, 17.29s/epoch]

Epoch 001 | Learning rate 0.001000 | train normalized MSE  34.9320 | val normalized MSE  24.8856, | val MAE   2.9198 | val MSE  24.8856


Epoch:   3%|▎         | 3/100 [00:51<27:46, 17.18s/epoch]

Epoch 002 | Learning rate 0.001000 | train normalized MSE  23.2602 | val normalized MSE  16.6670, | val MAE   2.2919 | val MSE  16.6670


Epoch:   4%|▍         | 4/100 [01:08<27:25, 17.14s/epoch]

Epoch 003 | Learning rate 0.001000 | train normalized MSE  19.3751 | val normalized MSE  19.4913, | val MAE   2.5829 | val MSE  19.4913


Epoch:   5%|▌         | 5/100 [01:25<27:05, 17.11s/epoch]

Epoch 004 | Learning rate 0.001000 | train normalized MSE  17.9221 | val normalized MSE  15.2310, | val MAE   2.1510 | val MSE  15.2310


Epoch:   6%|▌         | 6/100 [01:42<26:49, 17.12s/epoch]

Epoch 005 | Learning rate 0.001000 | train normalized MSE  16.5660 | val normalized MSE  14.7380, | val MAE   2.1193 | val MSE  14.7380


Epoch:   7%|▋         | 7/100 [02:00<26:34, 17.14s/epoch]

Epoch 006 | Learning rate 0.001000 | train normalized MSE  15.7167 | val normalized MSE  12.2206, | val MAE   1.8995 | val MSE  12.2206


Epoch:   8%|▊         | 8/100 [02:17<26:18, 17.16s/epoch]

Epoch 007 | Learning rate 0.001000 | train normalized MSE  15.5106 | val normalized MSE  13.4359, | val MAE   2.0332 | val MSE  13.4359


Epoch:   9%|▉         | 9/100 [02:34<26:04, 17.19s/epoch]

Epoch 008 | Learning rate 0.001000 | train normalized MSE  14.7068 | val normalized MSE  12.8026, | val MAE   1.8905 | val MSE  12.8027


Epoch:  10%|█         | 10/100 [02:51<25:45, 17.17s/epoch]

Epoch 009 | Learning rate 0.001000 | train normalized MSE  14.2900 | val normalized MSE  13.4104, | val MAE   2.0179 | val MSE  13.4104


Epoch:  11%|█         | 11/100 [03:08<25:27, 17.17s/epoch]

Epoch 010 | Learning rate 0.001000 | train normalized MSE  14.0720 | val normalized MSE  11.8562, | val MAE   1.8229 | val MSE  11.8562


Epoch:  12%|█▏        | 12/100 [03:26<25:17, 17.24s/epoch]

Epoch 011 | Learning rate 0.001000 | train normalized MSE  14.0107 | val normalized MSE  11.9436, | val MAE   1.8596 | val MSE  11.9436


Epoch:  13%|█▎        | 13/100 [03:43<24:56, 17.20s/epoch]

Epoch 012 | Learning rate 0.001000 | train normalized MSE  13.2205 | val normalized MSE  11.9203, | val MAE   1.8244 | val MSE  11.9203


Epoch:  14%|█▍        | 14/100 [04:00<24:36, 17.17s/epoch]

Epoch 013 | Learning rate 0.001000 | train normalized MSE  13.5133 | val normalized MSE  12.2319, | val MAE   1.8617 | val MSE  12.2319


Epoch:  15%|█▌        | 15/100 [04:17<24:18, 17.16s/epoch]

Epoch 014 | Learning rate 0.001000 | train normalized MSE  13.4192 | val normalized MSE  11.5753, | val MAE   1.8411 | val MSE  11.5753


Epoch:  16%|█▌        | 16/100 [04:34<23:59, 17.14s/epoch]

Epoch 015 | Learning rate 0.001000 | train normalized MSE  13.2000 | val normalized MSE  11.3134, | val MAE   1.7597 | val MSE  11.3134


Epoch:  17%|█▋        | 17/100 [04:51<23:40, 17.11s/epoch]

Epoch 016 | Learning rate 0.001000 | train normalized MSE  13.2409 | val normalized MSE  12.0366, | val MAE   1.8395 | val MSE  12.0366


Epoch:  18%|█▊        | 18/100 [05:08<23:20, 17.08s/epoch]

Epoch 017 | Learning rate 0.001000 | train normalized MSE  12.9820 | val normalized MSE  11.5420, | val MAE   1.8211 | val MSE  11.5420


Epoch:  19%|█▉        | 19/100 [05:25<22:59, 17.03s/epoch]

Epoch 018 | Learning rate 0.001000 | train normalized MSE  12.8233 | val normalized MSE  11.1044, | val MAE   1.7774 | val MSE  11.1044


Epoch:  20%|██        | 20/100 [05:42<22:43, 17.04s/epoch]

Epoch 019 | Learning rate 0.000750 | train normalized MSE  12.7748 | val normalized MSE  10.8319, | val MAE   1.7220 | val MSE  10.8319


Epoch:  21%|██        | 21/100 [05:59<22:27, 17.05s/epoch]

Epoch 020 | Learning rate 0.000750 | train normalized MSE  12.0260 | val normalized MSE  10.5695, | val MAE   1.6728 | val MSE  10.5695


Epoch:  22%|██▏       | 22/100 [06:16<22:09, 17.05s/epoch]

Epoch 021 | Learning rate 0.000750 | train normalized MSE  11.7242 | val normalized MSE  11.1606, | val MAE   1.7460 | val MSE  11.1606


Epoch:  23%|██▎       | 23/100 [06:33<21:52, 17.04s/epoch]

Epoch 022 | Learning rate 0.000750 | train normalized MSE  11.5792 | val normalized MSE   9.9480, | val MAE   1.6285 | val MSE   9.9480


Epoch:  24%|██▍       | 24/100 [06:50<21:32, 17.00s/epoch]

Epoch 023 | Learning rate 0.000750 | train normalized MSE  11.4791 | val normalized MSE   9.9005, | val MAE   1.6317 | val MSE   9.9005


Epoch:  25%|██▌       | 25/100 [07:07<21:13, 16.98s/epoch]

Epoch 024 | Learning rate 0.000750 | train normalized MSE  11.4751 | val normalized MSE  10.1124, | val MAE   1.6898 | val MSE  10.1124


Epoch:  26%|██▌       | 26/100 [07:24<20:57, 17.00s/epoch]

Epoch 025 | Learning rate 0.000750 | train normalized MSE  11.2980 | val normalized MSE  10.3370, | val MAE   1.7260 | val MSE  10.3370


Epoch:  27%|██▋       | 27/100 [07:41<20:41, 17.01s/epoch]

Epoch 026 | Learning rate 0.000750 | train normalized MSE  11.2651 | val normalized MSE   9.5312, | val MAE   1.5285 | val MSE   9.5312


Epoch:  28%|██▊       | 28/100 [07:59<20:29, 17.08s/epoch]

Epoch 027 | Learning rate 0.000750 | train normalized MSE  10.9755 | val normalized MSE  10.2373, | val MAE   1.6582 | val MSE  10.2373


Epoch:  29%|██▉       | 29/100 [08:16<20:19, 17.18s/epoch]

Epoch 028 | Learning rate 0.000750 | train normalized MSE  11.0271 | val normalized MSE   9.7339, | val MAE   1.5614 | val MSE   9.7339


Epoch:  30%|███       | 30/100 [08:33<20:07, 17.25s/epoch]

Epoch 029 | Learning rate 0.000750 | train normalized MSE  11.0009 | val normalized MSE   9.9107, | val MAE   1.6125 | val MSE   9.9107


Epoch:  31%|███       | 31/100 [08:51<19:49, 17.24s/epoch]

Epoch 030 | Learning rate 0.000750 | train normalized MSE  11.0054 | val normalized MSE   9.4817, | val MAE   1.5501 | val MSE   9.4817


Epoch:  32%|███▏      | 32/100 [09:08<19:32, 17.25s/epoch]

Epoch 031 | Learning rate 0.000750 | train normalized MSE  10.9216 | val normalized MSE   9.9067, | val MAE   1.6234 | val MSE   9.9067


Epoch:  33%|███▎      | 33/100 [09:25<19:14, 17.23s/epoch]

Epoch 032 | Learning rate 0.000750 | train normalized MSE  10.6247 | val normalized MSE   9.2281, | val MAE   1.5258 | val MSE   9.2281


Epoch:  34%|███▍      | 34/100 [09:42<18:56, 17.22s/epoch]

Epoch 033 | Learning rate 0.000750 | train normalized MSE  10.6109 | val normalized MSE   9.6692, | val MAE   1.5476 | val MSE   9.6692


Epoch:  35%|███▌      | 35/100 [09:59<18:36, 17.18s/epoch]

Epoch 034 | Learning rate 0.000750 | train normalized MSE  10.5302 | val normalized MSE   9.0754, | val MAE   1.5479 | val MSE   9.0754


Epoch:  36%|███▌      | 36/100 [10:16<18:16, 17.14s/epoch]

Epoch 035 | Learning rate 0.000750 | train normalized MSE  10.7197 | val normalized MSE   9.7244, | val MAE   1.5958 | val MSE   9.7244


Epoch:  37%|███▋      | 37/100 [10:33<17:55, 17.08s/epoch]

Epoch 036 | Learning rate 0.000750 | train normalized MSE  10.5275 | val normalized MSE   9.2475, | val MAE   1.5523 | val MSE   9.2475


Epoch:  38%|███▊      | 38/100 [10:50<17:37, 17.05s/epoch]

Epoch 037 | Learning rate 0.000750 | train normalized MSE  10.5837 | val normalized MSE   9.3655, | val MAE   1.5177 | val MSE   9.3655


Epoch:  39%|███▉      | 39/100 [11:07<17:21, 17.08s/epoch]

Epoch 038 | Learning rate 0.000750 | train normalized MSE  10.5441 | val normalized MSE   9.2623, | val MAE   1.4977 | val MSE   9.2623


Epoch:  40%|████      | 40/100 [11:25<17:05, 17.10s/epoch]

Epoch 039 | Learning rate 0.000563 | train normalized MSE  10.4636 | val normalized MSE   9.2318, | val MAE   1.5028 | val MSE   9.2318


Epoch:  41%|████      | 41/100 [11:42<16:49, 17.10s/epoch]

Epoch 040 | Learning rate 0.000563 | train normalized MSE  10.0927 | val normalized MSE   9.0070, | val MAE   1.5038 | val MSE   9.0070


Epoch:  42%|████▏     | 42/100 [11:59<16:34, 17.15s/epoch]

Epoch 041 | Learning rate 0.000563 | train normalized MSE   9.8332 | val normalized MSE   8.8749, | val MAE   1.4341 | val MSE   8.8749


Epoch:  43%|████▎     | 43/100 [12:16<16:18, 17.17s/epoch]

Epoch 042 | Learning rate 0.000563 | train normalized MSE   9.9898 | val normalized MSE   8.9119, | val MAE   1.4717 | val MSE   8.9119


Epoch:  44%|████▍     | 44/100 [12:33<15:59, 17.14s/epoch]

Epoch 043 | Learning rate 0.000563 | train normalized MSE   9.9548 | val normalized MSE   8.9315, | val MAE   1.4566 | val MSE   8.9315


Epoch:  45%|████▌     | 45/100 [12:50<15:37, 17.04s/epoch]

Epoch 044 | Learning rate 0.000563 | train normalized MSE   9.8419 | val normalized MSE   9.0210, | val MAE   1.4823 | val MSE   9.0210


Epoch:  46%|████▌     | 46/100 [13:07<15:20, 17.05s/epoch]

Epoch 045 | Learning rate 0.000563 | train normalized MSE   9.7509 | val normalized MSE   8.8278, | val MAE   1.4343 | val MSE   8.8278


Epoch:  47%|████▋     | 47/100 [13:24<15:03, 17.06s/epoch]

Epoch 046 | Learning rate 0.000563 | train normalized MSE   9.7780 | val normalized MSE   8.9280, | val MAE   1.5071 | val MSE   8.9280


Epoch:  48%|████▊     | 48/100 [13:41<14:47, 17.07s/epoch]

Epoch 047 | Learning rate 0.000563 | train normalized MSE   9.7128 | val normalized MSE   8.9028, | val MAE   1.4519 | val MSE   8.9028


Epoch:  49%|████▉     | 49/100 [13:58<14:32, 17.10s/epoch]

Epoch 048 | Learning rate 0.000563 | train normalized MSE   9.6625 | val normalized MSE   9.2245, | val MAE   1.4704 | val MSE   9.2245


Epoch:  50%|█████     | 50/100 [14:16<14:14, 17.10s/epoch]

Epoch 049 | Learning rate 0.000563 | train normalized MSE   9.6933 | val normalized MSE   8.8616, | val MAE   1.4548 | val MSE   8.8616


Epoch:  51%|█████     | 51/100 [14:33<13:57, 17.09s/epoch]

Epoch 050 | Learning rate 0.000563 | train normalized MSE   9.6765 | val normalized MSE   8.7384, | val MAE   1.4424 | val MSE   8.7384


Epoch:  52%|█████▏    | 52/100 [14:50<13:39, 17.08s/epoch]

Epoch 051 | Learning rate 0.000563 | train normalized MSE   9.5544 | val normalized MSE   8.6926, | val MAE   1.4610 | val MSE   8.6926


Epoch:  53%|█████▎    | 53/100 [15:07<13:20, 17.03s/epoch]

Epoch 052 | Learning rate 0.000563 | train normalized MSE   9.4537 | val normalized MSE   9.0789, | val MAE   1.4746 | val MSE   9.0789


Epoch:  54%|█████▍    | 54/100 [15:24<13:02, 17.01s/epoch]

Epoch 053 | Learning rate 0.000563 | train normalized MSE   9.5579 | val normalized MSE   8.2184, | val MAE   1.3843 | val MSE   8.2184


Epoch:  55%|█████▌    | 55/100 [15:41<12:45, 17.02s/epoch]

Epoch 054 | Learning rate 0.000563 | train normalized MSE   9.3889 | val normalized MSE   9.0796, | val MAE   1.4824 | val MSE   9.0796


Epoch:  56%|█████▌    | 56/100 [15:58<12:29, 17.04s/epoch]

Epoch 055 | Learning rate 0.000563 | train normalized MSE   9.5236 | val normalized MSE   8.8042, | val MAE   1.4626 | val MSE   8.8042


Epoch:  57%|█████▋    | 57/100 [16:15<12:13, 17.06s/epoch]

Epoch 056 | Learning rate 0.000563 | train normalized MSE   9.5550 | val normalized MSE   8.5579, | val MAE   1.4120 | val MSE   8.5579


Epoch:  58%|█████▊    | 58/100 [16:32<11:58, 17.11s/epoch]

Epoch 057 | Learning rate 0.000563 | train normalized MSE   9.4738 | val normalized MSE   8.8210, | val MAE   1.4499 | val MSE   8.8210


Epoch:  59%|█████▉    | 59/100 [16:49<11:40, 17.09s/epoch]

Epoch 058 | Learning rate 0.000563 | train normalized MSE   9.4398 | val normalized MSE   8.8190, | val MAE   1.4699 | val MSE   8.8190


Epoch:  60%|██████    | 60/100 [17:06<11:21, 17.04s/epoch]

Epoch 059 | Learning rate 0.000422 | train normalized MSE   9.3919 | val normalized MSE   8.9310, | val MAE   1.4615 | val MSE   8.9310


Epoch:  61%|██████    | 61/100 [17:23<11:04, 17.03s/epoch]

Epoch 060 | Learning rate 0.000422 | train normalized MSE   9.1375 | val normalized MSE   8.9389, | val MAE   1.4572 | val MSE   8.9389


Epoch:  62%|██████▏   | 62/100 [17:40<10:48, 17.05s/epoch]

Epoch 061 | Learning rate 0.000422 | train normalized MSE   9.1337 | val normalized MSE   8.7871, | val MAE   1.4576 | val MSE   8.7871


Epoch:  63%|██████▎   | 63/100 [17:57<10:29, 17.01s/epoch]

Epoch 062 | Learning rate 0.000422 | train normalized MSE   9.1242 | val normalized MSE   8.7058, | val MAE   1.4191 | val MSE   8.7058


Epoch:  63%|██████▎   | 63/100 [18:14<10:42, 17.38s/epoch]

Epoch 063 | Learning rate 0.000422 | train normalized MSE   9.0986 | val normalized MSE   8.6675, | val MAE   1.4539 | val MSE   8.6675
Early stop!





In [45]:

class SocialLSTM(nn.Module):
    """
    简化版 Social-LSTM：所有 agent 使用共享 LSTM，
    每个时间步聚合邻居的 hidden state，作为社交信息输入。
    """
    def __init__(self, input_dim=6, hidden_dim=128, pred_len=60, dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.pred_len = pred_len

        # 嵌入原始输入特征
        self.embed = nn.Linear(input_dim, hidden_dim)

        # 社交交互 + LSTM 单元
        self.lstm_cell = nn.LSTMCell(hidden_dim * 2, hidden_dim)

        # 最终预测头
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, pred_len * 2)
        )

    def forward(self, data):
        """
        data.x: shape (B*A, T, D)
        data.num_graphs = B
        """
        x = data.x  # (B*A, T, D)
        B = data.num_graphs
        A = x.size(0) // B
        T = x.size(1)

        # 重新组织形状
        x = x.view(B, A, T, self.input_dim)  # (B, A, T, D)

        # 嵌入输入
        x_embed = self.embed(x)  # (B, A, T, hidden_dim)

        # 初始化 hidden 和 cell
        h = torch.zeros(B, A, self.hidden_dim, device=x.device)
        c = torch.zeros(B, A, self.hidden_dim, device=x.device)

        # 时序迭代
        for t in range(T):
            x_t = x_embed[:, :, t, :]  # (B, A, hidden_dim)

            # 社交池化：对所有 agent 的 h 做平均
            social_context = h.mean(dim=1, keepdim=True).expand(-1, A, -1)  # (B, A, hidden_dim)

            # 拼接个体状态 + 社交特征
            lstm_input = torch.cat([x_t, social_context], dim=-1)  # (B, A, 2*hidden_dim)

            # LSTMCell 只能处理 2D，所以展平后逐 agent 处理
            h = h.view(B * A, self.hidden_dim)
            c = c.view(B * A, self.hidden_dim)
            lstm_input = lstm_input.view(B * A, -1)

            h, c = self.lstm_cell(lstm_input, (h, c))  # 单步更新
            h = h.view(B, A, self.hidden_dim)
            c = c.view(B, A, self.hidden_dim)

        # 最终只取 ego agent 的 hidden state
        ego_h = h[:, 0, :]  # (B, hidden_dim)

        out = self.fc(ego_h)              # (B, pred_len*2)
        out = out.view(B, self.pred_len, 2)  # (B, pred_len, 2)

        return out


# ====== 测试用例 ======
if __name__ == "__main__":
    class DummyData:
        pass

    B, A, T, D = 4, 50, 50, 6
    pred_len = 60
    torch.manual_seed(123)
    dummy = DummyData()
    dummy.x = torch.randn(B * A, T, D)
    dummy.num_graphs = B

    model = SocialLSTM(input_dim=D, hidden_dim=128, pred_len=pred_len, dropout=0.1)
    output = model(dummy)
    print(output.shape)  # 应输出: torch.Size([4, 60, 2])

torch.Size([4, 60, 2])


In [46]:
model = SocialLSTM(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    dropout=0.1,
).to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.75) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [47]:
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            val_loss += criterion(pred, y).item()

            # show MAE and MSE with unnormalized data
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, y).item()
            val_mse += nn.MSELoss()(pred, y).item()

    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    scheduler.step()
    # scheduler.step(val_loss)

    tqdm.tqdm.write(f"Epoch {epoch:03d} | Learning rate {optimizer.param_groups[0]['lr']:.6f} | train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f}, | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    if val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            print("Early stop!")
            break

  future = future - origin
  future = future @ R
  future = future @ R
Epoch:   1%|          | 1/100 [00:26<43:56, 26.63s/epoch]

Epoch 000 | Learning rate 0.001000 | train normalized MSE  64.3009 | val normalized MSE  23.5604, | val MAE   2.8736 | val MSE  23.5604


Epoch:   2%|▏         | 2/100 [00:53<43:43, 26.77s/epoch]

Epoch 001 | Learning rate 0.001000 | train normalized MSE  24.4762 | val normalized MSE  22.3684, | val MAE   2.6976 | val MSE  22.3684


Epoch:   3%|▎         | 3/100 [01:20<43:18, 26.79s/epoch]

Epoch 002 | Learning rate 0.001000 | train normalized MSE  22.3988 | val normalized MSE  19.2778, | val MAE   2.3538 | val MSE  19.2778


Epoch:   4%|▍         | 4/100 [01:47<42:55, 26.82s/epoch]

Epoch 003 | Learning rate 0.001000 | train normalized MSE  21.2690 | val normalized MSE  20.8845, | val MAE   2.6684 | val MSE  20.8845


Epoch:   5%|▌         | 5/100 [02:13<42:27, 26.82s/epoch]

Epoch 004 | Learning rate 0.001000 | train normalized MSE  20.2979 | val normalized MSE  18.3050, | val MAE   2.3161 | val MSE  18.3050


Epoch:   6%|▌         | 6/100 [02:40<42:00, 26.81s/epoch]

Epoch 005 | Learning rate 0.001000 | train normalized MSE  19.8324 | val normalized MSE  18.7722, | val MAE   2.5013 | val MSE  18.7722


Epoch:   7%|▋         | 7/100 [03:07<41:36, 26.84s/epoch]

Epoch 006 | Learning rate 0.001000 | train normalized MSE  19.0700 | val normalized MSE  17.9383, | val MAE   2.3352 | val MSE  17.9383


Epoch:   8%|▊         | 8/100 [03:34<41:12, 26.88s/epoch]

Epoch 007 | Learning rate 0.001000 | train normalized MSE  19.1668 | val normalized MSE  18.2268, | val MAE   2.2454 | val MSE  18.2268


Epoch:   9%|▉         | 9/100 [04:01<40:46, 26.89s/epoch]

Epoch 008 | Learning rate 0.001000 | train normalized MSE  18.5292 | val normalized MSE  16.6081, | val MAE   2.2136 | val MSE  16.6081


Epoch:  10%|█         | 10/100 [04:28<40:13, 26.82s/epoch]

Epoch 009 | Learning rate 0.001000 | train normalized MSE  18.2237 | val normalized MSE  19.6836, | val MAE   2.4774 | val MSE  19.6836


Epoch:  11%|█         | 11/100 [04:55<39:48, 26.84s/epoch]

Epoch 010 | Learning rate 0.001000 | train normalized MSE  18.0951 | val normalized MSE  18.9679, | val MAE   2.3672 | val MSE  18.9679


Epoch:  12%|█▏        | 12/100 [05:21<39:21, 26.83s/epoch]

Epoch 011 | Learning rate 0.001000 | train normalized MSE  17.6471 | val normalized MSE  15.6870, | val MAE   2.0764 | val MSE  15.6870


Epoch:  13%|█▎        | 13/100 [05:48<38:52, 26.81s/epoch]

Epoch 012 | Learning rate 0.001000 | train normalized MSE  16.9384 | val normalized MSE  15.3893, | val MAE   2.1109 | val MSE  15.3893


Epoch:  14%|█▍        | 14/100 [06:15<38:24, 26.80s/epoch]

Epoch 013 | Learning rate 0.001000 | train normalized MSE  16.7040 | val normalized MSE  15.0326, | val MAE   2.0354 | val MSE  15.0326


Epoch:  15%|█▌        | 15/100 [06:42<37:55, 26.78s/epoch]

Epoch 014 | Learning rate 0.001000 | train normalized MSE  16.8419 | val normalized MSE  15.1114, | val MAE   2.0373 | val MSE  15.1114


Epoch:  16%|█▌        | 16/100 [07:09<37:32, 26.82s/epoch]

Epoch 015 | Learning rate 0.001000 | train normalized MSE  16.3469 | val normalized MSE  14.4963, | val MAE   1.9514 | val MSE  14.4963


Epoch:  17%|█▋        | 17/100 [07:35<37:06, 26.83s/epoch]

Epoch 016 | Learning rate 0.001000 | train normalized MSE  16.3260 | val normalized MSE  15.9908, | val MAE   2.2047 | val MSE  15.9908


Epoch:  18%|█▊        | 18/100 [08:02<36:42, 26.86s/epoch]

Epoch 017 | Learning rate 0.001000 | train normalized MSE  16.0569 | val normalized MSE  14.6305, | val MAE   2.1361 | val MSE  14.6305


Epoch:  19%|█▉        | 19/100 [08:29<36:12, 26.82s/epoch]

Epoch 018 | Learning rate 0.001000 | train normalized MSE  15.9675 | val normalized MSE  14.0568, | val MAE   1.9319 | val MSE  14.0568


Epoch:  20%|██        | 20/100 [08:56<35:43, 26.80s/epoch]

Epoch 019 | Learning rate 0.000750 | train normalized MSE  15.8549 | val normalized MSE  14.8430, | val MAE   2.0306 | val MSE  14.8430


Epoch:  21%|██        | 21/100 [09:23<35:19, 26.83s/epoch]

Epoch 020 | Learning rate 0.000750 | train normalized MSE  14.9760 | val normalized MSE  13.1223, | val MAE   1.8347 | val MSE  13.1223


Epoch:  22%|██▏       | 22/100 [09:50<34:53, 26.84s/epoch]

Epoch 021 | Learning rate 0.000750 | train normalized MSE  14.8338 | val normalized MSE  14.3009, | val MAE   1.9314 | val MSE  14.3010


Epoch:  23%|██▎       | 23/100 [10:16<34:23, 26.80s/epoch]

Epoch 022 | Learning rate 0.000750 | train normalized MSE  14.5235 | val normalized MSE  13.4135, | val MAE   1.8839 | val MSE  13.4135


Epoch:  24%|██▍       | 24/100 [10:43<33:49, 26.71s/epoch]

Epoch 023 | Learning rate 0.000750 | train normalized MSE  14.5236 | val normalized MSE  12.9558, | val MAE   1.8845 | val MSE  12.9558


Epoch:  25%|██▌       | 25/100 [11:09<33:18, 26.64s/epoch]

Epoch 024 | Learning rate 0.000750 | train normalized MSE  14.5388 | val normalized MSE  13.5388, | val MAE   1.9121 | val MSE  13.5388


Epoch:  26%|██▌       | 26/100 [11:36<32:54, 26.68s/epoch]

Epoch 025 | Learning rate 0.000750 | train normalized MSE  14.4978 | val normalized MSE  13.4251, | val MAE   1.8950 | val MSE  13.4251


Epoch:  27%|██▋       | 27/100 [12:03<32:29, 26.70s/epoch]

Epoch 026 | Learning rate 0.000750 | train normalized MSE  14.4102 | val normalized MSE  13.2509, | val MAE   1.9301 | val MSE  13.2509


Epoch:  28%|██▊       | 28/100 [12:30<32:04, 26.74s/epoch]

Epoch 027 | Learning rate 0.000750 | train normalized MSE  14.2200 | val normalized MSE  13.7422, | val MAE   1.9619 | val MSE  13.7422


Epoch:  29%|██▉       | 29/100 [12:56<31:40, 26.77s/epoch]

Epoch 028 | Learning rate 0.000750 | train normalized MSE  14.1752 | val normalized MSE  13.9077, | val MAE   1.9773 | val MSE  13.9077


Epoch:  30%|███       | 30/100 [13:23<31:12, 26.75s/epoch]

Epoch 029 | Learning rate 0.000750 | train normalized MSE  14.0934 | val normalized MSE  13.0892, | val MAE   1.8236 | val MSE  13.0892


Epoch:  31%|███       | 31/100 [13:50<30:46, 26.76s/epoch]

Epoch 030 | Learning rate 0.000750 | train normalized MSE  14.2700 | val normalized MSE  12.8904, | val MAE   1.8912 | val MSE  12.8904


Epoch:  32%|███▏      | 32/100 [14:17<30:21, 26.78s/epoch]

Epoch 031 | Learning rate 0.000750 | train normalized MSE  14.2485 | val normalized MSE  12.7603, | val MAE   1.8155 | val MSE  12.7603


Epoch:  33%|███▎      | 33/100 [14:44<29:55, 26.80s/epoch]

Epoch 032 | Learning rate 0.000750 | train normalized MSE  14.0378 | val normalized MSE  13.0273, | val MAE   1.8561 | val MSE  13.0273


Epoch:  34%|███▍      | 34/100 [15:10<29:29, 26.81s/epoch]

Epoch 033 | Learning rate 0.000750 | train normalized MSE  13.9209 | val normalized MSE  12.7406, | val MAE   1.8892 | val MSE  12.7406


Epoch:  35%|███▌      | 35/100 [15:37<29:00, 26.78s/epoch]

Epoch 034 | Learning rate 0.000750 | train normalized MSE  13.8881 | val normalized MSE  13.6669, | val MAE   1.8937 | val MSE  13.6669


Epoch:  36%|███▌      | 36/100 [16:04<28:31, 26.74s/epoch]

Epoch 035 | Learning rate 0.000750 | train normalized MSE  13.9760 | val normalized MSE  13.2591, | val MAE   1.9283 | val MSE  13.2592


Epoch:  37%|███▋      | 37/100 [16:31<28:07, 26.79s/epoch]

Epoch 036 | Learning rate 0.000750 | train normalized MSE  14.0418 | val normalized MSE  12.6121, | val MAE   1.8402 | val MSE  12.6121


Epoch:  38%|███▊      | 38/100 [16:58<27:41, 26.80s/epoch]

Epoch 037 | Learning rate 0.000750 | train normalized MSE  13.8089 | val normalized MSE  13.0942, | val MAE   1.8895 | val MSE  13.0942


Epoch:  39%|███▉      | 39/100 [17:24<27:14, 26.79s/epoch]

Epoch 038 | Learning rate 0.000750 | train normalized MSE  13.7812 | val normalized MSE  13.8312, | val MAE   1.9305 | val MSE  13.8312


Epoch:  40%|████      | 40/100 [17:51<26:48, 26.81s/epoch]

Epoch 039 | Learning rate 0.000563 | train normalized MSE  13.8017 | val normalized MSE  13.1377, | val MAE   1.8963 | val MSE  13.1377


Epoch:  41%|████      | 41/100 [18:18<26:17, 26.73s/epoch]

Epoch 040 | Learning rate 0.000563 | train normalized MSE  13.2092 | val normalized MSE  12.0462, | val MAE   1.7380 | val MSE  12.0462


Epoch:  42%|████▏     | 42/100 [18:45<25:51, 26.75s/epoch]

Epoch 041 | Learning rate 0.000563 | train normalized MSE  13.3435 | val normalized MSE  12.8252, | val MAE   1.8333 | val MSE  12.8252


Epoch:  43%|████▎     | 43/100 [19:11<25:24, 26.75s/epoch]

Epoch 042 | Learning rate 0.000563 | train normalized MSE  13.1449 | val normalized MSE  12.1733, | val MAE   1.7331 | val MSE  12.1733


Epoch:  44%|████▍     | 44/100 [19:38<24:59, 26.77s/epoch]

Epoch 043 | Learning rate 0.000563 | train normalized MSE  12.9530 | val normalized MSE  11.9996, | val MAE   1.7423 | val MSE  11.9996


Epoch:  45%|████▌     | 45/100 [20:05<24:32, 26.77s/epoch]

Epoch 044 | Learning rate 0.000563 | train normalized MSE  13.0258 | val normalized MSE  12.5727, | val MAE   1.7522 | val MSE  12.5727


Epoch:  46%|████▌     | 46/100 [20:31<24:01, 26.70s/epoch]

Epoch 045 | Learning rate 0.000563 | train normalized MSE  12.8591 | val normalized MSE  12.0237, | val MAE   1.7559 | val MSE  12.0237


Epoch:  47%|████▋     | 47/100 [20:58<23:30, 26.62s/epoch]

Epoch 046 | Learning rate 0.000563 | train normalized MSE  12.7244 | val normalized MSE  11.8543, | val MAE   1.7030 | val MSE  11.8543


Epoch:  48%|████▊     | 48/100 [21:24<23:00, 26.56s/epoch]

Epoch 047 | Learning rate 0.000563 | train normalized MSE  12.7487 | val normalized MSE  11.6994, | val MAE   1.7271 | val MSE  11.6994


Epoch:  49%|████▉     | 49/100 [21:51<22:37, 26.61s/epoch]

Epoch 048 | Learning rate 0.000563 | train normalized MSE  12.7981 | val normalized MSE  12.2135, | val MAE   1.8074 | val MSE  12.2135


Epoch:  50%|█████     | 50/100 [22:18<22:12, 26.65s/epoch]

Epoch 049 | Learning rate 0.000563 | train normalized MSE  12.8571 | val normalized MSE  11.3482, | val MAE   1.6792 | val MSE  11.3482


Epoch:  51%|█████     | 51/100 [22:44<21:47, 26.68s/epoch]

Epoch 050 | Learning rate 0.000563 | train normalized MSE  12.7215 | val normalized MSE  12.2553, | val MAE   1.7799 | val MSE  12.2553


Epoch:  52%|█████▏    | 52/100 [23:11<21:23, 26.73s/epoch]

Epoch 051 | Learning rate 0.000563 | train normalized MSE  12.7897 | val normalized MSE  11.8608, | val MAE   1.8030 | val MSE  11.8608


Epoch:  53%|█████▎    | 53/100 [23:38<20:56, 26.73s/epoch]

Epoch 052 | Learning rate 0.000563 | train normalized MSE  12.9315 | val normalized MSE  11.6881, | val MAE   1.7178 | val MSE  11.6881


Epoch:  54%|█████▍    | 54/100 [24:05<20:27, 26.68s/epoch]

Epoch 053 | Learning rate 0.000563 | train normalized MSE  12.7650 | val normalized MSE  11.5555, | val MAE   1.7353 | val MSE  11.5555


Epoch:  55%|█████▌    | 55/100 [24:31<20:00, 26.69s/epoch]

Epoch 054 | Learning rate 0.000563 | train normalized MSE  12.8535 | val normalized MSE  11.5700, | val MAE   1.7097 | val MSE  11.5700


Epoch:  56%|█████▌    | 56/100 [24:58<19:35, 26.72s/epoch]

Epoch 055 | Learning rate 0.000563 | train normalized MSE  12.6456 | val normalized MSE  12.1141, | val MAE   1.7565 | val MSE  12.1141


Epoch:  57%|█████▋    | 57/100 [25:25<19:09, 26.74s/epoch]

Epoch 056 | Learning rate 0.000563 | train normalized MSE  12.7301 | val normalized MSE  12.1890, | val MAE   1.7764 | val MSE  12.1890


Epoch:  58%|█████▊    | 58/100 [25:52<18:43, 26.76s/epoch]

Epoch 057 | Learning rate 0.000563 | train normalized MSE  12.5996 | val normalized MSE  11.7170, | val MAE   1.7394 | val MSE  11.7170


Epoch:  59%|█████▉    | 59/100 [26:19<18:17, 26.78s/epoch]

Epoch 058 | Learning rate 0.000563 | train normalized MSE  12.7313 | val normalized MSE  13.0366, | val MAE   1.8004 | val MSE  13.0366


Epoch:  59%|█████▉    | 59/100 [26:45<18:35, 27.21s/epoch]

Epoch 059 | Learning rate 0.000422 | train normalized MSE  12.8282 | val normalized MSE  11.7551, | val MAE   1.7402 | val MSE  11.7551
Early stop!





In [49]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
model = SocialLSTM(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    dropout=0.1,
).to(device)

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)

        # Reshape the prediction to (N, 60, 2)
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission.csv', index=True)

In [9]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR

# -------------------
# 1) Positional Encoding 同上，不变
# -------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


# -------------------
# 2) Social Attention with Top‐K neighbors + 双层
# -------------------
class SocialTransformerAttention(nn.Module):
    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.1, topk: int = 10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.topk = topk

        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True,
            dropout=dropout
        )
        self.pos_proj = nn.Linear(3, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, ego_feat, other_feats, ego_pos, other_pos, mask=None):
        """
        ego_feat  : (B, H)
        other_feats: (B, N, H)
        ego_pos   : (B, 2)
        other_pos : (B, N, 2)
        mask      : (B, N) - True 表示屏蔽
        """
        B, N, H = other_feats.shape

        # 1) 先用 Top‐K 的方式动态挑最近的 neighbors
        #    如果 mask 已经标了很远/全零的也在 True，那只管还剩多少？我们尽量从剩余的里再取 topK
        with torch.no_grad():
            dist_all = torch.norm(other_pos - ego_pos.unsqueeze(1), dim=-1)  # (B, N)
            if mask is not None:
                dist_all = dist_all.masked_fill(mask, float('inf'))  # 把被 mask 的设为 +∞
            # 取最近 topk 个
            topk_vals, topk_idx = torch.topk(dist_all, k=min(self.topk, N), largest=False)  # (B, topk)

            # 构造一个 new_mask，使得 **不** 在 topk_idx 内的都变 True
            new_mask = torch.ones((B, N), device=other_feats.device, dtype=torch.bool)
            arange_b = torch.arange(B, device=other_feats.device).unsqueeze(1)  # (B,1)
            new_mask[arange_b, topk_idx] = False
            # 合并之前的 mask
            if mask is not None:
                mask = mask | new_mask
            else:
                mask = new_mask

        # 2) Query / Key / Value
        q = ego_feat.unsqueeze(1)  # (B,1,H)
        k = other_feats            # (B,N,H)
        v = other_feats            # (B,N,H)

        # 3) 相对位置编码
        rel_pos = other_pos - ego_pos.unsqueeze(1)       # (B,N,2)
        dist = torch.norm(rel_pos, dim=-1, keepdim=True) # (B,N,1)
        pos_embed = torch.cat([rel_pos, dist], dim=-1)   # (B,N,3)
        pos_bias = self.pos_proj(pos_embed)              # (B,N,H)
        k = k + pos_bias
        v = v + pos_bias

        # 4) 防止全 mask 导致 nan
        all_masked = mask.all(dim=1)
        if all_masked.any():
            mask[all_masked, 0] = False

        # 5) MultiHeadAttention
        attn_out, _ = self.attn(q, k, v, key_padding_mask=mask)  # (B,1,H)
        out = attn_out.squeeze(1)                                 # (B,H)
        out = self.ln1(ego_feat + self.dropout(out))

        # 6) FFN + 残差 + LayerNorm
        ff = self.ffn(out)
        out = self.ln2(out + self.dropout(ff))

        return out  # (B,H)


# -------------------
# 3) 主干网络：更深的 TransformerEncoder + 2 层 Social Attention + 融合 FFN
# -------------------
class SocialTransformerPredictor(nn.Module):
    def __init__(self,
                 input_dim: int = 5,      # 现在你只用 (x,y,vx,vy,ax) 共5维
                 hidden_dim: int = 512,   # 加大到 512
                 pred_len: int = 60,
                 num_agents: int = 50,
                 seq_len: int = 50,
                 num_heads: int = 4,
                 enc_layers: int = 4,     # 深度改成 4
                 dropout: float = 0.1,
                 neighbor_threshold: float = 10.0,
                 scale: float = 1.0,
                 topk: int = 10):         # 只选最近 10 个邻居
        super().__init__()
        self.hidden_dim = hidden_dim
        self.pred_len = pred_len
        self.num_agents = num_agents
        self.seq_len = seq_len
        self.neighbor_threshold = neighbor_threshold
        self.scale = scale

        # 3.1) 时序编码：Linear→PosEnc→TransformerEncoder
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.pos_enc = PositionalEncoding(d_model=hidden_dim, max_len=seq_len)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=enc_layers
        )
        self.temporal_ln = nn.LayerNorm(hidden_dim)

        # 3.2) Social Attention：两层串联
        self.social_attn1 = SocialTransformerAttention(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            topk=topk
        )
        self.social_attn2 = SocialTransformerAttention(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            topk=topk
        )

        # 3.3) 融合后再过一个小 FFN
        self.fusion_ln = nn.LayerNorm(hidden_dim * 2)
        self.fusion_dropout = nn.Dropout(dropout)
        self.fusion_ffn = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim * 2)
        )

        # 3.4) 最终 MLP Head
        self.head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 4, pred_len * 2)
        )

    def forward(self, data):
        """
        data.x          : (B, A, T, 5)
        data.num_graphs:  B
        data.current_pos: (B, A, 2)
        """
        x = data.x
        B = data.num_graphs

        # 1) reshape → (B, A, T, 5)
        x = x.view(B, self.num_agents, self.seq_len, -1)

        # 2) 提取当前帧位置 current_pos
        if hasattr(data, 'current_pos'):
            current_pos = data.current_pos       # (B,50,2)
        else:
            current_pos = x[:, :, -1, :2].contiguous()  # (B,50,2)

        # 3) 时序编码：Linear→PosEnc→TransformerEncoder
        x_flat    = x.view(B * self.num_agents, self.seq_len, -1)  # (B*50,50,5)
        x_proj    = self.input_proj(x_flat)                        # (B*50,50,512)
        x_enc     = self.pos_enc(x_proj)                            # (B*50,50,512)
        x_encoded = self.transformer_encoder(x_enc)                 # (B*50,50,512)

        # 4) 取最后一帧特征（也可以替换为 max pooling，或 concat 多种时序池化）
        last_feat = x_encoded[:, -1, :]             # (B*50,512)
        feats     = last_feat.view(B, self.num_agents, self.hidden_dim)  # (B,50,512)
        feats     = self.temporal_ln(feats)         # LayerNorm

        # 5) 分离 ego 与 others
        ego_feat    = feats[:, 0, :]    # (B,512)
        other_feats = feats[:, 1:, :]   # (B,49,512)

        ego_pos     = current_pos[:, 0, :]   # (B,2)
        other_pos   = current_pos[:, 1:, :]  # (B,49,2)

        # 6) 构造距离 dist & TopK 掩码
        dist_all = torch.norm(other_pos - ego_pos.unsqueeze(1), dim=-1)  # (B,49)
        threshold = self.neighbor_threshold / self.scale
        hard_mask = dist_all > threshold                 # (B,49)
        base_mask = hard_mask.to(feats.device)           # 只用距离mask

        # 7) 第一层 Social Attention
        updated_1 = self.social_attn1(
            ego_feat, other_feats, ego_pos, other_pos, mask=base_mask
        )  # (B,512)

        # 8) 第二层 Social Attention （用刚刚 updated_1 做新 ego_feat）
        updated_2 = self.social_attn2(
            updated_1, other_feats, ego_pos, other_pos, mask=base_mask
        )  # (B,512)

        # 9) 融合：把原始 ego_feat 与 updated_2 拼接
        fusion = torch.cat([ego_feat, updated_2], dim=-1)  # (B,1024)
        fusion = self.fusion_ln(fusion)
        fusion = self.fusion_dropout(fusion)
        ff_out = self.fusion_ffn(fusion)
        fusion = F.relu(fusion + ff_out)  # (B,1024)

        # 10) Head → (B, pred_len*2) → reshape → (B, pred_len, 2)
        out = self.head(fusion)
        out = out.view(B, self.pred_len, 2)
        return out


In [12]:
model = SocialTransformerPredictor(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    num_agents=50,
    seq_len=50,
    num_heads=2,
    enc_layers=2,
    dropout=0.1,
    neighbor_threshold=10.0,
    scale=1.0,
    topk=10
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.75) # You can try different schedulers
early_stopping_patience = 10
best_val_loss = float('inf')
no_improvement = 0
criterion = nn.MSELoss()

In [13]:
for epoch in tqdm.tqdm(range(100), desc="Epoch", unit="epoch"):
    # ---- Training ----
    model.train()
    train_loss = 0
    for batch in train_dataloader:
        batch = batch.to(device)
        pred = model(batch)
        y = batch.y.view(batch.num_graphs, 60, 2)
        loss = criterion(pred, y)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_loss += loss.item()

    # ---- Validation ----
    model.eval()
    val_loss = 0
    val_mae = 0
    val_mse = 0
    with torch.no_grad():
        for batch in val_dataloader:
            batch = batch.to(device)
            pred = model(batch)
            y = batch.y.view(batch.num_graphs, 60, 2)
            val_loss += criterion(pred, y).item()

            # show MAE and MSE with unnormalized data
            pred = pred * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            y = y * batch.scale.view(-1, 1, 1) + batch.origin.unsqueeze(1)
            val_mae += nn.L1Loss()(pred, y).item()
            val_mse += nn.MSELoss()(pred, y).item()

    train_loss /= len(train_dataloader)
    val_loss /= len(val_dataloader)
    val_mae /= len(val_dataloader)
    val_mse /= len(val_dataloader)
    scheduler.step()
    # scheduler.step(val_loss)

    tqdm.tqdm.write(f"Epoch {epoch:03d} | Learning rate {optimizer.param_groups[0]['lr']:.6f} | train normalized MSE {train_loss:8.4f} | val normalized MSE {val_loss:8.4f}, | val MAE {val_mae:8.4f} | val MSE {val_mse:8.4f}")
    if val_loss < best_val_loss - 1e-3:
        best_val_loss = val_loss
        no_improvement = 0
        torch.save(model.state_dict(), "best_model.pt")
    else:
        no_improvement += 1
        if no_improvement >= early_stopping_patience:
            print("Early stop!")
            break

  future = future - origin
  future = future @ R
Epoch:   1%|          | 1/100 [00:28<47:31, 28.80s/epoch]

Epoch 000 | Learning rate 0.000100 | train normalized MSE  92.9601 | val normalized MSE  18.3528, | val MAE   2.4280 | val MSE  18.3528


Epoch:   2%|▏         | 2/100 [00:57<46:51, 28.69s/epoch]

Epoch 001 | Learning rate 0.000100 | train normalized MSE  22.0859 | val normalized MSE  14.2802, | val MAE   2.1195 | val MSE  14.2802


Epoch:   3%|▎         | 3/100 [01:26<46:21, 28.67s/epoch]

Epoch 002 | Learning rate 0.000100 | train normalized MSE  18.2310 | val normalized MSE  13.2214, | val MAE   2.0725 | val MSE  13.2214


Epoch:   4%|▍         | 4/100 [01:54<45:54, 28.69s/epoch]

Epoch 003 | Learning rate 0.000100 | train normalized MSE  16.8202 | val normalized MSE  12.8164, | val MAE   1.9310 | val MSE  12.8164


Epoch:   5%|▌         | 5/100 [02:23<45:25, 28.68s/epoch]

Epoch 004 | Learning rate 0.000100 | train normalized MSE  15.6575 | val normalized MSE  13.4001, | val MAE   1.9481 | val MSE  13.4001


Epoch:   6%|▌         | 6/100 [02:52<44:57, 28.70s/epoch]

Epoch 005 | Learning rate 0.000100 | train normalized MSE  14.9806 | val normalized MSE  12.0607, | val MAE   1.9088 | val MSE  12.0607


Epoch:   7%|▋         | 7/100 [03:20<44:29, 28.70s/epoch]

Epoch 006 | Learning rate 0.000100 | train normalized MSE  14.7507 | val normalized MSE  12.0278, | val MAE   1.8395 | val MSE  12.0278


Epoch:   8%|▊         | 8/100 [03:49<43:58, 28.68s/epoch]

Epoch 007 | Learning rate 0.000100 | train normalized MSE  14.2540 | val normalized MSE  12.5609, | val MAE   1.8627 | val MSE  12.5609


Epoch:   9%|▉         | 9/100 [04:18<43:26, 28.64s/epoch]

Epoch 008 | Learning rate 0.000100 | train normalized MSE  14.2262 | val normalized MSE  11.5090, | val MAE   1.7533 | val MSE  11.5090


Epoch:  10%|█         | 10/100 [04:46<42:50, 28.57s/epoch]

Epoch 009 | Learning rate 0.000100 | train normalized MSE  13.9480 | val normalized MSE  11.5718, | val MAE   1.8940 | val MSE  11.5718


Epoch:  11%|█         | 11/100 [05:14<42:16, 28.50s/epoch]

Epoch 010 | Learning rate 0.000100 | train normalized MSE  13.8736 | val normalized MSE  12.1408, | val MAE   1.8350 | val MSE  12.1408


Epoch:  12%|█▏        | 12/100 [05:43<41:46, 28.48s/epoch]

Epoch 011 | Learning rate 0.000100 | train normalized MSE  13.6599 | val normalized MSE  11.3237, | val MAE   1.8002 | val MSE  11.3237


Epoch:  13%|█▎        | 13/100 [06:11<41:22, 28.54s/epoch]

Epoch 012 | Learning rate 0.000100 | train normalized MSE  13.6087 | val normalized MSE  11.0801, | val MAE   1.7486 | val MSE  11.0801


Epoch:  14%|█▍        | 14/100 [06:40<40:54, 28.54s/epoch]

Epoch 013 | Learning rate 0.000100 | train normalized MSE  13.5623 | val normalized MSE  11.9749, | val MAE   1.8263 | val MSE  11.9749


Epoch:  15%|█▌        | 15/100 [07:09<40:27, 28.56s/epoch]

Epoch 014 | Learning rate 0.000100 | train normalized MSE  13.2337 | val normalized MSE  11.5110, | val MAE   1.7691 | val MSE  11.5110


Epoch:  16%|█▌        | 16/100 [07:37<40:01, 28.59s/epoch]

Epoch 015 | Learning rate 0.000100 | train normalized MSE  13.3696 | val normalized MSE  11.2487, | val MAE   1.7288 | val MSE  11.2487


Epoch:  17%|█▋        | 17/100 [08:06<39:33, 28.60s/epoch]

Epoch 016 | Learning rate 0.000100 | train normalized MSE  12.9897 | val normalized MSE  11.8130, | val MAE   1.8057 | val MSE  11.8130


Epoch:  18%|█▊        | 18/100 [08:34<39:05, 28.60s/epoch]

Epoch 017 | Learning rate 0.000100 | train normalized MSE  13.0623 | val normalized MSE  10.3892, | val MAE   1.6400 | val MSE  10.3892


Epoch:  19%|█▉        | 19/100 [09:03<38:37, 28.62s/epoch]

Epoch 018 | Learning rate 0.000100 | train normalized MSE  12.6951 | val normalized MSE  11.3871, | val MAE   1.7832 | val MSE  11.3870


Epoch:  20%|██        | 20/100 [09:32<38:07, 28.59s/epoch]

Epoch 019 | Learning rate 0.000075 | train normalized MSE  12.7588 | val normalized MSE  11.3960, | val MAE   1.7510 | val MSE  11.3960


Epoch:  21%|██        | 21/100 [10:00<37:34, 28.54s/epoch]

Epoch 020 | Learning rate 0.000075 | train normalized MSE  12.2918 | val normalized MSE  11.1547, | val MAE   1.7519 | val MSE  11.1547


Epoch:  22%|██▏       | 22/100 [10:29<37:09, 28.58s/epoch]

Epoch 021 | Learning rate 0.000075 | train normalized MSE  11.9847 | val normalized MSE  10.9299, | val MAE   1.6299 | val MSE  10.9299


Epoch:  23%|██▎       | 23/100 [10:58<36:45, 28.64s/epoch]

Epoch 022 | Learning rate 0.000075 | train normalized MSE  11.9799 | val normalized MSE  11.5974, | val MAE   1.6886 | val MSE  11.5974


Epoch:  24%|██▍       | 24/100 [11:26<36:18, 28.67s/epoch]

Epoch 023 | Learning rate 0.000075 | train normalized MSE  11.9921 | val normalized MSE  10.3258, | val MAE   1.5935 | val MSE  10.3258


Epoch:  25%|██▌       | 25/100 [11:55<35:49, 28.66s/epoch]

Epoch 024 | Learning rate 0.000075 | train normalized MSE  12.1646 | val normalized MSE  11.7939, | val MAE   1.7338 | val MSE  11.7939


Epoch:  26%|██▌       | 26/100 [12:24<35:20, 28.65s/epoch]

Epoch 025 | Learning rate 0.000075 | train normalized MSE  11.7831 | val normalized MSE  11.6560, | val MAE   1.7852 | val MSE  11.6560


Epoch:  27%|██▋       | 27/100 [12:52<34:50, 28.64s/epoch]

Epoch 026 | Learning rate 0.000075 | train normalized MSE  11.8254 | val normalized MSE  11.4626, | val MAE   1.6836 | val MSE  11.4625


Epoch:  28%|██▊       | 28/100 [13:21<34:20, 28.62s/epoch]

Epoch 027 | Learning rate 0.000075 | train normalized MSE  11.7348 | val normalized MSE  11.0086, | val MAE   1.6269 | val MSE  11.0086


Epoch:  29%|██▉       | 29/100 [13:49<33:52, 28.62s/epoch]

Epoch 028 | Learning rate 0.000075 | train normalized MSE  11.6387 | val normalized MSE  10.6128, | val MAE   1.6722 | val MSE  10.6128


Epoch:  30%|███       | 30/100 [14:18<33:22, 28.60s/epoch]

Epoch 029 | Learning rate 0.000075 | train normalized MSE  11.8519 | val normalized MSE  10.3835, | val MAE   1.5556 | val MSE  10.3835


Epoch:  31%|███       | 31/100 [14:46<32:50, 28.56s/epoch]

Epoch 030 | Learning rate 0.000075 | train normalized MSE  11.6261 | val normalized MSE  10.5626, | val MAE   1.5778 | val MSE  10.5626


Epoch:  32%|███▏      | 32/100 [15:15<32:24, 28.60s/epoch]

Epoch 031 | Learning rate 0.000075 | train normalized MSE  11.8292 | val normalized MSE  12.1899, | val MAE   1.8424 | val MSE  12.1899


Epoch:  33%|███▎      | 33/100 [15:44<31:55, 28.59s/epoch]

Epoch 032 | Learning rate 0.000075 | train normalized MSE  11.7976 | val normalized MSE  10.7061, | val MAE   1.6482 | val MSE  10.7061


Epoch:  33%|███▎      | 33/100 [16:12<32:54, 29.48s/epoch]

Epoch 033 | Learning rate 0.000075 | train normalized MSE  11.6355 | val normalized MSE  10.9036, | val MAE   1.6445 | val MSE  10.9036
Early stop!





In [14]:
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=lambda xs: Batch.from_data_list(xs))

best_model = torch.load("best_model.pt")
model = SocialTransformerPredictor(
    input_dim=6,
    hidden_dim=256,
    pred_len=60,
    num_agents=50,
    seq_len=50,
    num_heads=2,
    enc_layers=2,
    dropout=0.1,
    neighbor_threshold=10.0,
    scale=1.0,
    topk=10
).to(device)

model.load_state_dict(best_model)
model.eval()

pred_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        pred_norm = model(batch)

        # Reshape the prediction to (N, 60, 2)
        pred = pred_norm * batch.scale.view(-1,1,1) + batch.origin.unsqueeze(1)
        pred_list.append(pred.cpu().numpy())
pred_list = np.concatenate(pred_list, axis=0)  # (N,60,2)
pred_output = pred_list.reshape(-1, 2)  # (N*60, 2)
output_df = pd.DataFrame(pred_output, columns=['x', 'y'])
output_df.index.name = 'index'
output_df.to_csv('submission.csv', index=True)