In [2]:
# block 1: imports (OGBN-Arxiv + GAT)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# PyG / OGB
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.nn import GATConv
from torch_geometric.utils import add_self_loops, to_undirected

# 常用工具
import torch_geometric.transforms as T
from torch_geometric.data import Data

# 数据处理 / 可视化 / 其他
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# typing / debug
from typing import Optional, Tuple, List


In [3]:
# block 2 (original features): Load OGBN-Arxiv graph using ORIGINAL OGB embeddings
import torch
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops, to_undirected

# === 路径设置 ===
base_dir = "/home/xyx/gnn_demo/arxiv/ogbn_arxiv"
raw_dir = f"{base_dir}/raw"
split_dir = f"{base_dir}/split/time"

# === 1. 加载原始节点特征 ===
node_feat = pd.read_csv(f"{raw_dir}/node-feat.csv.gz", compression="gzip", header=None).values
x = torch.tensor(node_feat, dtype=torch.float)
print(f"✅ Loaded original OGB node features: {x.shape}")

# === 2. 加载边关系、标签、年份 ===
edge_index = pd.read_csv(f"{raw_dir}/edge.csv.gz", compression="gzip", header=None).values.T
edge_index = torch.tensor(edge_index, dtype=torch.long)

labels = pd.read_csv(f"{raw_dir}/node-label.csv.gz", compression="gzip", header=None).values.squeeze()
labels = torch.tensor(labels, dtype=torch.long)

years = pd.read_csv(f"{raw_dir}/node_year.csv.gz", compression="gzip", header=None).values.squeeze()
years = torch.tensor(years, dtype=torch.long)

# === 3. 无向化 + 自环 ===
edge_index = to_undirected(edge_index, num_nodes=x.shape[0])
edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])

# === 4. 构造 PyG 数据对象 ===
data = Data(x=x, edge_index=edge_index, y=labels)
data.node_year = years

# === 5. 加载训练/验证/测试划分 ===
train_idx = torch.tensor(pd.read_csv(f"{split_dir}/train.csv.gz", compression="gzip", header=None)[0].values)
valid_idx = torch.tensor(pd.read_csv(f"{split_dir}/valid.csv.gz", compression="gzip", header=None)[0].values)
test_idx = torch.tensor(pd.read_csv(f"{split_dir}/test.csv.gz", compression="gzip", header=None)[0].values)

# === 6. 打印信息 ===
print("==============================================================")
print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Feature dim (OGB original): {data.num_features}")
print(f"Classes: {int(labels.max()) + 1}")
print(f"Years: {years.min().item()} - {years.max().item()}")
print(f"Train/Valid/Test: {len(train_idx)}/{len(valid_idx)}/{len(test_idx)}")
print("==============================================================")


✅ Loaded original OGB node features: torch.Size([169343, 128])
Nodes: 169343
Edges: 2484941
Feature dim (OGB original): 128
Classes: 40
Years: 1971 - 2020
Train/Valid/Test: 90941/29799/48603


In [4]:
# block 3
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.utils import to_undirected
from ogb.nodeproppred import Evaluator

# === 构建更复杂的 GAT ===
class GAT(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, heads=4, dropout=0.3):
        super().__init__()
        self.gat1 = GATConv(in_dim, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout)
        self.gat3 = GATConv(hidden_dim * heads, out_dim, heads=1, concat=False, dropout=dropout)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat3(x, edge_index)
        return F.log_softmax(x, dim=1)

# === 初始化设备与数据 ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

data = data.to(device)
edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
data.edge_index = edge_index

# === 维度参数 ===
in_dim = data.num_features         # 768 (SciBERT)
hidden_dim = 128                   # 可调
out_dim = int(data.y.max()) + 1    # 40 类
print(f"in_dim={in_dim}, hidden_dim={hidden_dim}, out_dim={out_dim}")

gat_model = GAT(in_dim, hidden_dim, out_dim, heads=4, dropout=0.4).to(device)
print("\n--- GAT Model Architecture ---")
print(gat_model)

# === 优化器与损失 ===
optimizer = torch.optim.Adam(gat_model.parameters(), lr=0.002, weight_decay=5e-4)
criterion = torch.nn.NLLLoss()

# === 评估器 ===
evaluator = Evaluator(name="ogbn-arxiv")

