LLMESR.py

In [None]:
# here put the import lib
import torch
import torch.nn as nn
from models.DualLLMSRS import DualLLMSASRec, DualLLMGRU4Rec, DualLLMBert4Rec
from models.utils import Contrastive_Loss2, ClusterHandler  # 新增导入



class LLMESR_SASRec(DualLLMSASRec):
    """SASRec variant that adds the LLM-ESR auxiliary losses to the parent loss.

    The value returned by ``forward`` is the parent's loss plus:

    * ``alpha`` x a user-alignment loss between the last-step feature of each
      sequence and the mean last-step feature of its similar-user sequences;
    * ``beta`` x an item regularisation loss (only when ``args.item_reg``);
    * ``gamma`` x the cluster-constraint loss from ``ClusterHandler``.
    """

    def __init__(self, user_num, item_num, device, args):
        """Build the parent model, then the auxiliary-loss components.

        Args:
            user_num, item_num, device: forwarded unchanged to the parent.
            args: namespace providing ``alpha``, ``gamma``, ``user_sim_func``
                (``"cl"`` or ``"kd"``), ``item_reg``, ``dataset``,
                ``hidden_size`` and, when ``item_reg`` is set, ``beta``.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """
        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster-constraint loss

        # Provides per-item cluster centres and the cluster loss.
        self.cluster_handler = ClusterHandler(args.dataset, args.hidden_size, device)

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() at the moment (the projected alignment is
        # disabled); kept so the module's parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            self.beta = args.beta
            self.reg = Contrastive_Loss2()

        self._init_weights()

    def forward(self,
                seq,
                pos,
                neg,
                positions,
                **kwargs):
        """Return the parent loss plus the alignment, item-reg and cluster terms.

        Args:
            seq, pos, neg, positions: forwarded to the parent ``forward``;
                ``seq`` and ``positions`` are also fed to ``log2feats``.
                Entries of ``seq`` equal to 0 are treated as padding here.
            **kwargs: must contain ``"sim_seq"`` and ``"sim_positions"``
                (presumably ``(batch, sim_num, seq_len)`` — they are reshaped
                to ``(-1, seq_len)`` below); everything is also passed on to
                the parent.

        Returns:
            The combined loss tensor.
        """
        loss = super().forward(seq, pos, neg, positions, **kwargs)  # parent loss

        batch_size, seq_len = seq.shape[0], seq.shape[1]

        # Last-step representation of each sequence in the batch.
        log_feats = self.log2feats(seq, positions)[:, -1, :]

        # Last-step representation of every similar-user sequence, detached so
        # it acts as a fixed target, then averaged over the similar users.
        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq_len)
        sim_positions = kwargs["sim_positions"].view(-1, seq_len)
        sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :]
        sim_log_feats = sim_log_feats.detach().view(batch_size, sim_num, -1)
        sim_log_feats = torch.mean(sim_log_feats, dim=1)

        # __init__ guarantees self.align matches user_sim_func, and both
        # variants are called the same way.
        align_loss = self.align(log_feats, sim_log_feats)

        # Non-padding item ids and their ID embeddings, shared by the two
        # item-level terms below.
        item_ids = seq[seq > 0]
        id_item_emb = self.id_item_emb(item_ids)

        # Terms are added out of place so the tensor returned by the parent is
        # never mutated.
        if self.item_reg:
            llm_item_emb = self.adapter(self.llm_item_emb(item_ids))
            loss = loss + self.beta * self.reg(llm_item_emb, id_item_emb)

        cluster_loss = self.cluster_handler.calculate_cluster_loss(item_ids, id_item_emb)
        loss = loss + self.gamma * cluster_loss
        loss = loss + self.alpha * align_loss

        return loss
    


class LLMESR_GRU4Rec(DualLLMGRU4Rec):
    """GRU4Rec variant that adds the LLM-ESR auxiliary losses to the parent loss.

    The value returned by ``forward`` is the parent's loss plus:

    * ``alpha`` x a user-alignment loss between the last-step feature of each
      sequence and the mean last-step feature of its similar-user sequences;
    * ``beta`` x an item regularisation loss (only when ``args.item_reg``);
    * ``gamma`` x the cluster-constraint loss from ``ClusterHandler``.
    """

    def __init__(self, user_num, item_num, device, args):
        """Build the parent model, then the auxiliary-loss components.

        Args:
            user_num, item_num, device: forwarded unchanged to the parent.
            args: namespace providing ``alpha``, ``gamma``, ``user_sim_func``
                (``"cl"`` or ``"kd"``), ``item_reg``, ``dataset``,
                ``hidden_size`` and, when ``item_reg`` is set, ``beta``.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """
        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster-constraint loss

        # Provides per-item cluster centres and the cluster loss.
        self.cluster_handler = ClusterHandler(args.dataset, args.hidden_size, device)

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() at the moment (the projected alignment is
        # disabled); kept so the module's parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            self.beta = args.beta
            self.reg = Contrastive_Loss2()

        self._init_weights()

    def forward(self,
                seq,
                pos,
                neg,
                positions,
                **kwargs):
        """Return the parent loss plus the alignment, item-reg and cluster terms.

        Args:
            seq, pos, neg, positions: forwarded to the parent ``forward``;
                only ``seq`` is fed to ``log2feats`` (the GRU encoder is called
                without positions). Entries of ``seq`` equal to 0 are treated
                as padding here.
            **kwargs: must contain ``"sim_seq"`` (presumably
                ``(batch, sim_num, seq_len)`` — it is reshaped to
                ``(-1, seq_len)`` below); everything is also passed on to the
                parent.

        Returns:
            The combined loss tensor.
        """
        loss = super().forward(seq, pos, neg, positions, **kwargs)  # parent loss

        batch_size, seq_len = seq.shape[0], seq.shape[1]

        # Last-step representation of each sequence in the batch.
        log_feats = self.log2feats(seq)[:, -1, :]

        # Last-step representation of every similar-user sequence, detached so
        # it acts as a fixed target, then averaged over the similar users.
        # "sim_positions" is deliberately not read: this encoder ignores it.
        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq_len)
        sim_log_feats = self.log2feats(sim_seq)[:, -1, :]
        sim_log_feats = sim_log_feats.detach().view(batch_size, sim_num, -1)
        sim_log_feats = torch.mean(sim_log_feats, dim=1)

        # __init__ guarantees self.align matches user_sim_func, and both
        # variants are called the same way.
        align_loss = self.align(log_feats, sim_log_feats)

        # Non-padding item ids and their ID embeddings, shared by the two
        # item-level terms below.
        item_ids = seq[seq > 0]
        id_item_emb = self.id_item_emb(item_ids)

        # Terms are added out of place so the tensor returned by the parent is
        # never mutated.
        if self.item_reg:
            llm_item_emb = self.adapter(self.llm_item_emb(item_ids))
            loss = loss + self.beta * self.reg(llm_item_emb, id_item_emb)

        cluster_loss = self.cluster_handler.calculate_cluster_loss(item_ids, id_item_emb)
        loss = loss + self.gamma * cluster_loss
        loss = loss + self.alpha * align_loss

        return loss



