In [68]:
# # 1. Uninstall the currently installed (too new) versions
# !pip uninstall torch torchvision torchaudio torch-scatter torch-sparse torch-geometric torch-geometric-temporal -y

# # 2. Install PyTorch 2.5.1 (stable release) + CUDA 12.4
# !pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# # 3. Install the companion libraries (Scatter/Sparse) built specifically for 2.5.1
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu124.html

# # 4. Install the main libraries
# !pip install pytorch_lightning torch-geometric torch-geometric-temporal

# # # 5. Runtime > Restart session
# # # 6. After the restart, ignore (keep commented out) this !pip section

In [69]:
# from google.colab import drive
# drive.mount('/content/drive')

In [70]:
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric_temporal.nn.recurrent import TGCN, EvolveGCNH
from torch.utils.data import DataLoader, TensorDataset, Sampler
from collections import defaultdict

# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous so that device-side errors
# are reported at the Python line that triggered them. It is a debugging aid only: it does not
# make results reproducible, and it slows GPU training down, so it is opt-in here. To take
# effect it must be set before CUDA is initialised in this process.
DEBUG_CUDA_SYNC = False
if DEBUG_CUDA_SYNC:
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import warnings
# NOTE: this silences *every* warning, including configuration hints from PyTorch Lightning;
# narrow it with `category=` / `module=` if useful messages go missing.
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & Seeding
def _seed_all_rngs(seed):
    """Seed every RNG this notebook draws from with the same value.

    Python's `random` drives the within-month shuffling of the batch sampler; NumPy and
    PyTorch (CPU generator, plus every visible GPU when CUDA is available) are seeded too.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


SEED = 42
_seed_all_rngs(SEED)

In [71]:
class TimeAwareBatchSampler(Sampler):
    """Batch sampler that never mixes time steps inside one batch.

    Dataset positions are bucketed by their time index (the third tensor of the
    ``TensorDataset``). Buckets are visited in ascending time order, and each bucket is cut
    into consecutive chunks of at most ``batch_size`` positions. Every yielded batch therefore
    holds samples from exactly one time step, so a consumer can select a single graph window
    per batch.

    Args:
        data_source: ``TensorDataset`` whose third tensor holds one time index per sample
            (built as ``(users, items, time_indices)`` in this notebook).
        batch_size: maximum number of samples per batch.
        shuffle: when True, positions are shuffled with Python's global ``random`` *within*
            each time step on every pass; the order of the time steps is never shuffled.
    """

    def __init__(self, data_source, batch_size, shuffle=True):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle

        # One time index per dataset position; bucket positions by that value.
        self.time_indices = data_source.tensors[2].numpy()
        buckets = defaultdict(list)
        for position, step in enumerate(self.time_indices):
            buckets[step].append(position)
        self.time_groups = buckets

        # Ascending time order, so training follows the history chronologically.
        self.time_keys = sorted(buckets)

    def __iter__(self):
        for step in list(self.time_keys):
            positions = list(self.time_groups[step])
            if self.shuffle:
                # Randomise only inside this time step; chronology across steps is kept.
                random.shuffle(positions)
            yield from (
                positions[lo : lo + self.batch_size]
                for lo in range(0, len(positions), self.batch_size)
            )

    def __len__(self):
        """Total batch count: ceil(bucket size / batch_size), summed over all time steps."""
        bucket_sizes = (len(self.time_groups[step]) for step in self.time_keys)
        return sum((size + self.batch_size - 1) // self.batch_size for size in bucket_sizes)

In [72]:
class DataModule(pl.LightningDataModule):
    """Lightning DataModule that turns a user-item interaction CSV into monthly graph snapshots.

    The CSV must provide ``user_id``, ``item_id`` and ``timestamp`` columns. Interactions are
    bucketed by calendar month, and months with fewer than ``min_interactions`` rows are
    dropped. The surviving months are numbered ``0 .. num_time_steps - 1`` in chronological
    order, so time step ``t + 1`` is the next *kept* month, which is not necessarily the next
    calendar month. Months are then split chronologically into train / validation / test.

    After ``prepare_data()`` these attributes are available: ``num_users``, ``num_items``,
    ``user_to_idx``, ``item_to_idx``, ``num_time_steps``, ``train_df``, ``val_df``,
    ``test_df``, ``edge_index_all``, and ``train/val/test_user_pos_items``.

    Args:
        interaction_file: path of the interaction CSV.
        min_interactions: minimum number of interactions for a month to be kept.
        batch_size: maximum batch size of every dataloader.
        train_size: fraction of time steps (months) used for training.
        val_size: fraction of time steps used for validation.
        test_size: kept for reference only; the test split is whatever remains after the
            train and validation months.
        built_dataset: accepted for backward compatibility; currently unused.
    """

    def __init__(self, interaction_file, min_interactions=100, batch_size=1024, train_size=0.7, val_size=0.15, test_size=0.15, built_dataset=None):
        super().__init__()
        self.interaction_file = interaction_file
        self.batch_size = batch_size
        self.train_size = train_size
        self.val_size = val_size
        self.test_size = test_size
        self.min_interactions = min_interactions
        # Set to True at the end of prepare_data(); later calls then return immediately.
        self._interactions_loaded = False

    def prepare_data(self):
        """Load the CSV and build the ID/time mappings, temporal split, per-step graphs and histories.

        Idempotent: once it has completed, further calls are no-ops. This matters because the
        notebook calls it explicitly (to learn ``num_users``/``num_items`` before building the
        model) and ``Trainer.fit`` calls it again.

        Raises:
            ValueError: if no month has at least ``min_interactions`` interactions.
        """
        if getattr(self, "_interactions_loaded", False):
            return

        # --- 1. Load & preprocess ---
        df = pd.read_csv(self.interaction_file)

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.sort_values('timestamp')
        df['month'] = df['timestamp'].dt.to_period('M')

        # Count rows per month and keep only the months that reach the threshold.
        month_counts = df['month'].value_counts()
        valid_months = month_counts[month_counts >= self.min_interactions].index
        df = df[df['month'].isin(valid_months)].copy()

        if len(df) == 0:
            raise ValueError("D·ªØ li·ªáu sau khi l·ªçc b·ªã r·ªóng! H√£y gi·∫£m ng∆∞·ª°ng MIN_INTERACTIONS.")

        print(f"Removed months with < {self.min_interactions} interactions. Remaining months: {len(valid_months)}")

        # Map raw IDs to contiguous indices (order of first appearance in the time-sorted frame).
        unique_users = df['user_id'].unique()
        unique_items = df['item_id'].unique()

        self.num_users = len(unique_users)
        self.num_items = len(unique_items)
        self.user_to_idx = {u: idx for idx, u in enumerate(unique_users)}
        self.item_to_idx = {i: idx for idx, i in enumerate(unique_items)}

        df['user_idx'] = df['user_id'].map(self.user_to_idx)
        df['item_idx'] = df['item_id'].map(self.item_to_idx)

        # Number the kept months 0, 1, 2, ... chronologically. Dropped months leave no gap,
        # so consecutive time steps may be more than one calendar month apart.
        valid_months_sorted = sorted(valid_months)
        month_to_idx = {m: i for i, m in enumerate(valid_months_sorted)}
        df['time_idx'] = df['month'].map(month_to_idx)

        self.num_time_steps = len(valid_months_sorted)

        # --- 2. Chronological split by month (test gets whatever remains) ---
        train_end = int(self.num_time_steps * self.train_size)
        val_end = train_end + int(self.num_time_steps * self.val_size)

        train_months = valid_months_sorted[:train_end]
        val_months = valid_months_sorted[train_end:val_end]
        test_months = valid_months_sorted[val_end:]

        self.train_df = df[df['month'].isin(train_months)].sort_values('timestamp')
        self.val_df = df[df['month'].isin(val_months)].sort_values('timestamp')
        self.test_df = df[df['month'].isin(test_months)].sort_values('timestamp')

        # --- 3. One bipartite graph snapshot per time step ---
        # Item node ids are offset by num_users so users and items share one node index space,
        # and every interaction is stored in both directions (undirected graph).
        # NOTE(review): snapshots are built for ALL months, including validation/test ones;
        # evaluation code must take care not to message-pass over the edges it is scoring.
        self.edge_index_all = [None] * self.num_time_steps

        for month, group in df.groupby('month'):
            t_idx = month_to_idx[month]
            src = torch.tensor(group['user_idx'].values, dtype=torch.long)
            dst = torch.tensor(group['item_idx'].values, dtype=torch.long) + self.num_users
            self.edge_index_all[t_idx] = torch.stack([torch.cat([src, dst]), torch.cat([dst, src])], dim=0)

        # Safety net: any step without interactions gets an empty (2, 0) edge_index.
        # With the month filtering above, every kept step already has at least one edge.
        for t in range(self.num_time_steps):
            if self.edge_index_all[t] is None:
                self.edge_index_all[t] = torch.empty((2, 0), dtype=torch.long)

        # --- 4. Per-split history: user index -> set of item indices ---
        self.train_user_pos_items = self._build_user_history(self.train_df)
        self.val_user_pos_items = self._build_user_history(self.val_df)
        self.test_user_pos_items = self._build_user_history(self.test_df)

        self._interactions_loaded = True

    def setup(self, stage=None):
        """Ensure the data is loaded in every process; Lightning calls this hook on each rank."""
        self.prepare_data()

    def _build_user_history(self, df_subset):
        """Return a defaultdict mapping each user_idx to the set of item_idx it interacted with."""
        user_pos_items = defaultdict(set)
        for u, i in zip(df_subset['user_idx'], df_subset['item_idx']):
            user_pos_items[u].add(i)
        return user_pos_items

    def _create_dataset(self, df_subset):
        """Wrap a split as ``TensorDataset(users, items, time_indices)`` (all int64)."""
        if len(df_subset) == 0:
            # Keep the int64 dtype even when empty so downstream indexing stays valid.
            empty = torch.empty(0, dtype=torch.long)
            return TensorDataset(empty, empty, empty)

        users = torch.tensor(df_subset['user_idx'].values, dtype=torch.long)
        items = torch.tensor(df_subset['item_idx'].values, dtype=torch.long)
        times = torch.tensor(df_subset['time_idx'].values, dtype=torch.long)

        return TensorDataset(users, items, times)

    def _time_grouped_loader(self, df_subset, shuffle):
        """DataLoader whose batches each contain a single time step (see TimeAwareBatchSampler)."""
        dataset = self._create_dataset(df_subset)
        batch_sampler = TimeAwareBatchSampler(dataset, batch_size=self.batch_size, shuffle=shuffle)
        return DataLoader(dataset, batch_sampler=batch_sampler)

    def train_dataloader(self):
        # Batches are (users, items, time_indices), and all time_indices within a batch are
        # equal, so the model can build one graph window per batch.
        return self._time_grouped_loader(self.train_df, shuffle=True)

    def val_dataloader(self):
        # Same one-time-step-per-batch contract as training. val_df is already sorted by
        # timestamp, so sample order is unchanged; batches are just split at month boundaries.
        return self._time_grouped_loader(self.val_df, shuffle=False)

    def test_dataloader(self):
        return self._time_grouped_loader(self.test_df, shuffle=False)

In [73]:
class TGCNRecommender(pl.LightningModule):
    """Implicit-feedback recommender that encodes a window of graph snapshots with T-GCN.

    Node ids ``[0, num_users)`` are users and ``[num_users, num_users + num_items)`` are
    items. A learned embedding table supplies the node features. T-GCN is unrolled over the
    snapshots that end at the batch's time step, and its final hidden state is used as the
    node embeddings. Training scores (user, item) pairs by dot product against one uniformly
    sampled negative item per positive.

    Args:
        num_users: number of distinct users.
        num_items: number of distinct items.
        sequence_length: number of *previous* snapshots fed to T-GCN in addition to the
            current one (window ``[t - sequence_length, t]``, clipped at 0).
        embedding_dim: size of the node embeddings and of the T-GCN hidden state.
        lr: Adam learning rate.
    """

    def __init__(self, num_users, num_items, sequence_length, embedding_dim, lr):
        super().__init__()
        self.save_hyperparameters()
        # Set to True by setup() once the graph snapshots have been taken from the DataModule.
        self.processed = False

        self.num_users = num_users
        self.num_items = num_items
        self.num_nodes = num_users + num_items
        self.seq_len = sequence_length
        self.embedding_dim = embedding_dim
        self.lr = lr

        # Learned node features, shared by every time step (users first, then items).
        self.node_emb = nn.Embedding(self.num_nodes, embedding_dim)
        nn.init.xavier_uniform_(self.node_emb.weight)

        self.tgcn = TGCN(in_channels=embedding_dim, out_channels=embedding_dim)

    def setup(self, stage=None):
        """Take references to the per-step graphs and training history from the attached DataModule."""
        if self.trainer.datamodule is not None:
            self.edge_index_all = self.trainer.datamodule.edge_index_all
            self.train_user_pos_items = self.trainer.datamodule.train_user_pos_items
            self.processed = True

    def forward(self, time_idx):
        """Return node embeddings of shape (num_nodes, embedding_dim) for time step ``time_idx``.

        T-GCN is unrolled from a zero hidden state over
        ``edge_index_all[max(0, time_idx - seq_len) : time_idx + 1]``. That is up to
        ``seq_len`` previous snapshots plus the current one, so ``seq_len + 1`` snapshots
        once ``time_idx >= seq_len``.
        """
        start = max(0, time_idx - self.seq_len)
        window_edges = self.edge_index_all[start : time_idx + 1]

        x = self.node_emb.weight
        # The hidden state restarts from zeros on every call instead of being carried across
        # batches (no truncated BPTT). The sampler may emit several batches per time step,
        # so each forward pass rebuilds its own window context. Allocate directly on x's device.
        h = torch.zeros(self.num_nodes, self.embedding_dim, device=x.device, dtype=x.dtype)

        for edge_index in window_edges:
            # Pass X / H by keyword: TGCN.forward takes edge_weight positionally before H.
            h = self.tgcn(X=x, edge_index=edge_index.to(x.device), H=h)

        return h

    def training_step(self, batch, batch_idx):
        """One optimisation step on a batch of (user, positive item, time_idx) triples.

        loss = mean(-log sigmoid(u . i_pos)) + mean(-log sigmoid(-u . i_neg)), with one
        uniformly sampled negative item per positive. The negative may occasionally be a
        true positive.
        """
        user_ids, item_ids, time_indices = batch

        # One graph window is computed per batch, so the whole batch must share one time step
        # (guaranteed by TimeAwareBatchSampler).
        current_t = time_indices[0].item()
        if bool((time_indices != current_t).any()):
            raise ValueError(
                "training_step expects every sample in a batch to share one time_idx "
                "(use TimeAwareBatchSampler)."
            )

        node_embs = self(current_t)  # (num_nodes, embedding_dim)
        user_embs = node_embs[:self.num_users]
        item_embs = node_embs[self.num_users:]

        batch_user_emb = user_embs[user_ids]
        pos_scores = (batch_user_emb * item_embs[item_ids]).sum(dim=1)

        neg_item_ids = torch.randint(0, self.num_items, (len(user_ids),), device=self.device)
        neg_scores = (batch_user_emb * item_embs[neg_item_ids]).sum(dim=1)

        # log(sigmoid(s)) and log(1 - sigmoid(s)) == log(sigmoid(-s)) in a numerically stable
        # form. An epsilon-clipped log(sigmoid(s) + eps) flattens to a constant with zero
        # gradient once sigmoid saturates in float32, so confidently wrong pairs stop learning.
        pos_loss = -F.logsigmoid(pos_scores).mean()
        neg_loss = -F.logsigmoid(-neg_scores).mean()

        loss = pos_loss + neg_loss
        self.log("train_loss", loss, prog_bar=True, batch_size=len(user_ids))
        return loss

    def configure_optimizers(self):
        """Adam over all parameters (node embeddings and T-GCN)."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [74]:
if __name__ == "__main__":
    # Local copy of the data; the commented path is the Google Drive copy used on Colab.
    file = 'data/book_interaction.csv'
    # file = "/content/drive/MyDrive/Computer Science Master/01 Luan Van/data/book_interaction.csv"

    # Load the data up front: the model constructor needs num_users / num_items.
    data_module = DataModule(file)
    data_module.prepare_data()

    model_config = {
        "num_users": data_module.num_users,
        "num_items": data_module.num_items,
        "sequence_length": 5,
        "embedding_dim": 32,
        "lr": 0.001,
    }
    model = TGCNRecommender(**model_config)

    trainer_config = {
        "max_epochs": 10,
        "accelerator": "auto",
        "devices": 1,
        "enable_progress_bar": True,
        "log_every_n_steps": 1,
    }
    trainer = pl.Trainer(**trainer_config)

    trainer.fit(model, data_module)
    print("Completed")


üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Removed months with < 100 interactions. Remaining months: 176



  | Name     | Type      | Params | Mode  | FLOPs
-------------------------------------------------------
0 | node_emb | Embedding | 588 K  | train | 0    
1 | tgcn     | TGCN      | 9.4 K  | train | 0    
-------------------------------------------------------
597 K     Trainable params
0         Non-trainable params
597 K     Total params
2.392     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


Removed months with < 100 interactions. Remaining months: 176
Epoch 0:  22%|██▏       | 27/123 [09:46<34:45,  0.05it/s, v_num=48, train_loss=1.380] 
Epoch 9: 100%|██████████| 123/123 [00:23<00:00,  5.18it/s, v_num=50, train_loss=0.325]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 123/123 [00:23<00:00,  5.17it/s, v_num=50, train_loss=0.325]
Completed
