In [None]:
import torch, torch.nn as nn
from torch_geometric.nn.models.tgn import TGNMemory, IdentityMessage, LastAggregator

In [None]:
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q torch-geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html

import torch
print("Torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")


Torch: 2.8.0+cu126 | CUDA: 12.6 | GPU: CPU


In [None]:
# --- PATCH 1: Build TemporalData with label propagation + int64 timestamps ---

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import TemporalData

def load_temporal_from_csvs_LABELPROP(data_dir: str = "."):
    """Build a ``TemporalData`` graph with propagated node labels from three CSV files.

    Files read:
      * ``features.csv`` -- columns ``node_id``, ``timestamp`` and any number of
        numeric feature columns (non-numeric columns are ignored).
      * ``labels.csv``   -- columns ``node_id``, ``timestamp``, ``label``.
      * ``edges.csv``    -- columns ``src``, ``dst``, ``timestamp``.

    Node ids found in ``features.csv`` are remapped to contiguous indices
    ``0..N-1`` (in sorted order). Each node's chronologically last label is
    propagated to every (node, timestamp) feature row of that node; nodes
    without any usable label get ``-1``.

    Args:
        data_dir: Directory searched first for the CSV files. A file that is
            not present there is read from ``/content`` instead, the location
            this function originally hard-coded.

    Returns:
        A tuple ``(data, num_nodes, in_channels)``:
          * ``data`` -- ``TemporalData`` with event tensors ``src``, ``dst``,
            ``t`` (int64, stored in chronological order) and feature-row
            tensors ``x``, ``n_id``, ``t_n``, ``y_n``. The feature-row tensors
            are ordered by timestamp, then node index, and are aligned with
            each other row by row. ``y_n`` is therefore indexed by feature
            row, NOT by node index.
          * ``num_nodes`` -- number of distinct nodes in ``features.csv``.
          * ``in_channels`` -- number of numeric feature columns.

    Raises:
        ValueError: If an edge references a node id that does not occur in
            ``features.csv``, or if ``features.csv`` has no row with a
            timestamp.
    """
    import os  # function-scope import keeps this cell runnable on its own

    def _read_csv(name: str) -> pd.DataFrame:
        """Read ``name`` from ``data_dir``, falling back to the legacy ``/content`` folder."""
        path = os.path.join(data_dir, name)
        if not os.path.isfile(path):
            path = os.path.join("/content", name)
        return pd.read_csv(path)

    feats = _read_csv("features.csv")
    labs = _read_csv("labels.csv")
    ed = _read_csv("edges.csv")

    # Map node_id -> 0..N-1 (sorted so the mapping is deterministic).
    node_ids = np.sort(feats["node_id"].unique())
    nid_map = {nid: i for i, nid in enumerate(node_ids)}
    num_nodes = len(node_ids)

    feats = feats.assign(nid=feats["node_id"].map(nid_map))

    # --- Edges (timestamps as int64/Long) ---
    ed = ed.assign(src_nid=ed["src"].map(nid_map), dst_nid=ed["dst"].map(nid_map))
    dangling = ed["src_nid"].isna() | ed["dst_nid"].isna()
    if dangling.any():
        # An unmapped endpoint is NaN; casting NaN to int64 would silently
        # produce a garbage node index, so fail loudly instead.
        raise ValueError(
            f"{int(dangling.sum())} edge(s) in edges.csv reference node ids "
            "that do not appear in features.csv"
        )
    # The temporal loader walks events in storage order, so store them
    # chronologically. The stable sort keeps the CSV order among equal timestamps.
    ed = ed.sort_values("timestamp", kind="stable")

    src = torch.tensor(ed["src_nid"].to_numpy(np.int64), dtype=torch.long)
    dst = torch.tensor(ed["dst_nid"].to_numpy(np.int64), dtype=torch.long)
    t_e = torch.tensor(ed["timestamp"].to_numpy(), dtype=torch.long)

    # --- Node features per (node, time), ordered by timestamp then node index ---
    ignore = {"node_id", "nid", "timestamp"}
    feat_cols = [
        c for c in feats.columns
        if c not in ignore and np.issubdtype(feats[c].dtype, np.number)
    ]

    feat_rows = feats.dropna(subset=["timestamp"]).sort_values(["timestamp", "nid"], kind="stable")
    if feat_rows.empty:
        raise ValueError("features.csv contains no rows with a timestamp")

    x = torch.tensor(feat_rows[feat_cols].to_numpy(np.float32), dtype=torch.float)
    n_id = torch.tensor(feat_rows["nid"].to_numpy(np.int64), dtype=torch.long)
    t_n = torch.tensor(feat_rows["timestamp"].to_numpy().astype(np.int64), dtype=torch.long)

    data = TemporalData(src=src, dst=dst, t=t_e, x=x, n_id=n_id, t_n=t_n)

    # -------- Label propagation: use each node's final label for ALL its timestamps --------
    # Labels for unknown nodes, or without a value, cannot be assigned to anything.
    labs = labs.assign(nid=labs["node_id"].map(nid_map)).dropna(subset=["nid", "label"])
    labs = labs.assign(nid=labs["nid"].astype(np.int64))
    final_label = (
        labs.sort_values(["nid", "timestamp"], kind="stable")
        .groupby("nid")["label"]
        .last()
    )

    # Dense per-node lookup; -1 marks nodes that never received a label.
    node_label = torch.full((num_nodes,), -1, dtype=torch.long)
    labelled_nodes = torch.tensor(final_label.index.to_numpy(np.int64), dtype=torch.long)
    node_label[labelled_nodes] = torch.tensor(
        final_label.to_numpy().astype(np.int64), dtype=torch.long
    )

    # Assign the label to every (nid, t) occurrence (row-aligned with n_id).
    data.y_n = node_label[n_id]

    in_channels = x.size(1)
    return data, num_nodes, in_channels


