# Initialization through shallow embeddings

This code is partially inspired by the paper: De Nadai, M., et al., "Personalized Audiobook Recommendations at Spotify Through Graph Neural Networks,"
  WWW '24: Companion Proceedings of the ACM Web Conference 2024, ACM, pp. 403--412, 2024.

In [None]:
# --- INSTALLATION AND IMPORTS ---
!pip install pandas -q
!pip install torch-geometric -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
!pip install py-tgb -q
!pip install modules

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn.models.tgn import LastNeighborLoader
from tgb.nodeproppred.dataset_pyg import PyGNodePropPredDataset
from tgb.nodeproppred.evaluate import Evaluator
from tgb.utils.utils import set_random_seed
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
import timeit
import os
import pandas as pd
import matplotlib.pyplot as plt

# --- CONFIGURATION CONSTANTS ---
# Seed passed to both torch.manual_seed and TGB's set_random_seed below.
SEED = 1
# Width of the genre embedding table; also used as the hidden/output width
# of the pre-training GCN encoder.
EMBEDDING_DIM = 100
# Number of genre nodes; the streamed label matrix must have this many columns.
NUM_GENRES = 264
# NOTE(review): not referenced in the visible cells; presumably the neighbor
# count for the main model's neighbor loader -- confirm against later cells.
NB_NEIGHBORS = 10
# NOTE(review): LR_MAIN / EPOCHS_MAIN are not used in the visible cells;
# presumably consumed by the main training loop further down.
LR_MAIN = 1e-3
# Adam learning rate for the genre-embedding pre-training loop.
LR_PRETRAIN = 0.01
EPOCHS_MAIN = 10
# Number of full-batch optimisation steps in the pre-training loop.
EPOCHS_PRETRAIN = 100
# Two genres are linked when their co-occurrence count, divided by either
# genre's own count, reaches this ratio.
CO_OCCURRENCE_THRESHOLD = 0.3
FINETUNE_EMBEDDINGS = True # Set to False to freeze embeddings
# File the pre-trained genre embedding tensor is written to with torch.save.
PRETRAINED_EMB_PATH = "pretrained_genre_embeddings.pt" #the path to the created embeddings
# --- END CONFIGURATION ---

# Prefer the GPU when one is visible, otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed PyTorch directly and via TGB's helper for reproducible runs.
torch.manual_seed(SEED)
set_random_seed(SEED)
print("Setting random seed to", SEED)
print("Device:", DEVICE)

# --- UTILITY CLASSES ---

