In [1]:
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data_utils
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.ensemble import RandomForestClassifier

import pandas as pd
import numpy as np
import scipy.sparse as sp
from scipy.sparse import issparse
import json
import inspect
import random 
from tqdm import tqdm
from collections import Counter
from pathlib import Path
from typing import Optional, Dict, List, Tuple
from matplotlib import pyplot as plt

# ============================================= function split data =============================================
class TestSplitter(object):
    def __init__(self, args):
        self.test_size = args['test_size']
        self.uid = 'user_id'
        self.tid = 'item_id'

    def split(self, df):
        train_index, test_index = split_test(df, self.test_size, self.uid)

        return train_index, test_index


class ValidationSplitter(object):
    def __init__(self, args):
        # self.fold_num = args.fold_num
        self.val_size = args['val_size']
        self.uid = 'user_id'
        self.tid = 'item_id'

    def split(self, df):
        train_val_index_zip = split_validation(df, self.val_size, self.uid)

        return train_val_index_zip


def split_test(df, test_size=0.1, uid='user_id'):

    test_ids = df.groupby(uid).apply(
        lambda x: x.sample(frac=test_size).index
    )
    test_ids = test_ids.explode().dropna().values.astype(int)
    # test_ids = np.array([int(x) for x in test_ids if not pd.isna(x)])
    test_ids = np.array(list(test_ids))
    train_ids = np.setdiff1d(df.index.values, test_ids)

    return train_ids, test_ids


def split_validation(train_set, val_size=.1, uid='user_id'):

    train_set = train_set.reset_index(drop=True)

    # train_set_list, val_set_list = [], []
    # for _ in range(fold_num):
    val_ids = train_set.groupby(uid).apply(
        lambda x: x.sample(frac=val_size).index
    )
    val_ids = val_ids.explode().dropna().values.astype(int)
    # val_ids = np.array([int(x) for x in val_ids if not pd.isna(x)])
    # val_ids = np.array(list(val_ids))
    train_ids = np.setdiff1d(train_set.index.values, val_ids)

    # train_set     _list.append(train_ids)
    # val_set_list.append(val_ids)

    return train_ids, val_ids 

# ============================================= function metrics =============================================

class Metric(object):
    def __init__(self, config) -> None:
        self.metrics = config['metrics']
        self.item_num = config['item_num']
        self.item_pop = config['item_pop'] if 'coverage' in self.metrics else None
        self.i_categories = config['i_categories'] if 'diversity' in self.metrics else None

    def run(self, test_ur, pred_ur, test_u):
        res = []
        for mc in self.metrics:
            if mc == 'ndcg':
                kpi = NDCG(test_ur, pred_ur, test_u)
            elif mc == 'recall':
                kpi = Recall(test_ur, pred_ur, test_u)
            elif mc == 'precision':
                kpi = Precision(test_ur, pred_ur, test_u)
            else:
                raise ValueError(f'Invalid metric name {mc}')

            res.append(kpi)

        return res

def Precision(test_ur, pred_ur, test_u):
    res = []
    for idx in range(len(test_u)):
        u = test_u[idx]
        gt = test_ur[u]
        pred = pred_ur[idx]
        pre = np.in1d(pred, list(gt)).sum() / len(pred)

        res.append(pre)

    return np.mean(res)


def Recall(test_ur, pred_ur, test_u):
    res = []
    for idx in range(len(test_u)):
        u = test_u[idx]
        gt = test_ur[u]
        pred = pred_ur[idx]
        rec = np.in1d(pred, list(gt)).sum() / len(gt)

        res.append(rec)

    return np.mean(res)

def getDCG(scores):
    return np.sum(
        np.divide(np.power(2, scores) - 1, np.log2(np.arange(scores.shape[0], dtype=np.float32) + 1)+1),
        # np.divide(scores, np.log2(np.arange(scores.shape[0], dtype=np.float32) + 2)+1),
        dtype=np.float32)

def getNDCG(rank_list, pos_items):
    relevance = np.ones_like(pos_items)
    it2rel = {it: r for it, r in zip(pos_items, relevance)}
    rank_scores = np.asarray([it2rel.get(it, 0.0) for it in rank_list], dtype=np.float32)
    idcg = getDCG(relevance)

    dcg = getDCG(rank_scores)

    if dcg == 0.0:
        return 0.0

    ndcg = dcg / idcg
    return ndcg

def NDCG(test_ur, pred_ur, test_u):
    res = []
    for idx in range(len(test_u)):
        u = test_u[idx]
        gt = test_ur[u]
        pred = pred_ur[idx]
        nd = getNDCG(pred, gt)
        res.append(nd)
    return np.mean(res)


def AUC(test_ur, pred_ur, test_u):
    res = []

    for idx in range(len(test_u)):
        u = test_u[idx]
        gt = test_ur[u]
        pred = pred_ur[idx]

        r = np.in1d(pred, list(gt))
        pos_num = r.sum()
        neg_num = len(pred) - pos_num

        # Handle edge cases: if no positive or no negative items, AUC is undefined
        if pos_num == 0 or neg_num == 0:
            continue  # Skip this user instead of adding NaN

        pos_rank_num = 0
        for j in range(len(r) - 1):
            if r[j]:
                pos_rank_num += np.sum(~r[j + 1:])

        auc = pos_rank_num / (pos_num * neg_num)
        res.append(auc)

    return np.mean(res) if len(res) > 0 else 0.0

def AUC_true_neg(test_ur, pred_scores, true_neg_dict, test_u):
    aucs = []
    for idx, u in enumerate(test_u):
        pos = set(test_ur[u])
        neg = true_neg_dict.get(u, set())
        if len(pos) == 0 or len(neg) == 0:
            continue

        scores = pred_scores[idx]
        pos_scores = [scores[i] for i in pos if i < len(scores)]
        neg_scores = [scores[i] for i in neg if i < len(scores)]

        cnt = 0
        for ps in pos_scores:
            cnt += sum(ps > ns for ns in neg_scores)

        aucs.append(cnt / (len(pos_scores) * len(neg_scores)))
    return np.mean(aucs) if aucs else 0.0

# ============================================= function get data =============================================

def get_ur(df):
    print("Method of getting user-rating pairs")
    ur = df.groupby('user_id').item_id.apply(list).to_dict()
    # print(ur)
    return ur

class BasicDataset(data_utils.Dataset):
    def __init__(self, samples):
  
        super(BasicDataset, self).__init__()
        self.data = samples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index][0], self.data[index][1], self.data[index][2]

class Cf_valDataset(data_utils.Dataset):
    def __init__(self, data):
        super(Cf_valDataset, self).__init__()
        self.user = data
        # self.data = data

    def __len__(self):
        return len(self.user)

    def __getitem__(self, index):
        user = self.user[index]
        return torch.tensor(user)#, torch.tensor(self.data[user])

def get_train_loader(dataset, args):
    dataloader = data_utils.DataLoader(dataset, batch_size=args['train_batch_size'], shuffle=True, pin_memory=True)
    return dataloader