In [None]:
# --- PATCH 2: Reload using label propagation, split by EDGE times, and train (with memory.detach) ---

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.loader import TemporalDataLoader

# Assumes:
# - load_temporal_from_csvs_LABELPROP(.) from Patch 1 is defined
# - TGNClassifier (version-robust) is defined

device = "cuda" if torch.cuda.is_available() else "cpu"

BATCH_SIZE = 4096  # number of events per temporal batch

# 1) Reload data (labels exist at ALL timestamps thanks to Patch 1)
data, num_nodes, in_ch = load_temporal_from_csvs_LABELPROP(".")
# No `shuffle` argument: TemporalDataLoader discards it and always iterates
# events in storage order, which is what a memory-based temporal model needs.
# Passing shuffle=True only suggested otherwise.
loader = TemporalDataLoader(data, batch_size=BATCH_SIZE)


def _split_edge_times(times, train_frac: float = 0.6, val_frac: float = 0.2):
    """Split sorted timestamps into chronological (train, val, test) lists.

    Args:
        times: Sorted list of distinct timestamps.
        train_frac: Share of timestamps assigned to training.
        val_frac: Share of timestamps assigned to validation.

    Returns:
        Three lists ``(train, val, test)``. With four or more timestamps the
        lists are disjoint, contiguous and each non-empty. With one to three
        timestamps the latest ones are reused so that no split is empty. With
        no timestamps all three lists are empty.
    """
    count = len(times)
    if count == 0:
        return [], [], []
    if count <= 3:
        # 1 -> (t0, t0, t0); 2 -> (t0, t1, t1); 3 -> (t0, t1, t2)
        return [times[0]], [times[min(1, count - 1)]], [times[-1]]

    n_train = max(1, int(round(train_frac * count)))
    n_val = max(1, int(round(val_frac * count)))
    if n_train + n_val >= count:
        # Leave at least one timestamp for the test split.
        n_val = max(1, count - n_train - 1)
    n_test = count - n_train - n_val
    if n_test < 1:
        # Still nothing left for test: take what is missing from train.
        n_train = max(1, n_train - (1 - n_test))
    return (
        times[:n_train],
        times[n_train:n_train + n_val],
        times[n_train + n_val:],
    )


# 2) Build edge-time splits with guarantees for small T
edge_times = sorted(int(t) for t in torch.unique(data.t).cpu().tolist())
n = len(edge_times)
times_train, times_val, times_test = _split_edge_times(edge_times)

times_train_set = set(times_train)
times_val_set   = set(times_val)
times_test_set  = set(times_test)

def _mask_for_phase(batch_t: torch.Tensor, phase: str) -> torch.Tensor:
    """Flag the events of a batch whose timestamp belongs to ``phase``.

    Args:
        batch_t: Timestamps of the events in the batch.
        phase: ``"train"`` or ``"val"``; any other value selects the test
            timestamps.

    Returns:
        Boolean tensor shaped like ``batch_t``; ``True`` where the timestamp
        is a member of the chosen phase's timestamp set.
    """
    wanted = (
        times_train_set if phase == "train"
        else times_val_set if phase == "val"
        else times_test_set
    )
    wanted_t = torch.tensor(list(wanted), dtype=batch_t.dtype, device=batch_t.device)

    try:
        return torch.isin(batch_t, wanted_t)
    except AttributeError:
        # torch build without `isin`: OR together one equality test per
        # wanted timestamp, starting from an all-False mask (which is also
        # the result when the phase has no timestamps).
        mask = torch.zeros_like(batch_t, dtype=torch.bool)
        for value in wanted_t:
            mask = mask | (batch_t == value)
        return mask

# 3) Model / optimizer / loss
SEED = 0               # fixes weight initialisation so reruns are comparable
LEARNING_RATE = 1e-3   # Adam step size

# Seed before the model is built: parameter initialisation draws from torch's
# global RNG, so without this every Restart & Run All trains a different model.
torch.manual_seed(SEED)

model = TGNClassifier(num_nodes=num_nodes, in_channels=in_ch).to(device)
opt   = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
crit  = nn.CrossEntropyLoss()