class LLMESR_Bert4Rec(DualLLMBert4Rec):
    """Bert4Rec variant that adds alignment and cluster losses to the parent loss.

    The value returned by ``forward`` is the parent's loss plus ``alpha`` x a
    user-alignment loss and ``gamma`` x the cluster-constraint loss. Unlike
    the SASRec/GRU4Rec variants, no item regularisation term is added here.
    """

    def __init__(self, user_num, item_num, device, args):
        """Build the parent model, then the auxiliary-loss components.

        Args:
            user_num, item_num, device: forwarded unchanged to the parent.
            args: namespace providing ``alpha``, ``gamma``, ``user_sim_func``
                (``"cl"`` or ``"kd"``), ``item_reg``, ``dataset`` and
                ``hidden_size``.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """
        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster-constraint loss

        # Provides per-item cluster centres and the cluster loss.
        self.cluster_handler = ClusterHandler(args.dataset, args.hidden_size, device)

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() at the moment (the projected alignment is
        # disabled); kept so the module's parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            # NOTE(review): created but never used by forward() in this class.
            self.reg = Contrastive_Loss2()

        self._init_weights()

    def forward(self,
                seq,
                pos,
                neg,
                positions,
                **kwargs):
        """Return the parent loss plus the alignment and cluster terms.

        Args:
            seq, pos, neg, positions: forwarded to the parent ``forward``;
                ``seq`` and ``positions`` are also fed to ``log2feats``.
            **kwargs: must contain ``"sim_seq"`` and ``"sim_positions"``
                (presumably ``(batch, sim_num, seq_len)`` — they are reshaped
                to ``(-1, seq_len)`` below); everything is also passed on to
                the parent.

        Returns:
            The combined loss tensor.
        """
        loss = super().forward(seq, pos, neg, positions, **kwargs)  # parent loss

        batch_size, seq_len = seq.shape[0], seq.shape[1]

        # Last-step representation of each sequence in the batch.
        log_feats = self.log2feats(seq, positions)[:, -1, :]

        # Last-step representation of every similar-user sequence, detached so
        # it acts as a fixed target, then averaged over the similar users.
        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq_len)
        sim_positions = kwargs["sim_positions"].view(-1, seq_len)
        sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :]
        sim_log_feats = sim_log_feats.detach().view(batch_size, sim_num, -1)
        sim_log_feats = torch.mean(sim_log_feats, dim=1)

        # __init__ guarantees self.align matches user_sim_func, and both
        # variants are called the same way.
        align_loss = self.align(log_feats, sim_log_feats)

        # Cluster constraint, with Bert4Rec-specific filtering of special ids.
        cluster_loss = self.calculate_cluster_loss(seq)

        # Terms are added out of place so the tensor returned by the parent is
        # never mutated.
        loss = loss + self.gamma * cluster_loss
        loss = loss + self.alpha * align_loss

        return loss

    def calculate_cluster_loss(self, seq):
        """Cluster-constraint loss over the real items that occur in ``seq``.

        Padding (0) and ``self.mask_token`` entries are dropped, as are ids
        that fall outside the cluster-label table (indexing the table with
        them would be out of range). Returns a zero scalar on ``seq.device``
        when nothing is left.
        """
        # 1) Keep only real items: neither PAD (0) nor the MASK token.
        valid_mask = (seq > 0) & (seq != self.mask_token)
        item_ids = seq[valid_mask]

        # 2) Keep only ids covered by the label table. The bound is taken from
        #    `item_cluster_labels` because that is the tensor
        #    ClusterHandler.get_cluster_center actually indexes with these ids.
        num_labelled = self.cluster_handler.item_cluster_labels.size(0)
        item_ids = item_ids[item_ids < num_labelled]

        if item_ids.numel() == 0:  # nothing but PAD / MASK / uncovered ids
            return torch.tensor(0.0, device=seq.device)

        # 3) Look up the ID embeddings and delegate to the handler.
        item_embeddings = self.id_item_emb(item_ids)
        return self.cluster_handler.calculate_cluster_loss(item_ids, item_embeddings)


utils.py

In [None]:
# here put the import lib
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from math import sqrt
import pickle
from pathlib import Path 

