In [54]:
# # 1. G·ª° b·ªè phi√™n b·∫£n qu√° m·ªõi hi·ªán t·∫°i
# !pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-geometric-temporal -y

# # 2. C√†i ƒë·∫∑t PyTorch 2.5.1 (B·∫£n ·ªïn ƒë·ªãnh) + CUDA 12.4
# !pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# # 3. C√†i ƒë·∫∑t c√°c th∆∞ vi·ªán v·ªá tinh (Scatter/Sparse) d√†nh RI√äNG cho b·∫£n 2.5.1
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu124.html

# # 4. C√†i th∆∞ vi·ªán ch√≠nh
# !pip install pytorch_lightning torch-geometric torch-geometric-temporal

# # # 5. Runtime > Restart session
# # # 6 Ignore this !pip section

In [55]:
# from google.colab import drive
# drive.mount('/content/drive')

In [56]:
import os
import random, math
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric_temporal.nn.recurrent import TGCN, EvolveGCNH
from torch.utils.data import DataLoader, TensorDataset, Sampler
from collections import defaultdict

# Set environment variables for reproducibility and safety
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & Seeding
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [57]:
class TimeAwareBatchSampler(Sampler):
    """
    K√≠ch ho·∫°t vi·ªác l·∫•y m·∫´u theo t·ª´ng b∆∞·ªõc th·ªùi gian (Time Step).
    ƒê·∫£m b·∫£o t·∫•t c·∫£ c√°c m·∫´u trong m·ªôt batch ƒë·ªÅu thu·ªôc c√πng m·ªôt th·ªùi ƒëi·ªÉm 'time_idx'.
    ƒêi·ªÅu n√†y c·ª±c k·ª≥ quan tr·ªçng cho T-GCN ƒë·ªÉ x·ª≠ l√Ω ƒë√∫ng c·ª≠a s·ªï ƒë·ªì th·ªã (edge_index_window).
    """
    def __init__(self, data_source, batch_size, shuffle=True):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        
        # data_source l√† TensorDataset(users, items, time_indices)
        # Ch√∫ng ta c·∫ßn gom nh√≥m index thay theo time_idx
        self.time_indices = data_source.tensors[2].numpy()
        
        # T·∫°o dictionary: time_idx -> list of dataset_indices
        self.time_groups = defaultdict(list)
        for idx, t in enumerate(self.time_indices):
            self.time_groups[t].append(idx)
            
        self.time_keys = sorted(list(self.time_groups.keys()))

    def __iter__(self):
        # Lu√¥n gi·ªØ th·ª© t·ª± th·ªùi gian tƒÉng d·∫ßn ƒë·ªÉ m√¥ h√¨nh h·ªçc theo di·ªÖn ti·∫øn l·ªãch s·ª≠
        # (Kh√¥ng shuffle keys n·ªØa)
        keys = self.time_keys[:] 
            
        for t in keys:
            indices = self.time_groups[t][:]
            
            # Ch·ªâ x√°o tr·ªôn d·ªØ li·ªáu B√äN TRONG m·ªôt th√°ng
            # ƒê·ªÉ c√°c batch trong c√πng 1 th√°ng c√≥ s·ª± ng·∫´u nhi√™n
            if self.shuffle:
                random.shuffle(indices)
            
            # T·∫°o c√°c batch t·ª´ indices c·ªßa th·ªùi ƒëi·ªÉm t
            for i in range(0, len(indices), self.batch_size):
                yield indices[i : i + self.batch_size]

    def __len__(self):
        # T·ªïng s·ªë batch
        count = 0
        for t in self.time_keys:
            indices = self.time_groups[t]
            count += (len(indices) + self.batch_size - 1) // self.batch_size
        return count