def get_val_loader(dataset, args):
    dataloader = data_utils.DataLoader(dataset, batch_size=args['val_batch_size'], shuffle=False, pin_memory=True)
    return dataloader

def get_test_loader(dataset, args):
    dataloader = data_utils.DataLoader(dataset, batch_size=args['test_batch_size'], shuffle=False, pin_memory=True)
    return dataloader

def get_inter_matrix(df, args):
    '''
    get the whole sparse interaction matrix
    '''
    print("get the whole sparse interaction matrix")
    user_num, item_num = args['user_num'], args['item_num']

    src, tar = df['user_id'].values, df['item_id'].values
    data = df['click'].values

    mat = sp.coo_matrix((data, (src, tar)), shape=(user_num, item_num))

    return mat

def build_relation_matrices_from_df(df: pd.DataFrame, relations: List[str], user_num: int, item_num: int) -> Dict[str, sp.coo_matrix]:
    """
    df: should have columns ['user_id', 'item_id', <relations...>] with binary flags 0/1
    relations: e.g. ['click','like','share','follow','exposed']
    Returns dict: relation -> scipy.sparse.coo_matrix (user_num x item_num)
    """
    rel_mats = {}
    for r in relations:
        sub = df[df[r] == 1]
        if sub.shape[0] == 0:
            rel_mats[r] = sp.coo_matrix((user_num, item_num))
            continue
        rows = sub['user_id'].values.astype(np.int32)
        cols = sub['item_id'].values.astype(np.int32)
        data = np.ones(len(rows), dtype=np.float32)
        mat = sp.coo_matrix((data, (rows, cols)), shape=(user_num, item_num))
        rel_mats[r] = mat
    return rel_mats

# ============================================= function neg sampler =============================================
import numpy as np
from collections import Counter

class HybridNegativeSampler:
    """
    Multi-method negative sampler for implicit CF / BPR triples.

    Supported methods:
      - 'uniform'  : sample unobserved items uniformly
      - 'high-pop' : sample unobserved items by popularity (freq^(3/4))
      - 'low-pop'  : sample unobserved items by inverse popularity
      - 'true_neg' : sample from provided true negative pool per user
      - 'hybrid'   : mix (true_neg + uniform) with ratio = sample_ratio

    Usage:
      sampler = HybridNegativeSampler(cf_args)
      triples = sampler.sampling(df_train_edges, train_ur, true_neg_df=final_neg_pool)
    """

    def __init__(self, args: dict):
        self.user_num = args['user_num']
        self.item_num = args['item_num']
        self.num_ng = args.get('num_ng', 4)

        self.method = args.get('sampler_method', 'uniform')
        self.sample_ratio = float(args.get('sample_ratio', 0.5))  # used by hybrid
        self.seed = int(args.get('seed', 42))

        assert self.method in ['uniform', 'high-pop', 'low-pop', 'true_neg', 'hybrid'], \
            f"Invalid sampler_method: {self.method}"

        np.random.seed(self.seed)

        # will be built if needed
        self.pop_prob = None  # size = item_num

    # ---------- popularity distribution ----------
    def _build_pop_prob(self, train_edges_df):
        """
        train_edges_df: DataFrame with columns ['user_id','item_id'] (encoded)
        """
        cnt = Counter(train_edges_df['item_id'].values.tolist())
        freq = np.zeros(self.item_num, dtype=np.float64)
        for i, c in cnt.items():
            if 0 <= i < self.item_num:
                freq[i] = c

        # smoothing like word2vec negative sampling: p(i) ~ f(i)^(3/4)
        prob = freq / (freq.sum() + 1e-12)
        prob = np.power(prob, 0.75)

        if self.method == 'high-pop':
            prob = prob / (prob.sum() + 1e-12)
        elif self.method == 'low-pop':
            # inverse popularity: items with low freq get higher probability
            inv = 1.0 - (prob / (prob.max() + 1e-12))
            inv = np.clip(inv, 0.0, None)
            prob = inv / (inv.sum() + 1e-12)
        else:
            prob = prob / (prob.sum() + 1e-12)

        self.pop_prob = prob.astype(np.float64)

    # ---------- true negative dict ----------
    @staticmethod
    def _build_true_neg_dict(true_neg_df):
        """
        true_neg_df: DataFrame ['user_id','item_id'] encoded, TRAIN-ONLY (no leakage)
        """
        d = {}
        if true_neg_df is None or len(true_neg_df) == 0:
            return d
        for u, i in zip(true_neg_df['user_id'].values, true_neg_df['item_id'].values):
            d.setdefault(int(u), set()).add(int(i))
        return d

    # ---------- single draw helpers ----------
    def _sample_uniform_one(self, u_seen, chosen):
        j = np.random.randint(self.item_num)
        while (j in u_seen) or (j in chosen):
            j = np.random.randint(self.item_num)
        return j

    def _sample_pop_one(self, u_seen, chosen):
        # assumes self.pop_prob built
        j = int(np.random.choice(self.item_num, p=self.pop_prob))
        tries = 0
        while (j in u_seen) or (j in chosen):
            j = int(np.random.choice(self.item_num, p=self.pop_prob))
            tries += 1
            if tries > 50:
                # fallback to uniform to avoid dead-loop for dense users
                return self._sample_uniform_one(u_seen, chosen)
        return j

    def _sample_true_neg_one(self, u, u_seen, chosen, true_neg_dict):
        pool = list(true_neg_dict.get(u, []))
        if len(pool) == 0:
            return None
        j = int(np.random.choice(pool))
        tries = 0
        while (j in u_seen) or (j in chosen):
            j = int(np.random.choice(pool))
            tries += 1
            if tries > 50:
                return None
        return j

    # ---------- main API ----------
    def sampling(self, train_edges_df, train_ur, true_neg_df=None):
        """
        Returns: np.ndarray of shape [num_triples, 3] with columns [u, pos, neg]
        Inputs:
          - train_edges_df: df containing positive edges used to build popularity dist (encoded)
          - train_ur: dict user -> list(pos_items) (encoded)
          - true_neg_df: df containing (u,i) true negatives, TRAIN-ONLY (encoded)
        """
        if self.num_ng <= 0:
            raise ValueError("num_ng must be > 0 for BPR")

        # Build pop dist only if needed
        if self.method in ['high-pop', 'low-pop']:
            self._build_pop_prob(train_edges_df)

        true_neg_dict = self._build_true_neg_dict(true_neg_df) if self.method in ['true_neg', 'hybrid'] else {}

        triples = []
        for u in range(self.user_num):
            pos_list = train_ur.get(u, [])
            if len(pos_list) == 0:
                continue
            u_seen = set(pos_list)

            for pos in pos_list:
                chosen = set()

                # --- HYBRID: mix true_neg + uniform (or you can change to pop if you want) ---
                if self.method == 'hybrid':
                    k_true = int(round(self.num_ng * self.sample_ratio))
                    k_true = max(0, min(self.num_ng, k_true))
                    k_other = self.num_ng - k_true

                    # true neg part
                    for _ in range(k_true):
                        j = self._sample_true_neg_one(u, u_seen, chosen, true_neg_dict)
                        if j is None:
                            break
                        chosen.add(j)

                    # uniform part (fallback)
                    while len(chosen) < self.num_ng:
                        chosen.add(self._sample_uniform_one(u_seen, chosen))

                elif self.method == 'true_neg':
                    # only true negatives, fallback to uniform if not enough
                    while len(chosen) < self.num_ng:
                        j = self._sample_true_neg_one(u, u_seen, chosen, true_neg_dict)
                        if j is None:
                            # fallback uniform to guarantee enough negs
                            j = self._sample_uniform_one(u_seen, chosen)
                        chosen.add(j)

                elif self.method == 'uniform':
                    while len(chosen) < self.num_ng:
                        chosen.add(self._sample_uniform_one(u_seen, chosen))

                elif self.method in ['high-pop', 'low-pop']:
                    while len(chosen) < self.num_ng:
                        chosen.add(self._sample_pop_one(u_seen, chosen))

                else:
                    raise ValueError(f"Unknown method: {self.method}")

                for neg in chosen:
                    triples.append([u, int(pos), int(neg)])

        return np.asarray(triples, dtype=np.int32)