# === 训练与验证循环 ===
def train():
    gat_model.train()
    optimizer.zero_grad()
    out = gat_model(data.x, data.edge_index)
    loss = criterion(out[train_idx], data.y[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    gat_model.eval()
    out = gat_model(data.x, data.edge_index)
    y_pred = out.argmax(dim=-1, keepdim=True)

    train_acc = evaluator.eval({
        'y_true': data.y[train_idx].unsqueeze(-1),
        'y_pred': y_pred[train_idx]
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': data.y[valid_idx].unsqueeze(-1),
        'y_pred': y_pred[valid_idx]
    })['acc']
    test_acc = evaluator.eval({
        'y_true': data.y[test_idx].unsqueeze(-1),
        'y_pred': y_pred[test_idx]
    })['acc']

    return train_acc, valid_acc, test_acc

# === 主训练过程 ===
best_val_acc = 0
best_test_acc = 0

# ✅ 新增：仅保存在内存中（用于后续 block 4）
epochs_list, train_acc_list, val_acc_list, test_acc_list = [], [], [], []

for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()

    # 记录每个 epoch 的数据
    epochs_list.append(epoch)
    train_acc_list.append(train_acc)
    val_acc_list.append(val_acc)
    test_acc_list.append(test_acc)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}")

print(f"\n✅ Training finished! Best Val ACC: {best_val_acc:.4f}, "
      f"Corresponding Test ACC: {best_test_acc:.4f}")

print("\n✅ Training logs (epochs_list, train_acc_list, val_acc_list, test_acc_list) are available in memory.")


Using device: cuda
in_dim=128, hidden_dim=128, out_dim=40

--- GAT Model Architecture ---
GAT(
  (gat1): GATConv(128, 128, heads=4)
  (gat2): GATConv(512, 128, heads=4)
  (gat3): GATConv(512, 40, heads=1)
)
Epoch 010 | Loss: 2.4985 | Train: 0.4005 | Val: 0.4211 | Test: 0.4204
Epoch 020 | Loss: 1.9429 | Train: 0.5496 | Val: 0.5672 | Test: 0.5528
Epoch 030 | Loss: 1.7017 | Train: 0.6070 | Val: 0.6232 | Test: 0.6166
Epoch 040 | Loss: 1.5717 | Train: 0.6323 | Val: 0.6410 | Test: 0.6341
Epoch 050 | Loss: 1.5013 | Train: 0.6434 | Val: 0.6487 | Test: 0.6397
Epoch 060 | Loss: 1.4601 | Train: 0.6542 | Val: 0.6600 | Test: 0.6540
Epoch 070 | Loss: 1.4332 | Train: 0.6607 | Val: 0.6652 | Test: 0.6627
Epoch 080 | Loss: 1.4089 | Train: 0.6675 | Val: 0.6695 | Test: 0.6633
Epoch 090 | Loss: 1.3960 | Train: 0.6704 | Val: 0.6721 | Test: 0.6681
Epoch 100 | Loss: 1.3916 | Train: 0.6738 | Val: 0.6733 | Test: 0.6659

✅ Training finished! Best Val ACC: 0.6733, Corresponding Test ACC: 0.6672

✅ Training logs (

In [5]:
# # block 5 (OGB original – NO_LEAK + optional LOW_DATA)
# import torch
# from torch_geometric.data import Data
# import torch_geometric.utils as utils
# import time
# import torch.nn.functional as F

# # === 参数区 ===
# LOW_DATA_MODE = True   # ← 改为 False 可使用全部训练边
# LOW_DATA_RATIO = 0.05  # 仅当 LOW_DATA_MODE=True 时生效 (可改 0.01)

# print("\n--- [No-Leak + Low-Data] OGBN-Arxiv Link Prediction Prep (OGB Original) ---")
# start_time = time.time()

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# gat_model.eval()
# data = data.to(device)

# # === Step 1: 边划分 ===
# print("[1/6] Splitting edges (10% val / 10% test)...")
# edge_index_full = data.edge_index.cpu()
# lp_data = Data(x=data.x.cpu(), edge_index=edge_index_full)
# split_data_arxiv = utils.train_test_split_edges(lp_data, val_ratio=0.1, test_ratio=0.1)
# print("✅ Edge split complete.")

# # === Step 2: 可选 Low-Data 下采样 ===
# if LOW_DATA_MODE:
#     ratio = LOW_DATA_RATIO
#     num_total = split_data_arxiv.train_pos_edge_index.size(1)
#     num_keep = int(num_total * ratio)
#     perm = torch.randperm(num_total)[:num_keep]
#     split_data_arxiv.train_pos_edge_index = split_data_arxiv.train_pos_edge_index[:, perm]
#     print(f"✅ Low-Data mode ON: Kept {num_keep} training edges ({ratio*100:.1f}%)")
# else:
#     print("✅ Low-Data mode OFF: Using all training edges.")

# # === Step 3: 构建训练子图 ===
# train_subgraph = Data(
#     x=data.x.cpu(),
#     edge_index=split_data_arxiv.train_pos_edge_index.cpu()
# )
# print(f"✅ Train subgraph built with {train_subgraph.edge_index.shape[1]} edges.")

# # === Step 4: 重算 GAT 嵌入（仅训练边，避免信息泄漏） ===
# print("[4/6] Recomputing GAT embeddings on train-only subgraph (OGB features)...")
# train_subgraph = train_subgraph.to(device)
# with torch.no_grad():
#     h = F.dropout(train_subgraph.x, p=gat_model.dropout, training=False)
#     h = gat_model.gat1(h, train_subgraph.edge_index); h = F.elu(h)
#     h = gat_model.gat2(h, train_subgraph.edge_index); h = F.elu(h)
#     node_embeddings_arxiv = h.detach().cpu()
# print(f"✅ New embeddings computed: {node_embeddings_arxiv.shape}")

# # === Step 5: 负采样（1:5） ===
# print("[5/6] Sampling negative edges (1:5 ratio)...")
# num_nodes = data.num_nodes
# num_train_pos = split_data_arxiv.train_pos_edge_index.size(1)
# num_train_neg = num_train_pos * 5
# train_neg_edge_index = utils.negative_sampling(
#     edge_index=split_data_arxiv.train_pos_edge_index,
#     num_nodes=num_nodes,
#     num_neg_samples=num_train_neg,
#     method='sparse'
# )
# split_data_arxiv.train_neg_edge_index = train_neg_edge_index
# print(f"✅ Negative sampling done. Train +: {num_train_pos}, Train -: {num_train_neg}")

# # === Step 6: 移到 GPU ===
# split_data_arxiv = split_data_arxiv.to(device)
# node_embeddings_arxiv = node_embeddings_arxiv.to(device)

# elapsed = time.time() - start_time
# print("\n--- Summary ---")
# print(f"Nodes: {num_nodes}")
# print(f"Train pos: {num_train_pos}")
# print(f"Train neg: {num_train_neg}")
# print(f"Val pos: {split_data_arxiv.val_pos_edge_index.shape[1]}")
# print(f"Test pos: {split_data_arxiv.test_pos_edge_index.shape[1]}")
# print(f"Runtime: {elapsed:.2f}s (~{elapsed/60:.2f} min)")
# print("✅ Block 5 (No-Leak + Low-Data, OGB original) finished successfully.")


In [6]:
# # block 5.1 (save – OGB original)
# torch.save({
#     'node_embeddings_arxiv': node_embeddings_arxiv.cpu(),
#     'split_data_arxiv': {
#         'train_pos_edge_index': split_data_arxiv.train_pos_edge_index.cpu(),
#         'train_neg_edge_index': split_data_arxiv.train_neg_edge_index.cpu(),
#         'val_pos_edge_index': split_data_arxiv.val_pos_edge_index.cpu(),
#         'val_neg_edge_index': split_data_arxiv.val_neg_edge_index.cpu(),
#         'test_pos_edge_index': split_data_arxiv.test_pos_edge_index.cpu(),
#         'test_neg_edge_index': split_data_arxiv.test_neg_edge_index.cpu(),
#     }
# }, "ogbn_arxiv_link_data_ogb.pt")

# print("✅ Saved OGB-original embeddings and splits to ogbn_arxiv_link_data_ogb.pt")


In [7]:
# block 5.2 (OGB original): Load Pre-Saved ogbn_arxiv_link_data_ogb.pt
import torch
from torch_geometric.data import Data

print("\n--- [Block 5.2] Loading Pre-Saved ogbn_arxiv_link_data_ogb.pt ---")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Step 1: 加载保存文件 ===
checkpoint = torch.load("ogbn_arxiv_link_data_ogb.pt", map_location="cpu")

# === Step 2: 恢复节点嵌入 ===
node_embeddings_arxiv = checkpoint["node_embeddings_arxiv"]
print(f"✅ Loaded node embeddings: {tuple(node_embeddings_arxiv.shape)}")

# === Step 3: 恢复边划分数据 ===
split_dict = checkpoint["split_data_arxiv"]

# 构造 Data 对象（与 Block 5 输出结构保持一致）
split_data_arxiv = Data()
split_data_arxiv.train_pos_edge_index = split_dict["train_pos_edge_index"]
split_data_arxiv.train_neg_edge_index = split_dict["train_neg_edge_index"]
split_data_arxiv.val_pos_edge_index   = split_dict["val_pos_edge_index"]
split_data_arxiv.val_neg_edge_index   = split_dict["val_neg_edge_index"]
split_data_arxiv.test_pos_edge_index  = split_dict["test_pos_edge_index"]
split_data_arxiv.test_neg_edge_index  = split_dict["test_neg_edge_index"]

# === Step 4: 移动到 GPU ===
node_embeddings_arxiv = node_embeddings_arxiv.to(device)
split_data_arxiv = split_data_arxiv.to(device)

# === Step 5: 打印验证信息 ===
print("✅ All components loaded and moved to device.")
print(f"Train pos edges: {split_data_arxiv.train_pos_edge_index.shape[1]}")
print(f"Train neg edges: {split_data_arxiv.train_neg_edge_index.shape[1]}")
print(f"Val pos edges:   {split_data_arxiv.val_pos_edge_index.shape[1]}")
print(f"Test pos edges:  {split_data_arxiv.test_pos_edge_index.shape[1]}")
print(f"Device: {device}")

print("\n--- Block 5.2 ready. You can proceed to Block 6 (LinkPredictor definition). ---")



--- [Block 5.2] Loading Pre-Saved ogbn_arxiv_link_data_ogb.pt ---
✅ Loaded node embeddings: (169343, 512)
✅ All components loaded and moved to device.
Train pos edges: 92624
Train neg edges: 463120
Val pos edges:   115779
Test pos edges:  115779
Device: cuda

--- Block 5.2 ready. You can proceed to Block 6 (LinkPredictor definition). ---


In [8]:
# block 5.3 (OGB original): Recompute GAT embeddings (Train Edges Only, No Leakage)
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
import time

print("\n--- [Block 5.3] Recomputing GAT embeddings using TRAIN edges only (OGB original) ---")
start_time = time.time()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gat_model = gat_model.to(device)
gat_model.eval()

# === Step 1: 构造仅包含训练正样本边的子图 ===
train_edge_index = split_data_arxiv.train_pos_edge_index
train_edge_index = to_undirected(train_edge_index)  # 确保为无向图
train_edge_index = train_edge_index.to(device)

print(f"Using only train edges: {train_edge_index.shape[1]} edges for GAT propagation.")

# === Step 2: 基于训练子图计算节点嵌入（防止信息泄漏） ===
with torch.no_grad():
    x = data.x.to(device)

    h = F.dropout(x, p=gat_model.dropout, training=False)
    h = gat_model.gat1(h, train_edge_index); h = F.elu(h)
    h = gat_model.gat2(h, train_edge_index); h = F.elu(h)
    node_embeddings_arxiv = h.detach()

print(f"✅ Done. New train-only embeddings computed: {node_embeddings_arxiv.shape}")
print(f"Runtime: {time.time() - start_time:.2f}s")

# === Step 3: 移动到 GPU（用于后续 Block 6–9） ===
node_embeddings_arxiv = node_embeddings_arxiv.to(device)
print("✅ node_embeddings_arxiv replaced with train-only embeddings.")
print("\n--- Block 5.3 ready. You can proceed to Block 6 (LinkPredictor definition). ---")



--- [Block 5.3] Recomputing GAT embeddings using TRAIN edges only (OGB original) ---
Using only train edges: 180516 edges for GAT propagation.
✅ Done. New train-only embeddings computed: torch.Size([169343, 512])
Runtime: 0.02s
✅ node_embeddings_arxiv replaced with train-only embeddings.

--- Block 5.3 ready. You can proceed to Block 6 (LinkPredictor definition). ---


In [9]:
# block 5.4 (OGB original): Apply Low-Data Sampling to Loaded Split
import torch
import time
from torch_geometric.utils import negative_sampling

print("\n--- [Block 5.4] Applying Optional Low-Data Sampling (OGB Original) ---")
start_time = time.time()

# === Step 0: 参数开关 ===
LOW_DATA_MODE = True      # ← 改为 False 关闭下采样
LOW_DATA_RATIO = 0.01     # 仅当 LOW_DATA_MODE=True 时生效 (可调 0.01, 0.05 等)

# === Step 1: 确认 split_data_arxiv 是否存在 ===
assert 'split_data_arxiv' in locals(), "❌ split_data_arxiv not found. Run Block 5.2 or 5.3 first."

if LOW_DATA_MODE:
    num_total = split_data_arxiv.train_pos_edge_index.size(1)
    num_keep = int(num_total * LOW_DATA_RATIO)
    perm = torch.randperm(num_total)[:num_keep]
    split_data_arxiv.train_pos_edge_index = split_data_arxiv.train_pos_edge_index[:, perm]
    print(f"✅ Low-Data mode ON: kept {num_keep}/{num_total} training edges ({LOW_DATA_RATIO*100:.1f}%)")

    # --- Step 2: 同步重新采样负边（保持 1:5 比例） ---
    print("[Resampling negative edges to match new ratio]...")
    num_nodes = data.num_nodes
    num_train_pos = split_data_arxiv.train_pos_edge_index.size(1)
    num_train_neg = num_train_pos * 5

    split_data_arxiv.train_neg_edge_index = negative_sampling(
        edge_index=split_data_arxiv.train_pos_edge_index,
        num_nodes=num_nodes,
        num_neg_samples=num_train_neg,
        method='sparse'
    )
    print(f"✅ Negative edges resampled: Train + {num_train_pos}, Train - {num_train_neg}")
else:
    print("✅ Low-Data mode OFF: keeping full training edge set.")

elapsed = time.time() - start_time
print(f"Runtime: {elapsed:.2f}s")
print("✅ Block 5.4 finished successfully. You can proceed to Block 6 (LinkPredictor definition).")



--- [Block 5.4] Applying Optional Low-Data Sampling (OGB Original) ---
✅ Low-Data mode ON: kept 926/92624 training edges (1.0%)
[Resampling negative edges to match new ratio]...
✅ Negative edges resampled: Train + 926, Train - 4630
Runtime: 0.04s
✅ Block 5.4 finished successfully. You can proceed to Block 6 (LinkPredictor definition).


In [10]:
# block 6 (OGB original): Enhanced LinkPredictor definition
import torch
import torch.nn as nn
import torch.nn.functional as F

class LinkPredictor(nn.Module):
    def __init__(self, embedding_dim, hidden_dim=64, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)
        self.dropout = dropout

    def forward(self, all_node_embeddings, edge_label_index):
        # edge_label_index shape: [2, num_edges]
        src = all_node_embeddings[edge_label_index[0]]
        dst = all_node_embeddings[edge_label_index[1]]
        h = torch.cat([src, dst], dim=1)
        h = F.relu(self.fc1(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = F.relu(self.fc2(h))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.fc3(h).squeeze(-1)

print("✅ Enhanced LinkPredictor defined (OGB original version: 3-layer MLP with dropout).")


✅ Enhanced LinkPredictor defined (OGB original version: 3-layer MLP with dropout).


In [11]:
# block 7 (OGB original): Initialize LinkPredictor, optimizer, and loss
import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === 自动推断嵌入维度 ===
embedding_dim_arxiv = node_embeddings_arxiv.shape[1]

# === 初始化链路预测器 ===
link_predictor_arxiv = LinkPredictor(
    embedding_dim=embedding_dim_arxiv,
    hidden_dim=128,     # 可调，例如 128 / 256
    dropout=0.3         # 稍微强一点的正则化
).to(device)

# === 优化器 ===
optimizer_link_arxiv = optim.AdamW(
    link_predictor_arxiv.parameters(),
    lr=1e-3,
    weight_decay=1e-4
)

# === 损失函数 ===
criterion_link_arxiv = torch.nn.BCEWithLogitsLoss()

print(f"\n✅ LinkPredictor model for OGBN-Arxiv (OGB Original features) initialized on {device}.")
print(f"Embedding dimension: {embedding_dim_arxiv}")
print(link_predictor_arxiv)



✅ LinkPredictor model for OGBN-Arxiv (OGB Original features) initialized on cuda.
Embedding dimension: 512
LinkPredictor(
  (fc1): Linear(in_features=1024, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=1, bias=True)
)


In [12]:
# block 8 (OGB original): Train and test functions for link prediction
import torch
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test_link_pred_arxiv(pos_edge_index, neg_edge_index, batch_size=1024):
    """
    评估链路预测模型在给定边集上的 AUC。
    使用 batch 计算避免显存溢出。
    """
    link_predictor_arxiv.eval()
    pos_scores, neg_scores = [], []
    total_pos = pos_edge_index.size(1)
    total_neg = neg_edge_index.size(1)

    # === 正样本批次计算 ===
    for start in range(0, total_pos, batch_size):
        end = min(start + batch_size, total_pos)
        edges = pos_edge_index[:, start:end]
        logits = link_predictor_arxiv(node_embeddings_arxiv, edges)
        pos_scores.append(logits.sigmoid().cpu())
        torch.cuda.empty_cache()

    # === 负样本批次计算 ===
    for start in range(0, total_neg, batch_size):
        end = min(start + batch_size, total_neg)
        edges = neg_edge_index[:, start:end]
        logits = link_predictor_arxiv(node_embeddings_arxiv, edges)
        neg_scores.append(logits.sigmoid().cpu())
        torch.cuda.empty_cache()

    # === 拼接结果 ===
    pos_scores = torch.cat(pos_scores)
    neg_scores = torch.cat(neg_scores)
    scores = torch.cat([pos_scores, neg_scores])
    labels = torch.cat([
        torch.ones(pos_scores.size(0)),
        torch.zeros(neg_scores.size(0))
    ])

    auc = roc_auc_score(labels.numpy(), scores.numpy())
    return auc


def train_link_pred_arxiv(batch_size=128):
    """
    在训练集上优化链路预测器。
    随机混合正负样本并以 batch 训练。
    """
    link_predictor_arxiv.train()
    optimizer_link_arxiv.zero_grad()

    pos_edges = split_data_arxiv.train_pos_edge_index
    neg_edges = split_data_arxiv.train_neg_edge_index
    num_pos = pos_edges.size(1)
    num_neg = neg_edges.size(1)

    # === 合并正负样本 ===
    all_edges = torch.cat([pos_edges, neg_edges], dim=1)
    all_labels = torch.cat([
        torch.ones(num_pos, device=device),
        torch.zeros(num_neg, device=device)
    ])
    perm = torch.randperm(num_pos + num_neg, device=device)
    all_edges, all_labels = all_edges[:, perm], all_labels[perm]

    total_loss, steps = 0.0, 0
    for start in range(0, all_edges.size(1), batch_size):
        end = min(start + batch_size, all_edges.size(1))
        e = all_edges[:, start:end]
        lbl = all_labels[start:end]

        logits = link_predictor_arxiv(node_embeddings_arxiv, e)
        loss = criterion_link_arxiv(logits, lbl)

        loss.backward()
        optimizer_link_arxiv.step()
        optimizer_link_arxiv.zero_grad(set_to_none=True)

        total_loss += loss.item() * (end - start)
        steps += 1

        if steps % 1000 == 0:
            print(f"Batch {steps}: avg_loss={total_loss/(steps*batch_size):.4f}")
            torch.cuda.empty_cache()

    return total_loss / all_edges.size(1)

print("✅ train_link_pred_arxiv() and test_link_pred_arxiv() defined (OGB Original version).")


✅ train_link_pred_arxiv() and test_link_pred_arxiv() defined (OGB Original version).


In [13]:
# block 9 (OGB original): Train loop for OGBN-Arxiv link predictor
import copy, time

best_val_auc_arxiv = 0.0
best_model_state_link_arxiv = None
num_epochs_link_arxiv = 10  # 可根据需要调大，例如 30 或 50
start_time = time.time()

print("\n--- Starting OGBN-Arxiv Link Predictor Training (OGB Original features) ---")

for epoch in range(1, num_epochs_link_arxiv + 1):
    # === 训练一个 epoch ===
    loss_arxiv = train_link_pred_arxiv()
    
    # === 验证集 AUC ===
    val_auc_arxiv = test_link_pred_arxiv(
        split_data_arxiv.val_pos_edge_index,
        split_data_arxiv.val_neg_edge_index
    )

    # === 保存最佳模型 ===
    if val_auc_arxiv > best_val_auc_arxiv:
        best_val_auc_arxiv = val_auc_arxiv
        best_model_state_link_arxiv = copy.deepcopy(link_predictor_arxiv.state_dict())

    # === 每隔若干 epoch 打印训练 / 测试结果 ===
    if epoch % 5 == 0 or epoch == 1:
        train_auc = test_link_pred_arxiv(
            split_data_arxiv.train_pos_edge_index,
            split_data_arxiv.train_neg_edge_index
        )
        test_auc = test_link_pred_arxiv(
            split_data_arxiv.test_pos_edge_index,
            split_data_arxiv.test_neg_edge_index
        )
        elapsed = time.time() - start_time
        print(f"Epoch {epoch:03d} | "
              f"Loss: {loss_arxiv:.4f} | "
              f"Train AUC: {train_auc:.4f} | "
              f"Val AUC: {val_auc_arxiv:.4f} | "
              f"Test AUC: {test_auc:.4f} | "
              f"Time: {elapsed/60:.2f} min")

print("\n--- Training Finished (OGB Original features) ---")
print(f"✅ Best Validation AUC: {best_val_auc_arxiv:.4f}")

# === 加载最佳模型 ===
if best_model_state_link_arxiv:
    link_predictor_arxiv.load_state_dict(best_model_state_link_arxiv)
    print("Best model state reloaded.")

# === 最终测试集评估 ===
final_test_auc_arxiv = test_link_pred_arxiv(
    split_data_arxiv.test_pos_edge_index,
    split_data_arxiv.test_neg_edge_index
)
print(f"Final Test AUC (OGB Original): {final_test_auc_arxiv:.4f}")



--- Starting OGBN-Arxiv Link Predictor Training (OGB Original features) ---
Epoch 001 | Loss: 0.3896 | Train AUC: 0.9617 | Val AUC: 0.8689 | Test AUC: 0.8699 | Time: 0.01 min
Epoch 005 | Loss: 0.0822 | Train AUC: 0.9966 | Val AUC: 0.9091 | Test AUC: 0.9093 | Time: 0.03 min
Epoch 010 | Loss: 0.0550 | Train AUC: 0.9981 | Val AUC: 0.9081 | Test AUC: 0.9077 | Time: 0.04 min

--- Training Finished (OGB Original features) ---
✅ Best Validation AUC: 0.9104
Best model state reloaded.
Final Test AUC (OGB Original): 0.9097


In [14]:
# block A1-OGB (Ablation: No GAT, Same Head, Original OGB Embedding)
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator

print("\n--- [Ablation A1-OGB] No-GAT (Original Embedding) ---")

# === 模型定义 ===
class NoGAT(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.4):
        super().__init__()
        self.fc1 = torch.nn.Linear(in_dim, hidden_dim * 4)
        self.fc2 = torch.nn.Linear(hidden_dim * 4, hidden_dim * 4)
        self.fc3 = torch.nn.Linear(hidden_dim * 4, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index=None):  # 保留接口，但不使用图结构
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.fc2(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

# === 初始化 ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)   # 此时 data.x 是 OGB 原生 embedding（dim=128）

in_dim = data.num_features       # 128
hidden_dim = 128
out_dim = int(data.y.max()) + 1

nogat_model_ogb = NoGAT(in_dim, hidden_dim, out_dim, dropout=0.4).to(device)
optimizer = torch.optim.Adam(nogat_model_ogb.parameters(), lr=0.002, weight_decay=5e-4)
criterion = torch.nn.NLLLoss()
evaluator = Evaluator(name="ogbn-arxiv")

# === 训练函数 ===
def train():
    nogat_model_ogb.train()
    optimizer.zero_grad()
    out = nogat_model_ogb(data.x)  # 不使用边
    loss = criterion(out[train_idx], data.y[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

# === 测试函数 ===
@torch.no_grad()
def test():
    nogat_model_ogb.eval()
    out = nogat_model_ogb(data.x)
    y_pred = out.argmax(dim=-1, keepdim=True)
    train_acc = evaluator.eval({'y_true': data.y[train_idx].unsqueeze(-1), 'y_pred': y_pred[train_idx]})['acc']
    val_acc = evaluator.eval({'y_true': data.y[valid_idx].unsqueeze(-1), 'y_pred': y_pred[valid_idx]})['acc']
    test_acc = evaluator.eval({'y_true': data.y[test_idx].unsqueeze(-1), 'y_pred': y_pred[test_idx]})['acc']
    return train_acc, val_acc, test_acc

# === 主训练循环 ===
best_val_acc = 0
best_test_acc = 0
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    if epoch % 10 == 0:
        print(f"[A1-OGB No-GAT] Epoch {epoch:03d} | Loss: {loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Test: {test_acc:.4f}")

print(f"\n✅ [A1-OGB No-GAT] Finished. Best Val ACC: {best_val_acc:.4f}, "
      f"Corresponding Test ACC: {best_test_acc:.4f}")



--- [Ablation A1-OGB] No-GAT (Original Embedding) ---


[A1-OGB No-GAT] Epoch 010 | Loss: 2.8105 | Train: 0.2688 | Val: 0.2728 | Test: 0.2413
[A1-OGB No-GAT] Epoch 020 | Loss: 2.4163 | Train: 0.3901 | Val: 0.4068 | Test: 0.3850
[A1-OGB No-GAT] Epoch 030 | Loss: 2.2498 | Train: 0.4358 | Val: 0.4588 | Test: 0.4431
[A1-OGB No-GAT] Epoch 040 | Loss: 2.1555 | Train: 0.4671 | Val: 0.4807 | Test: 0.4606
[A1-OGB No-GAT] Epoch 050 | Loss: 2.1026 | Train: 0.4786 | Val: 0.4926 | Test: 0.4752
[A1-OGB No-GAT] Epoch 060 | Loss: 2.0715 | Train: 0.4847 | Val: 0.4985 | Test: 0.4801
[A1-OGB No-GAT] Epoch 070 | Loss: 2.0577 | Train: 0.4889 | Val: 0.4999 | Test: 0.4822
[A1-OGB No-GAT] Epoch 080 | Loss: 2.0447 | Train: 0.4916 | Val: 0.5021 | Test: 0.4837
[A1-OGB No-GAT] Epoch 090 | Loss: 2.0381 | Train: 0.4940 | Val: 0.5048 | Test: 0.4854
[A1-OGB No-GAT] Epoch 100 | Loss: 2.0282 | Train: 0.4962 | Val: 0.5073 | Test: 0.4889

✅ [A1-OGB No-GAT] Finished. Best Val ACC: 0.5073, Corresponding Test ACC: 0.4889


In [15]:
# block A2-OGB (Ablation: Param-Matched No-GAT, Original OGB Embedding)
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator

print("\n--- [A2-OGB] Param-Matched No-GAT (Original OGB Embedding) ---")

# === 模型定义 ===
class NoGAT_ParamMatched(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.4):
        super().__init__()
        # 这里 hidden_dim 仍然调大，用于匹配 GAT 的参数规模
        self.fc1 = torch.nn.Linear(in_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, out_dim)
        self.dropout = dropout

    def forward(self, x, edge_index=None):  # 不使用 edge_index
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.fc2(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

# === 初始化 ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)  # 此时 data.x 是 OGB 原始 embedding (128维)

in_dim = data.num_features       # 128
hidden_dim = 384                 # 调大，匹配 GAT 的多头参数数量
out_dim = int(data.y.max()) + 1

nogat_matched_ogb = NoGAT_ParamMatched(in_dim, hidden_dim, out_dim, dropout=0.4).to(device)
optimizer = torch.optim.Adam(nogat_matched_ogb.parameters(), lr=0.002, weight_decay=5e-4)
criterion = torch.nn.NLLLoss()
evaluator = Evaluator(name="ogbn-arxiv")

# === 训练函数 ===
def train():
    nogat_matched_ogb.train()
    optimizer.zero_grad()
    out = nogat_matched_ogb(data.x)
    loss = criterion(out[train_idx], data.y[train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

# === 测试函数 ===
@torch.no_grad()
def test():
    nogat_matched_ogb.eval()
    out = nogat_matched_ogb(data.x)
    y_pred = out.argmax(dim=-1, keepdim=True)
    train_acc = evaluator.eval({'y_true': data.y[train_idx].unsqueeze(-1), 'y_pred': y_pred[train_idx]})['acc']
    val_acc = evaluator.eval({'y_true': data.y[valid_idx].unsqueeze(-1), 'y_pred': y_pred[valid_idx]})['acc']
    test_acc = evaluator.eval({'y_true': data.y[test_idx].unsqueeze(-1), 'y_pred': y_pred[test_idx]})['acc']
    return train_acc, val_acc, test_acc

# === 主训练循环 ===
best_val_acc, best_test_acc = 0, 0
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, best_test_acc = val_acc, test_acc
    if epoch % 10 == 0:
        print(f"[A2-OGB No-GAT ParamMatched] Epoch {epoch:03d} | "
              f"Loss: {loss:.4f} | Train: {train_acc:.4f} | "
              f"Val: {val_acc:.4f} | Test: {test_acc:.4f}")

print(f"\n✅ [A2-OGB No-GAT ParamMatched] Finished. "
      f"Best Val ACC: {best_val_acc:.4f}, Test ACC: {best_test_acc:.4f}")



--- [A2-OGB] Param-Matched No-GAT (Original OGB Embedding) ---
[A2-OGB No-GAT ParamMatched] Epoch 010 | Loss: 2.9076 | Train: 0.2422 | Val: 0.2126 | Test: 0.1854
[A2-OGB No-GAT ParamMatched] Epoch 020 | Loss: 2.5290 | Train: 0.3601 | Val: 0.3794 | Test: 0.3583
[A2-OGB No-GAT ParamMatched] Epoch 030 | Loss: 2.3498 | Train: 0.4065 | Val: 0.4294 | Test: 0.4125
[A2-OGB No-GAT ParamMatched] Epoch 040 | Loss: 2.2247 | Train: 0.4430 | Val: 0.4591 | Test: 0.4448
[A2-OGB No-GAT ParamMatched] Epoch 050 | Loss: 2.1582 | Train: 0.4637 | Val: 0.4787 | Test: 0.4601
[A2-OGB No-GAT ParamMatched] Epoch 060 | Loss: 2.1153 | Train: 0.4752 | Val: 0.4893 | Test: 0.4709
[A2-OGB No-GAT ParamMatched] Epoch 070 | Loss: 2.0904 | Train: 0.4805 | Val: 0.4938 | Test: 0.4754
[A2-OGB No-GAT ParamMatched] Epoch 080 | Loss: 2.0755 | Train: 0.4851 | Val: 0.4969 | Test: 0.4785
[A2-OGB No-GAT ParamMatched] Epoch 090 | Loss: 2.0621 | Train: 0.4888 | Val: 0.5010 | Test: 0.4802
[A2-OGB No-GAT ParamMatched] Epoch 100 | Loss

In [16]:
# block A3-OGB (Ablation: OGB-only Link Prediction, no GAT)
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === 0) 检查依赖 ===
assert 'data' in locals(), "❌ data 未定义（应包含 OGB 原始特征 data.x）"
assert 'split_data_arxiv' in locals(), "❌ 缺少 split_data_arxiv（请先运行 block 5.2 或 5.4）"

# === 1) 使用 OGB 原生节点特征（不经过 GAT） ===
node_embeddings_ogb = data.x.to(device)   # [N, 128]
embed_dim_ogb = node_embeddings_ogb.shape[1]
print(f"[A3-OGB] Using original OGB node embeddings: {tuple(node_embeddings_ogb.shape)}")

# === 2) 定义轻量链路预测器 ===
class LinkPredictorOGB(nn.Module):
    def __init__(self, embedding_dim, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(embedding_dim * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, all_node_embeddings, edge_label_index):
        src = all_node_embeddings[edge_label_index[0]]
        dst = all_node_embeddings[edge_label_index[1]]
        h = torch.cat([src, dst], dim=1)
        h = F.relu(self.fc1(h))
        return self.fc2(h).squeeze(-1)

link_predictor_ogb = LinkPredictorOGB(embed_dim_ogb, hidden_dim=64).to(device)
optimizer_ogb = torch.optim.AdamW(link_predictor_ogb.parameters(), lr=1e-3, weight_decay=1e-4)
criterion_ogb = nn.BCEWithLogitsLoss()

# === 3) 训练函数 ===
def train_link_pred_ogb(batch_size=512):
    link_predictor_ogb.train()
    optimizer_ogb.zero_grad()

    pos_edges = split_data_arxiv.train_pos_edge_index
    neg_edges = split_data_arxiv.train_neg_edge_index
    num_pos, num_neg = pos_edges.size(1), neg_edges.size(1)

    perm_pos = torch.randperm(num_pos, device=device)
    perm_neg = torch.randperm(num_neg, device=device)

    total_loss, steps = 0.0, 0
    half = max(1, batch_size // 2)
    for start in range(0, num_pos, half):
        end = min(start + half, num_pos)
        p = pos_edges[:, perm_pos[start:end]]
        n = neg_edges[:, perm_neg[start:end]]  # 与正样本等量

        edges = torch.cat([p, n], dim=1)
        labels = torch.cat([
            torch.ones(p.size(1), device=device),
            torch.zeros(n.size(1), device=device)
        ])

        logits = link_predictor_ogb(node_embeddings_ogb, edges)
        loss = criterion_ogb(logits, labels)

        loss.backward()
        optimizer_ogb.step()
        optimizer_ogb.zero_grad(set_to_none=True)

        total_loss += loss.item() * edges.size(1)
        steps += 1

    return total_loss / (steps * batch_size)

# === 4) 测试函数 ===
@torch.no_grad()
def test_link_pred_ogb(pos_edge_index, neg_edge_index, batch_size=2048):
    link_predictor_ogb.eval()
    pos_scores, neg_scores = [], []
    P, N = pos_edge_index.size(1), neg_edge_index.size(1)

    for s in range(0, P, batch_size):
        e = pos_edge_index[:, s:min(s + batch_size, P)]
        pos_scores.append(link_predictor_ogb(node_embeddings_ogb, e).sigmoid().cpu())
    for s in range(0, N, batch_size):
        e = neg_edge_index[:, s:min(s + batch_size, N)]
        neg_scores.append(link_predictor_ogb(node_embeddings_ogb, e).sigmoid().cpu())

    pos_scores = torch.cat(pos_scores); neg_scores = torch.cat(neg_scores)
    scores = torch.cat([pos_scores, neg_scores]).numpy()
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)]).numpy()
    return roc_auc_score(labels, scores)

# === 5) 训练主循环 ===
best_val_auc_ogb, best_state_ogb = 0.0, None
EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    loss = train_link_pred_ogb(batch_size=512)
    val_auc = test_link_pred_ogb(split_data_arxiv.val_pos_edge_index, split_data_arxiv.val_neg_edge_index)

    if val_auc > best_val_auc_ogb:
        best_val_auc_ogb = val_auc
        best_state_ogb = {k: v.detach().cpu() for k, v in link_predictor_ogb.state_dict().items()}

    if epoch % 5 == 0 or epoch == 1:
        train_auc = test_link_pred_ogb(split_data_arxiv.train_pos_edge_index, split_data_arxiv.train_neg_edge_index)
        test_auc = test_link_pred_ogb(split_data_arxiv.test_pos_edge_index, split_data_arxiv.test_neg_edge_index)
        print(f"[A3-OGB][Epoch {epoch:02d}] loss={loss:.4f} | trainAUC={train_auc:.4f} | "
              f"valAUC={val_auc:.4f} | testAUC={test_auc:.4f}")

# === 6) 加载最佳模型并输出最终结果 ===
if best_state_ogb is not None:
    link_predictor_ogb.load_state_dict({k: v.to(device) for k, v in best_state_ogb.items()})

final_test_auc_ogb = test_link_pred_ogb(
    split_data_arxiv.test_pos_edge_index,
    split_data_arxiv.test_neg_edge_index
)
print(f"\n✅ [A3-OGB Ablation] Final Test AUC (OGB-only): {final_test_auc_ogb:.4f}")


[A3-OGB] Using original OGB node embeddings: (169343, 128)
[A3-OGB][Epoch 01] loss=0.6272 | trainAUC=0.6146 | valAUC=0.6066 | testAUC=0.6084
[A3-OGB][Epoch 05] loss=0.6046 | trainAUC=0.6667 | valAUC=0.6512 | testAUC=0.6529
[A3-OGB][Epoch 10] loss=0.5800 | trainAUC=0.6782 | valAUC=0.6578 | testAUC=0.6594

✅ [A3-OGB Ablation] Final Test AUC (OGB-only): 0.6594
