# TGN with moving-average (MA) label features and cross-entropy loss, trained on a subgraph

The code was adapted from https://github.com/shenyangHuang/TGB

In [None]:
!pip install pandas -q
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install py-tgb -q
!pip install modules


Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.4/154.4 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.4/12.4 MB[0m [31m95.4 M

In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
from torch_geometric.nn import TransformerConv
import numpy as np

class NodePredictor(torch.nn.Module):
    """Two-layer MLP head turning node embeddings into per-class scores.

    Args:
        in_dim: Width of the input embedding; also used as the hidden width.
        out_dim: Number of scores produced for every node.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Hidden layer keeps the input width; the output layer maps to classes.
        self.lin_node = Linear(in_dim, in_dim)
        self.out = Linear(in_dim, out_dim)

    def forward(self, node_embed):
        """Return raw (unnormalised) scores for ``node_embed``.

        No softmax / log-softmax is applied here.
        """
        hidden = torch.relu(self.lin_node(node_embed))
        return self.out(hidden)

class GraphAttentionEmbedding(torch.nn.Module):
    """Graph-attention embedding layer built on ``TransformerConv``.

    Edge features are the encoded time difference concatenated with the raw
    edge message.

    Reference:
    - https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py

    Args:
        in_channels: Width of the node features passed to ``forward``.
        out_channels: Requested embedding width; each of the two attention
            heads is given ``out_channels // 2`` channels.
        msg_dim: Width of a single edge message.
        time_enc: Callable time encoder exposing ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            # time encoding + message, concatenated in forward()
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Embed nodes from features ``x`` and the given edges.

        Args:
            x: Node features.
            last_update: Per-node last-update time, indexed by the first row
                of ``edge_index``.
            edge_index: Edge index tensor; row 0 selects into ``last_update``.
            t: Per-edge timestamps.
            msg: Per-edge messages.
        """
        # Time difference between the first endpoint's last update and the edge.
        elapsed = (last_update[edge_index[0]] - t).to(x.dtype)
        edge_attr = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_attr)


class MAFeatures:
    """Per-node running average of label vectors.

    The first vector seen for a node is stored as-is (by reference). Every
    later vector is blended in as ``(old * (window - 1) + new) / window``.

    Args:
        num_class: Length of a label vector; used for the all-zero feature
            returned for nodes that have not been seen yet.
        window: Averaging window used in the blending formula above.
    """

    def __init__(self, num_class, window=7):
        self.num_class = num_class
        self.window = window
        self.dict = {}  # node_id -> current averaged label vector

    def reset(self):
        """Forget every stored average."""
        self.dict = {}

    def update_dict(self, node_id, label_vec):
        """Blend ``label_vec`` into the stored average for ``node_id``."""
        if node_id in self.dict:
            total = self.dict[node_id] * (self.window - 1) + label_vec
            self.dict[node_id] = total / self.window
        else:
            self.dict[node_id] = label_vec

    def query_dict(self, node_id):
        """Return the stored average for ``node_id``, or zeros if unseen."""
        if node_id in self.dict:
            return self.dict[node_id]
        return np.zeros(self.num_class, dtype=np.float32)

    def batch_query(self, node_ids):
        """Return a float32 array of shape ``(len(node_ids), num_class)``.

        An empty ``node_ids`` yields an empty ``(0, num_class)`` array; a bare
        ``np.stack`` call would raise ``ValueError`` on an empty list.
        """
        feats = [self.query_dict(int(n)) for n in node_ids]
        if not feats:
            return np.zeros((0, self.num_class), dtype=np.float32)
        return np.stack(feats, axis=0).astype(np.float32)


'def to_one_hot_if_needed(labels_tensor, num_classes):\n    if labels_tensor.dim() == 1:\n        return F.one_hot(labels_tensor.long(), num_classes=num_classes).float()\n    return labels_tensor.float()'

In [None]:
# Choose hyperparameters

# Optimisation
lr = 1e-4  # learning rate passed to the Adam optimizer
batch_size = 200  # events per TemporalDataLoader batch
epochs = 20  # number of train / validation / test rounds

# Model and neighbourhood sizes
global_hidden_dim = 100  # shared width of memory, time encoding and embedding
nb_neighbors = 10  # `size` given to LastNeighborLoader

# Moving-average label features
window_ma = 7  # `window` given to MAFeatures

In [None]:
from tqdm import tqdm
import torch
import timeit
import matplotlib.pyplot as plt

from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn import TGNMemory
from torch_geometric.nn.models.tgn import (
    IdentityMessage,
    LastAggregator,
    LastNeighborLoader,
)

from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed

# --- Reproducibility -------------------------------------------------------
seed = 1
print("setting random seed to be", seed)
torch.manual_seed(seed)
set_random_seed(seed)

# --- Dataset ---------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
name = "tgbn-genre"
dataset = PyGNodePropPredDataset(name=name, root="datasets")
train_mask = dataset.train_mask.to(device)
val_mask = dataset.val_mask.to(device)
test_mask = dataset.test_mask.to(device)

eval_metric = dataset.eval_metric
num_classes = dataset.num_classes
data = dataset.get_TemporalData().to(device)

# --- Subgraph selection ----------------------------------------------------
# Randomly keep 10% of the distinct source nodes.
unique_src_nodes = torch.unique(data.src)
num_src_nodes_to_select = int(len(unique_src_nodes) * 0.1)
shuffled_positions = torch.randperm(len(unique_src_nodes))
selected_src_nodes = unique_src_nodes[
    shuffled_positions[:num_src_nodes_to_select]
].to(device)

# Keep only events strictly later than this timestamp
# (presumably Unix seconds -- TODO confirm).
timestamp_threshold = 1135778006
time_mask = data.t > timestamp_threshold

# An event survives when its source was selected AND it is recent enough.
src_is_selected = torch.isin(data.src, selected_src_nodes)
combined_mask = (src_is_selected & time_mask).to(device)

train_data, val_data, test_data = (
    data[split_mask & combined_mask].to(device)
    for split_mask in (train_mask, val_mask, test_mask)
)

evaluator = Evaluator(name=name)
ma_tracker = MAFeatures(num_class=num_classes, window=window_ma)

# Smallest / largest destination node id in the full data set
# (not referenced again in the visible part of this file).
min_dst_idx = int(data.dst.min())
max_dst_idx = int(data.dst.max())

train_loader, val_loader, test_loader = (
    TemporalDataLoader(split, batch_size=batch_size)
    for split in (train_data, val_data, test_data)
)

# Keeps, for every node, its most recent `nb_neighbors` interactions.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=nb_neighbors, device=device)

# All three widths share the single `global_hidden_dim` hyperparameter.
memory_dim = time_dim = embedding_dim = global_hidden_dim

memory = TGNMemory(
    data.num_nodes,
    data.msg.size(-1),
    memory_dim,
    time_dim,
    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),
    aggregator_module=LastAggregator(),
).to(device)

gnn = (
    GraphAttentionEmbedding(
        in_channels=memory_dim,
        out_channels=embedding_dim,
        msg_dim=data.msg.size(-1),
        time_enc=memory.time_enc,
    )
    .to(device)
    .float()
)

# The predictor sees the GNN embedding concatenated with the moving-average
# label features, hence `embedding_dim + num_classes` inputs.
node_pred = NodePredictor(in_dim=embedding_dim + num_classes, out_dim=num_classes).to(device)

# `gnn` reuses `memory.time_enc`, so its parameters show up in both modules.
# De-duplicate while keeping a stable order: torch.optim asks for parameter
# collections with a deterministic ordering, which a `set` does not provide.
trainable_params = list(
    dict.fromkeys(
        [*memory.parameters(), *gnn.parameters(), *node_pred.parameters()]
    )
)
optimizer = torch.optim.Adam(trainable_params, lr=lr)

criterion = torch.nn.CrossEntropyLoss()
# Helper vector to map global node indices to local ones.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)


def plot_curve(scores, out_name):
    """Draw ``scores`` as a line plot and save it as ``<out_name>.pdf``.

    Uses the implicit pyplot figure, which is closed after saving.
    """
    plt.plot(scores, color="#e34a33")
    plt.ylabel("score")
    pdf_path = out_name + ".pdf"
    plt.savefig(pdf_path)
    plt.close()


def process_edges(src, dst, t, msg):
    """Feed a group of edges to the TGN memory and the neighbor loader.

    Does nothing when ``src`` holds no elements.
    """
    if src.numel() == 0:
        return
    memory.update_state(src, dst, t, msg)
    neighbor_loader.insert(src, dst)


def train():
    """Run one training epoch over ``train_loader`` and return mean metrics.

    The TGN memory, the neighbor loader and the moving-average tracker are
    reset first, so each call replays the training edges from scratch.
    Whenever a batch's last timestamp passes the dataset's current label
    time, the labelled nodes are embedded, scored, and one optimizer step is
    taken. Edges are added to the memory / neighbor loader only after the
    labels that precede them have been predicted.

    Returns:
        dict with ``"ce"`` (mean loss per labelled step) and ``eval_metric``
        (mean evaluator score per labelled step). Both are ``nan`` when no
        labelled step occurred, rather than raising ``ZeroDivisionError``.
    """
    memory.train()
    gnn.train()
    node_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.
    ma_tracker.reset()

    total_loss = 0
    label_t = dataset.get_label_time()  # check when does the first label start
    total_score = 0
    num_label_ts = 0

    for batch in tqdm(train_loader):
        batch = batch.to(device)
        optimizer.zero_grad()
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        # check if this batch moves to the next day
        if query_t > label_t:
            # find the node labels from the past day
            label_tuple = dataset.get_node_label(query_t)
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # modified for node property prediction
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)

            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # Move ids and labels to the host once; they are reused for the
            # moving-average lookup, the evaluator and the tracker update.
            label_src_ids = label_srcs.tolist()
            np_true = labels.cpu().detach().numpy()

            # moving-average features (computed BEFORE updating with current labels)
            ma_feats = torch.from_numpy(ma_tracker.batch_query(label_src_ids)).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # loss and metric computation
            pred = node_pred(z)

            loss = criterion(pred, labels.to(device))
            np_pred = pred.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            loss.backward()
            optimizer.step()
            total_loss += float(loss.detach())

            # AFTER using current labels, update moving-average tracker.
            # Rows are copied so the tracker never holds a view into the batch.
            for nid, vec in zip(label_src_ids, np_true):
                ma_tracker.update_dict(nid, vec.copy())

        # Update memory and neighbor loader with ground-truth state.
        process_edges(src, dst, t, msg)
        memory.detach()

    if num_label_ts == 0:
        # No labelled step was reached; avoid dividing by zero.
        return {"ce": float("nan"), eval_metric: float("nan")}

    metric_dict = {
        "ce": total_loss / num_label_ts,
    }
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


@torch.no_grad()
def test(loader):
    """Evaluate the model on the batches produced by ``loader``.

    Nothing is reset here: the TGN memory, the neighbor loader and the
    moving-average tracker carry on from the state left by the preceding
    ``train()`` / ``test()`` call. Evaluation stops early when the dataset
    reports no further labels.

    Args:
        loader: Iterable of temporal batches exposing ``src``, ``dst``, ``t``
            and ``msg`` (the file passes ``val_loader`` and ``test_loader``).

    Returns:
        dict mapping ``eval_metric`` to the mean evaluator score per labelled
        step, or ``nan`` when no labelled step occurred (instead of raising
        ``ZeroDivisionError``).
    """
    memory.eval()
    gnn.eval()
    node_pred.eval()

    label_t = dataset.get_label_time()  # check when does the first label start
    num_label_ts = 0
    total_score = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        src, dst, t, msg = batch.src, batch.dst, batch.t, batch.msg

        query_t = batch.t[-1]
        if query_t > label_t:
            label_tuple = dataset.get_node_label(query_t)
            if label_tuple is None:
                break
            label_srcs, labels = label_tuple[1], label_tuple[2]
            label_t = dataset.get_label_time()
            label_srcs = label_srcs.to(device)

            # Process all edges that are still in the past day
            previous_day_mask = batch.t < label_t
            process_edges(
                src[previous_day_mask],
                dst[previous_day_mask],
                t[previous_day_mask],
                msg[previous_day_mask],
            )
            # Reset edges to be the edges from tomorrow so they can be used later
            src, dst, t, msg = (
                src[~previous_day_mask],
                dst[~previous_day_mask],
                t[~previous_day_mask],
                msg[~previous_day_mask],
            )

            # modified for node property prediction
            # 1. sample neighbors from the neighbor loader for all nodes to be predicted
            # 2. extract memory from the sampled neighbors and the nodes
            # 3. run gnn with the extracted memory embeddings and the
            #    corresponding time and message
            n_id = label_srcs
            n_id_neighbors, mem_edge_index, e_id = neighbor_loader(n_id)
            assoc[n_id_neighbors] = torch.arange(n_id_neighbors.size(0), device=device)

            z, last_update = memory(n_id_neighbors)
            z = gnn(
                z,
                last_update,
                mem_edge_index,
                data.t[e_id].to(device),
                data.msg[e_id].to(device),
            )
            z = z[assoc[n_id]]

            # Move ids and labels to the host once; they are reused for the
            # moving-average lookup, the evaluator and the tracker update.
            label_src_ids = label_srcs.tolist()
            np_true = labels.cpu().detach().numpy()

            # Moving-average features, read BEFORE the current labels are added.
            ma_feats = torch.from_numpy(ma_tracker.batch_query(label_src_ids)).to(device)
            z = torch.cat([z, ma_feats], dim=1)

            # metric computation
            pred = node_pred(z)
            np_pred = pred.cpu().detach().numpy()

            input_dict = {
                "y_true": np_true,
                "y_pred": np_pred,
                "eval_metric": [eval_metric],
            }
            result_dict = evaluator.eval(input_dict)
            score = result_dict[eval_metric]
            total_score += score
            num_label_ts += 1

            # Update MA tracker with current labels for subsequent timesteps.
            # Rows are copied so the tracker never holds a view into the batch.
            for nid, vec in zip(label_src_ids, np_true):
                ma_tracker.update_dict(nid, vec.copy())

        process_edges(src, dst, t, msg)

    if num_label_ts == 0:
        # No labelled step was reached; avoid dividing by zero.
        return {eval_metric: float("nan")}

    metric_dict = {}
    metric_dict[eval_metric] = total_score / num_label_ts
    return metric_dict


# Per-epoch score histories.
train_curve, val_curve, test_curve = [], [], []

# Model selection: remember the epoch with the highest validation score and
# report the test score obtained at that same epoch.
max_val_score = 0
best_test_idx = 0

for epoch in range(1, epochs + 1):
    # ---- train ----
    start_time = timeit.default_timer()
    train_dict = train()
    print("------------------------------------")
    print(f"training Epoch: {epoch:02d}")
    print(train_dict)
    train_curve.append(train_dict[eval_metric])
    print("Training takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    # ---- validation ----
    start_time = timeit.default_timer()
    val_dict = test(val_loader)
    print(val_dict)
    val_score = val_dict[eval_metric]
    val_curve.append(val_score)
    if val_score > max_val_score:
        max_val_score, best_test_idx = val_score, epoch - 1
    print("Validation takes--- %s seconds ---" % (timeit.default_timer() - start_time))

    # ---- test ----
    start_time = timeit.default_timer()
    test_dict = test(test_loader)
    print(test_dict)
    test_curve.append(test_dict[eval_metric])
    print("Test takes--- %s seconds ---" % (timeit.default_timer() - start_time))
    print("------------------------------------")

    # Rewind the dataset's label pointer before the next epoch.
    dataset.reset_label_time()


# # code for plotting
# plot_curve(train_curve, "train_curve")
# plot_curve(val_curve, "val_curve")
# plot_curve(test_curve, "test_curve")

max_test_score = test_curve[best_test_idx]
print("------------------------------------")
print("------------------------------------")
print("best val score: ", max_val_score)
print("best validation epoch   : ", best_test_idx + 1)
print("best test score: ", max_test_score)

setting random seed to be 1


17858396it [01:09, 257708.69it/s]
2741936it [00:06, 411634.62it/s]
100%|██████████| 6205/6205 [01:38<00:00, 62.77it/s]


------------------------------------
training Epoch: 01
{'ce': 4.609626301574707, 'ndcg': np.float64(0.3683632440557754)}
Training takes--- 98.86132255900003 seconds ---


100%|██████████| 1311/1311 [00:17<00:00, 76.27it/s]


{'ndcg': np.float64(0.380812865408687)}
Validation takes--- 17.208037972 seconds ---


 99%|█████████▉| 1434/1446 [00:16<00:00, 86.75it/s]


{'ndcg': np.float64(0.3750938587548371)}
Test takes--- 16.534290844000054 seconds ---
------------------------------------


100%|██████████| 6205/6205 [01:36<00:00, 64.42it/s]


------------------------------------
training Epoch: 02
{'ce': 4.29920734577179, 'ndcg': np.float64(0.42863096199170103)}
Training takes--- 96.33452035200003 seconds ---


100%|██████████| 1311/1311 [00:17<00:00, 72.96it/s]


{'ndcg': np.float64(0.4350711261908297)}
Validation takes--- 17.992023750000044 seconds ---


 99%|█████████▉| 1434/1446 [00:16<00:00, 86.26it/s]


{'ndcg': np.float64(0.438029887201799)}
Test takes--- 16.629847084000062 seconds ---
------------------------------------


  6%|▌         | 375/6205 [00:10<02:49, 34.33it/s]


KeyboardInterrupt: 