# ============================================= Model =============================================


class LightGCN(nn.Module):
    """A self-contained LightGCN implementation.

    Features:
    - Builds normalized adjacency from a scipy COO interaction matrix (user-item bipartite)
    - Layer-wise propagation (no non-linearities / no feature transform)
    - Mean aggregation over (L+1) layers (including the 0-th embedding)
    - Supports BPR, hinge (HL), TOP1 (TL), and point-wise (BCEWithLogits, MSE) losses via configure_loss
    - Ranking utilities: rank(), full_rank()
    """
    def __init__(self, args):
        super().__init__()
        self.num_users = args['user_num']
        self.num_items = args['item_num']
        self.embedding_dim = args.get('embedding_dim', 64)
        self.num_layers = args.get('num_layers', 3)
        self.interaction_matrix = args.get('interaction_matrix', None)
        self.device = torch.device(args.get('device', 'cpu'))
        self.reg_1 = args.get('reg_1', 0.0)
        self.reg_2 = args.get('reg_2', 0.0)
        self.lr = args.get('lr', 0.001)
        self.topk = args.get('k', 20)
        self.val_ur = args.get('val_ur', None)
        self.val_u = args.get('val_u', None)
        self.early_stop = args.get('early_stop', True)
        self.save_path = args.get('save_path', './')
        self.true_neg = args.get('true_neg', False)
        self.true_neg_dict = args.get('true_neg_dict', {})
        self.load = args.get('load', False)


        # storage variables for rank evaluation acceleration
        self.restore_user_e = None
        self.restore_item_e = None

        # Embeddings
        self.embed_user = nn.Embedding(self.num_users, self.embedding_dim)
        self.embed_item = nn.Embedding(self.num_items, self.embedding_dim)
        self.apply(self._init_weights)

        if self.interaction_matrix is None:
            raise ValueError("interaction_matrix (scipy sparse) is required")
        if not sp.issparse(self.interaction_matrix):
            raise TypeError("interaction_matrix must be a scipy sparse matrix")

        self.register_buffer('norm_adj_matrix', self._build_norm_adj(self.interaction_matrix).coalesce())

    #  Initialization 
    def _init_weights(self, m):
        if isinstance(m, nn.Embedding):
            nn.init.xavier_normal_(m.weight)

    #  Adjacency 
    def _build_norm_adj(self, inter_M: sp.coo_matrix) -> torch.sparse.FloatTensor:
        """Build symmetric normalized adjacency A_hat for user-item bipartite graph."""
        inter_M = inter_M.tocoo()
        A = sp.dok_matrix((self.num_users + self.num_items, self.num_users + self.num_items), dtype=np.float32)
        # user->item (offset items by num_users)
        data_dict = dict(zip(zip(inter_M.row, inter_M.col + self.num_users), [1]*inter_M.nnz))
        # item->user
        data_dict.update(dict(zip(zip(inter_M.col + self.num_users, inter_M.row), [1]*inter_M.nnz)))
        A._update(data_dict)

        sum_arr = (A > 0).sum(axis=1)
        deg = np.array(sum_arr.flatten())[0] + 1e-7
        deg_inv_sqrt = np.power(deg, -0.5)
        D = sp.diags(deg_inv_sqrt)
        L = D * A * D  # symmetric norm
        L = sp.coo_matrix(L)
        indices = torch.LongTensor(np.vstack([L.row, L.col]))
        values = torch.FloatTensor(L.data)
        return torch.sparse.FloatTensor(indices, values, torch.Size(L.shape))

    #  Forward Propagation 
    def forward(self) -> Tuple[torch.Tensor, torch.Tensor]:
        all_embeddings = torch.cat([self.embed_user.weight, self.embed_item.weight], dim=0)
        embeddings_list = [all_embeddings]
        for _ in range(self.num_layers):
            all_embeddings = torch.sparse.mm(self.norm_adj_matrix, all_embeddings)
            embeddings_list.append(all_embeddings)
        # Mean over layers
        final = torch.mean(torch.stack(embeddings_list, dim=1), dim=1)
        user_final, item_final = torch.split(final, [self.num_users, self.num_items])
        return user_final, item_final

    def _bpr_loss(self, pos_scores, neg_scores):
        return -torch.mean(F.logsigmoid(pos_scores - neg_scores))

    def calc_loss(self, batch):
        # ensure model is on correct device
        self.to(self.device)
        # clear stored embeddings before computing
        if self.restore_user_e is not None or self.restore_item_e is not None:
            self.restore_user_e, self.restore_item_e = None, None

        # prepare batch indices
        user = batch[0].to(self.device).long()
        if user.dim() == 0:
            user = user.unsqueeze(0)
        pos_item = batch[1].to(self.device).long()
        if pos_item.dim() == 0:
            pos_item = pos_item.unsqueeze(0)

        # compute embeddings
        embed_user, embed_item = self.forward()
        embed_user = embed_user.to(self.device)
        embed_item = embed_item.to(self.device)

        # positive predictions
        u_emb = embed_user[user]
        p_emb = embed_item[pos_item]
        pos_pred = (u_emb * p_emb).sum(dim=1)

        # ego embeddings for regularization
        u_ego = self.embed_user(user)
        p_ego = self.embed_item(pos_item)

        # compute loss
        neg = batch[2].to(self.device).long()
        neg_emb = embed_item[neg]
        neg_pred = (u_emb * neg_emb).sum(dim=1)
        neg_ego = self.embed_item(neg)
        loss = self._bpr_loss(pos_pred, neg_pred)
        loss += self.reg_1 * (u_ego.norm(p=1) + p_ego.norm(p=1) + neg_ego.norm(p=1))
        loss += self.reg_2 * (u_ego.norm() + p_ego.norm() + neg_ego.norm())

        return loss

    #  Ranking 
    def rank(self, test_loader):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        rec_ids = torch.tensor([], device=self.device)
        self.eval()
        with torch.no_grad():
            for us in test_loader:
                us = us.to(self.device)
                rank_list = self.full_rank(us)

                rec_ids = torch.cat((rec_ids, rank_list), 0)

        return rec_ids.cpu().numpy().astype(np.int32)
    
    def full_rank(self, u):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        # ensure CPU indices for CPU embeddings and convert to long type
        if u.device != self.restore_user_e.device:
            u_idx = u.to(self.restore_user_e.device).long()
        else:
            u_idx = u.long()
        user_emb = self.restore_user_e[u_idx]  # (batch_size, dim)
        items_emb = self.restore_item_e  # (num_items, dim)
        # compute scores and top-k
        scores = torch.matmul(user_emb, items_emb.transpose(1, 0))
        rank = torch.argsort(scores, descending=True)[:, :self.topk]
        # move to evaluation device
        return rank.to(self.device)
    
    def predict(self, u, i):
        if self.restore_user_e is None or self.restore_item_e is None:
            self.restore_user_e, self.restore_item_e = self.forward()

        u_embedding = self.restore_user_e[u]
        i_embedding = self.restore_item_e[i]
        pred = torch.matmul(u_embedding, i_embedding.t())

        return pred.cpu().item()

    def fit(self, train_loader, val_loader, epochs: int = 10):
        opt = optim.Adam(self.parameters(), lr=self.lr)
        self.to(self.device)

        start = 0
        history = {'train_loss': [], 'val_ndcg': [], 'val_recall': [], 'val_precision': [], 'val_auc': []}
        best_ndcg = -np.inf
        patience_counter = 0
        if (Path(self.save_path) / 'best_model.pth').exists() and self.load:
            print("Load model from checkpoint")
            checkpoint = torch.load(Path(self.save_path) / 'best_model.pth', map_location=self.device, weights_only=False)
            self.load_state_dict(checkpoint['model_state_dict'])
            opt.load_state_dict(checkpoint['optimizer_state_dict'])
            start = checkpoint['epoch'] + 1
            history = checkpoint['history']
            best_ndcg = max(history['val_ndcg']) if len(history['val_ndcg']) > 0 else -np.inf   

        for epoch in range(start, epochs+1):
            self.train()
            current_loss = 0.0
            pbar = tqdm(train_loader)
            pbar.set_description(f'[Epoch {epoch:03d}]')
            for batch in pbar:
                opt.zero_grad()
                loss = self.calc_loss(batch)
                if torch.isnan(loss):
                    raise ValueError("NaN loss encountered")
                loss.backward()
                opt.step()
                current_loss += loss.item()

            epoch_loss = current_loss / len(train_loader)
            pbar.set_postfix(loss=epoch_loss)
            history['train_loss'].append(epoch_loss)

            preds = self.rank(val_loader)
            ndcg = NDCG(self.val_ur, preds, self.val_u)
            recall = Recall(self.val_ur, preds, self.val_u)
            precision = Precision(self.val_ur, preds, self.val_u)
            history['val_ndcg'].append(ndcg)
            history['val_recall'].append(recall)
            history['val_precision'].append(precision)
            if self.true_neg:
                if self.true_neg_dict is not None:
                    auc = AUC_true_neg(self.val_ur, self.restore_user_e[self.val_u].detach().cpu().numpy() @ self.restore_item_e.detach().cpu().numpy().T, self.true_neg_dict, self.val_u)
                else:
                    auc = 0.5
                history['val_auc'].append(auc)
                print(f"Training - Loss {epoch_loss:.4f} | Validation - NDCG@{self.topk}: {ndcg:.4f}, Recall@{self.topk}: {recall:.4f}, Precision@{self.topk}: {precision:.4f}, AUC@{self.topk}: {auc:.4f}")
            else:
                print(f"Training - Loss {epoch_loss:.4f} | Validation - NDCG@{self.topk}: {ndcg:.4f}, Recall@{self.topk}: {recall:.4f}, Precision@{self.topk}: {precision:.4f}")
            
            # Save best model
            state = {
                'epoch': epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'history': history
            }
            if ndcg > best_ndcg:
                best_ndcg = ndcg
                torch.save(state, Path(self.save_path) / 'best_model.pth')
                patience_counter = 0
            else:
                patience_counter += 1
                print(f'Patience counter: {patience_counter}/5')
                
            torch.save(state, Path(self.save_path) / 'last_model.pth')
            
            # Early stopping
            if self.early_stop and patience_counter >= 5:
                print('Satisfy early stop mechanism')
                break
        return history
    


