In [17]:
# !pip install pytorch_lightning

In [18]:
import random, os, pickle, time
import pandas as pd
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm

from collections import defaultdict


# Set environment variables for reproducibility and safety
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import warnings
warnings.filterwarnings('ignore')
from sklearn.metrics import precision_score, recall_score, accuracy_score

# 1. Configuration & Seeding
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [19]:
name = 'book'

#### TCKG

In [None]:
class TCKG:
    def __init__(self, tckg_csv_path):
        self.adj_list = defaultdict(list)

        print(f"Loading TCKG from {tckg_csv_path}...")
        df_tckg = pd.read_csv(tckg_csv_path, usecols=['head_id', 'relation_id', 'tail_id'])

        offset = df_tckg['relation_id'].max()

        data = df_tckg[['head_id', 'relation_id', 'tail_id']].to_numpy() # Using numpy to speedup
        for h, r, t in data:
            h, r, t = int(h), int(r), int(t)

            self.adj_list[h].append((r, t))
            self.adj_list[t].append((r + offset, h)) # Ex: r=5 (watched) -> r_inv=105 (watched_by)

        self._prepare_tensors()

        print(f"TCKG Loaded successfully. Graph construction complete.")

    def _prepare_tensors(self):
        num_nodes = max(self.adj_list.keys())

        self.max_edges = 0

        for node, edges in self.adj_list.items():
            if self.max_edge < len(edges):
                self.max_edge = len(edges)

        self.rel_matrix = torch.zeros((num_nodes, self.max_neighbors), dtype=torch.long)
        self.tail_matrix = torch.zeros((num_nodes, self.max_neighbors), dtype=torch.long)
        self.mask_matrix = torch.zeros((num_nodes, self.max_neighbors), dtype=torch.bool)

        for node, edges in self.adj_list.items():
            # Kh√¥ng c√≤n l·ªánh if c·∫Øt b·ªè (random.sample) ·ªü ƒë√¢y n·ªØa
            
            rels = [e[0] for e in edges]
            tails = [e[1] for e in edges]
            length = len(edges)

            # G√°n v√†o ma tr·∫≠n
            self.rel_matrix[node, :length] = torch.tensor(rels, dtype=torch.long)
            self.tail_matrix[node, :length] = torch.tensor(tails, dtype=torch.long)
            self.mask_matrix[node, :length] = True

    def get_neighbors(self, node_id):
        a = self.adj_list[node_id]
        return self.adj_list[node_id]

    def get_all_nodes(self):
        return list(self.adj_list.keys())

    def get_batched_neighbors(self, node_ids_tensor):
        rels = self.rel_matrix[node_ids_tensor]
        tails = self.tail_matrix[node_ids_tensor]
        masks = self.mask_matrix[node_ids_tensor]
        return rels, tails, masks

#### TimeAwareRewardFunction

In [21]:
class TimeAwareRewardFunction(nn.Module):
    def __init__(self, user_embs, entity_embs, relation_embs, interaction_cluster_ids, bias_embs=None, temperature= None):
        """
        Args:
            user_embs (nn.Embedding): Embedding c·ªßa User (e_u)
            entity_embs (nn.Embedding): Embedding c·ªßa Item/Entity (e_v)
            relation_embs (nn.Embedding): Embedding c·ªßa Relation (d√πng ƒë·ªÉ l·∫•y V_U)
            interaction_cluster_ids (list or tensor): Danh s√°ch c√°c Relation ID ƒë·∫°i di·ªán cho Time Clusters.
                                                      V√≠ d·ª•: [20, 21, 22] ·ª©ng v·ªõi interacted_0, interacted_1...
                                                      ƒê√¢y ch√≠nh l√† t·∫≠p {V_U^1, ..., V_U^L}
            bias_embs (nn.Embedding, optional): Bias c·ªßa entity (b_v). N·∫øu None s·∫Ω t·ª± kh·ªüi t·∫°o.
        """
        super(TimeAwareRewardFunction, self).__init__()

        self.user_embs = user_embs
        self.entity_embs = entity_embs
        self.relation_embs = relation_embs

        # Danh s√°ch ID c·ªßa c√°c cluster t∆∞∆°ng t√°c theo th·ªùi gian (V_U)
        # Chuy·ªÉn th√†nh tensor ƒë·ªÉ t√≠nh to√°n song song
        self.register_buffer('cluster_ids', torch.tensor(interaction_cluster_ids, dtype=torch.long))

        # Entity Bias (b_v) - Eq (11)
        if bias_embs is None:
            num_entities = entity_embs.num_embeddings - 1
            self.bias_embs = nn.Embedding(num_entities + 1, 1, padding_idx = 0)
            nn.init.zeros_(self.bias_embs.weight) # Kh·ªüi t·∫°o bias b·∫±ng 0
        else:
            self.bias_embs = bias_embs

        if temperature is None:
            self.temperature = self.entity_embs.embedding_dim ** 0.5
        else:
            self.temperature = temperature

    def calculate_weights(self, history_relation_ids):
        """
        Th·ª±c hi·ªán Eq (13): T√≠nh tr·ªçng s·ªë w_h d·ª±a tr√™n t·∫ßn su·∫•t xu·∫•t hi·ªán trong l·ªãch s·ª≠.

        Args:
            history_relation_ids: (Batch, Max_History_Len) - Ch·ª©a relation ID trong qu√° kh·ª© c·ªßa user.

        Returns:
            weights: (Batch, Num_Clusters) - Vector W_hu
        """
        # 1. So kh·ªõp History v·ªõi Cluster IDs
        # history: (B, H, 1)
        # clusters: (1, 1, L)
        # matches: (B, H, L) -> True n·∫øu relation t·∫°i history kh·ªõp v·ªõi cluster ID
        hist_expanded = history_relation_ids.unsqueeze(-1)
        clusters_expanded = self.cluster_ids.view(1, 1, -1)

        matches = (hist_expanded == clusters_expanded).float()

        # 2. ƒê·∫øm s·ªë l·∫ßn xu·∫•t hi·ªán (T·ª≠ s·ªë Eq 13)
        # Sum theo chi·ªÅu History (dim=1) -> (B, L)
        counts = matches.sum(dim=1)

        # 3. T√≠nh ƒë·ªô d√†i q th·ª±c t·∫ø (M·∫´u s·ªë Eq 13)
        # q = t·ªïng s·ªë l·∫ßn xu·∫•t hi·ªán c·ªßa b·∫•t k·ª≥ cluster n√†o trong history (tr√°nh t√≠nh padding 0)
        q = counts.sum(dim=1, keepdim=True)

        # 4. Normalize ƒë·ªÉ ra tr·ªçng s·ªë (tr√°nh chia cho 0)
        weights = counts / (q + 1e-9)

        return weights

    def forward(self, user_ids, item_ids, history_relation_ids):
        """
        T√≠nh Reward Score g_R(v | u)

        Args:
            user_ids: (Batch,)
            item_ids: (Batch,) - Item ƒë√≠ch (v_hat) m√† Agent d·ª± ƒëo√°n/d·ª´ng l·∫°i
            history_relation_ids: (Batch, History_Len) - L·ªãch s·ª≠ relation c·ªßa user

        Returns:
            scores: (Batch,) - ƒêi·ªÉm reward
        """
        # --- B∆Ø·ªöC 1: L·∫•y Embeddings c∆° b·∫£n ---
        u_e = self.user_embs(user_ids)       # (B, Dim) -> e_u
        
        v_e = self.entity_embs(item_ids)     # (B, Dim) -> e_v
        v_b = self.bias_embs(item_ids).squeeze(-1) # (B,) -> b_v

        # --- B∆Ø·ªöC 2: T√≠nh Personalized Interaction Relation (Eq 12) ---
        # r_vu^T = W_hu * V_U

        # a. T√≠nh weights (B, L)
        weights = self.calculate_weights(history_relation_ids)

        # b. L·∫•y embedding c·ªßa c√°c cluster V_U^1...L
        # cluster_embs shape: (L, Dim)
        cluster_embs = self.relation_embs(self.cluster_ids)

        # c. T√≠nh t·ªïng c√≥ tr·ªçng s·ªë
        # (B, L) x (L, Dim) -> (B, Dim)
        r_interaction = torch.matmul(weights, cluster_embs)

        # --- B∆Ø·ªöC 3: T√≠nh Score (Eq 11 & Final Eq) ---
        # g = (e_u + r_interaction) . e_v + b_v

        # Dot product: (e_u + r) * e_v
        query_vector = u_e + r_interaction # (B, Dim)
        dot_product = torch.sum(query_vector * v_e, dim=1) # (B,)

        scores = dot_product + v_b # C·ªông bias

        # 1. Scale Score (Chia cho nhi·ªát ƒë·ªô)
        scaled_score = scores / self.temperature

        # 2. √Åp d·ª•ng Sigmoid
        rewards = torch.sigmoid(scaled_score)

        return rewards