# 新增：聚类信息处理类
class ClusterHandler(nn.Module):
    """Map items to projected cluster centres and compute the cluster loss.

    Loads, for one dataset, a per-item cluster label (``-1`` marks noise
    items) and an 8-dimensional centre per cluster. Centres are projected to
    ``hidden_size`` by a frozen linear layer.
    """

    def __init__(self, dataset, hidden_size, device):
        """Load the cluster files for ``dataset`` and build the projection.

        Args:
            dataset: dataset directory name under ``data/``.
            hidden_size: output width of the centre projection.
            device: device on which the loaded tensors are placed.
        """
        super().__init__()
        cluster_path = Path("data") / dataset / "item_cluster.pt"
        self.item_cluster = torch.load(cluster_path, map_location=device)

        self.hidden_size = hidden_size
        self.device = device
        self.load_cluster_info(dataset)

        # 8-d centre -> hidden_size projection. It is frozen, so it stays at
        # its initial weights for the whole run.
        self.cluster_projection = nn.Linear(8, hidden_size)
        for param in self.cluster_projection.parameters():
            param.requires_grad = False

    def load_cluster_info(self, dataset):
        """Load cluster labels and 8-d centres from ``data/<dataset>/handled``.

        Sets ``self.item_cluster_labels`` (long tensor) and
        ``self.cluster_centers_8d`` (float tensor with one row per cluster id).
        """
        data_dir = Path("data") / dataset / "handled"

        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at files you produced yourself.
        with open(data_dir / "item_cluster_labels.pkl", "rb") as f:
            labels = pickle.load(f)
        self.item_cluster_labels = torch.tensor(labels, dtype=torch.long, device=self.device)

        with open(data_dir / "cluster_centers_8d.pkl", "rb") as f:
            centers_by_id = pickle.load(f)

        # One row per cluster id. Sizing by the largest id (not just the number
        # of entries) keeps the row writes below in range when ids have gaps.
        num_rows = int(max(len(centers_by_id), max(centers_by_id, default=-1) + 1))
        self.cluster_centers_8d = torch.zeros(num_rows, 8, device=self.device)
        for c_id, center in centers_by_id.items():
            self.cluster_centers_8d[c_id] = torch.tensor(
                center, dtype=torch.float32, device=self.device
            )

    def get_cluster_center(self, item_ids):
        """Return projected centres for ``item_ids`` and the non-noise mask.

        Args:
            item_ids: 1-D integer tensor used to index the label table.

        Returns:
            ``(centers_d, non_noise_mask)``: ``centers_d`` has shape
            ``(len(item_ids), hidden_size)`` with all-zero rows for noise
            items; ``non_noise_mask`` is True where the item's label is not -1.
        """
        cluster_labels = self.item_cluster_labels[item_ids]
        non_noise_mask = cluster_labels != -1  # label -1 marks noise items

        # Project the 8-d centres of the non-noise items only.
        centers_8d = self.cluster_centers_8d[cluster_labels[non_noise_mask]]
        projected = self.cluster_projection(centers_8d)

        # Noise items keep a zero vector.
        centers_d = torch.zeros(item_ids.size(0), self.hidden_size, device=item_ids.device)
        centers_d[non_noise_mask] = projected
        return centers_d, non_noise_mask

    def calculate_cluster_loss(self, item_ids, item_embeddings):
        """Mean squared L2 distance between non-noise items and their centres.

        Args:
            item_ids: 1-D integer tensor of item ids.
            item_embeddings: embeddings aligned row-for-row with ``item_ids``.

        Returns:
            A scalar tensor; zero (on ``self.device``) when ``item_ids`` is
            empty or every item is noise.
        """
        if item_ids.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        centers_d, non_noise_mask = self.get_cluster_center(item_ids)

        if not non_noise_mask.any():
            return torch.tensor(0.0, device=self.device)

        # Squared distance as a plain sum of squares: the same value as
        # norm(...)**2 without a square root that is squared straight away
        # (and whose gradient is badly behaved at zero distance).
        diff = item_embeddings[non_noise_mask] - centers_d[non_noise_mask]
        return diff.pow(2).sum(dim=1).mean()

class PointWiseFeedForward(torch.nn.Module):
    """Two kernel-size-1 convolutions with ReLU, dropout and a residual add.

    The input is transposed on its last two axes before the convolutions
    (``Conv1d`` expects ``(N, C, Length)``) and transposed back afterwards.
    """

    def __init__(self, hidden_units, dropout_rate):
        """Create both conv/dropout pairs; channel count is ``hidden_units``."""
        super().__init__()

        self.conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        """Apply conv1 -> dropout1 -> relu -> conv2 -> dropout2, then add ``inputs``."""
        channels_first = inputs.transpose(-1, -2)
        hidden = self.relu(self.dropout1(self.conv1(channels_first)))
        outputs = self.dropout2(self.conv2(hidden))
        outputs = outputs.transpose(-1, -2)  # back to the layout of `inputs`
        outputs += inputs  # residual connection
        return outputs
    


class PointWiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward block built from 1x1 ``Conv1d`` layers.

    NOTE(review): this repeats the ``PointWiseFeedForward`` definition that
    appears just before it in this file; being the later one, it is the
    definition the name ends up bound to.
    """

    def __init__(self, hidden_units, dropout_rate):
        """Create conv1/dropout1/relu/conv2/dropout2 for ``hidden_units`` channels."""
        super().__init__()

        self.conv1 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = torch.nn.Dropout(p=dropout_rate)
        self.relu = torch.nn.ReLU()
        self.conv2 = torch.nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = torch.nn.Dropout(p=dropout_rate)

    def forward(self, inputs):
        """Run the two conv layers on the transposed input and add the residual."""
        # Conv1d wants (N, C, Length), hence the transposes around the layers.
        x = inputs.transpose(-1, -2)
        x = self.conv1(x)
        x = self.dropout1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = x.transpose(-1, -2)
        x += inputs
        return x
    


class Contrastive_Loss2(nn.Module):
    """Symmetric soft-target contrastive loss between two batches of vectors.

    Logits are the cross similarities ``X @ Y.T`` divided by the temperature.
    Targets are a softmax over the average of the two self-similarity
    matrices, multiplied by the temperature. The loss is the mean of the
    row-wise and column-wise soft cross-entropies.
    """

    def __init__(self, tau=1) -> None:
        """Store the temperature ``tau``."""
        super().__init__()

        self.temperature = tau

    def forward(self, X, Y):
        """Return the scalar loss for ``X`` and ``Y``.

        Both arguments are used in ``X @ Y.T``, ``X @ X.T`` and ``Y @ Y.T``,
        so they must be 2-D with matching shapes ``(batch, dim)``.
        """
        logits = (X @ Y.T) / self.temperature
        self_similarity = (Y @ Y.T + X @ X.T) / 2
        targets = F.softmax(self_similarity * self.temperature, dim=-1)
        X_loss = self.cross_entropy(logits, targets, reduction='none')
        Y_loss = self.cross_entropy(logits.T, targets.T, reduction='none')
        loss = (Y_loss + X_loss) / 2.0  # shape: (batch_size)
        return loss.mean()

    def cross_entropy(self, preds, targets, reduction='none'):
        """Soft-target cross-entropy of ``preds`` (logits) against ``targets``.

        Args:
            preds: logits; log-softmax is taken over the last dimension.
            targets: target weights with the same shape as ``preds``.
            reduction: ``"none"`` (one value per row), ``"mean"`` or ``"sum"``.

        Raises:
            ValueError: for any other ``reduction`` (this used to return
                ``None`` silently).
        """
        loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
        if reduction == "none":
            return loss
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        raise ValueError(
            f"Unsupported reduction {reduction!r}; expected 'none', 'mean' or 'sum'"
        )
    


class CalculateAttention(nn.Module):
    """Parameter-free masked dot-product attention.

    Scores are ``Q @ K^T``; positions where ``mask`` is True are set to
    ``-1e9``; the scores are then divided by ``sqrt(Q.size(-1))``, passed
    through a softmax over the last axis, and used to weight ``V``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask):
        """Return the attention-weighted combination of ``V``."""
        scores = Q @ K.transpose(-1, -2)
        # Masking happens before the scaling, exactly as written here.
        scores.masked_fill_(mask, -1e9)
        scale = sqrt(Q.size(-1))
        weights = (scores / scale).softmax(dim=-1)
        return weights @ V



