In [None]:
# Install PyTorch (CUDA 지원 버전)
!pip install torch==2.4.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install PyG dependencies
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu124.html
!pip install torch-geometric


Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.4.0
  Downloading https://download.pytorch.org/whl/cu124/torch-2.4.0%2Bcu124-cp312-cp312-linux_x86_64.whl (797.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m797.2/797.2 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.99 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_nvrtc_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (24.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.7/24.7 MB[0m [31m80.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.4.99 (from torch==2.4.0)
  Downloading https://download.pytorch.org/whl/cu124/nvidia_cuda_runtime_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (883 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.4/883.4 kB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.4.99 (

In [1]:
import torch
from sklearn.metrics import average_precision_score, roc_auc_score
from torch.nn import Linear
import os.path as osp

from torch_geometric.datasets import JODIEDataset
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import IdentityMessage, LastAggregator, LastNeighborLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# FIXED for Colab: the dataset path must NOT use __file__
root = '/content/data/JODIE'

dataset = JODIEDataset(root, name='wikipedia')
data = dataset[0]
data = data.to(device)


KeyboardInterrupt: 

In [None]:
data

TemporalData(src=[157474], dst=[157474], t=[157474], msg=[157474, 172], y=[157474])

In [None]:
# 데이터 크기가 너무 커서 자름.
data_split = data[0:10000]

In [None]:
data_split.num_nodes

8750

In [None]:
train_data, val_data, test_data = data.train_val_test_split(
    val_ratio=0.15, test_ratio=0.15)

train_loader = TemporalDataLoader(train_data, batch_size=200, neg_sampling_ratio=1.0)
val_loader = TemporalDataLoader(val_data, batch_size=200, neg_sampling_ratio=1.0)
test_loader = TemporalDataLoader(test_data, batch_size=200, neg_sampling_ratio=1.0)

neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)

# Global index → Local index 매핑용 텐서
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


In [None]:
class GraphAttentionEmbedding(torch.nn.Module):
    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        edge_dim = msg_dim + time_enc.out_channels
        self.conv = TransformerConv(
            in_channels, out_channels // 2, heads=2,
            dropout=0.1, edge_dim=edge_dim
        )

    def forward(self, x, last_update, edge_index, t, msg):
        rel_t = last_update[edge_index[0]] - t
        rel_t_enc = self.time_enc(rel_t.to(x.dtype))
        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)
        return self.conv(x, edge_index, edge_attr)


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, 1)

    def forward(self, z_src, z_dst):
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = torch.relu(h)
        return self.lin_final(h)


In [None]:
memory_dim = time_dim = embedding_dim = 100

memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=data.msg.size(-1),
    time_enc=memory.time_enc,
).to(device)

link_pred = LinkPredictor(in_channels=embedding_dim).to(device)

optimizer = torch.optim.Adam(
    list(memory.parameters()) +
    list(gnn.parameters()) +
    list(link_pred.parameters()),
    lr=0.0001
)

criterion = torch.nn.BCEWithLogitsLoss()


  return disable_fn(*args, **kwargs)


In [None]:
def train():
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()
    neighbor_loader.reset_state()

    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        loss = criterion(pos_out, torch.ones_like(pos_out))
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        memory.update_state(batch.src, batch.dst, batch.t.long(), batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

        loss.backward()
        optimizer.step()
        memory.detach()

        total_loss += float(loss) * batch.num_events

    return total_loss / train_data.num_events


In [None]:
for batch in train_loader:
  print(batch.src.long())
  break

tensor([ 0,  1,  1,  2,  1,  2,  3,  1,  4,  4,  5,  1,  1,  2,  6,  7,  3,  8,
         1,  2,  2,  9,  1,  6,  4,  1,  3,  2,  1,  2, 10, 11,  3,  3, 12,  1,
         4, 13, 14, 12,  1, 14, 14, 15, 12, 16, 17,  4,  4,  3, 18, 11, 19, 20,
        21, 22,  3, 17, 11,  4, 23,  4, 24, 12, 25,  3,  4,  2, 26,  4,  4, 27,
        11, 12, 28, 18, 29, 11,  2,  8,  4,  4, 30, 31, 28, 28, 16, 32,  4,  3,
        33,  8,  4, 16, 24,  4,  3,  1,  3, 34, 16, 35,  3, 36,  3,  4, 16,  3,
        27, 16, 12, 24, 28,  4, 16,  1, 14, 31, 34,  3,  4, 19, 31, 27, 31, 14,
        37,  1, 14,  4, 37,  4, 37,  3,  4, 38,  3, 14, 16, 16,  4, 14,  8, 12,
         1, 16,  3, 14, 14, 16,  3, 39,  3,  1, 40, 16, 41, 42, 43,  1, 28, 16,
        44, 12, 42, 16, 34,  3, 45, 16,  3,  3, 16, 16,  3,  6,  3, 46, 16,  3,
        16, 16,  3,  1,  6,  3,  3,  1, 12,  3,  1,  3, 47, 48,  3, 48, 16,  3,
        49, 50], device='cuda:0')


In [None]:
@torch.no_grad()
def test(loader):
    memory.eval()
    gnn.eval()
    link_pred.eval()

    torch.manual_seed(12345)

    aps, aucs = [], []
    for batch in loader:
        batch = batch.to(device)

        n_id, edge_index, e_id = neighbor_loader(batch.n_id)
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])

        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])
        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])

        y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu()
        y_true = torch.cat([
            torch.ones(pos_out.size(0)),
            torch.zeros(neg_out.size(0))
        ], dim=0)

        aps.append(average_precision_score(y_true, y_pred))
        aucs.append(roc_auc_score(y_true, y_pred))

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)

    return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())


In [None]:
for epoch in range(1, 5):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

    val_ap, val_auc = test(val_loader)
    test_ap, test_auc = test(test_loader)

    print(f'Val AP:  {val_ap:.4f}, Val AUC:  {val_auc:.4f}')
    print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')


Epoch: 01, Loss: 0.6086
Val AP:  0.9557, Val AUC:  0.9521
Test AP: 0.9502, Test AUC: 0.9446
Epoch: 02, Loss: 0.5951
Val AP:  0.9557, Val AUC:  0.9522
Test AP: 0.9494, Test AUC: 0.9457
Epoch: 03, Loss: 0.5804
Val AP:  0.9559, Val AUC:  0.9534
Test AP: 0.9478, Test AUC: 0.9459
Epoch: 04, Loss: 0.5672
Val AP:  0.9600, Val AUC:  0.9577
Test AP: 0.9542, Test AUC: 0.9517