In [58]:
class DataModule(pl.LightningDataModule):
    def __init__(self, interaction_file, min_interactions= 100, batch_size=1024, train_size=0.7, val_size=0.15, test_size=0.15, built_dataset=None):
        super().__init__()
        self.interaction_file = interaction_file
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.min_interactions = min_interactions

    def prepare_data(self):
        # --- 1. Load & Preprocess ---
        df = pd.read_csv(self.interaction_file)

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp')
        df['month'] = df['timestamp'].dt.to_period('M')

        # ƒê·∫øm s·ªë d√≤ng m·ªói th√°ng
        month_counts = df['month'].value_counts()
        valid_months = month_counts[month_counts >= self.min_interactions].index
        
        # Ch·ªâ gi·ªØ l·∫°i th√°ng h·ª£p l·ªá
        df = df[df['month'].isin(valid_months)].copy()
        
        if len(df) == 0:
            raise ValueError("D·ªØ li·ªáu sau khi l·ªçc b·ªã r·ªóng! H√£y gi·∫£m ng∆∞·ª°ng MIN_INTERACTIONS.")
            
        print(f"Removed months with < {self.min_interactions} interactions. Remaining months: {len(valid_months)}")

        # Mapping ID
        unique_users = df['user_id'].unique()
        unique_items = df['item_id'].unique()
        
        self.num_users = len(unique_users)
        self.num_items = len(unique_items)
        self.user_to_idx = {u: idx for idx, u in enumerate(unique_users)}
        self.item_to_idx = {i: idx for idx, i in enumerate(unique_items)}

        df['user_idx'] = df['user_id'].map(self.user_to_idx)
        df['item_idx'] = df['item_id'].map(self.item_to_idx)

        # T·∫°o Time Index (0, 1, 2, ...) cho c√°c th√°ng
        valid_months_sorted = sorted(valid_months)

        month_to_idx = {m: i for i, m in enumerate(valid_months_sorted)}
        df['time_idx'] = df['month'].map(month_to_idx)
        
        self.num_time_steps = len(valid_months_sorted)

        # --- 2. Temporal Split ---
        train_end = int(self.num_time_steps * self.train_size)
        val_end = train_end + int(self.num_time_steps * self.val_size)
        
        train_months = valid_months_sorted[:train_end]
        val_months = valid_months_sorted[train_end:val_end]
        test_months = valid_months_sorted[val_end:]
        
        # L·ªçc Time Index t∆∞∆°ng ·ª©ng
        self.train_df = df[df['month'].isin(train_months)].sort_values('timestamp')
        self.val_df = df[df['month'].isin(val_months)].sort_values('timestamp')
        self.test_df = df[df['month'].isin(test_months)].sort_values('timestamp')

        # --- 3. Build Graph cho TO√ÄN B·ªò th·ªùi gian ---
        self.edge_index_all = [None] * self.num_time_steps
        
        for month, group in df.groupby('month'):
            t_idx = month_to_idx[month]
            
            src = torch.tensor(group['user_idx'].values, dtype=torch.long)
            dst = torch.tensor(group['item_idx'].values, dtype=torch.long) + self.num_users
            
            # Undirected
            edge_index = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])], dim=0)
            self.edge_index_all[t_idx] = edge_index

        # Fill c√°c th√°ng b·ªã thi·∫øu
        for t in range(self.num_time_steps):
            if self.edge_index_all[t] is None:
                self.edge_index_all[t] = torch.empty((2, 0), dtype=torch.long)

        # --- 4. User History ---
        self.train_user_pos_items = self._build_user_history(self.train_df)
        self.val_user_pos_items = self._build_user_history(self.val_df)
        self.test_user_pos_items = self._build_user_history(self.test_df)

    def _build_user_history(self, df_subset):
        user_pos_items = defaultdict(set)
        for u, i in zip(df_subset['user_idx'], df_subset['item_idx']):
            user_pos_items[u].add(i)
        return user_pos_items

    def _create_dataset(self, df_subset):
        if len(df_subset) == 0:
            return TensorDataset(torch.empty(0), torch.empty(0), torch.empty(0))
            
        users = torch.tensor(df_subset['user_idx'].values, dtype=torch.long)
        items = torch.tensor(df_subset['item_idx'].values, dtype=torch.long)
        times = torch.tensor(df_subset['time_idx'].values, dtype=torch.long)
        
        return TensorDataset(users, items, times)

    def train_dataloader(self):
        dataset = self._create_dataset(self.train_df)
        batch_sampler = TimeAwareBatchSampler(dataset, batch_size=self.batch_size, shuffle=True)
        return DataLoader(dataset, batch_sampler=batch_sampler)

    def val_dataloader(self):
        dataset = self._create_dataset(self.val_df)
        batch_sampler = TimeAwareBatchSampler(dataset, batch_size=self.batch_size, shuffle=False)
        return DataLoader(dataset, batch_sampler=batch_sampler)

    def test_dataloader(self):
        dataset = self._create_dataset(self.test_df)
        batch_sampler = TimeAwareBatchSampler(dataset, batch_size=self.batch_size, shuffle=False)
        return DataLoader(dataset, batch_sampler=batch_sampler)