def _cut_memory_graph():
    """Detach the TGN memory from the autograd graph between batches.

    Falls back to resetting the memory when the model exposes no
    ``memory.detach()``.
    """
    try:
        model.memory.detach()
    except AttributeError:
        model.reset_memory()


def run_epoch(train: bool = True, phase: str = "train"):
    """Run one pass over ``loader`` and score the events that belong to ``phase``.

    Args:
        train: Put the model in train mode (``True``) or eval mode (``False``).
            Parameters are only optimised when ``train`` is true AND
            ``phase == "train"``.
        phase: One of ``"train"``, ``"val"``, ``"test"``. Selects, via the
            edge timestamps, which events contribute to loss and accuracy.

    Returns:
        ``(tot_loss, acc)``: the SUM of the per-batch mean cross-entropy
        losses over the scored batches, and the accuracy over all scored
        events (``0.0`` if none was scored).

    The TGN memory is reset at the start and detached after every batch to
    avoid backprop-through-history errors.
    """
    is_training = train and phase == "train"

    if train:
        model.train()
    else:
        model.eval()
    model.reset_memory()

    # `data.y_n` is aligned with the (node, time) feature rows listed in
    # `data.n_id`; it is NOT indexed by node. Scatter it into a per-node lookup
    # so it can be indexed with `batch.dst`. Labels are propagated, so all rows
    # of a node carry the same value and the scatter is unambiguous. The lookup
    # is moved to `device` because indexing a CPU tensor with device indices
    # fails on CUDA.
    node_labels = torch.full((num_nodes,), -1, dtype=torch.long, device=data.y_n.device)
    node_labels[data.n_id] = data.y_n
    node_labels = node_labels.to(device)

    tot_loss = 0.0
    total = 0
    correct = 0

    for batch in loader:
        batch = batch.to(device)

        event_mask = _mask_for_phase(batch.t, phase)
        if is_training and not bool(event_mask.any()):
            # Only val/test (i.e. later) events in this batch: nothing to learn from.
            _cut_memory_graph()
            continue
        # While evaluating, batches without events of this phase are still fed
        # through the model so the memory has seen the earlier history when the
        # phase's events are scored.
        # NOTE(review): assumes TGNClassifier.forward is what writes a batch's
        # events into the memory -- confirm against its definition.

        # Messages: zeros (alignment-safe). Replace with per-event features if desired.
        msg = torch.zeros((batch.src.size(0), in_ch), dtype=torch.float, device=device)

        # No autograd bookkeeping unless this batch is actually optimised.
        with torch.set_grad_enabled(is_training):
            logits = model(batch.src, batch.dst, batch.t, msg)

        # Score an event only if it is in this phase and its destination node is labelled.
        y_full = node_labels[batch.dst]
        sel = event_mask & (y_full >= 0)
        if not bool(sel.any()):
            _cut_memory_graph()
            continue
        targets = y_full[sel]

        if is_training:
            opt.zero_grad()
            loss = crit(logits[sel], targets)
            loss.backward()
            opt.step()
        else:
            loss = crit(logits[sel], targets)
        tot_loss += float(loss.item())

        # Critical for TGN: cut the graph between batches.
        _cut_memory_graph()

        pred = logits.argmax(dim=-1)
        correct += int((pred[sel] == targets).sum().item())
        total   += int(sel.sum().item())

    acc = (correct / total) if total else 0.0
    return tot_loss, acc

# 4) Train / validate / test
NUM_EPOCHS = 10  # full passes over the loader

for ep in range(1, NUM_EPOCHS + 1):
    # One optimisation pass, then an evaluation pass over the validation timestamps.
    tr_loss, tr_acc = run_epoch(train=True, phase="train")
    val_acc = run_epoch(train=False, phase="val")[1]
    print(f"Epoch {ep:02d} | train_loss={tr_loss:.3f} | train_acc={tr_acc:.3f} | val_acc={val_acc:.3f}")

# Final held-out score, plus the timestamps behind each split for the record.
test_acc = run_epoch(train=False, phase="test")[1]
print("Test acc =", round(test_acc, 3))
print("Times used -> train:", times_train, "| val:", times_val, "| test:", times_test)


Epoch 01 | train_loss=5.259 | train_acc=0.859 | val_acc=0.857
Epoch 02 | train_loss=4.281 | train_acc=0.859 | val_acc=0.857
Epoch 03 | train_loss=4.234 | train_acc=0.859 | val_acc=0.857
Epoch 04 | train_loss=4.205 | train_acc=0.859 | val_acc=0.857
Epoch 05 | train_loss=4.152 | train_acc=0.859 | val_acc=0.857
Epoch 06 | train_loss=4.166 | train_acc=0.859 | val_acc=0.857
Epoch 07 | train_loss=4.140 | train_acc=0.859 | val_acc=0.857
Epoch 08 | train_loss=4.151 | train_acc=0.859 | val_acc=0.857
Epoch 09 | train_loss=4.123 | train_acc=0.859 | val_acc=0.857
