In [10]:
import random, os, pickle, time
import pandas as pd
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from collections import defaultdict


# Set environment variables for reproducibility and safety
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & Seeding
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [11]:
name = 'book'

In [None]:
class TCKG:
    def __init__(self, tckg_csv_path):
        self.adj_list = defaultdict(list)
        
        print(f"Loading TCKG from {tckg_csv_path}...")
        df_tckg = pd.read_csv(tckg_csv_path, usecols=['head_id', 'relation_id', 'tail_id'])
        
        offset = df_tckg['relation_id'].max()
            
        # self.num_relations = offset * 2 + 1# Total relation_id (bidirection)
        # print(f"Inverse Relation Offset: {offset}. Total Relation Space: {self.num_relations}")

        data = df_tckg[['head_id', 'relation_id', 'tail_id']].to_numpy() # Using numpy to speedup
        
        for h, r, t in data:
            h, r, t = int(h), int(r), int(t)

            self.adj_list[h].append((r, t))
            self.adj_list[t].append((r + offset, h)) # Ex: r=5 (watched) -> r_inv=105 (watched_by)

        print(f"TCKG Loaded successfully. Graph construction complete.")

    def get_neighbors(self, node_id):
        return self.adj_list[node_id]

    def get_all_nodes(self):
        return list(self.adj_list.keys())

In [13]:
class TimeAwareRewardFunction(nn.Module):
    def __init__(self, user_embs, entity_embs, relation_embs, interaction_cluster_ids, bias_embs=None, temperature= None):
        """
        Args:
            user_embs (nn.Embedding): Embedding c·ªßa User (e_u)
            entity_embs (nn.Embedding): Embedding c·ªßa Item/Entity (e_v)
            relation_embs (nn.Embedding): Embedding c·ªßa Relation (d√πng ƒë·ªÉ l·∫•y V_U)
            interaction_cluster_ids (list or tensor): Danh s√°ch c√°c Relation ID ƒë·∫°i di·ªán cho Time Clusters.
                                                      V√≠ d·ª•: [20, 21, 22] ·ª©ng v·ªõi interacted_0, interacted_1...
                                                      ƒê√¢y ch√≠nh l√† t·∫≠p {V_U^1, ..., V_U^L}
            bias_embs (nn.Embedding, optional): Bias c·ªßa entity (b_v). N·∫øu None s·∫Ω t·ª± kh·ªüi t·∫°o.
        """
        super(TimeAwareRewardFunction, self).__init__()
        
        self.user_embs = user_embs
        self.entity_embs = entity_embs
        self.relation_embs = relation_embs
        
        # Danh s√°ch ID c·ªßa c√°c cluster t∆∞∆°ng t√°c theo th·ªùi gian (V_U)
        # Chuy·ªÉn th√†nh tensor ƒë·ªÉ t√≠nh to√°n song song
        self.register_buffer('cluster_ids', torch.tensor(interaction_cluster_ids, dtype=torch.long))
        
        # Entity Bias (b_v) - Eq (11)
        if bias_embs is None:
            num_entities = entity_embs.num_embeddings
            self.bias_embs = nn.Embedding(num_entities, 1)
            nn.init.zeros_(self.bias_embs.weight) # Kh·ªüi t·∫°o bias b·∫±ng 0
        else:
            self.bias_embs = bias_embs

        if temperature is None:
            self.temperature = self.entity_embs.embedding_dim ** 0.5
        else:
            self.temperature = temperature

    def calculate_weights(self, history_relation_ids):
        """
        Th·ª±c hi·ªán Eq (13): T√≠nh tr·ªçng s·ªë w_h d·ª±a tr√™n t·∫ßn su·∫•t xu·∫•t hi·ªán trong l·ªãch s·ª≠.
        
        Args:
            history_relation_ids: (Batch, Max_History_Len) - Ch·ª©a relation ID trong qu√° kh·ª© c·ªßa user.
            
        Returns:
            weights: (Batch, Num_Clusters) - Vector W_hu
        """
        # 1. So kh·ªõp History v·ªõi Cluster IDs
        # history: (B, H, 1)
        # clusters: (1, 1, L)
        # matches: (B, H, L) -> True n·∫øu relation t·∫°i history kh·ªõp v·ªõi cluster ID
        hist_expanded = history_relation_ids.unsqueeze(-1)
        clusters_expanded = self.cluster_ids.view(1, 1, -1)
        
        matches = (hist_expanded == clusters_expanded).float()
        
        # 2. ƒê·∫øm s·ªë l·∫ßn xu·∫•t hi·ªán (T·ª≠ s·ªë Eq 13)
        # Sum theo chi·ªÅu History (dim=1) -> (B, L)
        counts = matches.sum(dim=1) 
        
        # 3. T√≠nh ƒë·ªô d√†i q th·ª±c t·∫ø (M·∫´u s·ªë Eq 13)
        # q = t·ªïng s·ªë l·∫ßn xu·∫•t hi·ªán c·ªßa b·∫•t k·ª≥ cluster n√†o trong history (tr√°nh t√≠nh padding 0)
        q = counts.sum(dim=1, keepdim=True)
        
        # 4. Normalize ƒë·ªÉ ra tr·ªçng s·ªë (tr√°nh chia cho 0)
        weights = counts / (q + 1e-9)
        
        return weights

    def forward(self, user_ids, item_ids, history_relation_ids):
        """
        T√≠nh Reward Score g_R(v | u)
        
        Args:
            user_ids: (Batch,)
            item_ids: (Batch,) - Item ƒë√≠ch (v_hat) m√† Agent d·ª± ƒëo√°n/d·ª´ng l·∫°i
            history_relation_ids: (Batch, History_Len) - L·ªãch s·ª≠ relation c·ªßa user
            
        Returns:
            scores: (Batch,) - ƒêi·ªÉm reward
        """
        # --- B∆Ø·ªöC 1: L·∫•y Embeddings c∆° b·∫£n ---
        u_e = self.user_embs(user_ids)       # (B, Dim) -> e_u
        v_e = self.entity_embs(item_ids)     # (B, Dim) -> e_v
        v_b = self.bias_embs(item_ids).squeeze(-1) # (B,) -> b_v
        
        # --- B∆Ø·ªöC 2: T√≠nh Personalized Interaction Relation (Eq 12) ---
        # r_vu^T = W_hu * V_U
        
        # a. T√≠nh weights (B, L)
        weights = self.calculate_weights(history_relation_ids)
        
        # b. L·∫•y embedding c·ªßa c√°c cluster V_U^1...L
        # cluster_embs shape: (L, Dim)
        cluster_embs = self.relation_embs(self.cluster_ids)
        
        # c. T√≠nh t·ªïng c√≥ tr·ªçng s·ªë
        # (B, L) x (L, Dim) -> (B, Dim)
        r_interaction = torch.matmul(weights, cluster_embs)
        
        # --- B∆Ø·ªöC 3: T√≠nh Score (Eq 11 & Final Eq) ---
        # g = (e_u + r_interaction) . e_v + b_v
        
        # Dot product: (e_u + r) * e_v
        query_vector = u_e + r_interaction # (B, Dim)
        dot_product = torch.sum(query_vector * v_e, dim=1) # (B,)
        
        scores = dot_product + v_b # C·ªông bias
        
        # 1. Scale Score (Chia cho nhi·ªát ƒë·ªô)
        scaled_score = scores / self.temperature
        
        # 2. √Åp d·ª•ng Sigmoid
        rewards = torch.sigmoid(scaled_score)
        
        return rewards


In [14]:
class TPRecEnvironment(nn.Module):
    def __init__(self, kg, entity_embeddings, relation_embeddings, reward_function, max_path_len=3, history_len=3):
        """
        Th√™m tham s·ªë: reward_function
        """
        super(TPRecEnvironment, self).__init__()
        self.kg = kg
        self.entity_embs = entity_embeddings
        self.relation_embs = relation_embeddings
        self.max_path_len = max_path_len
        self.history_len = history_len
        
        # L∆ØU REWARD FUNCTION ƒê∆Ø·ª¢C TRUY·ªÄN V√ÄO
        self.reward_function = reward_function 

        # State tracking
        self.current_entities = None
        self.current_users = None
        self.path_history = None
        self.step_counter = 0

    def reset(self, user_ids):
        """
        Initiate status s_0 = (u, u, ‚àÖ)
        """
        batch_size = user_ids.size(0)

        self.current_users = user_ids       # Multiple values because of batch_size
        self.current_entities = user_ids    # Initiate status s_0 = (u, u, ‚àÖ)
        
        # History h_k: store (relation, entity)
        # B·ªï sung device=user_ids.device ƒë·ªÉ tr√°nh l·ªói khi d√πng GPU
        self.path_history = torch.zeros((batch_size, self.history_len * 2), 
                                        dtype=torch.long, 
                                        device=user_ids.device)

        self.step_counter = 0        
        return self._get_state_embedding()

    def _get_state_embedding(self):
        """
        State = [User_Emb, Flattened_History_Emb, Current_Entity_Emb]
        Input cho Policy Network s·∫Ω l√† vector n·ªëi d√†i c·ªßa 3 th√†nh ph·∫ßn n√†y.
        """
        u_emb = self.entity_embs(self.current_users)
        e_emb = self.entity_embs(self.current_entities)
        
        # 2a: T√°ch index c·ªßa Relation v√† Entity b·∫±ng k·ªπ thu·∫≠t Slicing
        r_indices = self.path_history[:, 0::2] # (Batch, history_len)
        e_indices = self.path_history[:, 1::2] # (Batch, history_len)
        
        # 2b: Lookup Embedding
        r_vecs = self.relation_embs(r_indices) 
        e_vecs = self.entity_embs(e_indices)
        
        # 2c: K·∫øt h·ª£p Relation v√† Entity t·∫°i m·ªói b∆∞·ªõc
        step_vecs = torch.cat([r_vecs, e_vecs], dim=2)
        
        # 2d: L√†m ph·∫≥ng (Flatten) to√†n b·ªô l·ªãch s·ª≠
        batch_size = step_vecs.size(0)
        h_emb_flat = step_vecs.view(batch_size, -1)

        # 3. K·∫æT H·ª¢P T·∫§T C·∫¢ (Concatenate)
        # K·∫øt qu·∫£ c√≥ c·∫•u tr√∫c logic t∆∞∆°ng ƒë∆∞∆°ng: [User_Emb, r1, e1, r2, e2, r3, e3, Current_Entity_Emb]
        state_vector = torch.cat([u_emb, h_emb_flat, e_emb], dim=1) 
        
        return state_vector

    def get_pruned_actions(self, epsilon=10):
        """
        Th·ª±c hi·ªán Eq (8): Pruning function g_k((r, e_{k+1}) | u)
        C·∫Øt t·ªâa kh√¥ng gian h√†nh ƒë·ªông, ch·ªâ gi·ªØ l·∫°i top-epsilon neighbors.
        """
        batch_size = self.current_users.size(0)
        device = self.current_users.device # L·∫•y device hi·ªán t·∫°i
        valid_actions = []
        
        u_emb = self.entity_embs(self.current_users) # (B, dim)
        
        for i in range(batch_size):
            u_id = self.current_users[i].item()
            curr_node = self.current_entities[i].item()
            
            neighbors = self.kg.get_neighbors(curr_node) 
            
            if not neighbors:
                valid_actions.append([]) # Dead end
                continue
                
            # ƒê√£ b·ªï sung `device=device` ƒë·ªÉ tensor t·∫°o ra n·∫±m c√πng thi·∫øt b·ªã v·ªõi model
            rels = torch.tensor([n[0] for n in neighbors], device=device)
            next_nodes = torch.tensor([n[1] for n in neighbors], device=device)
            
            r_emb = self.relation_embs(rels)
            next_node_emb = self.entity_embs(next_nodes)
            
            query = u_emb[i] + r_emb 
            scores = torch.sum(query * next_node_emb, dim=1) 
            
            k = min(epsilon, len(scores))
            top_scores, top_indices = torch.topk(scores, k)
            
            actions_i = []
            for idx in top_indices:
                actions_i.append((rels[idx].item(), next_nodes[idx].item())) # L·∫•y .item() l∆∞u v√†o list
            valid_actions.append(actions_i)
            
        return valid_actions

    def get_action_space_batch(self):
        """
        [H√ÄM M·ªöI] L·∫•y kh√¥ng gian h√†nh ƒë·ªông cho c·∫£ Batch v√† Padding th√†nh Tensor.
        """
        raw_actions_list = self.get_pruned_actions() 
        batch_size = len(raw_actions_list)
        
        lengths = [len(acts) for acts in raw_actions_list]
        max_len = max(lengths) if lengths else 0
        if max_len == 0:
            max_len = 1
            
        device = self.current_entities.device
        
        r_indices = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
        e_indices = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
        action_mask = torch.zeros((batch_size, max_len), dtype=torch.float, device=device)
        
        for i, actions in enumerate(raw_actions_list):
            num_acts = len(actions)
            if num_acts > 0:
                rs = [a[0] for a in actions]
                es = [a[1] for a in actions]
                
                r_indices[i, :num_acts] = torch.tensor(rs, device=device)
                e_indices[i, :num_acts] = torch.tensor(es, device=device)
                action_mask[i, :num_acts] = 1.0
                
        r_emb = self.relation_embs(r_indices)
        e_emb = self.entity_embs(e_indices)
        
        action_embs = torch.cat([r_emb, e_emb], dim=-1)
        
        return action_embs, action_mask, raw_actions_list

    def step(self, actions):
        """
        Transition function (Eq 9): Chuy·ªÉn tr·∫°ng th√°i sang b∆∞·ªõc k+1
        """
        device = self.current_entities.device
        batch_size = len(actions)
        
        rels_list = [a[0] for a in actions]
        ents_list = [a[1] for a in actions]
        
        next_relations = torch.tensor(rels_list, dtype=torch.long, device=device)
        next_entities = torch.tensor(ents_list, dtype=torch.long, device=device)
        
        new_r = next_relations.unsqueeze(1)
        new_e = next_entities.unsqueeze(1)
        new_entry = torch.cat([new_r, new_e], dim=1)
        
        history_shifted = self.path_history[:, 2:]
        self.path_history = torch.cat([history_shifted, new_entry], dim=1)
        
        self.current_entities = next_entities
        self.step_counter += 1
        done = (self.step_counter >= self.max_path_len)
        
        return self._get_state_embedding(), done

    def step_with_indices(self, action_indices, raw_actions_list):
        """
        [H√ÄM M·ªöI] Th·ª±c hi·ªán b∆∞·ªõc ƒëi d·ª±a tr√™n Index (0, 1, 2...) m√† Agent ch·ªçn.
        """
        selected_real_actions = []
        batch_size = len(action_indices)
        
        for i in range(batch_size):
            idx = action_indices[i].item()
            user_acts = raw_actions_list[i]
            
            if len(user_acts) > 0:
                idx = min(idx, len(user_acts) - 1) # B·∫£o v·ªá index out of bounds
                real_action = user_acts[idx] 
            else:
                curr_node = self.current_entities[i].item()
                real_action = (0, curr_node) # Dead-end: T·ª± tr·ªè v·ªÅ ch√≠nh n√≥
                
            selected_real_actions.append(real_action)
            
        return self.step(selected_real_actions)

    def get_reward(self):
        """
        G·ªçi khi done=True. T√≠nh Time-aware Reward g_R(v|u).
        """
        user_ids = self.current_users
        item_ids = self.current_entities
        
        # Ch·ªâ l·∫•y c·ªôt relation t·ª´ l·ªãch s·ª≠
        history_relation_ids = self.path_history[:, 0::2] 
        
        rewards = self.reward_function(user_ids, item_ids, history_relation_ids)
        return rewards

In [15]:
class TPRecPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128, dropout=0.1):
        """
        M·∫°ng Policy (Actor) s·ª≠ d·ª•ng MLP ƒë·ªÉ ch·∫•m ƒëi·ªÉm h√†nh ƒë·ªông.
        
        Args:
            state_dim (int): K√≠ch th∆∞·ªõc vector tr·∫°ng th√°i s_k.
                             = dim(u) + dim(flattened_history) + dim(e_k)
            action_dim (int): K√≠ch th∆∞·ªõc vector h√†nh ƒë·ªông a_k.
                              = dim(relation) + dim(next_node)
            hidden_dim (int): S·ªë neuron l·ªõp ·∫©n.
            dropout (float): T·ª∑ l·ªá dropout ƒë·ªÉ tr√°nh overfitting.
        """
        super(TPRecPolicy, self).__init__()
        
        # Layer 1: K·∫øt h·ª£p State v√† Action
        # Input size = State + Action
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        
        # Activation function (ELU ho·∫∑c ReLU th∆∞·ªùng d√πng trong Graph RL)
        self.act = nn.ELU()
        
        # Dropout
        self.dropout = nn.Dropout(dropout)
        
        # Layer 2: Output ra 1 ƒëi·ªÉm s·ªë (Scalar Score) cho c·∫∑p (State, Action)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, state_emb, action_embs, action_mask=None):
        """
        Forward pass ƒë·ªÉ t√≠nh x√°c su·∫•t h√†nh ƒë·ªông.
        
        Args:
            state_emb: (Batch, State_Dim) - Vector tr·∫°ng th√°i s_k
            action_embs: (Batch, Max_Actions, Action_Dim) - C√°c h√†nh ƒë·ªông ·ª©ng vi√™n
            action_mask: (Batch, Max_Actions) - Mask (1=H·ª£p l·ªá, 0=Padding)
            
        Returns:
            probs: (Batch, Max_Actions) - X√°c su·∫•t (ƒë√£ qua Softmax)
            log_probs: (Batch, Max_Actions) - Log x√°c su·∫•t (ƒë·ªÉ t√≠nh Loss)
        """
        batch_size, num_actions, _ = action_embs.size()
        
        # 1. M·ªü r·ªông State ƒë·ªÉ kh·ªõp v·ªõi s·ªë l∆∞·ª£ng Action (Broadcasting)
        # state: (B, S_Dim) -> (B, 1, S_Dim) -> (B, K, S_Dim)
        state_expanded = state_emb.unsqueeze(1).expand(-1, num_actions, -1)
        
        # 2. Gh√©p (Concatenate) State v√† Action
        # input: (B, K, S_Dim + A_Dim)
        inputs = torch.cat([state_expanded, action_embs], dim=2)
        
        # 3. Qua m·∫°ng MLP
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.dropout(x)
        scores = self.fc2(x).squeeze(-1) # (B, K, 1) -> (B, K)
        
        # 4. Masking (C·ª±c k·ª≥ quan tr·ªçng)
        # G√°n ƒëi·ªÉm s·ªë r·∫•t th·∫•p (-1e9) cho c√°c h√†nh ƒë·ªông padding ƒë·ªÉ Softmax = 0
        if action_mask is not None:
            scores = scores.masked_fill(action_mask == 0, -1e9)
        
        # 5. T√≠nh x√°c su·∫•t (Softmax)
        probs = F.softmax(scores, dim=1)
        
        # Tr·∫£ v·ªÅ c·∫£ log_probs ƒë·ªÉ ti·ªán t√≠nh REINFORCE Loss sau n√†y
        log_probs = F.log_softmax(scores, dim=1)
        
        return probs, log_probs

### Class Config 

In [16]:
class Config:
    # ƒê∆∞·ªùng d·∫´n d·ªØ li·ªáu
    TCKG_PATH = f"./data/{name}/{name}_TCKG.csv"  # File CSV ch·ª©a ƒë·ªì th·ªã (nh∆∞ b·∫°n ƒë√£ g·ª≠i)
    # TRAIN_USERS_PATH = "data/train_users.csv" # File ch·ª©a danh s√°ch user ID d√πng ƒë·ªÉ train
    SAVE_DIR = "./checkpoints"
    
    # Si√™u tham s·ªë Model
    EMBED_DIM = 64
    HIDDEN_DIM = 256
    HISTORY_LEN = 3   # k' (ƒë·ªô d√†i l·ªãch s·ª≠)
    MAX_PATH_LEN = 3  # K (s·ªë b∆∞·ªõc ƒëi)
    
    # Si√™u tham s·ªë Training
    BATCH_SIZE = 512 # Batch l·ªõn gi√∫p RL ·ªïn ƒë·ªãnh h∆°n
    NUM_EPOCHS = 500
    LEARNING_RATE = 5e-4
    BETA_ENTROPY = 0.01
    
    # Thi·∫øt b·ªã
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # IDs c·ªßa c√°c Relation Interaction (Quan tr·ªçng cho Reward)
    # V√≠ d·ª•: 20=interacted_0, 21=interacted_1, 22=interacted_2
    INTERACTION_CLUSTER_IDS = [20, 21, 22, 23] 

### Load Dataset

In [17]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# 1. T·∫°o class Dataset l√†m c·∫ßu n·ªëi
class InteractionDataset(Dataset):
    def __init__(self, df):
        """Nh·∫≠n v√†o m·ªôt DataFrame v√† chuy·ªÉn ƒë·ªïi c√°c c·ªôt c·∫ßn thi·∫øt th√†nh m·∫£ng NumPy"""
        self.users = df['user_id'].values
        self.entities = df['entity_id'].values

    def __len__(self):
        """Tr·∫£ v·ªÅ t·ªïng s·ªë d√≤ng d·ªØ li·ªáu"""
        return len(self.users)

    def __getitem__(self, idx):
        """L·∫•y d·ªØ li·ªáu t·∫°i v·ªã tr√≠ idx v√† bi·∫øn th√†nh PyTorch Tensor"""
        user_tensor = torch.tensor(self.users[idx], dtype=torch.long)
        entity_tensor = torch.tensor(self.entities[idx], dtype=torch.long)
        
        return user_tensor

# 2. ƒê·ªçc file CSV th√†nh Pandas DataFrame
# (L∆∞u √Ω nh·ªè: M√¨nh gi·ªØ nguy√™n t√™n file c·ªßa b·∫°n, nh∆∞ng ch·ªØ 'interacions' h√¨nh nh∆∞ ƒëang thi·∫øu ch·ªØ 't', b·∫°n nh·ªõ ki·ªÉm tra l·∫°i t√™n file th·∫≠t nh√©)
train_df = pd.read_csv(f'./data/{name}/{name}_train_interactions.csv')
val_df = pd.read_csv(f'./data/{name}/{name}_val_interactions.csv')
test_df = pd.read_csv(f'./data/{name}/{name}_test_interactions.csv')

# 3. B·ªçc DataFrame v√†o class Dataset v·ª´a t·∫°o
train_dataset = InteractionDataset(train_df)
val_dataset = InteractionDataset(val_df)
test_dataset = InteractionDataset(test_df)

# 4. ƒê∆∞a Dataset v√†o DataLoader
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False, num_workers=0)

### Main function

In [18]:
def train(env, policy_net, train_loader, num_epochs=50, 
          learning_rate=1e-3, beta_entropy=0.01, 
          save_path="./checkpoints", device="cuda"):
    """
    H√†m hu·∫•n luy·ªán g·ªôp (kh√¥ng d√πng h√†m ph·ª• tr·ª£ train_one_epoch).
    
    C·∫•u tr√∫c l·ªìng nhau:
    1. Loop Epochs (H√†ng trƒÉm l·∫ßn)
       2. Loop Batches (Duy·ªát qua to√†n b·ªô dataset)
          3. Loop Steps (Agent ƒëi k b∆∞·ªõc t√¨m ƒë∆∞·ªùng)
             -> T√≠nh Reward & Loss
             -> Backpropagation
    """
    
    # --- 1. SETUP ---
    if not os.path.exists(save_path):
        os.makedirs(save_path)
        
    # ƒê∆∞a model v√† env v√†o thi·∫øt b·ªã (GPU/CPU)
    policy_net.to(device)
    # env.to(device) # N·∫øu class Env c·ªßa b·∫°n c√≥ h·ªó tr·ª£ .to()
    
    # Optimizer (Adam)
    optimizer = optim.Adam(policy_net.parameters(), lr=learning_rate)
    
    # Learning Rate Scheduler (Gi·∫£m LR sau m·ªói 50 epoch)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
    
    print(f"üöÄ B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán {num_epochs} epochs...")
    start_time = time.time()

    print(f"T·ªïng s·ªë m·∫´u trong Dataset: {len(train_loader.dataset)}")
    print(f"T·ªïng s·ªë Batch s·∫Ω ch·∫°y: {len(train_loader)}")
    
    # --- 2. V√íNG L·∫∂P EPOCH (V√≤ng ngo√†i c√πng) ---
    for epoch in range(1, num_epochs + 1):
        
        policy_net.train() # B·∫≠t ch·∫ø ƒë·ªô train (Dropout, Batchnorm...)
        
        total_epoch_loss = 0
        total_epoch_reward = 0
        num_batches = 0
        
        # Thanh ti·∫øn tr√¨nh cho t·ª´ng Epoch
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")
        
        # --- 3. V√íNG L·∫∂P BATCH (Duy·ªát Dataset) ---
        for batch_users in train_loader:
            batch_users = batch_users.to(device)
            
            # A. Reset m√¥i tr∆∞·ªùng cho batch m·ªõi
            # state_emb: (Batch, State_Dim)
            state_emb = env.reset(batch_users)
            
            # Danh s√°ch l∆∞u l·∫°i th√¥ng tin ƒë·ªÉ t√≠nh Loss sau khi ƒëi xong
            saved_log_probs = [] # Log œÄ(a|s)
            saved_entropies = [] # H(œÄ)
            
            # --- 4. V√íNG L·∫∂P STEP (Agent ƒëi t√¨m ƒë∆∞·ªùng) ---
            # ƒêi max_path_len b∆∞·ªõc (v√≠ d·ª•: 3 b∆∞·ªõc)
            for t in range(env.max_path_len):
                
                # B. L·∫•y h√†nh ƒë·ªông kh·∫£ thi (Pruned Action Space)
                # action_embs: (Batch, K, Dim)
                # mask: (Batch, K) - ƒê·ªÉ che padding
                # raw_actions: List ƒë·ªÉ map l·∫°i ID th·ª±c t·∫ø
                action_embs, action_mask, raw_actions = env.get_action_space_batch()
                
                # C. Policy Forward (Ra quy·∫øt ƒë·ªãnh)
                probs, log_probs = policy_net(state_emb, action_embs, action_mask)
                
                # D. Ch·ªçn h√†nh ƒë·ªông (Sampling d·ª±a tr√™n x√°c su·∫•t)
                m = torch.distributions.Categorical(probs)
                action_indices = m.sample() # (Batch,)
                
                # E. L∆∞u Log Prob v√† Entropy
                saved_log_probs.append(m.log_prob(action_indices))
                saved_entropies.append(m.entropy())
                
                # F. Th·ª±c hi·ªán b∆∞·ªõc ƒëi (Transition)
                # H√†m n√†y c·∫≠p nh·∫≠t v·ªã tr√≠ agent v√† l·ªãch s·ª≠
                state_emb, done = env.step_with_indices(action_indices, raw_actions)
            
            # --- 5. T√çNH TO√ÅN LOSS & UPDATE (K·∫øt th√∫c 1 trajectory) ---
            
            # G. T√≠nh Reward cu·ªëi c√πng (Terminal Reward)
            # rewards shape: (Batch,)
            rewards = env.get_reward() 
            
            # H. T√≠nh REINFORCE Loss
            # Loss = - sum(log_prob * reward) - beta * entropy
            batch_loss = 0
            
            # C·ªông d·ªìn loss qua c√°c b∆∞·ªõc th·ªùi gian (t=1..K)
            # Gi·∫£ s·ª≠ Reward ·ªü b∆∞·ªõc cu·ªëi √°p d·ª•ng cho to√†n b·ªô chu·ªói h√†nh ƒë·ªông
            for log_prob in saved_log_probs:
                # D·∫•u tr·ª´ (-) ƒë·ªÉ chuy·ªÉn b√†i to√°n Maximize Reward th√†nh Minimize Loss
                batch_loss += -log_prob * rewards
                
            # Tr·ª´ ƒëi Entropy (Khuy·∫øn kh√≠ch kh√°m ph√°)
            for entropy in saved_entropies:
                batch_loss -= beta_entropy * entropy
                
            # L·∫•y trung b√¨nh loss tr√™n to√†n b·ªô Batch users
            loss = batch_loss.mean()
            
            # I. Backpropagation (Lan truy·ªÅn ng∆∞·ª£c)
            optimizer.zero_grad() # X√≥a gradient c≈©
            loss.backward()       # T√≠nh gradient m·ªõi
            
            # K·ªπ thu·∫≠t Gradient Clipping (Tr√°nh b√πng n·ªï gradient - R·∫•t quan tr·ªçng trong RL)
            torch.nn.utils.clip_grad_norm_(policy_net.parameters(), max_norm=1.0)
            
            optimizer.step()      # C·∫≠p nh·∫≠t tr·ªçng s·ªë
            
            # --- 6. LOGGING ---
            total_epoch_loss += loss.item()
            total_epoch_reward += rewards.mean().item()
            num_batches += 1
            
            # C·∫≠p nh·∫≠t thanh ti·∫øn tr√¨nh (Hi·ªÉn th·ªã loss/reward hi·ªán t·∫°i)
            pbar.set_postfix({'Loss': f"{loss.item():.4f}", 'R': f"{rewards.mean().item():.4f}"})
            
        # --- K·∫æT TH√öC 1 EPOCH ---
        scheduler.step() # C·∫≠p nh·∫≠t Learning Rate
        
        avg_loss = total_epoch_loss / num_batches
        avg_reward = total_epoch_reward / num_batches
        
        # In k·∫øt qu·∫£ Epoch
        # print(f"\n[K·∫øt th√∫c Epoch {epoch}] Avg Loss: {avg_loss:.4f} | Avg Reward: {avg_reward:.4f}")
        
        # L∆∞u Checkpoint m·ªói 10 Epoch
        if epoch % 10 == 0:
            ckpt_path = os.path.join(save_path, f"model_epoch_{epoch}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': policy_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, ckpt_path)
            # print(f"üíæ ƒê√£ l∆∞u model t·∫°i {ckpt_path}")

    print(f"‚úÖ Ho√†n th√†nh training sau {(time.time() - start_time):.0f} gi√¢y.")

def main():
    cfg = Config()
    print(f"üî• Starting TPRec Training on {cfg.DEVICE}...")
    
    # -------------------------------------------------
    # B∆Ø·ªöC 1: X√¢y d·ª±ng ƒê·ªì th·ªã Tri th·ª©c (TCKG)
    # -------------------------------------------------
    # L∆∞u √Ω: C·∫ßn ƒë·∫£m b·∫£o file CSV ƒë√£ ƒë∆∞·ª£c map ID v·ªÅ d·∫°ng s·ªë nguy√™n li√™n t·ª•c (0, 1, 2...)
    print("Loading Knowledge Graph...")
    # T√≠nh offset t·ª± ƒë·ªông (nh∆∞ class TCKG t·ªëi ∆∞u t√¥i ƒë√£ vi·∫øt)
    tckg = TCKG(cfg.TCKG_PATH) 

    # -------------------------------------------------
    # B∆Ø·ªöC 2: Kh·ªüi t·∫°o Embeddings t·ª´ file Pickle
    # -------------------------------------------------
    print("Loading Pre-trained TransE Embeddings...")
    pickle_file_path = './pickle/book_transE_embeddings_2026-02-19_14-30-57.pkl'

    with open(pickle_file_path, 'rb') as f:
        saved_data = pickle.load(f)
        
    pretrained_ent = saved_data['entity_embeddings']
    pretrained_rel = saved_data['relation_embeddings']
    
    # 2. Chuy·ªÉn ƒë·ªïi sang PyTorch Tensor (√©p ki·ªÉu Float32 ƒë·ªÉ t√≠nh to√°n neural network)
    # N·∫øu data ƒëang l√† Numpy array:
    if isinstance(pretrained_ent, np.ndarray):
        ent_tensor = torch.tensor(pretrained_ent, dtype=torch.float32)
        rel_tensor = torch.tensor(pretrained_rel, dtype=torch.float32)
    else:
        # N·∫øu data ƒë√£ l√† Tensor s·∫µn:
        ent_tensor = pretrained_ent.clone().detach().float()
        rel_tensor = pretrained_rel.clone().detach().float()
        
    # 3. N·∫°p v√†o nn.Embedding
    # freeze=False: Cho ph√©p RL Agent ti·∫øp t·ª•c c·∫≠p nh·∫≠t (fine-tune) vector trong l√∫c t√¨m ƒë∆∞·ªùng
    # freeze=True: Kh√≥a c·ª©ng vector, RL Agent ch·ªâ h·ªçc Policy Network (H·ªçc nhanh h∆°n, ch·ªëng overfit)
    entity_embs = nn.Embedding.from_pretrained(ent_tensor, freeze=False)
    relation_embs = nn.Embedding.from_pretrained(rel_tensor, freeze=False)

    
    # -------------------------------------------------
    # B∆Ø·ªöC 3: Kh·ªüi t·∫°o Reward Function
    # -------------------------------------------------
    print("Setting up Reward Function...")
    reward_func = TimeAwareRewardFunction(
        user_embs=entity_embs,    # Chia s·∫ª tr·ªçng s·ªë v·ªõi Env
        entity_embs=entity_embs,
        relation_embs=relation_embs,
        interaction_cluster_ids=cfg.INTERACTION_CLUSTER_IDS,
        bias_embs=None, # T·ª± t·∫°o bias m·ªõi
        temperature= 2.0 
    )
    
    # -------------------------------------------------
    # B∆Ø·ªöC 4: Kh·ªüi t·∫°o M√¥i tr∆∞·ªùng (Environment)
    # -------------------------------------------------
    print("Setting up Environment...")
    env = TPRecEnvironment(
        kg=tckg,
        entity_embeddings=entity_embs,
        relation_embeddings=relation_embs,
        reward_function=reward_func, # Inject reward v√†o env
        max_path_len=cfg.MAX_PATH_LEN,
        history_len=cfg.HISTORY_LEN
    )
    
    # -------------------------------------------------
    # B∆Ø·ªöC 5: Kh·ªüi t·∫°o Policy Network (Agent)
    # -------------------------------------------------
    print("Building Policy Network...")
    
    # T√≠nh to√°n k√≠ch th∆∞·ªõc State v√† Action theo c√¥ng th·ª©c chu·∫©n
    # Action = Relation + Entity
    action_dim = cfg.EMBED_DIM + cfg.EMBED_DIM 
    
    # State = User + Flattened_History + Current_Entity
    # Flattened_History = k' * (Relation + Entity)
    history_flat_dim = cfg.HISTORY_LEN * (cfg.EMBED_DIM + cfg.EMBED_DIM)
    state_dim = cfg.EMBED_DIM + history_flat_dim + cfg.EMBED_DIM
    
    print(f"--> State Dim: {state_dim} | Action Dim: {action_dim}")
    
    policy_net = TPRecPolicy(
        state_dim=state_dim,
        action_dim=action_dim,
        hidden_dim=cfg.HIDDEN_DIM
    )
    
    
    # -------------------------------------------------
    # B∆Ø·ªöC 7: B·∫ÆT ƒê·∫¶U TRAINING
    # -------------------------------------------------
    # G·ªçi h√†m train monolithic (t·∫•t c·∫£ trong m·ªôt) b·∫°n ƒë√£ ch·ªçn

    train(
        env=env,
        policy_net=policy_net,
        train_loader=train_loader,
        num_epochs=cfg.NUM_EPOCHS,
        learning_rate=cfg.LEARNING_RATE,
        beta_entropy=cfg.BETA_ENTROPY,
        save_path=cfg.SAVE_DIR,
        device=cfg.DEVICE
    )

if __name__ == "__main__":
    # T·∫°o th∆∞ m·ª•c data gi·∫£ l·∫≠p n·∫øu ch∆∞a c√≥ ƒë·ªÉ test logic (T√πy ch·ªçn)
    if not os.path.exists("checkpoints"):
        os.makedirs("checkpoints")
        
    main()

üî• Starting TPRec Training on cpu...
Loading Knowledge Graph...
Loading TCKG from ./data/book/book_TCKG.csv...
Inverse Relation Offset: 24. Total Relation Space: 49
TCKG Loaded successfully. Graph construction complete.
Loading Pre-trained TransE Embeddings...
Setting up Reward Function...
Setting up Environment...
Building Policy Network...
--> State Dim: 512 | Action Dim: 128
üöÄ B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán 500 epochs...
T·ªïng s·ªë m·∫´u trong Dataset: 14683
T·ªïng s·ªë Batch s·∫Ω ch·∫°y: 15


Epoch 1/500:   0%|          | 0/15 [00:10<?, ?it/s, Loss=3.3544, R=0.5030]
Epoch 1/500:   0%|          | 0/15 [00:10<?, ?it/s, Loss=3.3544, R=0.5030]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Epoch 2/500:   0%|          | 0/15 [00:09<?, ?it/s, Loss=3.3780, R=0.5029]
Epoch 3/500:   0%|          | 0/15 [00:08<?, ?it/s, Loss=3.3806, R=0.5032]

KeyboardInterrupt: 