#### TPRecEnvironment

In [22]:
import torch
import torch.nn as nn

class TPRecEnvironment(nn.Module):
    def __init__(self, tckg, entity_embeddings, relation_embeddings, reward_function, max_path_len=3, history_len=3):
        """
        M√¥i tr∆∞·ªùng TPRec ƒë∆∞·ª£c thi·∫øt k·∫ø chu·∫©n x√°c theo ki·∫øn tr√∫c Sequence State
        """
        super(TPRecEnvironment, self).__init__()
        self.tckg = tckg
        self.entity_embs = entity_embeddings
        self.relation_embs = relation_embeddings
        self.max_path_len = max_path_len
        self.history_len = history_len

        # L∆ØU REWARD FUNCTION ƒê∆Ø·ª¢C TRUY·ªÄN V√ÄO
        self.reward_function = reward_function

        # State tracking
        self.current_entities = None
        self.current_users = None
        self.path_history = None
        self.step_counter = 0

        # Bi·∫øn l∆∞u tr·ªØ ƒë√°p √°n m·ª•c ti√™u ƒë·ªÉ t√≠nh Terminal Bonus
        self.target_items = None

    def reset(self, user_ids, target_items=None):
        """
        Kh·ªüi t·∫°o tr·∫°ng th√°i s_0 = (u, u, ‚àÖ)
        C·∫≠p nh·∫≠t: N·∫°p th√™m target_items ƒë·ªÉ Tr·ªçng t√†i (M√¥i tr∆∞·ªùng) c·∫ßm s·∫µn ƒë√°p √°n
        """
        batch_size = user_ids.size(0)

        self.current_users = user_ids
        self.current_entities = user_ids

        # N·∫†P ƒê√ÅP √ÅN: L∆∞u l·∫°i Target Items cho b∆∞·ªõc t√≠nh Reward cu·ªëi c√πng
        self.target_items = target_items

        # History h_k: store (entity, relation) theo chu·∫©n paper
        self.path_history = torch.zeros((batch_size, self.history_len * 2),
                                        dtype=torch.long,
                                        device=user_ids.device)

        self.step_counter = 0
        return self._get_state_embedding()

    def _get_state_embedding(self):
        """
        K·∫øt h·ª£p u, h_k, e_k th√†nh m·ªôt chu·ªói duy nh·∫•t ƒë∆∞a v√†o BLSTM
        """
        # 1. Th√™m chi·ªÅu th·ªùi gian (unsqueeze) cho u v√† e_k
        u_emb = self.entity_embs(self.current_users).unsqueeze(1)    # Shape: (B, 1, d)
        e_emb = self.entity_embs(self.current_entities).unsqueeze(1) # Shape: (B, 1, d)

        # 2. L·∫•y l·ªãch s·ª≠ e v√† r
        e_indices = self.path_history[:, 0::2]
        r_indices = self.path_history[:, 1::2]

        e_vecs = self.entity_embs(e_indices)   # Shape: (B, L, d)
        r_vecs = self.relation_embs(r_indices) # Shape: (B, L, d)

        # 3. Tr·ªôn xen k·∫Ω e v√† r th√†nh chu·ªói l·ªãch s·ª≠ [e1, r1, e2, r2...]
        B, L, d = e_vecs.shape
        history_seq = torch.zeros((B, L * 2, d), device=e_vecs.device)
        history_seq[:, 0::2, :] = e_vecs
        history_seq[:, 1::2, :] = r_vecs

        # 4. K·∫æT N·ªêI TO√ÄN B·ªò THEO ƒê√öNG C√îNG TH·ª®C S_K TRONG PAPER
        full_state_seq = torch.cat([u_emb, history_seq, e_emb], dim=1) # Shape: (B, Seq_Len, d)

        # Tr·∫£ v·ªÅ ƒê√öNG 1 BI·∫æN DUY NH·∫§T ƒë·∫°i di·ªán cho S_k
        return full_state_seq

    def get_pruned_actions(self, epsilon=15):
        """
        Phi√™n b·∫£n T·ªëi ∆∞u h√≥a Vectorization (GPU Accelerated)
        """
        batch_size = self.current_users.size(0)
        device = self.current_users.device

        # 1. K√âO V·ªÄ CPU M·ªòT L·∫¶N DUY NH·∫§T: Tr√°nh g·ªçi .item() N l·∫ßn
        curr_nodes_cpu = self.current_entities.tolist()

        # Tra c·ª©u l√°ng gi·ªÅng c·ª±c nhanh tr√™n CPU b·∫±ng List Comprehension
        batch_neighbors = [self.tckg.get_neighbors(node) for node in curr_nodes_cpu]

        # =====================================================================
        # üõë NG·∫ÆT LU√îN LI√äN K·∫æT: CH·ªêNG R√í R·ªà D·ªÆ LI·ªÜU (TARGET LEAKAGE)
        # =====================================================================
        if self.target_items is not None:
            targets_cpu = self.target_items.tolist()
            users_cpu = self.current_users.tolist()
            filtered_batch_neighbors = []

            for i, neighbors in enumerate(batch_neighbors):
                # N·∫æU Agent ƒëang ƒë·ª©ng t·∫°i ƒë√∫ng User g·ªëc c·ªßa n√≥
                if curr_nodes_cpu[i] == users_cpu[i]:
                    target_node = targets_cpu[i]
                    # CH√âM B·ªé ƒê∆Ø·ªúNG T·∫ÆT: X√≥a Target Item kh·ªèi danh s√°ch l√°ng gi·ªÅng
                    valid_neighbors = [n for n in neighbors if n[1] != target_node]
                    filtered_batch_neighbors.append(valid_neighbors)
                else:
                    # N·∫øu Agent ƒëang ·ªü Node kh√°c (Item, T√°c gi·∫£...), cho ph√©p ƒëi b√¨nh th∆∞·ªùng
                    filtered_batch_neighbors.append(neighbors)

            batch_neighbors = filtered_batch_neighbors
        # =====================================================================

        # T√¨m node c√≥ s·ªë l∆∞·ª£ng l√°ng gi·ªÅng l·ªõn nh·∫•t trong batch n√†y
        lengths = [len(n) for n in batch_neighbors]
        max_len = max(lengths) if lengths else 0

        valid_actions = []

        # N·∫øu to√†n b·ªô batch ƒë·ªÅu r∆°i v√†o ng√µ c·ª•t
        if max_len == 0:
            return [[] for _ in range(batch_size)]

        # 2. T·∫†O MA TR·∫¨N PADDING TR√äN CPU
        batch_rels = torch.zeros((batch_size, max_len), dtype=torch.long)
        batch_next_nodes = torch.zeros((batch_size, max_len), dtype=torch.long)
        mask = torch.zeros((batch_size, max_len), dtype=torch.bool)

        # ƒêi·ªÅn d·ªØ li·ªáu v√†o ma tr·∫≠n
        for i, neighbors in enumerate(batch_neighbors):
            num_n = lengths[i]
            if num_n > 0:
                batch_rels[i, :num_n] = torch.tensor([n[0] for n in neighbors])
                batch_next_nodes[i, :num_n] = torch.tensor([n[1] for n in neighbors])
                mask[i, :num_n] = True

        # 3. ƒê·∫®Y TO√ÄN B·ªò MA TR·∫¨N L√äN GPU M·ªòT L·∫¶N DUY NH·∫§T
        batch_rels = batch_rels.to(device)
        batch_next_nodes = batch_next_nodes.to(device)
        mask = mask.to(device)

        # 4. T√çNH TO√ÅN SONG SONG B·∫∞NG MA TR·∫¨N TR√äN GPU
        curr_emb = self.entity_embs(self.current_entities) # L·∫•y node HI·ªÜN T·∫†I (ƒê√£ s·ª≠a)
        r_emb = self.relation_embs(batch_rels)
        n_emb = self.entity_embs(batch_next_nodes)

        # T√≠nh Query: C·∫ßn unsqueeze curr_emb ƒë·ªÉ c·ªông broadcast v·ªõi r_emb
        query = curr_emb.unsqueeze(1) + r_emb

        # L·ªñI C≈® C·∫¶N X√ìA: scores = torch.sum(query * n_emb, dim=-1)

        # C√ÅCH CHU·∫®N M·ªöI D√ÄNH CHO TRANSE:
        # T√≠nh kho·∫£ng c√°ch L2 (Euclidean distance) gi·ªØa (h+r) v√† t
        # Kho·∫£ng c√°ch c√†ng nh·ªè c√†ng t·ªët -> Th√™m d·∫•u tr·ª´ (-) ƒë·ªÉ h√†m topk ch·ªçn gi√° tr·ªã g·∫ßn 0 nh·∫•t
        distances = torch.norm(query - n_emb, p=2, dim=-1)
        scores = -distances

        # CHE PADDING (Masking)
        scores = scores.masked_fill(~mask, float('-inf'))

        # 5. CH·ªåN TOP-K CHO C·∫¢ BATCH
        k = min(epsilon, max_len)
        top_scores, top_indices = torch.topk(scores, k, dim=1) # (B, k)

        # 6. K√âO K·∫æT QU·∫¢ V·ªÄ L·∫†I CPU
        top_indices_cpu = top_indices.tolist()
        batch_rels_cpu = batch_rels.tolist()
        batch_nodes_cpu = batch_next_nodes.tolist()

        # Gi·∫£i n√©n th√†nh d·∫°ng List g·ªëc
        for i in range(batch_size):
            actions_i = []
            num_real = lengths[i]

            if num_real > 0:
                valid_k = min(k, num_real)
                for j in range(valid_k):
                    idx = top_indices_cpu[i][j]
                    actions_i.append((batch_rels_cpu[i][idx], batch_nodes_cpu[i][idx]))

            valid_actions.append(actions_i)

        return valid_actions

    def get_action_space_batch(self):
        """
        L·∫•y kh√¥ng gian h√†nh ƒë·ªông cho c·∫£ Batch v√† Padding th√†nh Tensor.
        """
        raw_actions_list = self.get_pruned_actions()
        batch_size = len(raw_actions_list)

        lengths = [len(acts) for acts in raw_actions_list]
        max_len = max(lengths) if lengths else 0
        if max_len == 0:
            max_len = 1

        device = self.current_entities.device

        r_indices = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
        e_indices = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
        action_mask = torch.zeros((batch_size, max_len), dtype=torch.float, device=device)

        for i, actions in enumerate(raw_actions_list):
            num_acts = len(actions)
            if num_acts > 0:
                rs = [a[0] for a in actions]
                es = [a[1] for a in actions]

                r_indices[i, :num_acts] = torch.tensor(rs, device=device)
                e_indices[i, :num_acts] = torch.tensor(es, device=device)
                action_mask[i, :num_acts] = 1.0

        r_emb = self.relation_embs(r_indices)
        e_emb = self.entity_embs(e_indices)

        action_embs = torch.cat([r_emb, e_emb], dim=-1)

        return action_embs, action_mask, raw_actions_list

    def step(self, actions):
        """
        Transition function (Eq 9): Chuy·ªÉn tr·∫°ng th√°i sang b∆∞·ªõc k+1
        """
        device = self.current_entities.device
        batch_size = len(actions)

        rels_list = [a[0] for a in actions]
        ents_list = [a[1] for a in actions]

        next_relations = torch.tensor(rels_list, dtype=torch.long, device=device)
        next_entities = torch.tensor(ents_list, dtype=torch.long, device=device)

        # L·∫•y Node hi·ªán t·∫°i l√†m ƒëi·ªÉm xu·∫•t ph√°t (e_{k-1})
        curr_e = self.current_entities.unsqueeze(1)
        new_r = next_relations.unsqueeze(1)

        # N·ªëi l·∫°i theo chu·∫©n Paper: [e_{k-1}, r_k]
        new_entry = torch.cat([curr_e, new_r], dim=1)

        history_shifted = self.path_history[:, 2:]
        self.path_history = torch.cat([history_shifted, new_entry], dim=1)

        self.current_entities = next_entities #tensor([8673, 8614])
        self.step_counter += 1
        done = (self.step_counter >= self.max_path_len)

        return self._get_state_embedding(), done

    def step_with_indices(self, action_indices, raw_actions_list):
        """
        Th·ª±c hi·ªán b∆∞·ªõc ƒëi d·ª±a tr√™n Index (0, 1, 2...) m√† Agent ch·ªçn.
        """
        selected_real_actions = []
        batch_size = len(action_indices)

        for i in range(batch_size):
            idx = action_indices[i].item()
            user_acts = raw_actions_list[i]

            if len(user_acts) > 0:
                idx = min(idx, len(user_acts) - 1)
                real_action = user_acts[idx]
            else:
                curr_node = self.current_entities[i].item()
                real_action = (0, curr_node) # Dead-end: T·ª± tr·ªè v·ªÅ ch√≠nh n√≥

            selected_real_actions.append(real_action)

        return self.step(selected_real_actions)

    def get_reward(self):
        """
        G·ªçi khi done=True. T√≠nh Time-aware Reward g_R(v|u) + Terminal Bonus
        """
        user_ids = self.current_users
        item_ids = self.current_entities

        # history_relation_ids = []   
        # for user_id in user_ids.tolist():   # .tolist() to transfer from gpu to cpu
        #     batch_neighbors = self.tckg.get_neighbors(user_id)            
        #     for (rel, tail) in batch_neighbors:
                # history_relation_ids.append(rel)

        batch_rels, batch_tails, batch_masks = self.tckg.get_batched_neighbors(user_ids.cpu())
        
        # 2. L·∫•y ra m·ªôt tensor ph·∫≥ng 1 chi·ªÅu (flattened) y h·ªát nh∆∞ code c≈© c·ªßa b·∫°n
        # C√∫ ph√°p `[batch_masks]` s·∫Ω lo·∫°i b·ªè t·∫•t c·∫£ c√°c √¥ padding tr·ªëng, 
        # ch·ªâ gi·ªØ l·∫°i ƒë√∫ng c√°c relation h·ª£p l·ªá.
        history_relation_tensor = batch_rels[batch_masks]

        # 1. SOFT REWARD (ƒêi·ªÉm d·∫´n ƒë∆∞·ªùng)
        rewards = self.reward_function(user_ids, item_ids, history_relation_ids)

        return rewards