class Multi_CrossAttention(nn.Module):
    """Multi-head cross-attention.

    In ``forward`` the first argument produces the queries and the second
    argument produces the keys and values.
    """

    def __init__(self,hidden_size,all_head_size,head_num):
        """Create the Q/K/V and output projections.

        Args:
            hidden_size: input (and final output) feature width.
            all_head_size: total width across all heads; must be divisible
                by ``head_num``.
            head_num: number of attention heads.
        """
        super().__init__()
        self.hidden_size = hidden_size      # input width
        self.all_head_size = all_head_size  # width after the Q/K/V projections
        self.num_heads = head_num           # number of heads
        self.h_size = all_head_size // head_num

        assert all_head_size % head_num == 0

        # W_Q, W_K, W_V: hidden_size -> all_head_size, no bias.
        self.linear_q = nn.Linear(hidden_size, all_head_size, bias=False)
        self.linear_k = nn.Linear(hidden_size, all_head_size, bias=False)
        self.linear_v = nn.Linear(hidden_size, all_head_size, bias=False)
        self.linear_output = nn.Linear(all_head_size, hidden_size)

        # Stored normalisation constant (not read by forward()).
        self.norm = sqrt(all_head_size)

    def print(self):
        """Print the configured sizes and the K/Q/V projection layers."""
        print(self.hidden_size,self.all_head_size)
        print(self.linear_k,self.linear_q,self.linear_v)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, S, H*W) to (B, H, S, W)."""
        return projected.view(batch_size, -1, self.num_heads, self.h_size).transpose(1, 2)

    def forward(self,x,y,log_seqs):
        """Attend from ``x`` (queries) over ``y`` (keys and values).

        Positions where ``log_seqs == 0`` are masked out as keys for every
        query position.
        """
        batch_size = x.size(0)

        # Each of these is [batch_size, num_heads, seq_length, h_size].
        q_s = self._split_heads(self.linear_q(x), batch_size)
        k_s = self._split_heads(self.linear_k(y), batch_size)
        v_s = self._split_heads(self.linear_v(y), batch_size)

        # (B, S) -> (B, 1, S) -> (B, S, S) -> (B, 1, S, S): the same key mask
        # is repeated for every query row, then given a head axis.
        seq_len = log_seqs.size(1)
        padding = log_seqs == 0
        attention_mask = padding.unsqueeze(1).repeat(1, seq_len, 1).unsqueeze(1)

        attended = CalculateAttention()(q_s, k_s, v_s, attention_mask)
        # Merge heads: [batch_size, seq_length, num_heads * h_size].
        merged = attended.transpose(1, 2).contiguous().view(
            batch_size, -1, self.num_heads * self.h_size
        )

        # Project back: [batch_size, seq_length, hidden_size].
        return self.linear_output(merged)





main.py

In [None]:
# here put the import lib
import os
import argparse
import torch
from generators.generator import Seq2SeqGeneratorAllUser
from generators.generator import GeneratorAllUser
from generators.bert_generator import BertGeneratorAllUser
from trainers.sequence_trainer import SeqTrainer
from utils.utils import set_seed
from utils.logger import Logger


parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument("--model_name", 
                    default='llmesr_sasrec',
                    choices=[
                    "llmesr_sasrec", "llmesr_bert4rec", "llmesr_gru4rec",
                    ],
                    type=str, 
                    required=False,
                    help="model name")
parser.add_argument("--dataset", 
                    default="yelp", 
                    choices=["yelp", "fashion", "beauty",],  # preprocess by myself
                    help="Choose the dataset")
parser.add_argument("--inter_file",
                    default="inter",
                    type=str,
                    help="the name of interaction file")
parser.add_argument("--demo", 
                    default=False, 
                    action='store_true', 
                    help='whether run demo')
parser.add_argument("--pretrain_dir",
                    type=str,
                    default="sasrec_seq",
                    help="the path that pretrained model saved in")
parser.add_argument("--output_dir",
                    default='./saved/',
                    type=str,
                    required=False,
                    help="The output directory where the model checkpoints will be written.")
parser.add_argument("--check_path",
                    default='',
                    type=str,
                    help="the save path of checkpoints for different running")
parser.add_argument("--do_test",
                    default=False,
                    action="store_true",
                    help="whehther run the test on the well-trained model")
parser.add_argument("--do_emb",
                    default=False,
                    action="store_true",
                    help="save the user embedding derived from the SRS model")
parser.add_argument("--do_group",
                    default=False,
                    action="store_true",
                    help="conduct the group test")
parser.add_argument("--keepon",
                    default=False,
                    action="store_true",
                    help="whether keep on training based on a trained model")
parser.add_argument("--keepon_path",
                    type=str,
                    default="normal",
                    help="the path of trained model for keep on training")
parser.add_argument("--clip_path",
                    type=str,
                    default="",
                    help="the path to save the CLIP-pretrained embedding and adapter")
parser.add_argument("--ts_user",
                    type=int,
                    default=10,
                    help="the threshold to split the short and long seq")
parser.add_argument("--ts_item",
                    type=int,
                    default=20,
                    help="the threshold to split the long-tail and popular items")

# Model parameters
parser.add_argument("--hidden_size",
                    default=64,
                    type=int,
                    help="the hidden size of embedding")
parser.add_argument("--trm_num",
                    default=2,
                    type=int,
                    help="the number of transformer layer")
parser.add_argument("--num_heads",
                    default=1,
                    type=int,
                    help="the number of heads in Trm layer")
parser.add_argument("--num_layers",
                    default=1,
                    type=int,
                    help="the number of GRU layers")
parser.add_argument("--cl_scale",
                    type=float,
                    default=0.1,
                    help="the scale for contastive loss")
parser.add_argument("--mask_crop_ratio",
                    type=float,
                    default=0.3,
                    help="the mask/crop ratio for CL4SRec")
parser.add_argument("--tau",
                    default=1,
                    type=float,
                    help="the temperature for contrastive loss")
parser.add_argument("--sse_ratio",
                    default=0.4,
                    type=float,
                    help="the sse ratio for SSE-PT model")
parser.add_argument("--dropout_rate",
                    default=0.5,
                    type=float,
                    help="the dropout rate")
parser.add_argument("--max_len",
                    default=200,
                    type=int,
                    help="the max length of input sequence")
parser.add_argument("--mask_prob",
                    type=float,
                    default=0.4,
                    help="the mask probability for training Bert model")
parser.add_argument("--aug",
                    default=False,
                    action="store_true",
                    help="whether augment the sequence data")
parser.add_argument("--aug_seq",
                    default=False,
                    action="store_true",
                    help="whether use the augmented data")
parser.add_argument("--aug_seq_len",
                    default=0,
                    type=int,
                    help="the augmented length for each sequence")
parser.add_argument("--aug_file",
                    default="inter",
                    type=str,
                    help="the augmentation file name")
parser.add_argument("--train_neg",
                    default=1,
                    type=int,
                    help="the number of negative samples for training")
parser.add_argument("--test_neg",
                    default=100,
                    type=int,
                    help="the number of negative samples for test")
parser.add_argument("--suffix_num",
                    default=5,
                    type=int,
                    help="the suffix number for augmented sequence")
parser.add_argument("--prompt_num",
                    default=2,
                    type=int,
                    help="the number of prompts")
parser.add_argument("--freeze",
                    default=False,
                    action="store_true",
                    help="whether freeze the pretrained architecture when finetuning")
parser.add_argument("--pg",
                    default="length",
                    choices=['length', 'attention'],
                    type=str,
                    help="choose the prompt generator")
parser.add_argument("--use_cross_att",
                    default=False,
                    action="store_true",
                    help="whether add a cross-attention to interact the dual-view")
parser.add_argument("--alpha",
                    default=0.1,
                    type=float,
                    help="the weight of auxiliary loss")
parser.add_argument("--gamma",
                    default=0.1,
                    type=float,
                    help="the weight of cluster loss")
parser.add_argument("--user_sim_func",
                    default="kd",
                    type=str,
                    help="the type of user similarity function to derive the loss")
parser.add_argument("--item_reg",
                    default=False,
                    action="store_true",
                    help="whether regularize the item embedding by CL")
parser.add_argument("--beta",
                    default=0.1,
                    type=float,
                    help="the weight of regulation loss")
parser.add_argument("--sim_user_num",
                    default=10,
                    type=int,
                    help="the number of similar users for enhancement")
parser.add_argument("--split_backbone",
                    default=False,
                    action="store_true",
                    help="whether use a split backbone")
parser.add_argument("--co_view",
                    default=False,
                    action="store_true",
                    help="only use the collaborative view")
parser.add_argument("--se_view",
                    default=False,
                    action="store_true",
                    help="only use the semantic view")


# Other parameters
parser.add_argument("--train_batch_size",
                    default=512,
                    type=int,
                    help="Total batch size for training.")
parser.add_argument("--lr",
                    default=0.001,
                    type=float,
                    help="The initial learning rate for Adam.")
parser.add_argument("--l2",
                    default=0,
                    type=float,
                    help='The L2 regularization')
parser.add_argument("--num_train_epochs",
                    default=100,
                    type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--lr_dc_step",
                    default=1000,
                    type=int,
                    help='every n step, decrease the lr')
parser.add_argument("--lr_dc",
                    default=0,
                    type=float,
                    help='how many learning rate to decrease')
parser.add_argument("--patience",
                    type=int,
                    default=20,
                    help='How many steps to tolerate the performance decrease while training')
parser.add_argument("--watch_metric",
                    type=str,
                    default='NDCG@10',
                    help="which metric is used to select model.")
parser.add_argument('--seed',
                    type=int,
                    default=42,
                    help="random seed for different data split")
parser.add_argument("--no_cuda",
                    action='store_true',
                    help="Whether not to use CUDA when available")
parser.add_argument('--gpu_id',
                    default=0,
                    type=int,
                    help='The device id.')
parser.add_argument('--num_workers',
                    default=0,
                    type=int,
                    help='The number of workers in dataloader')
parser.add_argument("--log", 
                    default=False,
                    action="store_true",
                    help="whether create a new log file")

# Autograd anomaly detection is switched on for the whole process.
# NOTE(review): PyTorch documents this as a debugging aid that slows execution.
torch.autograd.set_detect_anomaly(True)

args = parser.parse_args()
set_seed(args.seed)  # fix the random seed

# Directory layout derived from the CLI values (the base directory defaults to ./saved/):
#   <output_dir>/<dataset>/<pretrain_dir>                  -> args.pretrain_dir
#   <output_dir>/<dataset>/<model_name>/<keepon_path>      -> args.keepon_path
#   <output_dir>/<dataset>/<model_name>/<check_path>       -> args.output_dir
# args.output_dir is reassigned last because the other two are built from its
# original value. With an empty check_path the join only appends a trailing
# separator to the model directory.
args.pretrain_dir = os.path.join(args.output_dir, args.dataset, args.pretrain_dir)
args.keepon_path = os.path.join(args.output_dir, args.dataset, args.model_name, args.keepon_path)
args.output_dir = os.path.join(args.output_dir, args.dataset, args.model_name, args.check_path)


def main():
    """Run the action selected on the command line for the chosen model.

    Sets up logging, picks the device, creates the output directory, builds
    the data generator that matches ``args.model_name`` and hands everything
    to ``SeqTrainer``. The trainer then runs exactly one of: test, save user
    embeddings, group test, or (by default) train.

    Raises:
        ValueError: if ``args.model_name`` has no matching data generator.
    """
    log_manager = Logger(args)  # initialize the log manager
    logger, writer = log_manager.get_logger()    # get the logger
    args.now_str = log_manager.get_now_str()

    # Always release the logger, even when the run below raises.
    try:
        use_cuda = torch.cuda.is_available() and not args.no_cuda
        device = torch.device(f"cuda:{args.gpu_id}" if use_cuda else "cpu")

        os.makedirs(args.output_dir, exist_ok=True)

        # The generator manages the dataset; each backbone has its own class.
        generator_classes = {
            "llmesr_gru4rec": GeneratorAllUser,
            "llmesr_bert4rec": BertGeneratorAllUser,
            "llmesr_sasrec": Seq2SeqGeneratorAllUser,
        }
        if args.model_name not in generator_classes:
            raise ValueError(f"No data generator for model_name {args.model_name!r}")
        generator = generator_classes[args.model_name](args, logger, device)

        trainer = SeqTrainer(args, logger, writer, device, generator)

        if args.do_test:
            trainer.test()
        elif args.do_emb:
            trainer.save_user_emb()
        elif args.do_group:
            trainer.test_group()
        else:
            trainer.train()
    finally:
        log_manager.end_log()   # delete the logger threads



# Script entry point: call main() only when this file is executed directly.
if __name__ == "__main__":

    main()





保留hdbscan传统的部分

In [None]:
# here put the import lib
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from math import sqrt
import pickle
from pathlib import Path 

# 新增：聚类信息处理类（支持传统聚类+模糊-密度约束）
class ClusterHandler(nn.Module):
    """Cluster-based regularizer for item embeddings.

    Exactly one of two losses is computed by :meth:`calculate_cluster_loss`,
    selected by ``use_fuzzy``:

    * traditional: mean squared L2 distance between each non-noise item
      embedding and its cluster center, where the 8-d centers are mapped to
      ``hidden_size`` dimensions by a frozen linear layer;
    * fuzzy: MSE between each item embedding and a prototype formed as
      ``sum_j(U[i, j] ** fuzzy_m * center_j)``.

    Args:
        dataset: dataset name; files are read from ``data/<dataset>/handled/``.
        hidden_size: dimensionality of the item embeddings.
        device: device on which the loaded tensors and zero losses are created.
        use_fuzzy: if True, load the fuzzy files and use the fuzzy loss.
        fuzzy_m: exponent applied to the membership values (default 2.0, the
            previously hard-coded value).
    """

    def __init__(self, dataset, hidden_size, device, use_fuzzy=False, fuzzy_m=2.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.device = device
        self.dataset = dataset
        self.use_fuzzy = use_fuzzy
        self.fuzzy_m = fuzzy_m

        # ---- traditional clustering: labels + 8-d centers ----
        self.load_cluster_info(dataset)
        # 8-d -> hidden_size projection; frozen, so it is never trained.
        self.cluster_projection = nn.Linear(8, hidden_size)
        for param in self.cluster_projection.parameters():
            param.requires_grad = False

        # ---- fuzzy-density constraint (optional) ----
        if self.use_fuzzy:
            self._load_fuzzy_constraint_files()

        # Report the active branch once here instead of on every training step.
        print(f"[ClusterHandler] use_fuzzy={self.use_fuzzy} | 执行{'模糊损失' if self.use_fuzzy else '传统聚类损失'}")

    def load_cluster_info(self, dataset):
        """Load per-item cluster labels and the 8-d cluster centers.

        Sets ``self.item_cluster_labels`` (long tensor; -1 marks noise items)
        and ``self.cluster_centers_8d`` (``[num_clusters, 8]`` float tensor).

        NOTE: ``pickle.load`` can execute arbitrary code; only point this at
        trusted, locally generated files.
        """
        data_dir = Path(f'data/{dataset}/handled/')
        with open(data_dir / 'item_cluster_labels.pkl', 'rb') as f:
            labels = pickle.load(f)
        self.item_cluster_labels = torch.tensor(labels, dtype=torch.long, device=self.device)

        with open(data_dir / 'cluster_centers_8d.pkl', 'rb') as f:
            centers_by_id = pickle.load(f)
        # The pickle is a dict {cluster_id: 8-d center}; ids are used directly as
        # row indices, so they are assumed to be 0..num_clusters-1 — TODO confirm.
        self.cluster_centers_8d = torch.zeros(len(centers_by_id), 8, device=self.device)
        for c_id, center in centers_by_id.items():
            self.cluster_centers_8d[c_id] = torch.tensor(center, dtype=torch.float32, device=self.device)

    def _load_fuzzy_constraint_files(self):
        """Load the fuzzy membership matrix and the fuzzy cluster centers.

        Sets ``self.fuzzy_U`` (indexed by item id, one column per cluster),
        ``self.hdbscan_cluster_centers`` and ``self.num_fuzzy_clusters``.

        Raises:
            FileNotFoundError: if either pickle file is missing.

        NOTE: ``pickle.load`` can execute arbitrary code; only trusted files.
        """
        data_dir = Path(f'data/{self.dataset}/handled/')

        fuzzy_U_path = data_dir / "hdbscan_fuzzy_U.pkl"
        if not fuzzy_U_path.exists():
            raise FileNotFoundError(f"模糊隶属度文件不存在：{fuzzy_U_path}")
        with open(fuzzy_U_path, 'rb') as f:
            fuzzy_U_np = pickle.load(f)
        self.fuzzy_U = torch.tensor(fuzzy_U_np, dtype=torch.float32, device=self.device)

        cluster_centers_path = data_dir / "hdbscan_cluster_centers.pkl"
        if not cluster_centers_path.exists():
            raise FileNotFoundError(f"加权簇中心文件不存在：{cluster_centers_path}")
        with open(cluster_centers_path, 'rb') as f:
            cluster_centers_np = pickle.load(f)
        # The center dimensionality must equal the item-embedding dimensionality,
        # otherwise the MSE in _calculate_fuzzy_loss cannot be computed.
        self.hdbscan_cluster_centers = torch.tensor(cluster_centers_np, dtype=torch.float32, device=self.device)
        self.num_fuzzy_clusters = self.hdbscan_cluster_centers.shape[0]

    def get_cluster_center(self, item_ids):
        """Return projected cluster centers for a 1-D tensor of item ids.

        Items labelled -1 (noise), and ids beyond the end of the label table,
        get a zero vector and ``False`` in the mask.

        Returns:
            (centers_d, non_noise_mask): ``[len(item_ids), hidden_size]`` float
            tensor and a boolean mask of the same length.
        """
        # Ids past the label table are treated as noise instead of raising an
        # index error (mirrors the bounds filter used by the fuzzy loss).
        in_range = item_ids < self.item_cluster_labels.shape[0]
        cluster_labels = torch.full_like(item_ids, -1, dtype=self.item_cluster_labels.dtype)
        cluster_labels[in_range] = self.item_cluster_labels[item_ids[in_range]]

        non_noise_mask = cluster_labels != -1
        non_noise_indices = torch.where(non_noise_mask)[0]
        non_noise_labels = cluster_labels[non_noise_indices]

        # 8-d centers -> hidden_size-d centers
        cluster_centers_d = self.cluster_projection(self.cluster_centers_8d[non_noise_labels])

        centers_d = torch.zeros(item_ids.size(0), self.hidden_size, device=self.device)
        centers_d[non_noise_indices] = cluster_centers_d
        return centers_d, non_noise_mask

    def calculate_cluster_loss(self, item_ids, item_embeddings):
        """Return the cluster loss for the given items as a scalar tensor.

        Args:
            item_ids: 1-D tensor of item ids.
            item_embeddings: ``[len(item_ids), hidden_size]`` embeddings.

        Returns:
            The fuzzy loss when ``use_fuzzy`` is set, otherwise the traditional
            cluster loss; a zero tensor when ``item_ids`` is empty.
        """
        if item_ids.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        if self.use_fuzzy:
            return self._calculate_fuzzy_loss(item_ids, item_embeddings)
        return self._calculate_traditional_cluster_loss(item_ids, item_embeddings)

    def _calculate_traditional_cluster_loss(self, item_ids, item_embeddings):
        """Mean squared L2 distance of non-noise items to their cluster center."""
        centers_d, non_noise_mask = self.get_cluster_center(item_ids)
        if not non_noise_mask.any():
            return torch.tensor(0.0, device=self.device)

        diff = item_embeddings[non_noise_mask] - centers_d[non_noise_mask]
        # Sum of squares directly: same value as norm(...)**2 without the
        # intermediate square root.
        return diff.pow(2).sum(dim=1).mean()

    def _calculate_fuzzy_loss(self, item_ids, item_embeddings):
        """MSE between item embeddings and their fuzzy prototypes.

        Items whose id is not covered by ``fuzzy_U`` are skipped; if none
        remain, a zero tensor is returned.
        """
        valid_mask = item_ids < self.fuzzy_U.shape[0]
        valid_ids = item_ids[valid_mask]
        valid_emb = item_embeddings[valid_mask]

        if valid_ids.numel() == 0:
            return torch.tensor(0.0, device=self.device)

        # Membership weights U ** m: [valid_num, C]
        fuzzy_weights = self.fuzzy_U[valid_ids] ** self.fuzzy_m
        # Prototype_i = sum_j(weight_ij * center_j): [valid_num, C] @ [C, d] -> [valid_num, d]
        prototypes = torch.matmul(fuzzy_weights, self.hdbscan_cluster_centers)

        return F.mse_loss(valid_emb, prototypes)


In [None]:
# here put the import lib
import torch
import torch.nn as nn
from models.DualLLMSRS import DualLLMSASRec, DualLLMGRU4Rec, DualLLMBert4Rec
from models.utils import Contrastive_Loss2, ClusterHandler  # 确保导入修改后的ClusterHandler


class LLMESR_SASRec(DualLLMSASRec):
    """DualLLMSASRec with user alignment, optional item regularization and a cluster loss.

    ``forward`` returns the parent loss plus
    ``beta * reg_loss`` (only when ``args.item_reg``),
    ``gamma * cluster_loss`` and ``alpha * align_loss``.
    """

    def __init__(self, user_num, item_num, device, args):
        """Read hyper-parameters from ``args`` and build the auxiliary losses.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """

        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster constraint

        # use_fuzzy is read with a default so argument sets that predate the
        # flag keep working (they get the traditional cluster loss).
        self.cluster_handler = ClusterHandler(
            dataset=args.dataset,
            hidden_size=args.hidden_size,
            device=device,
            use_fuzzy=getattr(args, "use_fuzzy", False),
        )

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() in this class; left in place so the module's
        # parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            self.beta = args.beta
            self.reg = Contrastive_Loss2()

        self._init_weights()


    def forward(self, 
                seq, 
                pos, 
                neg, 
                positions,
                **kwargs):
        """Return the total training loss for one batch.

        ``kwargs`` must contain ``sim_seq`` and ``sim_positions``; both are
        reshaped to ``(-1, seq.shape[1])``, and ``sim_seq.shape[1]`` is taken as
        the number of similar users per sample.
        """

        loss = super().forward(seq, pos, neg, positions, **kwargs)  # base loss

        # Representation of each user: features at the last sequence position.
        log_feats = self.log2feats(seq, positions)[:, -1, :]

        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq.shape[1])
        sim_positions = kwargs["sim_positions"].view(-1, seq.shape[1])
        # The similar-user features are a fixed target (the original detached
        # them), so no autograd graph is needed for this pass.
        with torch.no_grad():
            sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :]
        # (bs*sim_num, d) -> (bs, sim_num, d) -> mean over similar users
        sim_log_feats = sim_log_feats.view(seq.shape[0], sim_num, -1).mean(dim=1)

        # self.align was chosen in __init__ ("cl" or "kd"); the call is the same.
        align_loss = self.align(log_feats, sim_log_feats)

        if self.item_reg:
            unfold_item_id = torch.masked_select(seq, seq>0)
            llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id))
            id_item_emb = self.id_item_emb(unfold_item_id)
            reg_loss = self.reg(llm_item_emb, id_item_emb)
            loss = loss + self.beta * reg_loss

        # Cluster constraint on the ID embeddings of all non-padding items;
        # ClusterHandler picks the traditional or fuzzy loss internally.
        item_ids = seq[seq > 0]
        item_embeddings = self.id_item_emb(item_ids)
        cluster_loss = self.cluster_handler.calculate_cluster_loss(item_ids, item_embeddings)
        loss = loss + self.gamma * cluster_loss

        loss = loss + self.alpha * align_loss

        return loss


# ---------------------- LLMESR_GRU4Rec and LLMESR_Bert4Rec below: same ClusterHandler / use_fuzzy changes as LLMESR_SASRec ----------------------
class LLMESR_GRU4Rec(DualLLMGRU4Rec):
    """DualLLMGRU4Rec with user alignment, optional item regularization and a cluster loss.

    ``forward`` returns the parent loss plus
    ``beta * reg_loss`` (only when ``args.item_reg``),
    ``gamma * cluster_loss`` and ``alpha * align_loss``.
    """

    def __init__(self, user_num, item_num, device, args):
        """Read hyper-parameters from ``args`` and build the auxiliary losses.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """

        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster constraint

        # use_fuzzy is read with a default so argument sets that predate the
        # flag keep working (they get the traditional cluster loss).
        self.cluster_handler = ClusterHandler(
            dataset=args.dataset,
            hidden_size=args.hidden_size,
            device=device,
            use_fuzzy=getattr(args, "use_fuzzy", False),
        )

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() in this class; left in place so the module's
        # parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            self.beta = args.beta
            self.reg = Contrastive_Loss2()

        self._init_weights()


    def forward(self, 
                seq, 
                pos, 
                neg, 
                positions,
                **kwargs):
        """Return the total training loss for one batch.

        ``kwargs`` must contain ``sim_seq``; it is reshaped to
        ``(-1, seq.shape[1])`` and ``sim_seq.shape[1]`` is taken as the number
        of similar users per sample. ``log2feats`` is called without positions
        in this class.
        """

        loss = super().forward(seq, pos, neg, positions, **kwargs)  # base loss

        # Representation of each user: features at the last sequence position.
        log_feats = self.log2feats(seq)[:, -1, :]

        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq.shape[1])
        # The similar-user features are a fixed target (the original detached
        # them), so no autograd graph is needed for this pass.
        with torch.no_grad():
            sim_log_feats = self.log2feats(sim_seq)[:, -1, :]
        # (bs*sim_num, d) -> (bs, sim_num, d) -> mean over similar users
        sim_log_feats = sim_log_feats.view(seq.shape[0], sim_num, -1).mean(dim=1)

        # self.align was chosen in __init__ ("cl" or "kd"); the call is the same.
        align_loss = self.align(log_feats, sim_log_feats)

        if self.item_reg:
            unfold_item_id = torch.masked_select(seq, seq>0)
            llm_item_emb = self.adapter(self.llm_item_emb(unfold_item_id))
            id_item_emb = self.id_item_emb(unfold_item_id)
            reg_loss = self.reg(llm_item_emb, id_item_emb)
            loss = loss + self.beta * reg_loss

        # Cluster constraint on the ID embeddings of all non-padding items;
        # ClusterHandler picks the traditional or fuzzy loss internally.
        item_ids = seq[seq > 0]
        item_embeddings = self.id_item_emb(item_ids)
        cluster_loss = self.cluster_handler.calculate_cluster_loss(item_ids, item_embeddings)
        loss = loss + self.gamma * cluster_loss

        loss = loss + self.alpha * align_loss

        return loss


class LLMESR_Bert4Rec(DualLLMBert4Rec):
    """DualLLMBert4Rec with user alignment and a cluster loss.

    ``forward`` returns the parent loss plus ``gamma * cluster_loss`` and
    ``alpha * align_loss``.

    NOTE(review): unlike the SASRec/GRU4Rec variants, ``forward`` here applies
    no item-regularization term even when ``args.item_reg`` is set
    (``self.reg`` is created but never used, and ``self.beta`` is not read).
    Confirm whether that is intended.
    """

    def __init__(self, user_num, item_num, device, args):
        """Read hyper-parameters from ``args`` and build the auxiliary losses.

        Raises:
            ValueError: if ``args.user_sim_func`` is neither ``"cl"`` nor ``"kd"``.
        """

        super().__init__(user_num, item_num, device, args)
        self.alpha = args.alpha
        self.user_sim_func = args.user_sim_func
        self.item_reg = args.item_reg
        self.gamma = args.gamma  # weight of the cluster constraint

        # use_fuzzy is read with a default so argument sets that predate the
        # flag keep working (they get the traditional cluster loss).
        self.cluster_handler = ClusterHandler(
            dataset=args.dataset,
            hidden_size=args.hidden_size,
            device=device,
            use_fuzzy=getattr(args, "use_fuzzy", False),
        )

        if self.user_sim_func == "cl":
            self.align = Contrastive_Loss2()
        elif self.user_sim_func == "kd":
            self.align = nn.MSELoss()
        else:
            raise ValueError(
                f"Unsupported user_sim_func {self.user_sim_func!r}; expected 'cl' or 'kd'"
            )

        # Not used by forward() in this class; left in place so the module's
        # parameter set stays the same.
        self.projector1 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)
        self.projector2 = nn.Linear(2*args.hidden_size, 2*args.hidden_size)

        if self.item_reg:
            self.reg = Contrastive_Loss2()

        self._init_weights()


    def forward(self, 
                seq, 
                pos, 
                neg, 
                positions,
                **kwargs):
        """Return the total training loss for one batch.

        ``kwargs`` must contain ``sim_seq`` and ``sim_positions``; both are
        reshaped to ``(-1, seq.shape[1])``, and ``sim_seq.shape[1]`` is taken as
        the number of similar users per sample.
        """

        loss = super().forward(seq, pos, neg, positions, **kwargs)  # base loss

        # Representation of each user: features at the last sequence position.
        log_feats = self.log2feats(seq, positions)[:, -1, :]

        sim_num = kwargs["sim_seq"].shape[1]
        sim_seq = kwargs["sim_seq"].view(-1, seq.shape[1])
        sim_positions = kwargs["sim_positions"].view(-1, seq.shape[1])
        # The similar-user features are a fixed target (the original detached
        # them), so no autograd graph is needed for this pass.
        with torch.no_grad():
            sim_log_feats = self.log2feats(sim_seq, sim_positions)[:, -1, :]
        # (bs*sim_num, d) -> (bs, sim_num, d) -> mean over similar users
        sim_log_feats = sim_log_feats.view(seq.shape[0], sim_num, -1).mean(dim=1)

        # self.align was chosen in __init__ ("cl" or "kd"); the call is the same.
        align_loss = self.align(log_feats, sim_log_feats)

        cluster_loss = self.calculate_cluster_loss(seq)

        loss = loss + self.gamma * cluster_loss
        loss = loss + self.alpha * align_loss

        return loss

    def calculate_cluster_loss(self, seq):
        """Cluster loss over the real items of ``seq``.

        Padding (ids <= 0), the mask token, and ids beyond the handler's label
        table are excluded. Returns a zero tensor on ``seq.device`` when no
        item is left.
        """
        keep = (seq > 0) & (seq != self.mask_token)
        item_ids = seq[keep]

        # The label table is stored as `item_cluster_labels` on the handler
        # (there is no `item_cluster` attribute); its length bounds valid ids.
        max_id_in_cluster = self.cluster_handler.item_cluster_labels.size(0) - 1
        item_ids = item_ids[item_ids <= max_id_in_cluster]

        # One check suffices: filtering an empty tensor stays empty.
        if item_ids.numel() == 0:
            return torch.tensor(0.0, device=seq.device)

        item_embeddings = self.id_item_emb(item_ids)
        return self.cluster_handler.calculate_cluster_loss(item_ids, item_embeddings)