In [59]:
class TGCNRecommender(pl.LightningModule):
    def __init__(self, num_users, num_items, sequence_length, embedding_dim, lr, dropout=0.2, weight_decay=1e-5):
        super().__init__()
        self.save_hyperparameters()

        self.num_users = num_users
        self.num_items = num_items
        self.num_nodes = num_users + num_items
        self.seq_len = sequence_length
        self.embedding_dim = embedding_dim
        self.lr = lr
        self.dropout_rate = dropout
        self.weight_decay = weight_decay

        self.node_emb = nn.Embedding(self.num_nodes, embedding_dim)
        nn.init.xavier_uniform_(self.node_emb.weight)
        
        self.tgcn = TGCN(in_channels=embedding_dim, out_channels=embedding_dim) 
        self.dropout = nn.Dropout(self.dropout_rate)

    def setup(self, stage=None):
        self.edge_index_all = self.trainer.datamodule.edge_index_all
        self.train_user_pos_items = self.trainer.datamodule.train_user_pos_items
        self.val_user_pos_items = self.trainer.datamodule.val_user_pos_items
        self.test_user_pos_items = self.trainer.datamodule.test_user_pos_items
        
    @staticmethod
    def hit_at_k(pred_items, true_items, k):
        hits = 0
        for pred, true in zip(pred_items, true_items):
            if len(set(pred[:k]) & set(true)) > 0:
                hits += 1
        return hits / len(true_items)

    @staticmethod
    def ndcg_at_k(pred_items, true_items, k):
        ndcg = 0.0
        for pred, true in zip(pred_items, true_items):
            gains = []
            for idx, item in enumerate(pred[:k]):
                gains.append(1 if item in true else 0)
            ideal_gains = [1] * min(len(true), k)
            dcg = sum(g / math.log2(i+2) for i, g in enumerate(gains))
            idcg = sum(g / math.log2(i+2) for i, g in enumerate(ideal_gains))
            ndcg += dcg / idcg if idcg > 0 else 0
        return ndcg / len(true_items)

    @staticmethod
    def recall_at_k(pred_items, true_items, k):
        recall = 0.0
        for pred, true in zip(pred_items, true_items):
            recall += len(set(pred[:k]) & set(true)) / len(true)
        return recall / len(true_items)

    @staticmethod
    def precision_at_k(pred_items, true_items, k):
        precision = 0.0
        for pred, true in zip(pred_items, true_items):
            precision += len(set(pred[:k]) & set(true)) / k
        return precision / len(true_items)

    def forward(self, time_idx):
        if time_idx == 0:
            return self.node_emb.weight
        
        start = max(0, time_idx - self.seq_len)
        window_edges = self.edge_index_all[start : time_idx] 
        
        x = self.node_emb.weight
        # Apply Dropout to input features
        x = self.dropout(x)
        
        h = torch.zeros(self.num_nodes, self.embedding_dim).to(x.device)
        
        for edge_index in window_edges:
            edge_index = edge_index.to(x.device)          
            h = self.tgcn(X=x, edge_index=edge_index, H=h) 

        return h

    def compute_loss(self, batch, user_embs, item_embs):
        user_ids, pos_item_ids, time_indices = batch

        # Get embeddings
        user_emb = user_embs[user_ids]
        pos_emb = item_embs[pos_item_ids]

        # Compute positive scores
        pos_scores = torch.exp(-torch.abs(user_emb - pos_emb).sum(dim=1))

        ####################### Hard negative Sampling #######################
        distances = torch.cdist(user_emb, item_embs, p=1)
        scores = torch.exp(-distances)

        ######## Mask all pos_item_ids of the user in train_dataset ########
        ### Basically, the  model should only see the information in the train_dataset.
        ### Therefore, only mask the pos_item_ids of the user in train_dataset
        ### All cell (user, item) in val_dataset should be treated as blank hence don't mask the val_dataset

        for i, u in enumerate(user_ids.tolist()):
            pos_item_ids = [item for item in self.train_user_pos_items[u]]
            scores[i, pos_item_ids] = float('-inf')
        ######## Mask all pos_item_ids of the user in train_dataset ########

        k = 10 # Select top-K most negatives for each user
        neg_item_ids = torch.topk(scores, k=k, dim=1).indices

        # Get embeddings for these negatives
        neg_emb = item_embs[neg_item_ids]

        neg_scores = torch.exp(-torch.abs(user_emb.unsqueeze(1) - neg_emb).sum(dim=2))
        neg_scores = neg_scores.mean(dim=1)
        ####################### Hard negative Sampling #######################


        ####################### Compute Loss #######################
        scores = torch.cat([pos_scores, neg_scores], dim=0)
        labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)], dim=0)

        loss = F.binary_cross_entropy(scores, labels)
        ####################### Compute Loss #######################
        return loss

    def training_step(self, batch, batch_idx):
        user_ids, item_ids, time_indices = batch
        current_t = time_indices[0].item()
        
        node_embs = self(current_t) 
        user_embs = node_embs[:self.num_users]
        item_embs = node_embs[self.num_users:]

        loss = self.compute_loss(batch, user_embs, item_embs)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        user_ids, item_ids, time_indices = batch

        current_t = time_indices[0].item()
        
        node_embs = self(current_t) 

        user_embs = node_embs[:self.num_users]
        item_embs = node_embs[self.num_users:]

        batch_user_emb = user_embs[user_ids]

        distances = torch.cdist(batch_user_emb, item_embs, p=1)
        scores = torch.exp(-distances)  # it is the score between the ith user in batch_size and ALL items

        ########## Mask those user-item pair that already in training set so that it won't suggest again
        mask = torch.zeros_like(scores, dtype=torch.bool)
        for i, u in enumerate(user_ids.tolist()):
            trained_items = [item for item in self.train_user_pos_items[u]]
            mask[i, trained_items] = True

        scores = scores.masked_fill(mask, float('-inf'))    #### Make them to -inf so that TopK won't pick again
        ########## Mask those user-item pair that already in training set so that it won't suggest again

        #################### Calculate metrics
        # k_values = [5, 10, 15, 20]  # Example: you can add more values as needed
        k_values = [10]

        for k in k_values:
            # Get top-k items for this k
            topk_items = torch.topk(scores, k=k, dim=1).indices.tolist() # (1024, K=5)

            true_items = []  # each user may have multiple positive items
            for u in user_ids.tolist():
                adjusted_val_items = [item - self.num_users for item in self.val_user_pos_items[u]]
                true_items.append(adjusted_val_items)

            # Compute metrics for this k
            hit = self.hit_at_k(topk_items, true_items, k)
            ndcg = self.ndcg_at_k(topk_items, true_items, k)
            recall = self.recall_at_k(topk_items, true_items, k)
            precision = self.precision_at_k(topk_items, true_items, k)

            # Log metrics dynamically
            self.log(f"val_hit@{k:02d}", hit, prog_bar=True)
            self.log(f"val_recall@{k:02d}", recall, prog_bar=True)
            self.log(f"val_precision@{k:02d}", precision, prog_bar=True)
            self.log(f"val_ndcg@{k:02d}", ndcg, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, recall_10, precision_10 = self._common_step(batch, evaluate=True, k=10)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_recall_10", recall_10, prog_bar=True)
        self.log("test_precision_10", precision_10, prog_bar=True)
        return loss
        
    def _common_step(self, batch, evaluate=False, k=10):
        user_ids, item_ids, time_indices = batch
        current_t = time_indices[0].item()
        
        node_embs = self(current_t) 
        user_embs = node_embs[:self.num_users]
        item_embs = node_embs[self.num_users:]
        
        batch_user_emb = user_embs[user_ids]
        batch_pos_item_emb = item_embs[item_ids]
        pos_scores = (batch_user_emb * batch_pos_item_emb).sum(dim=1)
        
        neg_item_ids = torch.randint(0, self.num_items, (len(user_ids),), device=self.device)
        neg_scores = (batch_user_emb * item_embs[neg_item_ids]).sum(dim=1)
        
        loss = -torch.log(torch.sigmoid(pos_scores - neg_scores) + 1e-10).mean()
        
        recall, precision = 0.0, 0.0
        if evaluate:
             all_scores = torch.matmul(batch_user_emb, item_embs.T)
             _, topk_indices = torch.topk(all_scores, k, dim=1)
             hits = (topk_indices == item_ids.view(-1, 1)).any(dim=1).float()
             recall = hits.mean()
             precision = hits.mean() / k
             
        return loss, recall, precision

    def configure_optimizers(self):
        # Added weight_decay
        return torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

In [60]:
if __name__ == "__main__":
    file = 'data/book_interaction.csv'
    # file = "/content/drive/MyDrive/Computer Science Master/01 Luan Van/data/book_interaction.csv"
    data_module = DataModule(file)

    data_module.prepare_data()

    model = TGCNRecommender(
        num_users=data_module.num_users,
        num_items=data_module.num_items,
        sequence_length = 6,
        embedding_dim= 64,
        lr = 0.001,
    )

    trainer = pl.Trainer(
        max_epochs=10,
        accelerator="auto",
        devices=1,
        enable_progress_bar=True,
        log_every_n_steps=1
    )

    trainer.fit(model, data_module)
    print("Completed")


üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Removed months with < 100 interactions. Remaining months: 176
Removed months with < 100 interactions. Remaining months: 176



  | Name     | Type      | Params | Mode  | FLOPs
-------------------------------------------------------
0 | node_emb | Embedding | 1.2 M  | train | 0    
1 | tgcn     | TGCN      | 37.2 K | train | 0    
2 | dropout  | Dropout   | 0      | train | 0    
-------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.858     Total estimated model params size (MB)
15        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 3:  72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 88/123 [00:56<00:22,  1.56it/s, v_num=11, train_loss=0.749, val_hit@10=0.000, val_recall@10=0.000, val_precision@10=0.000, val_ndcg@10=0.000] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1