#### TPRecPolicy

In [23]:
import torch
import torch.nn as nn

class TPRecPolicy(nn.Module):
    def __init__(self, embed_dim, hidden_dim=128, dropout=0.1):
        super(TPRecPolicy, self).__init__()

        # BLSTM gi·ªù ƒë√¢y ch·ªâ c·∫ßn nh·∫≠n input_size = embed_dim (kh√¥ng c·∫ßn nh√¢n 2 n·ªØa)
        self.blstm = nn.LSTM(input_size=embed_dim,
                             hidden_size=hidden_dim // 2,
                             num_layers=1,
                             batch_first=True,
                             bidirectional=True)

        # W1 gi·ªù ƒë√¢y ch·ªâ c·∫ßn nh·∫≠n ƒë·∫ßu ra c·ªßa BLSTM
        self.W1 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        self.Wa = nn.Linear(hidden_dim, embed_dim * 2)
        self.Wc = nn.Linear(hidden_dim, 1)

    def forward(self, full_state_seq, action_embs, action_mask):
        # 1. ƒê∆∞a to√†n b·ªô s_k v√†o BLSTM nh∆∞ c√¥ng th·ª©c (15)
        lstm_out, _ = self.blstm(full_state_seq)

        # # L·∫•y tr·∫°ng th√°i ·ªü b∆∞·ªõc th·ªùi gian cu·ªëi c√πng ƒë·∫°i di·ªán cho to√†n b·ªô chu·ªói
        # lstm_last = lstm_out[:, -1, :]

        # V√¨ bidirectional=True, hidden_dim n√†y th·ª±c ch·∫•t l√† 2 ph·∫ßn gh√©p l·∫°i
        half_dim = self.blstm.hidden_size

        # L·∫•y tr·∫°ng th√°i T√≥m t·∫Øt Chi·ªÅu ƒëi t·ªõi (·ªû index cu·ªëi c√πng: -1)
        forward_last = lstm_out[:, -1, :half_dim]

        # L·∫•y tr·∫°ng th√°i T√≥m t·∫Øt Chi·ªÅu ƒëi l√πi (·ªû index ƒë·∫ßu ti√™n: 0)
        backward_last = lstm_out[:, 0, half_dim:]

        # Gh√©p 2 t√≥m t·∫Øt n√†y l·∫°i th√†nh m·ªôt Context Vector ho√†n h·∫£o
        lstm_last = torch.cat([forward_last, backward_last], dim=-1)

        # 2. T√≠nh x_k theo ƒë√∫ng ph∆∞∆°ng tr√¨nh
        x_k = self.dropout(torch.relu(self.W1(lstm_last)))

        # --- (Ph·∫ßn t√≠nh Actor v√† Critic gi·ªØ nguy√™n) ---
        query = self.Wa(x_k).unsqueeze(1)
        scores = torch.sum(query * action_embs, dim=-1)
        scores = scores.masked_fill(action_mask.bool() == False, float('-inf'))
        probs = torch.softmax(scores, dim=-1)

        value_baseline = self.Wc(x_k).squeeze(-1)

        return probs, value_baseline