# ============================================= RF denoiser =============================================
def build_behavior_features(df):
    """
    df: raw_df_train (TRAIN-ONLY)
    Required cols: user_id, item_id, click, like, share, follow, watching_times
    """
    agg = df.groupby(['user_id','item_id']).agg(
        count_click=('click','sum'),
        count_like=('like','sum'),
        count_share=('share','sum'),
        count_follow=('follow','sum'),
        avg_watch_time=('watching_times','mean')
    ).reset_index()

    return agg

def normalize_video_category(x):
    if x == 0:
        return 0
    if x == 1:
        return 1
    if isinstance(x, str):
        x = x.strip()
        if x == '0':
            return 2
        if x == '1':
            return 3
        return 4
    return 4


def build_rf_features(raw_df_train):
    """
    Returns X (features), y (label), feature_names
    """

    # -------- behavior --------
    beh = build_behavior_features(raw_df_train)

    # -------- user meta --------
    user_meta = raw_df_train[['user_id','gender','age']].drop_duplicates()

    # -------- item meta --------
    item_meta = raw_df_train[['item_id','video_category']].drop_duplicates()

    item_meta['video_category'] = (
        item_meta['video_category']
        .apply(normalize_video_category)
        .astype(np.int32)
    )
    # -------- merge --------
    feat = beh.merge(user_meta, on='user_id', how='left') \
              .merge(item_meta, on='item_id', how='left')

    # -------- fillna --------
    feat[['gender','age','video_category']] = feat[['gender','age','video_category']].fillna(0)

    # -------- label --------
    # Strong positive = like | share | follow
    feat['label'] = ((feat['count_like'] +
                      feat['count_share'] +
                      feat['count_follow']) > 0).astype(int)

    y = feat['label'].values
    X = feat.drop(columns=['user_id','item_id','label'])

    return feat[['user_id','item_id']], X.values, y, X.columns.tolist()

#Denoise ambiguous interactions using Random Forest
def train_rf_denoiser(raw_df_train, cf_args):
    ui_keys, X, y, feat_names = build_rf_features(raw_df_train)

    rf = RandomForestClassifier(
        n_estimators=cf_args['rf_n_estimators'],
        max_depth=cf_args['rf_max_depth'],
        random_state=cf_args['rf_random_state'],
        n_jobs=-1
    )
    rf.fit(X, y)

    probs = rf.predict_proba(X)[:,1]

    res = ui_keys.copy()
    res['prob'] = probs
    res['label'] = y

    return rf, res, feat_names

def denoise_interactions(rf_res, pos_th=0.6, neg_th=0.4):
    """
    rf_res: output of train_rf_denoiser
    """
    pos_df = rf_res[rf_res['prob'] >= pos_th][['user_id','item_id']]
    true_neg_df = rf_res[rf_res['prob'] <= neg_th][['user_id','item_id']]

    return pos_df.drop_duplicates(), true_neg_df.drop_duplicates()

def prepare_training_data(raw_df_train, train_click, cf_args):
    """
    Returns:
      train_pos_df : edges for LightGCN training
      true_neg_df  : true negatives for sampler (or None)
    """

    if not cf_args['denoise']:
        print("[INFO] Denoise OFF → use click-only training")
        return train_click[['user_id','item_id']], None

    print("[INFO] Denoise ON → training RF denoiser")
    rf, rf_res, feat_names = train_rf_denoiser(raw_df_train, cf_args)

    pos_df, true_neg_df = denoise_interactions(rf_res)

    print(f"[RF] train_click edges: {len(train_click)}")
    print(f"[RF] denoised positives: {len(pos_df)}")
    print(f"[RF] true negatives: {len(true_neg_df)}")

    return pos_df, true_neg_df

# ============================================= Reranking =============================================

def get_topN_candidates(model, users, cf_args):
    """
    Returns: np.ndarray [num_users, topN]
    """
    model.topk = cf_args.get('rerank_topN', 20)
    loader = get_test_loader(Cf_valDataset(users), cf_args)
    return model.rank(loader)

def build_behavior_features_ui(raw_df_train):
    return raw_df_train.groupby(['user_id','item_id']).agg(
        cnt_click=('click','sum'),
        cnt_like=('like','sum'),
        cnt_share=('share','sum'),
        cnt_follow=('follow','sum'),
        avg_watch=('watching_times','mean')
    ).reset_index()

def get_user_item_meta(raw_df):
    user_meta = raw_df[['user_id','gender','age']].drop_duplicates()
    item_meta = raw_df[['item_id','video_category']].drop_duplicates()
    return user_meta, item_meta

def build_rerank_features(
    candidates,
    model,
    users,
    raw_df_train,
    raw_df_all
):
    """
    candidates: [num_users, topN]
    users: user ids aligned with candidates
    """

    # embeddings
    U_emb, I_emb = model.forward()
    U_emb = U_emb.detach().cpu().numpy()
    I_emb = I_emb.detach().cpu().numpy()

    # meta & behavior
    beh = build_behavior_features_ui(raw_df_train)
    user_meta, item_meta = get_user_item_meta(raw_df_all)

    beh = encode_ui(beh)
    user_meta = encode_ui(user_meta)
    item_meta = encode_ui(item_meta)

    beh_dict = {(u,i): list(row)  # row already contains just the behavior features after unpacking
                for u,i,*row in beh.itertuples(index=False)}

    X = []
    keys = []
    
    # Debug: print embedding dimensions
    print(f"DEBUG: U_emb shape: {U_emb.shape}, I_emb shape: {I_emb.shape}")
    print(f"DEBUG: beh_dict has {len(beh_dict)} entries")
    print(f"DEBUG: user_meta shape: {user_meta.shape}, item_meta shape: {item_meta.shape}")

    for ui, u in enumerate(tqdm(users, desc="building reranking feature")):
        for rank, i in enumerate(candidates[ui]):
            u, i = int(u), int(i)

            feat = []
            feat_debug = {}  # Track component lengths

            # embeddings - convert numpy arrays to lists of floats
            u_emb_list = [float(x) for x in U_emb[u]]
            i_emb_list = [float(x) for x in I_emb[i]]
            feat.extend(u_emb_list)
            feat.extend(i_emb_list)
            feat_debug['u_emb'] = len(u_emb_list)
            feat_debug['i_emb'] = len(i_emb_list)

            # dot product
            feat.append(float(np.dot(U_emb[u], I_emb[i])))
            feat_debug['dot'] = 1

            # rank position
            feat.append(float(rank))
            feat_debug['rank'] = 1

            # behavior - ensure all are floats
            beh_feats = beh_dict.get((u,i), [0,0,0,0,0])
            beh_feats_list = [float(x) for x in beh_feats]
            feat.extend(beh_feats_list)
            feat_debug['beh'] = len(beh_feats_list)

            # user meta
            um = user_meta[user_meta['user_id']==u]
            if len(um) > 0:
                um_vals = um[['gender','age']].values[0]
                feat.extend([float(um_vals[0]), float(um_vals[1])])
                feat_debug['u_meta'] = 2
            else:
                feat.extend([0.0, 0.0])
                feat_debug['u_meta'] = 2

            # item meta
            im = item_meta[item_meta['item_id']==i]
            if len(im) > 0:
                feat.append(float(im['video_category'].values[0]))
                feat_debug['i_meta'] = 1
            else:
                feat.append(0.0)
                feat_debug['i_meta'] = 1
            
            # Debug first few features
            if len(X) < 3:
                print(f"DEBUG: Feature {len(X)}: u={u}, i={i}, components={feat_debug}, total_len={len(feat)}")

            X.append(feat)
            keys.append((u,i))

    # First check if all features have the same length
    if len(X) > 0:
        feat_lens = [len(f) for f in X]
        unique_lens = set(feat_lens)
        if len(unique_lens) > 1:
            print(f"\n⚠ WARNING: Inconsistent feature lengths detected!")
            print(f"Unique lengths: {sorted(unique_lens)}")
            unique_counts = dict(zip(*np.unique(feat_lens, return_counts=True)))
            print(f"Length distribution: {unique_counts}")
            
            # Show samples of each unique length
            for length in sorted(unique_lens)[:5]:  # Show up to 5 different lengths
                idx = next(i for i, l in enumerate(feat_lens) if l == length)
                u, i = keys[idx]
                print(f"\nSample with length {length} (index {idx}, u={u}, i={i}):")
                print(f"  Feature[:10]: {X[idx][:10]}")
                print(f"  Feature[-10:]: {X[idx][-10:]}")
    
    try:
        X_arr = np.array(X, dtype=np.float32)
    except ValueError as e:
        print(f"\n✗ Error converting features to array: {e}")
        print(f"Total features: {len(X)}")
        print(f"Feature length stats:")
        print(f"  Min: {min(feat_lens)}")
        print(f"  Max: {max(feat_lens)}")
        print(f"  Mode: {max(set(feat_lens), key=feat_lens.count)}")
        
        # Find and show the problematic features
        mode_len = max(set(feat_lens), key=feat_lens.count)
        problem_indices = [i for i, l in enumerate(feat_lens) if l != mode_len]
        print(f"\nFound {len(problem_indices)} features with non-standard length")
        print(f"Standard length: {mode_len}, Problem indices (first 10): {problem_indices[:10]}")
        
        raise
    
    return X_arr, keys

def train_reranker_rf(
    model,
    train_users,
    train_ur,
    raw_df_train,
    raw_df_all,
    cf_args
):
    # candidates from LightGCN
    candidates = get_topN_candidates(model, train_users, cf_args)

    # build features
    X, keys = build_rerank_features(
        candidates,
        model,
        train_users,
        raw_df_train,
        raw_df_all
    )

    # labels: whether (u,i) is positive in train_ur
    y = np.array([
        1 if i in train_ur.get(u, []) else 0
        for u,i in keys
    ])

    rf = RandomForestClassifier(
        n_estimators=cf_args.get('rerank_rf_estimators', 100),
        max_depth=cf_args.get('rerank_rf_max_depth', 10),
        random_state=cf_args.get('rerank_rf_random_state', 42),
        n_jobs=-1
    )
    rf.fit(X, y)

    return rf

def rerank_candidates(
    model,
    rf,
    test_users,
    raw_df_train,
    raw_df_all,
    cf_args
):
    candidates = get_topN_candidates(model, test_users, cf_args)

    X, keys = build_rerank_features(
        candidates,
        model,
        test_users,
        raw_df_train,
        raw_df_all
    )

    scores = rf.predict_proba(X)[:,1]

    reranked = []
    idx = 0
    topN = cf_args.get('rerank_topN', 20)

    for _ in tqdm(test_users, desc="reranking"):
        s = scores[idx:idx+topN]
        c = candidates[len(reranked)]
        order = np.argsort(-s)
        reranked.append(c[order])
        idx += topN
    
    return np.array(reranked)



# ============================================= plot =============================================
def save_results(history, test_results, cf_args):
    """
    Save training history and test results to JSON file.
    Only saves JSON-serializable configuration parameters.
    """
    exp_dir =cf_args['save_path']
    os.makedirs(exp_dir, exist_ok=True)

    # Select only serializable config fields
    config_to_save = {
        'exp_name': cf_args['exp_name'],
        'embedding_dim': cf_args['embedding_dim'],
        'lr': cf_args['lr'],
        'num_layers': cf_args['num_layers'],
        'epochs': cf_args['epochs'],
        'k': cf_args['k'],
        'reg_1': cf_args['reg_1'],
        'reg_2': cf_args['reg_2'],
        'train_batch_size': cf_args['train_batch_size'],
        'val_batch_size': cf_args['val_batch_size'],
        'test_batch_size': cf_args['test_batch_size'],
        'sample_method': cf_args.get('sample_method', cf_args.get('sampler_method', 'uniform')),
        'sample_ratio': cf_args['sample_ratio'],
        'num_ng': cf_args['num_ng'],
        'test_size': cf_args['test_size'],
        'val_size': cf_args['val_size'],
        'denoise': cf_args['denoise'],
        'true_neg': cf_args['true_neg'],
        'rerank': cf_args['rerank'],
        'user_num': cf_args['user_num'],
        'item_num': cf_args['item_num'],
    }

    result = {
        'config': config_to_save,
        'final_train_loss': float(history['train_loss'][-1]) if history['train_loss'] else None,
        'best_val_recall': float(max(history['val_recall'])) if history['val_recall'] else None,
        'best_val_ndcg': float(max(history['val_ndcg'])) if history['val_ndcg'] else None,
        'test_metrics': {k: float(v) for k, v in test_results.items()}
    }

    with open(os.path.join(exp_dir, 'results.json'), 'w') as f:
        json.dump(result, f, indent=2)

    print(f"[INFO] Results saved to {exp_dir}")

def plot_training_curves(history, cf_args):
    os.makedirs(cf_args['plot_path'], exist_ok=True)
    exp_dir = os.path.join(cf_args['plot_path'], cf_args['exp_name'])
    os.makedirs(exp_dir, exist_ok=True)

    epochs = range(1, len(history['train_loss']) + 1)

    # ----- LOSS -----
    plt.figure()
    plt.plot(epochs, history['train_loss'])
    plt.xlabel('Epoch')
    plt.ylabel('BPR Loss')
    plt.title('Training Loss')
    if cf_args['plot_path']:
        plt.savefig(os.path.join(exp_dir, 'train_loss.png'), dpi=300)
    plt.show()

    # ----- VALIDATION METRICS -----
    if len(history['val_recall']) > 0:
        plt.figure()
        plt.plot(epochs, history['val_recall'], label='Recall')
        plt.plot(epochs, history['val_ndcg'], label='NDCG')
        plt.xlabel('Epoch')
        plt.ylabel('Score')
        plt.title(f'Validation Metrics @K={cf_args["k"]}')
        plt.legend()
        if cf_args['plot_path']:
            plt.savefig(os.path.join(exp_dir, 'val_metrics.png'), dpi=300)
        plt.show()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# CF Task Pipeline
# 1. Define CF parameters
cf_args = {
    'save_path': './checkpoint/',
    'plot_path': './plots/',
    'dataset_path': './data/data.csv',
    'device': 'cuda',
    'exp_name': 'LightGCN_QB_video',
    
    # Training
    'train_batch_size': 2048,
    'val_batch_size': 64,
    'test_batch_size': 64,

    # Model
    'embedding_dim': 128,
    'lr': 0.005,
    'num_layers': 3,
    'epochs': 30,
    'k': 20,
    'reg_1': 0.0,
    'reg_2': 0.0,

    # Negative sampling
    'sample_method': 'hybrid',  # uniform, high-pop, low-pop, true_neg, hybrid
    'sample_ratio': 0.3,
    'num_ng': 4,

    # Data split
    'test_size': 0.1,
    'val_size': 0.1111,

    # RF denoise
    'rf_n_estimators': 300,
    'rf_max_depth': 14,
    'rf_random_state': 42,

    # rerank
    'rerank_topN': 100,
    'rerank_rf_estimators': 300,
    'rerank_rf_max_depth': 14,

    # hybrid
    'early_stop': True,
    'true_neg': True,
    'denoise': True,
    'rerank': True,
    'plot': True,
    'load': False,
}

os.makedirs(cf_args['save_path'], exist_ok=True)
os.makedirs(cf_args['plot_path'], exist_ok=True)

# 2. Load and preprocess data
# ===== Load data =====
raw_df = pd.read_csv(cf_args['dataset_path'])
raw_df.fillna(0, inplace=True)

# ===== Fixed-universe mapping =====
user_ids = raw_df['user_id'].unique()
item_ids = raw_df['item_id'].unique()

user_map = {u:i for i,u in enumerate(user_ids)}
item_map = {i:j for j,i in enumerate(item_ids)}

cf_args['user_num'] = len(user_map)
cf_args['item_num'] = len(item_map)
def encode_ui(df):
    df = df.copy()
    # Đảm bảo user_id và item_id là cột, nếu đang là index thì reset về cột
    if 'user_id' not in df.columns and getattr(df.index, 'name', None) == 'user_id':
        df = df.reset_index()  # đưa index user_id thành cột
    if 'item_id' not in df.columns and getattr(df.index, 'name', None) == 'item_id':
        df = df.reset_index()  # đưa index item_id thành cột
    # Thực hiện mapping nếu cột tồn tại
    if 'user_id' in df.columns:
        df['user_id'] = df['user_id'].map(user_map)
    if 'item_id' in df.columns:
        df['item_id'] = df['item_id'].map(item_map)
    
    # Remove NaN (unmapped IDs) and convert to int
    df = df.dropna(subset=[col for col in ['user_id', 'item_id'] if col in df.columns])
    if 'user_id' in df.columns:
        df['user_id'] = df['user_id'].astype(int)
    if 'item_id' in df.columns:
        df['item_id'] = df['item_id'].astype(int)
    
    return df

print(f"Users={cf_args['user_num']}, Items={cf_args['item_num']}")

# 3. Split train/val/test
# ===== Click-positive interactions =====
click_df = (
    raw_df[raw_df['click'] == 1][['user_id','item_id']]
    .drop_duplicates()
    .reset_index(drop=True)
)

# ===== Train / Test =====
train_idx, test_idx = split_test(click_df, cf_args['test_size'])

train_click_raw_full = click_df.iloc[train_idx].reset_index(drop=True)
test_click_raw       = click_df.iloc[test_idx].reset_index(drop=True)

# ===== Train / Val (IMPORTANT: split on FULL train, then slice from FULL train) =====
train_idx2, val_idx = split_validation(train_click_raw_full, cf_args['val_size'])

train_click_raw = train_click_raw_full.iloc[train_idx2].reset_index(drop=True)
val_click_raw   = train_click_raw_full.iloc[val_idx].reset_index(drop=True)

# Encode
train_click = encode_ui(train_click_raw)
val_click = encode_ui(val_click_raw)
test_click = encode_ui(test_click_raw)

train_ur = get_ur(train_click)
val_ur = get_ur(val_click)
test_ur = get_ur(test_click)

cf_args['train_ur'] = train_ur

cf_args['val_ur'] = val_ur

cf_args['test_ur'] = test_ur
cf_args['test_u'] = np.array(list(test_ur.keys()))
cf_args['val_u'] = np.array(list(val_ur.keys()))

train_users_raw = set(train_click_raw_full['user_id'].values)  # raw IDs
raw_df_train = raw_df[raw_df['user_id'].isin(train_users_raw)].copy()

Users=34240, Items=130637
Method of getting user-rating pairs
Method of getting user-rating pairs
Method of getting user-rating pairs


In [None]:

ablation_settings = [
    {'name': 'A0_Baseline', 'load': True,
     'denoise': False, 'true_neg': False,
     'sampler_method': 'uniform', 'rerank': False,
     'save_path': './checkpoint/A0_Baseline',},

    {'name': 'A1_Rerank', 'load': True,
     'denoise': False, 'true_neg': False,
     'sampler_method': 'uniform', 'rerank': True,
     'save_path': './checkpoint/A1_Rerank',},

    {'name': 'A2_Denoise',
     'denoise': True, 'true_neg': False,
     'sampler_method': 'uniform', 'rerank': False,
     'save_path': './checkpoint/A2_Denoise',},

    {'name': 'A3_TrueNeg',
     'denoise': True, 'true_neg': True,
     'sampler_method': 'true_neg', 'rerank': False,
     'save_path': './checkpoint/A3_TrueNeg',},

    {'name': 'A4_HybridNeg',
     'denoise': True, 'true_neg': True,
     'sampler_method': 'hybrid', 'rerank': False,
     'save_path': './checkpoint/A4_HybridNeg',},

    {'name': 'A5_Full',
     'denoise': True, 'true_neg': True,
     'sampler_method': 'hybrid', 'rerank': True,
     'save_path': './checkpoint/A5_Full',},
]


ablation_results = {}

for setting in ablation_settings:
    cf_run = cf_args.copy()
    cf_run.update(setting)
    cf_run['exp_name'] = setting['name']
    cf_run['save_path'] = setting['save_path']
    cf_run['load'] = setting.get('load', False)
    print(f"\n===== Running {cf_run['exp_name']} =====")

    # ----- DATA PREP -----
    train_pos_df, true_neg_df = prepare_training_data(
        raw_df_train, train_click, cf_run
    )
    train_pos_df = encode_ui(train_pos_df)
    if true_neg_df is not None:
        true_neg_df = encode_ui(true_neg_df)

    # Validate encoded IDs (ablation loop)
    print(f"[{cf_run['exp_name']}] Encoded edges: {len(train_pos_df)}")
    if len(train_pos_df) == 0:
        print(f"[ERROR] No training edges for {cf_run['exp_name']}. Skipping...")
        continue
    
    assert train_pos_df['user_id'].min() >= 0, f"Negative user_id in {cf_run['exp_name']}"
    assert train_pos_df['item_id'].min() >= 0, f"Negative item_id in {cf_run['exp_name']}"
    assert train_pos_df['user_id'].max() < cf_run['user_num'], f"user_id OOB in {cf_run['exp_name']}"
    assert train_pos_df['item_id'].max() < cf_run['item_num'], f"item_id OOB in {cf_run['exp_name']}"


    # ----- TRAIN GNN -----
    graph_df = train_pos_df.copy()
    graph_df['click'] = 1
    cf_run['interaction_matrix'] = get_inter_matrix(graph_df, cf_run)
    cf_run['true_neg_dict'] = HybridNegativeSampler._build_true_neg_dict(true_neg_df) if cf_run['true_neg'] else None

    os.makedirs(cf_run['save_path'], exist_ok=True)

    sampler = HybridNegativeSampler(cf_run)
    triples = sampler.sampling(
        train_edges_df=train_pos_df,
        train_ur=train_ur,
        true_neg_df=true_neg_df if cf_run['true_neg'] else None
    )

    train_ds = BasicDataset(triples)
    train_loader = get_train_loader(train_ds, cf_run)

    model = LightGCN(cf_run)
    history = model.fit(
        train_loader,
        val_loader=get_val_loader(Cf_valDataset(cf_args['val_u']), cf_run),
        epochs=cf_run['epochs']
    )

    # ----- INFERENCE -----
    if cf_run['rerank']:
        rf_reranker = train_reranker_rf(
            model, list(train_ur.keys()), train_ur,
            raw_df_train, raw_df, cf_run
        )
        preds = rerank_candidates(
            model, rf_reranker,
            cf_args['test_u'], raw_df_train, raw_df, cf_run
        )
    else:
        preds = get_topN_candidates(model, cf_args['test_u'], cf_run)

    # ----- METRICS -----
    results = {
        'Recall@20': Recall(test_ur, preds, cf_args['test_u']),
        'NDCG@20': NDCG(test_ur, preds, cf_args['test_u']),
        'Precision@20': Precision(test_ur, preds, cf_args['test_u']),
    }

    save_results(history, results, cf_run)
    ablation_results[setting['name']] = results



===== Running A0_Baseline =====
[INFO] Denoise OFF → use click-only training
[A0_Baseline] Encoded edges: 46
get the whole sparse interaction matrix


[Epoch 000]: 100%|██████████| 2647/2647 [04:41<00:00,  9.40it/s]


Training - Loss 0.1724 | Validation - NDCG@20: 0.0469, Recall@20: 0.0945, Precision@20: 0.0141


[Epoch 001]: 100%|██████████| 2647/2647 [04:36<00:00,  9.58it/s]


Training - Loss 0.0584 | Validation - NDCG@20: 0.0481, Recall@20: 0.0959, Precision@20: 0.0154


[Epoch 002]: 100%|██████████| 2647/2647 [04:47<00:00,  9.20it/s]


Training - Loss 0.0298 | Validation - NDCG@20: 0.0491, Recall@20: 0.0967, Precision@20: 0.0157


[Epoch 003]: 100%|██████████| 2647/2647 [04:22<00:00, 10.08it/s]


Training - Loss 0.0143 | Validation - NDCG@20: 0.0502, Recall@20: 0.1010, Precision@20: 0.0159


[Epoch 004]: 100%|██████████| 2647/2647 [04:23<00:00, 10.04it/s]


Training - Loss 0.0060 | Validation - NDCG@20: 0.0463, Recall@20: 0.0976, Precision@20: 0.0154
Patience counter: 1/5


[Epoch 005]: 100%|██████████| 2647/2647 [04:24<00:00, 10.02it/s]


Training - Loss 0.0022 | Validation - NDCG@20: 0.0476, Recall@20: 0.0993, Precision@20: 0.0155
Patience counter: 2/5


[Epoch 006]: 100%|██████████| 2647/2647 [04:29<00:00,  9.83it/s]


Training - Loss 0.0007 | Validation - NDCG@20: 0.0474, Recall@20: 0.0990, Precision@20: 0.0155
Patience counter: 3/5


[Epoch 007]: 100%|██████████| 2647/2647 [04:23<00:00, 10.06it/s]


Training - Loss 0.0002 | Validation - NDCG@20: 0.0456, Recall@20: 0.0972, Precision@20: 0.0152
Patience counter: 4/5


[Epoch 008]: 100%|██████████| 2647/2647 [04:43<00:00,  9.33it/s]


Training - Loss 0.0001 | Validation - NDCG@20: 0.0468, Recall@20: 0.0986, Precision@20: 0.0155
Patience counter: 5/5
Satisfy early stop mechanism
[INFO] Results saved to ./checkpoint/A0_Baseline

===== Running A1_Rerank =====
[INFO] Denoise OFF → use click-only training
[A1_Rerank] Encoded edges: 46
get the whole sparse interaction matrix


[Epoch 000]: 100%|██████████| 2647/2647 [05:02<00:00,  8.74it/s]


Training - Loss 0.1746 | Validation - NDCG@20: 0.0449, Recall@20: 0.0921, Precision@20: 0.0134


[Epoch 001]: 100%|██████████| 2647/2647 [04:56<00:00,  8.92it/s]


Training - Loss 0.0620 | Validation - NDCG@20: 0.0502, Recall@20: 0.0997, Precision@20: 0.0155


[Epoch 002]: 100%|██████████| 2647/2647 [04:56<00:00,  8.93it/s]


Training - Loss 0.0310 | Validation - NDCG@20: 0.0462, Recall@20: 0.0923, Precision@20: 0.0150
Patience counter: 1/5


[Epoch 003]: 100%|██████████| 2647/2647 [04:57<00:00,  8.88it/s]


Training - Loss 0.0145 | Validation - NDCG@20: 0.0487, Recall@20: 0.0984, Precision@20: 0.0158
Patience counter: 2/5


[Epoch 004]: 100%|██████████| 2647/2647 [04:58<00:00,  8.88it/s]


Training - Loss 0.0059 | Validation - NDCG@20: 0.0481, Recall@20: 0.0978, Precision@20: 0.0154
Patience counter: 3/5


[Epoch 005]: 100%|██████████| 2647/2647 [04:54<00:00,  8.98it/s]


Training - Loss 0.0021 | Validation - NDCG@20: 0.0472, Recall@20: 0.0994, Precision@20: 0.0155
Patience counter: 4/5


[Epoch 006]: 100%|██████████| 2647/2647 [05:07<00:00,  8.61it/s]


Training - Loss 0.0007 | Validation - NDCG@20: 0.0459, Recall@20: 0.0963, Precision@20: 0.0151
Patience counter: 5/5
Satisfy early stop mechanism
DEBUG: U_emb shape: (34240, 128), I_emb shape: (130637, 128)
DEBUG: beh_dict has 2431000 entries
DEBUG: user_meta shape: (34240, 3), item_meta shape: (169708, 2)


building reranking feature:   0%|          | 0/34237 [00:00<?, ?it/s]

DEBUG: Feature 0: u=0, i=1328, components={'u_emb': 128, 'i_emb': 128, 'dot': 1, 'rank': 1, 'beh': 5, 'u_meta': 2, 'i_meta': 1}, total_len=266
DEBUG: Feature 1: u=0, i=718, components={'u_emb': 128, 'i_emb': 128, 'dot': 1, 'rank': 1, 'beh': 5, 'u_meta': 2, 'i_meta': 1}, total_len=266
DEBUG: Feature 2: u=0, i=172, components={'u_emb': 128, 'i_emb': 128, 'dot': 1, 'rank': 1, 'beh': 5, 'u_meta': 2, 'i_meta': 1}, total_len=266


building reranking feature:   9%|▉         | 3249/34237 [14:17<3:45:56,  2.29it/s] 

In [None]:
def collect_ablation_results(log_dir, settings):
    rows = []
    for s in settings:
        path = f"{log_dir}/{s['name']}/results.json"
        with open(path) as f:
            r = json.load(f)
        row = {'Model': s['name']}
        row.update(r['test_metrics'])
        rows.append(row)
    return pd.DataFrame(rows)

abl_df = collect_ablation_results(cf_args['save_path'], ablation_settings)
print(abl_df)


Ks = [5, 10, 20, 50]
plt.figure()

for name in ['A0_Baseline', 'A5_Full']:
    ndcgs = []
    for k in Ks:
        cf_args['k'] = k
        preds = get_topN_candidates(model, cf_args['test_u'], cf_args)
        ndcgs.append(NDCG(test_ur, preds, cf_args['test_u']))
    plt.plot(Ks, ndcgs, label=name)

plt.xlabel('K')
plt.ylabel('NDCG@K')
plt.legend()
plt.title('NDCG@K Comparison')
plt.show()


plt.figure()
plt.bar(abl_df['Model'], abl_df['NDCG@20'])
plt.xticks(rotation=45)
plt.ylabel('NDCG@20')
plt.title('Ablation Study on QB-video')
plt.tight_layout()
plt.show()

def load_loss(exp_name):
    with open(f"{cf_args['save_path']}/{exp_name}/results.json") as f:
        return json.load(f)['final_train_loss']

# hoặc plot từ history đã lưu


abl_df = collect_ablation_results(cf_args['save_path'], ablation_settings)
display(abl_df)

plt.figure()
plt.bar(abl_df['Model'], abl_df['NDCG@20'])
plt.xticks(rotation=45)
plt.ylabel('NDCG@20')
plt.title('Ablation Study on QB-video')
plt.tight_layout()
plt.show()