class NodePredictor(nn.Module):
    """Two-layer MLP head that maps a node embedding to ``out_dim`` scores.

    The hidden layer is square (``in_dim -> in_dim``) and followed by a ReLU;
    the output layer (``in_dim -> out_dim``) returns raw, un-activated values.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Hidden layer is created first so parameter initialisation draws
        # from the RNG in the same sequence on every construction.
        self.lin_node = nn.Linear(in_dim, in_dim)
        self.out = nn.Linear(in_dim, out_dim)

    def forward(self, node_embed):
        """Return the output-layer activations for ``node_embed``."""
        hidden = F.relu(self.lin_node(node_embed))
        return self.out(hidden)

class StaticEmbeddingMemory(nn.Module):
    """
    Core Embedding layer for the main model.
    It handles loading pretrained weights or initializing randomly.

    Args:
        num_nodes: Number of rows in the embedding table.
        emb_dim: Width of each embedding vector.
        pretrained_path: Optional path to a tensor saved with ``torch.save``.
            It is used only when the file exists and its shape is exactly
            ``(num_nodes, emb_dim)``; otherwise the table keeps its random
            Xavier-uniform initialisation.
        freeze: Passed to ``nn.Embedding.from_pretrained``. It only takes
            effect when pretrained weights are actually loaded; a randomly
            initialised table always stays trainable.

    ``update_state``, ``reset_state`` and ``detach`` are no-ops that return
    ``None``.
    """
    def __init__(self, num_nodes, emb_dim, pretrained_path=None, freeze=False):
        super().__init__()

        # Default random initialization (Xavier uniform)
        self.emb = nn.Embedding(num_nodes, emb_dim)
        nn.init.xavier_uniform_(self.emb.weight)

        # Load and potentially freeze pretrained weights
        if pretrained_path and os.path.exists(pretrained_path):
            # map_location="cpu" lets a checkpoint that was saved from a CUDA
            # tensor be read on a CPU-only machine; the result is then moved
            # to the active device exactly as before.
            pretrained_weights = torch.load(pretrained_path, map_location="cpu").to(DEVICE)
            if pretrained_weights.shape == (num_nodes, emb_dim):
                print(f"Loaded pretrained embeddings from {pretrained_path}.")
                self.emb = nn.Embedding.from_pretrained(pretrained_weights, freeze=freeze)
            else:
                print(f"Warning: Pretrained weights shape {pretrained_weights.shape} mismatch. Initializing randomly.")
        else:
            print("No valid pretrained embeddings found. Initializing randomly.")

    def forward(self, n_id):
        """Look up the embedding rows for the node ids in ``n_id``."""
        return self.emb(n_id)

    def update_state(self, *args, **kwargs):
        """No-op: the embeddings carry no per-event state."""
        return

    def reset_state(self):
        """No-op: there is no state to reset."""
        return

    def detach(self):
        """No-op: there is no state to detach."""
        return

class SimpleGCN(nn.Module):
    """Encoder for the main model: two stacked ``GCNConv`` layers.

    The first layer's output goes through ReLU and dropout before the second
    layer; the second layer's output is returned without an activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Propagate node features ``x`` over ``edge_index`` through both layers."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout is active only while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index)

def reset_all_label_pointers(ds):
    """Set every known label-cursor attribute that exists on ``ds`` to 0.

    Both the generic names and the per-split variants (with and without a
    leading underscore) are tried. Attributes that are absent are skipped,
    and any failure while assigning is ignored, so this is strictly
    best-effort.
    """
    candidate_names = ["_label_time_idx", "label_time_idx", "_current_label_idx"]
    for split in ("train", "valid", "val", "test"):
        candidate_names += [f"_{split}_label_time_idx", f"{split}_label_time_idx"]

    for attr in candidate_names:
        if not hasattr(ds, attr):
            continue
        try:
            setattr(ds, attr, 0)
        except Exception:
            # Best effort: an attribute that cannot be assigned (e.g. a
            # property without a setter) is simply left untouched.
            pass

def rebuild_dataset_and_loaders(dataset, data, train_mask, val_mask, test_mask, batch_size, device):
    """Re-create the dataset, its temporal data, the three splits and loaders.

    A new ``PyGNodePropPredDataset`` is built from ``dataset.name`` (root
    ``"datasets"``); its temporal data is moved to ``device`` and indexed by
    the three masks. The ``data`` argument is accepted but not used.

    Returns:
        An 8-tuple ``(dataset, data, train_data, val_data, test_data,
        train_loader, val_loader, test_loader)``; the loaders do not shuffle.
    """
    fresh_ds = PyGNodePropPredDataset(name=dataset.name, root="datasets")
    fresh_data = fresh_ds.get_TemporalData().to(device)

    split_data = [
        fresh_data[mask].to(device)
        for mask in (train_mask, val_mask, test_mask)
    ]
    split_loaders = [
        TemporalDataLoader(part, batch_size=batch_size, shuffle=False)
        for part in split_data
    ]

    return (fresh_ds, fresh_data, *split_data, *split_loaders)

Looking in links: https://data.pyg.org/whl/torch-2.8.0+cu126.html
Setting random seed to 1
Device: cuda


In [None]:
# 1. GRAPH BUILDING (STREAMING LABELS)
print("\n--- 1. Building Static Genre Co-occurrence Graph via Streaming ---")

# Objects left over from an earlier cell are reused; anything missing is
# (re)created so that this cell also runs on a fresh kernel.
if 'dataset' not in locals():
    name = "tgbn-genre"
    dataset = PyGNodePropPredDataset(name=name, root="datasets")

if 'full_data' not in locals():
    full_data = dataset.get_TemporalData()

# The split tensors were previously assumed to exist from an earlier session,
# which raised NameError on a clean run. Derive them from the dataset's own
# split masks when they are not already defined.
if 'train_data' not in locals():
    train_data = full_data[dataset.train_mask]
if 'val_data' not in locals():
    val_data = full_data[dataset.val_mask]
if 'test_data' not in locals():
    test_data = full_data[dataset.test_mask]

batch_size = 10

# Data Loaders (shuffle=False keeps each split in its stored order)
train_loader_full = TemporalDataLoader(train_data, batch_size=batch_size, shuffle=False)
val_loader_full = TemporalDataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader_full = TemporalDataLoader(test_data, batch_size=batch_size, shuffle=False)

def stream_and_collect_labels(loader, split_name):
    """Iterate ``loader`` and gather the label tensors the dataset releases.

    For every batch, the timestamp of its last event is compared with the
    dataset's current label time; when it is later, the dataset is queried
    for labels and any returned label tensor is moved to CPU and kept.

    Args:
        loader: Iterable of batches exposing a ``t`` timestamp tensor.
        split_name: Human-readable split name, used only in the log line.

    Returns:
        All collected label tensors concatenated along dim 0, or an empty
        ``(0, NUM_GENRES)`` tensor when nothing was collected.
    """
    print(f"Streaming {split_name} split...")

    # NOTE(review): the label cursor is reset at the start of every split;
    # presumably this restarts label retrieval from the dataset's first label
    # time -- confirm this is intended for the validation/test passes.
    dataset.reset_label_time()
    next_label_time = dataset.get_label_time()

    collected = []
    for batch in tqdm(loader):
        latest_event_time = batch.t[-1]
        if not (latest_event_time > next_label_time):
            continue

        fetched = dataset.get_node_label(latest_event_time)
        if fetched is None:
            continue

        # Only the label matrix is needed; timestamps and node ids are dropped.
        _, _, batch_labels = fetched
        next_label_time = dataset.get_label_time()
        collected.append(batch_labels.cpu())

    if not collected:
        return torch.empty((0, NUM_GENRES))
    return torch.cat(collected, dim=0)

# Collect the label matrix of each split (train, validation, test, in that order)
all_labels_list = [
    stream_and_collect_labels(split_loader, split_label)
    for split_loader, split_label in (
        (train_loader_full, "Train"),
        (val_loader_full, "Validation"),
        (test_loader_full, "Test"),
    )
]
all_labels = torch.cat(all_labels_list, dim=0)

# Extra trailing columns are dropped; too few columns is a hard error.
if all_labels.shape[1] > NUM_GENRES:
    print(f"WARNING: Slicing label matrix from {all_labels.shape[1]} columns down to {NUM_GENRES} columns.")
    all_labels = all_labels[:, :NUM_GENRES]

if all_labels.shape[1] != NUM_GENRES:
    raise RuntimeError(f"Label matrix must have {NUM_GENRES} columns. Found {all_labels.shape[1]}. Cannot proceed.")

num_genres = all_labels.size(1)
print(f"Total Labeled Events extracted via streaming: {all_labels.shape[0]}")


# Keep only rows that carry at least some label mass
valid_labels = all_labels[all_labels.sum(dim=1) > 0]
if valid_labels.shape[0] == 0:
    raise RuntimeError("The dataset appears to contain no valid labeled events (data.y is all zeros or empty). Cannot build co-occurrence graph.")

# Genre-by-genre co-occurrence matrix (labels^T @ labels)
label_tensor = valid_labels.float()
co_occurrence_matrix = (label_tensor.T @ label_tensor).cpu().numpy()

# Threshold the co-occurrence, normalised by either genre's diagonal entry.
# epsilon guards against division by zero for genres that never occur.
N = np.diag(co_occurrence_matrix)
epsilon = 1e-8
N_i = N[:, np.newaxis] + epsilon
N_j = N[np.newaxis, :] + epsilon
ratio_i = co_occurrence_matrix / N_i
ratio_j = co_occurrence_matrix / N_j
adjacency_matrix = (
    (ratio_i >= CO_OCCURRENCE_THRESHOLD) | (ratio_j >= CO_OCCURRENCE_THRESHOLD)
).astype(int)
np.fill_diagonal(adjacency_matrix, 0)  # no self-loops

# Wrap the adjacency as a PyTorch Geometric graph
coo_adj = sp.coo_matrix(adjacency_matrix)
edge_index = torch.tensor(np.stack((coo_adj.row, coo_adj.col)), dtype=torch.long)
initial_x = nn.Embedding(NUM_GENRES, EMBEDDING_DIM).weight.data
genre_graph = Data(x=initial_x, edge_index=edge_index, num_nodes=NUM_GENRES).to(DEVICE)

print(f"Genre Nodes: {NUM_GENRES}, Edges (Undirected): {edge_index.size(1) // 2}")
print(f"Co-occurrence Threshold: {CO_OCCURRENCE_THRESHOLD}")

# 2. DEFINING THE SHALLOW EMBEDDING MODEL

class ShallowGCNEncoder(nn.Module):
    """Learnable node-embedding table followed by two ``GCNConv`` layers.

    The table (``num_nodes x in_channels``) is Xavier-uniform initialised.
    The first convolution's output goes through ReLU and dropout (p=0.1,
    training mode only) before the second convolution.
    """

    def __init__(self, num_nodes, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Construction order (table, then conv1, conv2) is kept so parameter
        # initialisation draws from the RNG in the same sequence.
        self.embedding = nn.Embedding(num_nodes, in_channels)
        nn.init.xavier_uniform_(self.embedding.weight)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, node_ids, edge_index):
        """Embed ``node_ids`` and propagate them over ``edge_index``."""
        hidden = self.conv1(self.embedding(node_ids), edge_index)
        hidden = F.dropout(F.relu(hidden), p=0.1, training=self.training)
        return self.conv2(hidden, edge_index)

def dot_product_loss(z, pos_edge_index, neg_edge_index):
    """Binary cross-entropy link loss on dot-product edge scores.

    Each edge ``(u, v)`` is scored by the dot product of ``z[u]`` and
    ``z[v]``. Edges in ``pos_edge_index`` get target 1, edges in
    ``neg_edge_index`` get target 0, and the mean BCE-with-logits over all
    of them is returned.
    """
    def edge_logits(edges):
        src, dst = edges[0], edges[1]
        return (z[src] * z[dst]).sum(dim=1)

    pos_logits = edge_logits(pos_edge_index)
    neg_logits = edge_logits(neg_edge_index)

    logits = torch.cat((pos_logits, neg_logits))
    labels = torch.cat((torch.ones_like(pos_logits), torch.zeros_like(neg_logits)))
    return F.binary_cross_entropy_with_logits(logits, labels)

def negative_sample_edges(num_nodes, pos_edge_index, num_neg_samples):
    """Sample random node pairs that are neither positive edges nor self-pairs.

    Previously ``pos_edge_index`` was ignored, so a sampled "negative" could
    be a real edge (contradictory training targets) or an ``(i, i)`` pair.
    Candidates are now drawn uniformly and rejected when they are self-pairs
    or appear in ``pos_edge_index``, with a bounded number of retries.

    Args:
        num_nodes: Node ids are drawn from ``[0, num_nodes)``.
        pos_edge_index: ``(2, E)`` long tensor of positive edges to avoid.
            The result is created on this tensor's device.
        num_neg_samples: Number of pairs to return.

    Returns:
        A ``(2, num_neg_samples)`` long tensor. If the retries cannot find
        enough valid pairs (e.g. a nearly complete graph), the remainder is
        filled with unfiltered random pairs so the shape is always honoured.
    """
    device = pos_edge_index.device
    # Encode each (src, dst) pair as one integer for a vectorised lookup.
    pos_codes = pos_edge_index[0] * num_nodes + pos_edge_index[1]

    accepted = [torch.empty((2, 0), dtype=torch.long, device=device)]
    missing = num_neg_samples
    max_rounds = 20  # bounded so a saturated graph cannot loop forever
    for _ in range(max_rounds):
        if missing <= 0:
            break
        candidates = torch.randint(0, num_nodes, (2, missing), dtype=torch.long, device=device)
        codes = candidates[0] * num_nodes + candidates[1]
        is_valid = (candidates[0] != candidates[1]) & ~torch.isin(codes, pos_codes)
        kept = candidates[:, is_valid]
        accepted.append(kept)
        missing -= kept.size(1)

    if missing > 0:
        # Best-effort fallback: keep the promised shape with unfiltered pairs.
        accepted.append(torch.randint(0, num_nodes, (2, missing), dtype=torch.long, device=device))

    return torch.cat(accepted, dim=1)

# 3. EMBEDDING TRAINING AND SAVING
# Pre-trains the genre encoder on the co-occurrence graph with a link
# prediction objective, then writes the learned embedding table to disk.

# All three widths are EMBEDDING_DIM, so the lookup table and the GCN
# output have the same width.
genre_encoder = ShallowGCNEncoder(
    num_nodes=NUM_GENRES,
    in_channels=EMBEDDING_DIM,
    hidden_channels=EMBEDDING_DIM,
    out_channels=EMBEDDING_DIM
).to(DEVICE)

optimizer_pretrain = torch.optim.Adam(genre_encoder.parameters(), lr=LR_PRETRAIN)
# The graph's edges serve both as message-passing structure and as the
# positive examples of the loss.
pos_edge_index = genre_graph.edge_index.to(DEVICE)
all_node_ids = torch.arange(NUM_GENRES, device=DEVICE)
# One negative pair is drawn per positive edge.
num_neg_samples = pos_edge_index.size(1)

print(f"\n--- 3. Starting Shallow GNN Pre-training ({EPOCHS_PRETRAIN} Epochs) ---")

# Full-batch training: each epoch is a single optimisation step over the
# whole graph.
for epoch in range(1, EPOCHS_PRETRAIN + 1):
    genre_encoder.train()
    optimizer_pretrain.zero_grad()

    z = genre_encoder(all_node_ids, pos_edge_index)
    # Negatives are re-drawn every epoch.
    neg_edge_index = negative_sample_edges(NUM_GENRES, pos_edge_index, num_neg_samples)
    loss = dot_product_loss(z, pos_edge_index, neg_edge_index)

    loss.backward()
    optimizer_pretrain.step()

    # Log every 20th epoch and always the final one.
    if epoch % 20 == 0 or epoch == EPOCHS_PRETRAIN:
        print(f"Epoch {epoch:03d}/{EPOCHS_PRETRAIN}: Loss = {loss.item():.6f}")

with torch.no_grad():
    genre_encoder.eval()
    # What is saved is the encoder's input embedding table (the lookup
    # weights), not the GCN output `z`; it is copied to CPU before saving.
    pretrained_embeddings = genre_encoder.embedding.weight.data.clone().cpu()
    torch.save(pretrained_embeddings, PRETRAINED_EMB_PATH)

print(f"\nPre-training Complete. Embeddings saved to: {PRETRAINED_EMB_PATH}")


--- 1. Building Static Genre Co-occurrence Graph via Streaming ---
Streaming Train split...


100%|██████████| 1250088/1250088 [06:52<00:00, 3029.19it/s]


Streaming Validation split...


100%|██████████| 267876/267876 [01:27<00:00, 3054.72it/s]


Streaming Test split...


100%|██████████| 267876/267876 [01:27<00:00, 3046.00it/s]


Total Labeled Events extracted via streaming: 668557
Genre Nodes: 264, Edges (Undirected): 341
Co-occurrence Threshold: 0.3

--- 3. Starting Shallow GNN Pre-training (100 Epochs) ---
Epoch 020/100: Loss = 0.432593
Epoch 040/100: Loss = 0.418416
Epoch 060/100: Loss = 0.394377
Epoch 080/100: Loss = 0.395304
Epoch 100/100: Loss = 0.408549

Pre-training Complete. Embeddings saved to: pretrained_genre_embeddings.pt