#### TPRecLightningModel

In [24]:
import pytorch_lightning as pl
import torch
import torch.optim as optim

class TPRecLightningModel(pl.LightningModule):
    def __init__(self, env, policy_net, learning_rate=1e-3, beta_entropy=0.01):
        super().__init__()

        # L∆∞u l·∫°i hyperparameter (t·ª± ƒë·ªông log v√†o tensorboard n·∫øu d√πng)
        self.save_hyperparameters(ignore=['env', 'policy_net'])

        self.env = env
        self.policy_net = policy_net
        self.learning_rate = learning_rate
        self.beta_entropy = beta_entropy

    def forward(self, full_state_seq, action_embs, action_mask):
        """
        ƒê·∫°o di·ªÖn Lightning chuy·ªÉn ti·∫øp ƒê√öNG 3 tham s·ªë cho Di·ªÖn vi√™n Policy Network
        """
        return self.policy_net(full_state_seq, action_embs, action_mask)

    def training_step(self, batch, batch_idx):
        """
        N∆°i di·ªÖn ra to√†n b·ªô chuy·∫øn ƒëi c·ªßa Agent trong 1 Batch
        B·ªï sung: T√≠nh to√°n HR@10 v√† NDCG@10 tr√™n t·∫≠p Train
        """
        # 1. H·ª©ng c·∫£ User v√† Target Item t·ª´ DataLoader
        if isinstance(batch, (list, tuple)) and len(batch) >= 2:
            batch_users, target_items = batch[0], batch[1]
        else:
            batch_users = batch[0] if isinstance(batch, (list, tuple)) else batch
            target_items = None

        batch_size = batch_users.size(0)

        # 2. G·ªçi reset v√† truy·ªÅn c·∫£ target_items ƒë·ªÉ Env t√≠nh Bonus cu·ªëi c√πng
        full_state_seq = self.env.reset(batch_users, target_items)

        saved_log_probs = []
        saved_values = []
        saved_entropies = []

        for t in range(self.env.max_path_len):
            action_embs, action_mask, raw_actions = self.env.get_action_space_batch()

            # 3. Truy·ªÅn 1 bi·∫øn state duy nh·∫•t v√†o Policy
            probs, value_baseline = self(full_state_seq, action_embs, action_mask)

            # --- [ƒêO·∫†N M·ªöI TH√äM] T√çNH TRAIN METRICS ·ªû B∆Ø·ªöC CU·ªêI C√ôNG ---
            if t == self.env.max_path_len - 1 and target_items is not None:
                k = min(10, probs.size(1))
                _, topk_indices = torch.topk(probs, k=k, dim=1) # (Batch, k)

                top10_list = []
                for i in range(batch_size):
                    user_acts = raw_actions[i]
                    items_i = []
                    for idx in topk_indices[i].tolist():
                        if idx < len(user_acts):
                            items_i.append(user_acts[idx][1])
                    while len(items_i) < 10:
                        items_i.append(0)
                    top10_list.append(items_i)

                final_items_top10 = torch.tensor(top10_list, device=self.device)
                target_expanded = target_items.unsqueeze(1)

                # Ch·∫•m ƒëi·ªÉm Hit Ratio
                hits_matrix = (final_items_top10 == target_expanded).float()
                hr_at_10 = hits_matrix.sum(dim=1).clamp(max=1.0).mean()

                # Ch·∫•m ƒëi·ªÉm NDCG
                ranks = torch.arange(1, 11, device=self.device).float()
                discount = 1.0 / torch.log2(ranks + 1)
                ndcg_at_10 = (hits_matrix * discount).sum(dim=1).mean()

                # Ghi nh·∫≠n v√†o h·ªá th·ªëng Log
                self.log('train_hr@10', hr_at_10, prog_bar=True, on_step=False, on_epoch=True)
                self.log('train_ndcg@10', ndcg_at_10, prog_bar=True, on_step=False, on_epoch=True)
            # -----------------------------------------------------------

            # 4. Agent v·∫´n l·∫•y m·∫´u ng·∫´u nhi√™n (sample) ƒë·ªÉ ph·ª•c v·ª• cho RL
            m = torch.distributions.Categorical(probs)
            action_indices = m.sample()

            saved_log_probs.append(m.log_prob(action_indices))
            saved_values.append(value_baseline)
            saved_entropies.append(m.entropy())

            # 5. H·ª©ng State m·ªõi v√† b∆∞·ªõc sang ch·∫∑ng ti·∫øp theo
            full_state_seq, done = self.env.step_with_indices(action_indices, raw_actions)

        # =====================================================================
        # T√çNH TO√ÅN LOSS THEO C√îNG TH·ª®C (17) V√Ä (18) C·ª¶A B√ÄI B√ÅO
        # =====================================================================
        final_rewards = self.env.get_reward().detach() # Shape: (Batch,)
        
        # 1. C√îNG TH·ª®C 17: T√≠nh Expected Reward (G) c√≥ chi·∫øt kh·∫•u \gamma
        # \sum_{t=0}^{K-1} \gamma^t R_{k+1}
        gamma = 0.99 # H·ªá s·ªë chi·∫øt kh·∫•u chu·∫©n trong RL
        returns = []
        R = final_rewards # Ph·∫ßn th∆∞·ªüng nh·∫≠n ƒë∆∞·ª£c ·ªü b∆∞·ªõc cu·ªëi c√πng
        
        # T√≠nh ng∆∞·ª£c t·ª´ b∆∞·ªõc cu·ªëi v·ªÅ b∆∞·ªõc ƒë·∫ßu
        for step in reversed(range(self.env.max_path_len)):
            returns.insert(0, R)
            R = R * gamma # Chi·∫øt kh·∫•u l√πi v·ªÅ qu√° kh·ª© 1 b∆∞·ªõc

        policy_loss = 0
        value_loss = 0

        # 2. C√îNG TH·ª®C 18: T·ªëi ∆∞u REINFORCE Algorithm
        # \nabla_\Theta \log \pi_\Theta (G - \hat{c}(s_k))
        for G, log_prob, value_baseline in zip(returns, saved_log_probs, saved_values):
            
            # (G - \hat{c}(s_k)): T√≠nh Advantage (S·ª± ch√™nh l·ªách gi·ªØa Th·ª±c t·∫ø G v√† K·ª≥ v·ªçng baseline)
            advantage = G - value_baseline.detach()
            
            # \nabla_\Theta \log \pi_\Theta * Advantage
            # Th√™m d·∫•u TR·ª™ v√¨ PyTorch t·ª± ƒë·ªông Minimize Loss, trong khi ta mu·ªën Maximize Reward
            step_policy_loss = -log_prob * advantage
            policy_loss += step_policy_loss.mean() # L·∫•y trung b√¨nh cho c·∫£ Batch
            
            # Hu·∫•n luy·ªán m·∫°ng Critic (Wa) ƒë·ªÉ d·ª± ƒëo√°n G chu·∫©n x√°c h∆°n b·∫±ng MSE
            step_value_loss = torch.nn.functional.mse_loss(value_baseline, G)
            value_loss += step_value_loss

        # T·ªïng h·ª£p Loss cho to√†n b·ªô ƒë∆∞·ªùng ƒëi
        # (B·∫°n c√≥ th·ªÉ tr·ª´ ƒëi entropy_bonus ·ªü ƒë√¢y n·∫øu mu·ªën Agent kh√°m ph√° nhi·ªÅu h∆°n)
        total_loss = policy_loss + value_loss 

        # Ghi nh·∫≠n log
        self.log('train_reward', final_rewards.mean(), prog_bar=True, on_step=False, on_epoch=True)
        self.log('train_loss', total_loss, prog_bar=False, on_step=False, on_epoch=True)

        return total_loss

    def validation_step(self, batch, batch_idx):
        """
        ƒê√°nh gi√° m√¥ h√¨nh: T√≠nh HR@10 v√† NDCG@10 t·∫°i b∆∞·ªõc cu·ªëi c√πng.
        """
        # 1. H·ª©ng bi·∫øn t·ª´ Dataloader
        if isinstance(batch, (list, tuple)) and len(batch) >= 2:
            batch_users, target_items = batch[0], batch[1]
        else:
            batch_users = batch[0] if isinstance(batch, (list, tuple)) else batch
            target_items = None

        # 2. Reset Env (Truy·ªÅn target_items ƒë·ªÉ t√≠nh reward gi·ªëng Train)
        full_state_seq = self.env.reset(batch_users, target_items)
        batch_size = batch_users.size(0)

        final_items_top10 = None

        # B·∫ÆT ƒê·∫¶U ƒêI T√åM ƒê∆Ø·ªúNG
        for t in range(self.env.max_path_len):
            action_embs, action_mask, raw_actions = self.env.get_action_space_batch()

            # 3. Truy·ªÅn 1 bi·∫øn state duy nh·∫•t v√†o Policy
            probs, _ = self(full_state_seq, action_embs, action_mask)

            # N·∫æU L√Ä B∆Ø·ªöC CU·ªêI C√ôNG: L·∫•y Top 10 thay v√¨ Top 1
            if t == self.env.max_path_len - 1:
                k = min(10, probs.size(1))
                _, topk_indices = torch.topk(probs, k=k, dim=1) # (Batch, k)

                top10_list = []
                for i in range(batch_size):
                    user_acts = raw_actions[i]
                    items_i = []

                    for idx in topk_indices[i].tolist():
                        if idx < len(user_acts):
                            items_i.append(user_acts[idx][1])

                    while len(items_i) < 10:
                        items_i.append(0)

                    top10_list.append(items_i)

                final_items_top10 = torch.tensor(top10_list, device=self.device)
                action_indices = topk_indices[:, 0]
            else:
                action_indices = torch.argmax(probs, dim=1)

            # 4. H·ª©ng ƒë√∫ng 2 bi·∫øn gi·ªëng h·ªát b√™n Train
            full_state_seq, done = self.env.step_with_indices(action_indices, raw_actions)

        # T√≠nh to√°n v√† Log ph·∫ßn th∆∞·ªüng (Reward)
        val_rewards = self.env.get_reward().detach()
        self.log('val_reward', val_rewards.mean(), prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

        # --- ƒê√ÅNH GI√Å HIT RATIO @10 & NDCG @10 ---
        if target_items is not None and final_items_top10 is not None:
            target_expanded = target_items.unsqueeze(1)
            hits_matrix = (final_items_top10 == target_expanded).float()

            hits_per_user = hits_matrix.sum(dim=1).clamp(max=1.0)
            hr_at_10 = hits_per_user.mean()
            self.log('val_hr@10', hr_at_10, prog_bar=True, on_epoch=True, sync_dist=True)

            ranks = torch.arange(1, 11, device=self.device).float()
            discount = 1.0 / torch.log2(ranks + 1)

            ndcg_per_user = (hits_matrix * discount).sum(dim=1)
            ndcg_at_10 = ndcg_per_user.mean()
            self.log('val_ndcg@10', ndcg_at_10, prog_bar=True, on_epoch=True, sync_dist=True)


    def configure_optimizers(self):
        # L·ªçc ra CH·ªà nh·ªØng tham s·ªë ƒëang requires_grad=True (M·∫°ng Policy)
        # Gi√∫p tr√°nh l·ªói t·ªëi ∆∞u h√≥a c√°c tham s·ªë ƒë√£ b·ªã freeze (nh∆∞ Embeddings)
        trainable_params = filter(lambda p: p.requires_grad, self.parameters())

        optimizer = optim.Adam(trainable_params, lr=self.learning_rate)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)

        return [optimizer], [scheduler]

### Class Config

In [25]:
class Config:
    # ƒê∆∞·ªùng d·∫´n d·ªØ li·ªáu
    TCKG_PATH = f"./data/{name}/{name}_TCKG.csv"  # File CSV ch·ª©a ƒë·ªì th·ªã (nh∆∞ b·∫°n ƒë√£ g·ª≠i)
    # TRAIN_USERS_PATH = "data/train_users.csv" # File ch·ª©a danh s√°ch user ID d√πng ƒë·ªÉ train
    SAVE_DIR = "./checkpoints"

    # Si√™u tham s·ªë Model
    EMBED_DIM = 64
    HIDDEN_DIM = 128
    HISTORY_LEN = 3   # k' (ƒë·ªô d√†i l·ªãch s·ª≠)
    MAX_PATH_LEN = 3  # K (s·ªë b∆∞·ªõc ƒëi)

    # Si√™u tham s·ªë Training
    BATCH_SIZE = 512 # Batch l·ªõn gi√∫p RL ·ªïn ƒë·ªãnh h∆°n
    NUM_EPOCHS = 500
    LEARNING_RATE = 1e-3
    BETA_ENTROPY = 0.01

    # Thi·∫øt b·ªã
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # IDs c·ªßa c√°c Relation Interaction (Quan tr·ªçng cho Reward)
    # V√≠ d·ª•: 20=interacted_0, 21=interacted_1, 22=interacted_2
    INTERACTION_CLUSTER_IDS = [21, 22, 23, 24, 45, 46, 47, 48]

### Load Dataset

In [26]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class InteractionDataset(Dataset):
    def __init__(self, df):
        """Nh·∫≠n v√†o m·ªôt DataFrame v√† chuy·ªÉn ƒë·ªïi c√°c c·ªôt c·∫ßn thi·∫øt th√†nh m·∫£ng NumPy"""
        self.users = df['user_id'].values
        self.entities = df['entity_id'].values # ƒê√ÇY CH√çNH L√Ä TARGET ITEM

    def __len__(self):
        """Tr·∫£ v·ªÅ t·ªïng s·ªë d√≤ng d·ªØ li·ªáu"""
        return len(self.users)

    def __getitem__(self, idx):
        """L·∫•y d·ªØ li·ªáu t·∫°i v·ªã tr√≠ idx v√† bi·∫øn th√†nh PyTorch Tensor"""
        user_tensor = torch.tensor(self.users[idx], dtype=torch.long)
        entity_tensor = torch.tensor(self.entities[idx], dtype=torch.long)

        # ƒê√É S·ª¨A: Tr·∫£ v·ªÅ m·ªôt Tuple g·ªìm (User, Target_Item)
        return user_tensor, entity_tensor

# 2. ƒê·ªçc file CSV th√†nh Pandas DataFrame
# (L∆∞u √Ω nh·ªè: M√¨nh gi·ªØ nguy√™n t√™n file c·ªßa b·∫°n, nh∆∞ng ch·ªØ 'interacions' h√¨nh nh∆∞ ƒëang thi·∫øu ch·ªØ 't', b·∫°n nh·ªõ ki·ªÉm tra l·∫°i t√™n file th·∫≠t nh√©)

train_df = pd.read_csv(f'./data/{name}/{name}_train_interactions.csv')
val_df = pd.read_csv(f'./data/{name}/{name}_val_interactions.csv')
test_df = pd.read_csv(f'./data/{name}/{name}_test_interactions.csv')

# 3. B·ªçc DataFrame v√†o class Dataset v·ª´a t·∫°o
train_dataset = InteractionDataset(train_df)
val_dataset = InteractionDataset(val_df)
test_dataset = InteractionDataset(test_df)

# 4. ƒê∆∞a Dataset v√†o DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=0)

### Main function

In [30]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

cfg = Config()
print("üî• Starting TPRec Training with PyTorch Lightning...")

# -------------------------------------------------
# B∆Ø·ªöC 1: X√¢y d·ª±ng ƒê·ªì th·ªã Tri th·ª©c (TCKG)
# -------------------------------------------------
# L∆∞u √Ω: C·∫ßn ƒë·∫£m b·∫£o file CSV ƒë√£ ƒë∆∞·ª£c map ID v·ªÅ d·∫°ng s·ªë nguy√™n li√™n t·ª•c (0, 1, 2...)
print("Loading Knowledge Graph...")
# T√≠nh offset t·ª± ƒë·ªông (nh∆∞ class TCKG t·ªëi ∆∞u t√¥i ƒë√£ vi·∫øt)
tckg = TCKG(cfg.TCKG_PATH)

# -------------------------------------------------
# B∆Ø·ªöC 2: Kh·ªüi t·∫°o Embeddings t·ª´ file Pickle
# -------------------------------------------------
print("Loading Pre-trained TransE Embeddings...")
pickle_file_path = f'./pickle/{name}_transE_embeddings_2026-02-22_13-17-36.pkl'

with open(pickle_file_path, 'rb') as f:
    saved_data = pickle.load(f)

pretrained_ent = saved_data['entity_embeddings']
pretrained_rel = saved_data['relation_embeddings']

# 2. Chuy·ªÉn ƒë·ªïi sang PyTorch Tensor (√©p ki·ªÉu Float32 ƒë·ªÉ t√≠nh to√°n neural network)
# N·∫øu data ƒëang l√† Numpy array:
if isinstance(pretrained_ent, np.ndarray):
    ent_tensor = torch.tensor(pretrained_ent, dtype=torch.float32)
    rel_tensor = torch.tensor(pretrained_rel, dtype=torch.float32)
else:
    # N·∫øu data ƒë√£ l√† Tensor s·∫µn:
    ent_tensor = pretrained_ent.clone().detach().float()
    rel_tensor = pretrained_rel.clone().detach().float()

# 3. N·∫°p v√†o nn.Embedding
# freeze=False: Cho ph√©p RL Agent ti·∫øp t·ª•c c·∫≠p nh·∫≠t (fine-tune) vector trong l√∫c t√¨m ƒë∆∞·ªùng
# freeze=True: Kh√≥a c·ª©ng vector, RL Agent ch·ªâ h·ªçc Policy Network (H·ªçc nhanh h∆°n, ch·ªëng overfit)
entity_embs = nn.Embedding.from_pretrained(ent_tensor, freeze=True, padding_idx=0)
relation_embs = nn.Embedding.from_pretrained(rel_tensor, freeze=True, padding_idx=0)


# -------------------------------------------------
# B∆Ø·ªöC 3: Kh·ªüi t·∫°o Reward Function
# -------------------------------------------------
print("Setting up Reward Function...")
reward_func = TimeAwareRewardFunction(
    user_embs=entity_embs,    # Chia s·∫ª tr·ªçng s·ªë v·ªõi Env
    entity_embs=entity_embs,
    relation_embs=relation_embs,
    interaction_cluster_ids=cfg.INTERACTION_CLUSTER_IDS,
    bias_embs=None, # T·ª± t·∫°o bias m·ªõi
    temperature= None
)

# -------------------------------------------------
# B∆Ø·ªöC 4: Kh·ªüi t·∫°o M√¥i tr∆∞·ªùng (Environment)
# -------------------------------------------------
print("Setting up Environment...")
env = TPRecEnvironment(
    tckg=tckg,
    entity_embeddings=entity_embs,
    relation_embeddings=relation_embs,
    reward_function=reward_func, # Inject reward v√†o env
    max_path_len=cfg.MAX_PATH_LEN,
    history_len=cfg.HISTORY_LEN
)

# -------------------------------------------------
# B∆Ø·ªöC 5: Kh·ªüi t·∫°o Policy Network (Agent)
# -------------------------------------------------
print("Building Policy Network...")

# T√≠nh to√°n k√≠ch th∆∞·ªõc State v√† Action theo c√¥ng th·ª©c chu·∫©n
# Action = Relation + Entity
action_dim = cfg.EMBED_DIM + cfg.EMBED_DIM

# State = User + Flattened_History + Current_Entity
# Flattened_History = k' * (Relation + Entity)
history_flat_dim = cfg.HISTORY_LEN * (cfg.EMBED_DIM + cfg.EMBED_DIM)
state_dim = cfg.EMBED_DIM + history_flat_dim + cfg.EMBED_DIM

print(f"--> State Dim: {state_dim} | Action Dim: {action_dim}")

policy_net = TPRecPolicy(
        embed_dim=cfg.EMBED_DIM,      # V√≠ d·ª•: 64
        hidden_dim=cfg.HIDDEN_DIM,    # V√≠ d·ª•: 128
        dropout=0.1                   # T·ªâ l·ªá dropout gi√∫p ch·ªëng Overfit (theo paper)
    )

# B∆Ø·ªöC 6: ƒê√ìNG G√ìI V√ÄO LIGHTNING MODEL
print("Packing into Lightning Module...")
lightning_model = TPRecLightningModel(
    env=env,
    policy_net=policy_net,
    learning_rate=cfg.LEARNING_RATE,
    beta_entropy=cfg.BETA_ENTROPY
)

# B∆Ø·ªöC 7: C·∫§U H√åNH L∆ØU BEST MODEL (CHECKPOINT)
# T·ª± ƒë·ªông theo d√µi 'train_reward' ·ªü cu·ªëi m·ªói epoch v√† l∆∞u l·∫°i b·∫£n c√≥ ƒëi·ªÉm cao nh·∫•t
checkpoint_callback = ModelCheckpoint(
    dirpath=cfg.SAVE_DIR,
    filename='tprec-best-{epoch:02d}-{train_reward:.4f}',
    monitor='train_hr@10',
    mode='max', # L∆∞u model c√≥ reward l·ªõn nh·∫•t
    save_top_k=1,
    save_last=True # L∆∞u th√™m model ·ªü epoch cu·ªëi c√πng ƒë·ªÉ ph√≤ng h·ªù
)

# 2. TH√äM CALLBACK EARLY STOPPING V√ÄO ƒê√ÇY
early_stop_callback = EarlyStopping(
      monitor='val_hr@10',   # Ph·∫£i c√πng t√™n v·ªõi bi·∫øn monitor ·ªü Checkpoint
      min_delta=0.001,       # S·ª± thay ƒë·ªïi t·ªëi thi·ªÉu ƒë·ªÉ ƒë∆∞·ª£c t√≠nh l√† "c√≥ c·∫£i thi·ªán"
      patience=500,           # S·ª©c ch·ªãu ƒë·ª±ng: Cho ph√©p m√¥ h√¨nh d·∫≠m ch√¢n t·∫°i ch·ªó t·ªëi ƒëa 10 Epoch
      verbose=True,          # B·∫≠t in th√¥ng b√°o ra m√†n h√¨nh khi Early Stop k√≠ch ho·∫°t
      mode='max'             # 'max' v√¨ ta mu·ªën ch·ªâ s·ªë HR@10/Reward c√†ng l·ªõn c√†ng t·ªët
  )

# B∆Ø·ªöC 8: KH·ªûI T·∫†O TRAINER V√Ä B·∫ÆT ƒê·∫¶U CH·∫†Y
print("Initializing Lightning Trainer...")
trainer = pl.Trainer(
    max_epochs=cfg.NUM_EPOCHS,
    accelerator="auto", # T·ª± ƒë·ªông t√¨m v√† d√πng GPU n·∫øu c√≥
    devices=1,
    gradient_clip_val=1.0, # T·ª± ƒë·ªông √°p d·ª•ng Gradient Clipping
    callbacks=[checkpoint_callback, early_stop_callback],
    enable_progress_bar=True,
    # log_every_n_steps=10,
    num_sanity_val_steps=0
)

print("üöÄ B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán...")
# DataLoader c·ªßa b·∫°n c·∫ßn ƒë∆∞·ª£c truy·ªÅn v√†o ƒë√¢y
trainer.fit(
        model=lightning_model,
        train_dataloaders=train_loader,
        val_dataloaders=val_loader
    )


üî• Starting TPRec Training with PyTorch Lightning...
Loading Knowledge Graph...
Loading TCKG from ./data/book/book_TCKG.csv...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores


TCKG Loaded successfully. Graph construction complete.
Loading Pre-trained TransE Embeddings...
Setting up Reward Function...
Setting up Environment...
Building Policy Network...
--> State Dim: 512 | Action Dim: 128
Packing into Lightning Module...
Initializing Lightning Trainer...
üöÄ B·∫Øt ƒë·∫ßu hu·∫•n luy·ªán...



  | Name       | Type             | Params | Mode  | FLOPs
----------------------------------------------------------------
0 | env        | TPRecEnvironment | 2.5 M  | train | 0    
1 | policy_net | TPRecPolicy      | 99.7 K | train | 0    
----------------------------------------------------------------
138 K     Trainable params
2.5 M     Non-trainable params
2.6 M     Total params
10.522    Total estimated model params size (MB)
11        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 0:   0%|          | 0/18757 [00:00<?, ?it/s]

AttributeError: 'TCKG' object has no attribute 'rel_matrix'

In [None]:
trainer.validate(dataloaders=test_loader, ckpt_path='best')