Skip to content

yihaohu0118/NLP-Sentiment_Classification

Repository files navigation

BertWithMetadata 电商评论情感分类模型

项目概述

本项目实现了 BertWithMetadata 模型,一个基于 BERT 的自定义模型,用于电商产品评论的情感分类。该模型通过融合产品类别(category)和评分(rating)元数据增强了标准 BERT 架构(bert-base-uncased),并采用焦点损失(Focal Loss)优化不平衡数据集的分类性能,尤其提升了负面情感的分类效果。项目包含模型训练、评估和实验代码,实验基于 Synthetic E-commerce Product Reviews Dataset 的 10 万条评论子集,覆盖八个产品类别(电子产品、家居厨房、时尚、美妆、玩具游戏、图书、健康护理、运动户外)。

功能特点

  • 模型架构:扩展 BertForSequenceClassification,融合类别和评分元数据,提升语义理解能力。
  • 焦点损失:通过聚焦难分类样本(尤其是负面情感),解决类别不平衡问题。
  • 数据集:使用 Synthetic E-commerce Product Reviews Dataset 的子集(10 万条评论),类别均衡后各情感占比 33.3%。
  • 评估指标:报告宏平均精度、召回率、F1 分数、ROC-AUC 和准确率,重点关注负面情感性能。
  • 实验设计:包括对比实验(与 9 种基线模型比较)和消融实验(验证元数据和焦点损失的贡献)。

安装指南

前提条件

  • Python 版本:3.9
  • 硬件:推荐 NVIDIA A100 GPU(支持 CUDA 11.8/12.1,48GB 显存)
  • 操作系统:Linux、Windows 或 macOS

依赖安装

  1. 克隆本仓库:

    git clone https://github.com/your-repo/bertwithmetadata.git
    cd bertwithmetadata
  2. 创建并激活虚拟环境(推荐):

    python -m venv venv
    source venv/bin/activate  # Linux/macOS
    venv\Scripts\activate     # Windows
  3. 安装依赖: 使用提供的 requirements.txt 文件安装必要库:

    pip install -r requirements.txt

    requirements.txt 内容

    torch==2.1.0
    transformers==4.35.0
    numpy==1.24.3
    imbalanced-learn==0.11.0
    scikit-learn==1.3.2
    lightgbm==4.1.0
    python==3.9.*
    matplotlib==3.7.2
    pandas==2.0.3
    
  4. 验证安装:

    python -c "import torch; print(torch.__version__)"

    确保输出为 2.1.0

GPU 支持

若使用 GPU,确保安装与 CUDA 11.8 或 12.1 兼容的 PyTorch 版本。运行以下命令检查 GPU 可用性:

python -c "import torch; print(torch.cuda.is_available())"

使用方法

数据准备

  1. 数据集:下载 Synthetic E-commerce Product Reviews Dataset(400 万条评论),从中随机抽取 10 万条记录。数据应包含以下字段:
    • review_text:评论文本
    • category:产品类别(8 类)
    • rating:评分(1.0-5.0)
    • sentiment:情感标签(正面、中性、负面)
  2. 预处理
    • 缺失值处理:review_text 填充为空字符串,rating 填充为中位数 3.0。
    • 情感编码:使用 LabelEncoder 将情感标签编码为 0(正面)、1(中性)、2(负面)。
    • 类别平衡:使用 RandomOverSampler 使正面、中性、负面情感各占 33.3%。
    • 数据划分:按 7:1.5:1.5 划分为训练集(70%)、验证集(15%)和测试集(15%),采用分层抽样保持类别平衡。
  3. 保存预处理后的数据为 CSV 或其他格式,供模型加载。

模型训练

  1. 模型文件:核心模型代码位于 model.py,实现如下:

    import torch
    from torch import nn
    from transformers import BertForSequenceClassification
    
    class BertWithMetadata(BertForSequenceClassification):
        """自定义BertWithMetadata模型,继承自BertForSequenceClassification,融合category和rating元数据,并使用焦点损失优化情感分类"""
        
        def __init__(self, config, num_categories):
            """
            初始化BertWithMetadata模型,添加元数据嵌入层和融合层。
    
            参数:
                config: transformers.BertConfig, BERT模型的配置文件,包含隐藏层大小等参数
                num_categories: int, 产品类别的数量,用于定义类别嵌入层的输入维度
            """
            super().__init__(config)
            self.num_categories = num_categories
            self.category_embedding = nn.Embedding(num_categories, 32)
            self.rating_fc = nn.Linear(1, 32)
            self.fusion_fc = nn.Linear(config.hidden_size + 32 + 32, config.hidden_size)
            self.dropout = nn.Dropout(0.3)
    
        def forward(self, input_ids, attention_mask, category, rating, labels=None):
            """
            前向传播,融合BERT输出、类别嵌入和评分嵌入,计算焦点损失和分类logits。
    
            参数:
                input_ids: torch.Tensor, 形状为(batch_size, seq_len),输入的token ID序列
                attention_mask: torch.Tensor, 形状为(batch_size, seq_len),输入的注意力掩码
                category: torch.Tensor, 形状为(batch_size,),产品类别ID
                rating: torch.Tensor, 形状为(batch_size,),评分值
                labels: torch.Tensor, 形状为(batch_size,),可选的真实标签,用于计算损失
    
            返回:
                dict: 包含损失('loss')和分类logits('logits')的字典
            """
            bert_output = super().forward(input_ids, attention_mask, output_hidden_states=True)
            pooled_output = bert_output.hidden_states[-1][:, 0, :]  # [CLS] token
            category_embed = self.category_embedding(category)
            rating_embed = self.rating_fc(rating.unsqueeze(1))
            fused = torch.cat([pooled_output, category_embed, rating_embed], dim=-1)
            fused = self.fusion_fc(fused)
            fused = self.dropout(fused)
            logits = self.classifier(fused)
    
            loss = None
            if labels is not None:
                alpha = torch.tensor([0.25, 0.5, 1.0]).to(logits.device)  # Focal loss weights
                gamma = 2.0
                ce_loss = nn.CrossEntropyLoss(reduction='none')(logits, labels)
                pt = torch.exp(-ce_loss)
                focal_loss = (alpha[labels] * (1 - pt) ** gamma * ce_loss).mean()
                loss = focal_loss
    
            return {'loss': loss, 'logits': logits}
  2. 训练脚本(示例 BERTwithmetadata训练.py.py):

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages