In [None]:
# # 1. Gỡ bỏ phiên bản quá mới hiện tại
# !pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-geometric-temporal -y

# # 2. Cài đặt PyTorch 2.5.1 (Bản ổn định) + CUDA 12.4
# !pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# # 3. Cài đặt các thư viện vệ tinh (Scatter/Sparse) dành RIÊNG cho bản 2.5.1
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu124.html

# # 4. Cài thư viện chính
# !pip install pytorch_lightning torch-geometric torch-geometric-temporal

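# # 5. Restart the runtime (Runtime > Restart session) so the newly installed versions are loaded.
# # 6. After the restart, skip this cell — do not re-run the !pip commands.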

In [None]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric_temporal.nn.recurrent import TGCN
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder

# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous, so device-side errors
# are reported at the Python line that triggered them. It is a debugging aid, not a
# reproducibility setting, and it slows GPU training; set it before the first CUDA call.
# setdefault keeps any value the user already exported (e.g. CUDA_LAUNCH_BLOCKING=0).
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', '1')
import warnings
# NOTE(review): this silences *all* warnings (including PyTorch Lightning / PyG hints and
# deprecations); narrow the filter when debugging.
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & seeding: fix every RNG the notebook draws from (Python `random` for
# the val/test snapshot shuffle, torch for embedding init and negative sampling, NumPy).
SEED = 42


def _seed_all_rngs(seed):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU, plus all GPUs if present)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_all_rngs(SEED)

In [None]:
def _encode_ids(raw_ids):
    """Encode raw IDs as contiguous 0-based integer codes.

    IDs are compared as strings and codes follow sorted string order. This is the same
    mapping that ``LabelEncoder().fit_transform(raw_ids.astype(str))`` produces.

    Args:
        raw_ids: pandas Series of raw user or item IDs.

    Returns:
        ``(codes, num_classes)``: an integer ndarray aligned with ``raw_ids`` and the
        number of distinct IDs.

    Raises:
        ValueError: if any ID is still missing after the string conversion.
    """
    codes, uniques = pd.factorize(raw_ids.astype(str), sort=True)
    if (codes < 0).any():  # factorize marks missing values with -1
        raise ValueError("IDs contain missing values; drop or fill them before encoding.")
    return codes, len(uniques)


def _build_static_edge_index(df_train, num_users):
    """Build an undirected user-item ``edge_index`` from the unique training pairs.

    Item indices are shifted by ``num_users`` so that users and items share one node-id
    space. Every unique (user, item) pair is emitted in both directions.

    Returns:
        ``(edge_index, num_unique_pairs)``, where ``edge_index`` has shape
        ``[2, 2 * num_unique_pairs]``.
    """
    pairs = df_train[['user_idx', 'item_idx']].drop_duplicates()
    users = torch.tensor(pairs['user_idx'].values, dtype=torch.long)
    items = torch.tensor(pairs['item_idx'].values + num_users, dtype=torch.long)
    edge_index = torch.stack([torch.cat([users, items]), torch.cat([items, users])], dim=0)
    return edge_index, len(pairs)


def _split_snapshots(dataset, train_frac=0.7, val_frac=0.1):
    """Split snapshots into train / val / test.

    The first ``train_frac`` of the snapshots go to training, in chronological order. The
    rest are shuffled in place with Python's global ``random`` module, which is seeded at
    the top of the notebook. Validation gets ``int(len(dataset) * val_frac)`` snapshots.
    That count is raised to one when it rounds to zero but at least two snapshots remain,
    so validation is never silently empty while a test set exists. Test gets the rest.
    """
    total_len = len(dataset)
    train_len = int(total_len * train_frac)
    remainder = dataset[train_len:]
    random.shuffle(remainder)
    val_len = int(total_len * val_frac)
    if val_len == 0 and len(remainder) >= 2:
        val_len = 1
    return dataset[:train_len], remainder[:val_len], remainder[val_len:]


def load_tgcn_data(data_dir='data', timestamp_unit=None):
    """Load ``book_interaction.csv`` and turn it into monthly T-GCN snapshots.

    Users get node ids ``[0, num_users)`` and items get
    ``[num_users, num_users + num_items)``. Every snapshot shares one static, undirected
    ``edge_index``. It is built from the unique (user, item) pairs of the first 70% of
    months, so the graph structure never contains later (validation/test) interactions.

    Args:
        data_dir: directory containing ``book_interaction.csv``. Column-name suffixes
            after ``:`` (e.g. ``user_id:token``) are stripped. The columns ``user_id``,
            ``item_id`` and ``timestamp`` are required.
        timestamp_unit: forwarded to ``pd.to_datetime(unit=...)``. Leave it as ``None``
            for date strings. Pass e.g. ``'s'`` for numeric Unix-second timestamps;
            without it, pandas treats numbers as nanoseconds and every row lands in
            January 1970.

    Returns:
        ``(train_dataset, val_dataset, test_dataset, num_users, num_items)``.
        Each dataset is a list of dicts with these keys:
        ``edge_index`` ([2, E] long), ``y`` (float ones, one per interaction),
        ``target_u`` (user node ids) and ``target_i`` (item node ids).
        The first 70% of snapshots form the training set. The remaining snapshots are
        shuffled and split into validation and test (see ``_split_snapshots``).

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    print("Loading interactions...")
    file_path = os.path.join(data_dir, 'book_interaction.csv')
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    df_inter = pd.read_csv(file_path)
    # Strip type suffixes from column names, e.g. user_id:token -> user_id.
    df_inter.columns = [c.split(':')[0] for c in df_inter.columns]
    df_inter['timestamp'] = pd.to_datetime(df_inter['timestamp'], unit=timestamp_unit)

    print("Mapping IDs...")
    user_codes, num_users = _encode_ids(df_inter['user_id'])
    item_codes, num_items = _encode_ids(df_inter['item_id'])
    df_inter['user_idx'] = user_codes
    df_inter['item_idx'] = item_codes
    print(f"Total Users: {num_users}, Total Items: {num_items}")

    print("Creating temporal snapshots...")
    df_inter['month'] = df_inter['timestamp'].dt.to_period('M')
    months = sorted(df_inter['month'].unique())

    # The static graph uses only the training months, so it cannot leak later interactions.
    train_months = months[:int(len(months) * 0.7)]
    df_train = df_inter[df_inter['month'].isin(train_months)]
    train_edge_index, num_train_pairs = _build_static_edge_index(df_train, num_users)
    print(f"Static Training Graph created with {num_train_pairs} edges.")

    # One snapshot per month, in chronological order. groupby sorts the month keys and
    # keeps the file order of rows within each month.
    dataset = []
    for _, snapshot_df in df_inter.groupby('month', sort=True):
        dataset.append({
            'edge_index': train_edge_index,  # same static training graph for every snapshot
            'y': torch.ones(len(snapshot_df), dtype=torch.float),  # every logged interaction is a positive
            'target_u': torch.tensor(snapshot_df['user_idx'].values, dtype=torch.long),
            'target_i': torch.tensor(snapshot_df['item_idx'].values + num_users, dtype=torch.long),
        })
    print(f"Loaded {len(dataset)} snapshots.")

    train_dataset, val_dataset, test_dataset = _split_snapshots(dataset)
    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    return train_dataset, val_dataset, test_dataset, num_users, num_items

In [None]:
class TGCNRecommender(pl.LightningModule):
    """Temporal GCN recommender over a user-item bipartite graph.

    Users occupy node ids ``[0, num_users)`` and items occupy
    ``[num_users, num_users + num_items)``. Each batch is one temporal snapshot dict with
    the keys ``edge_index``, ``y``, ``target_u`` and ``target_i``.

    For every snapshot, the learnable node embeddings are propagated through a stack of
    T-GCN cells. Cell k receives cell k-1's output as input. Each cell's hidden state is
    carried over from the previous snapshot, detached, so gradients stop at the snapshot
    boundary.

    A (user, item) pair is scored as ``exp(-||u - i||_1)``, a value in ``(0, 1]``.
    Training uses manual optimisation:
      * the snapshot's interactions are processed in chunks of ``interaction_batch_size``;
      * gradients are accumulated over the chunks;
      * one optimiser step is taken per snapshot.

    Args:
        num_cells: number of stacked T-GCN cells (>= 1). It has a default so keyword-only
            construction without it works; positional calls are unaffected.
        num_users: number of user nodes. Required.
        num_items: number of item nodes. Required.
        embedding_dim: size of the node embeddings and of every T-GCN hidden state.
        lr: Adam learning rate.
        interaction_batch_size: interactions per chunk for loss and metric computation.
    """

    # Number of hardest negatives mined per positive in compute_loss.
    num_hard_negatives = 10

    def __init__(self, num_cells=1, num_users=None, num_items=None, embedding_dim=64, lr=0.01,
                 interaction_batch_size=1024):
        super().__init__()
        if num_users is None or num_items is None:
            raise TypeError("TGCNRecommender requires num_users and num_items")
        if num_cells < 1:
            raise ValueError(f"num_cells must be >= 1, got {num_cells}")
        if interaction_batch_size < 1:
            raise ValueError(f"interaction_batch_size must be >= 1, got {interaction_batch_size}")
        self.save_hyperparameters()

        # Manual optimisation: one optimiser step per snapshot, with gradients accumulated
        # over interaction chunks (see training_step).
        self.automatic_optimization = False

        self.num_cells = num_cells
        self.num_users = num_users
        self.num_items = num_items
        self.num_nodes = num_users + num_items
        self.embedding_dim = embedding_dim
        self.lr = lr

        # Learnable node embeddings: the T-GCN input for every snapshot.
        self.node_emb = nn.Embedding(self.num_nodes, embedding_dim)
        nn.init.xavier_uniform_(self.node_emb.weight)

        # Stacked T-GCN cells.
        self.tgcns = nn.ModuleList(
            [TGCN(in_channels=embedding_dim, out_channels=embedding_dim) for _ in range(num_cells)]
        )

        # Recurrent state is a list with one [num_nodes, embedding_dim] tensor per cell,
        # or None, which means "start from zeros".
        self.h = None              # running state during training
        self._last_train_h = None  # state after the most recent training snapshot
        self._eval_h = None        # running state during validation / testing
        # Maps user index -> list of 0-based item indices in the training graph.
        # Built lazily from edge_index.
        self.train_user_pos_items = None

    def on_train_epoch_start(self):
        # Every epoch replays the snapshots from the first month, so restart from zero state.
        self.h = None

    def on_validation_epoch_start(self):
        # Evaluate forward in time from where training stopped; the training state is not touched.
        self._eval_h = self._last_train_h

    def on_test_epoch_start(self):
        self._eval_h = self._last_train_h

    def _propagate(self, edge_index, h=None):
        """Run one T-GCN step over the whole graph.

        Args:
            edge_index: [2, E] LongTensor of graph connectivity.
            h: previous state. This is either None (zeros for every cell), a single tensor
               used as every cell's previous state, or a sequence with one tensor per cell.

        Returns:
            A list holding each cell's new hidden state. The last entry is the output
            node embedding.
        """
        if h is None or isinstance(h, torch.Tensor):
            h = [h] * self.num_cells
        elif len(h) != self.num_cells:
            raise ValueError(f"expected {self.num_cells} hidden states, got {len(h)}")
        layer_input = self.node_emb.weight
        new_h = []
        for cell, h_prev in zip(self.tgcns, h):
            layer_input = cell(layer_input, edge_index, None, h_prev)
            new_h.append(layer_input)
        return new_h

    def forward(self, edge_index, target_u=None, target_i=None, h=None):
        """Return ``(user_embs, item_embs)`` after one T-GCN step from state ``h``.

        ``target_u`` and ``target_i`` are accepted for interface compatibility and are unused.
        """
        node_out = self._propagate(edge_index, h)[-1]
        return node_out[:self.num_users], node_out[self.num_users:]

    @staticmethod
    def _pair_score(a, b):
        """Return exp(-L1 distance) over the last dimension: a similarity in (0, 1]."""
        return torch.exp(-(a - b).abs().sum(dim=-1))

    def _ensure_train_positives(self, edge_index):
        """Build ``train_user_pos_items`` once, from the user->item half of ``edge_index``.

        Assumes ``edge_index`` is the static training graph: the loader passes the same
        graph, built from training interactions only, to every snapshot.
        """
        if self.train_user_pos_items is not None:
            return
        src, dst = edge_index.detach().cpu()
        is_user_src = src < self.num_users
        lookup = {}
        for u, i in zip(src[is_user_src].tolist(), (dst[is_user_src] - self.num_users).tolist()):
            lookup.setdefault(u, []).append(i)
        self.train_user_pos_items = lookup

    def compute_loss(self, batch, user_embs, item_embs):
        """BCE loss for a chunk of positive interactions against mined hard negatives.

        Args:
            batch: ``(user_ids, item_ids)``. These are user node ids and item *node* ids
                (offset by ``num_users``), both LongTensors of shape [B].
            user_embs: [num_users, D] user embeddings.
            item_embs: [num_items, D] item embeddings.

        Returns:
            The scalar loss. Positives are labelled 1; the mean score of each positive's
            hard negatives is labelled 0.
        """
        user_ids, item_ids = batch
        pos_item_ids = item_ids - self.num_users

        user_emb = user_embs[user_ids]
        pos_scores = self._pair_score(user_emb, item_embs[pos_item_ids])

        # Hard negative mining: rank every item by closeness to the user. Only the indices
        # are needed, so no graph is built for the [B, num_items] distance matrix. Ranking
        # by -distance (instead of exp(-distance)) avoids underflow turning into ties.
        with torch.no_grad():
            ranking = -torch.cdist(user_emb, item_embs, p=1)
            # Only training positives are masked. The model must only see the training
            # data, so val/test (user, item) cells are treated as unobserved.
            lookup = self.train_user_pos_items or {}
            rows, cols = [], []
            for row, u in enumerate(user_ids.tolist()):
                items = lookup.get(u, ())
                rows.extend([row] * len(items))
                cols.extend(items)
            if rows:
                ranking[torch.tensor(rows, device=ranking.device),
                        torch.tensor(cols, device=ranking.device)] = float('-inf')
            k = min(self.num_hard_negatives, item_embs.size(0))
            top = torch.topk(ranking, k=k, dim=1)
            neg_item_ids = top.indices
            # Entries that are still -inf are masked positives, picked only because fewer
            # than k candidates remained; exclude them from the negative average.
            valid = torch.isfinite(top.values)

        neg_scores_all = self._pair_score(user_emb.unsqueeze(1), item_embs[neg_item_ids])  # [B, k]
        valid = valid.to(neg_scores_all.dtype)
        neg_scores = (neg_scores_all * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)

        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], dim=0)
        return F.binary_cross_entropy(scores, labels)

    def training_step(self, batch, batch_idx):
        """Train on one snapshot.

        Propagates the graph, accumulates chunked loss gradients and takes one optimiser step.
        """
        opt = self.optimizers()
        edge_index = batch['edge_index']
        target_u, target_i = batch['target_u'], batch['target_i']
        self._ensure_train_positives(edge_index)

        opt.zero_grad()
        # self.h is stored detached below, so gradients stop at the snapshot boundary.
        new_h = self._propagate(edge_index, self.h)
        node_out = new_h[-1]
        user_embs, item_embs = node_out[:self.num_users], node_out[self.num_users:]

        num_interactions = target_u.numel()
        chunk = self.hparams.interaction_batch_size
        total_loss = node_out.new_zeros(())
        for start in range(0, num_interactions, chunk):
            end = min(start + chunk, num_interactions)
            loss = self.compute_loss((target_u[start:end], target_i[start:end]), user_embs, item_embs)
            # Weight each chunk by its size, so the accumulated gradient equals the gradient
            # of the snapshot-mean loss.
            weighted = loss * ((end - start) / num_interactions)
            # The shared T-GCN graph must survive until the last chunk has back-propagated.
            self.manual_backward(weighted, retain_graph=end < num_interactions)
            total_loss = total_loss + weighted.detach()
        if num_interactions > 0:
            opt.step()

        self.h = [state.detach() for state in new_h]
        self._last_train_h = self.h
        self.log('train_loss', total_loss, prog_bar=True, logger=True, batch_size=1)
        return total_loss

    def validation_step(self, batch, batch_idx):
        return self._evaluate_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self._evaluate_step(batch, batch_idx, "test")

    def _evaluate_step(self, batch, batch_idx, stage):
        """Score a snapshot's positives against random negatives and log the metrics.

        Negatives are sampled 1:1. A pair is predicted positive when its score exceeds 0.5,
        i.e. when the L1 distance is below ln 2. Logs ``{stage}_precision``,
        ``{stage}_recall`` and ``{stage}_accuracy``, and returns the precision.
        """
        edge_index = batch['edge_index']
        target_u, target_i = batch['target_u'], batch['target_i']

        # Advance the evaluation state by one step; the training state is left untouched.
        new_h = self._propagate(edge_index, self._eval_h)
        self._eval_h = [state.detach() for state in new_h]
        node_out = new_h[-1]

        # Score in chunks to bound memory.
        chunk = self.hparams.interaction_batch_size
        num_interactions = target_u.numel()
        total_tp = total_fp = total_fn = total_tn = 0

        for start in range(0, num_interactions, chunk):
            end = min(start + chunk, num_interactions)
            u_emb = node_out[target_u[start:end]]
            pos_probs = self._pair_score(u_emb, node_out[target_i[start:end]])

            # Random item negatives (1:1), drawn on the fly. They are not checked against
            # known positives.
            neg_i = torch.randint(self.num_users, self.num_nodes, (end - start,), device=node_out.device)
            neg_probs = self._pair_score(u_emb, node_out[neg_i])

            pos_hit = pos_probs > 0.5
            neg_hit = neg_probs > 0.5
            total_tp += pos_hit.sum().item()
            total_fn += (~pos_hit).sum().item()
            total_fp += neg_hit.sum().item()
            total_tn += (~neg_hit).sum().item()

        precision = total_tp / (total_tp + total_fp + 1e-8)
        recall = total_tp / (total_tp + total_fn + 1e-8)
        accuracy = (total_tp + total_tn) / (total_tp + total_tn + total_fp + total_fn + 1e-8)

        # batch_size=1: each snapshot counts equally in the epoch average.
        self.log_dict({
            f"{stage}_precision": precision,
            f"{stage}_recall": recall,
            f"{stage}_accuracy": accuracy
        }, prog_bar=True, batch_size=1)

        return precision

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
if __name__ == "__main__":
    # Reload data so this cell always starts from a clean state.
    train_dataset, val_dataset, test_dataset, num_users, num_items = load_tgcn_data("data")

    # Run configuration.
    EMBEDDING_DIM = 64
    NUM_CELLS = 1                   # number of stacked T-GCN cells
    INTERACTION_BATCH_SIZE = 1024   # interactions scored per chunk; adjust to GPU memory
    MAX_EPOCHS = 10

    # num_cells is passed explicitly; it is the constructor's first parameter.
    model = TGCNRecommender(
        num_cells=NUM_CELLS,
        num_users=num_users,
        num_items=num_items,
        embedding_dim=EMBEDDING_DIM,
        interaction_batch_size=INTERACTION_BATCH_SIZE,
    )

    def unwrap_snapshot(batch):
        """Collate function for batch_size=1: pass the single snapshot dict through unchanged."""
        return batch[0]

    def make_loader(snapshots):
        """Build a loader that yields one temporal snapshot per step, in stored order."""
        return DataLoader(snapshots, batch_size=1, collate_fn=unwrap_snapshot, shuffle=False)

    train_loader = make_loader(train_dataset)
    # An empty validation list would only produce a zero-length loader; pass None instead.
    val_loader = make_loader(val_dataset) if val_dataset else None
    test_loader = make_loader(test_dataset)

    trainer = pl.Trainer(
        max_epochs=MAX_EPOCHS,
        accelerator="auto",
        devices=1,
        enable_progress_bar=True,
        log_every_n_steps=1,  # each step is one monthly snapshot, so log every step
    )

    print("Starting Training...")
    trainer.fit(model, train_loader, val_loader)
    print("Training Complete!")

    if test_dataset:
        print("Starting Testing...")
        trainer.test(model, test_loader)
        print("Testing Complete!")
    else:
        print("No test snapshots; skipping testing.")