# 文本分类实例

## 第一步：导入相关包

In [23]:
from transformers import AutoTokenizer,AutoModelForSequenceClassification,Trainer,TrainingArguments
from datasets import load_dataset

## 第二步：加载数据集

In [24]:
# Fetch the train split of the ChnSentiCorp_htl_all dataset and keep only the
# rows whose 'review' field is present (not None).
dataset = load_dataset(
    'dirtycomputer/ChnSentiCorp_htl_all',
    split='train',
).filter(lambda example: example['review'] is not None)


## 第三步：划分数据集

In [25]:
# Hold out 10% of the rows as the test split. The seed is fixed so that a
# fresh "Restart & Run All" produces the same partition, which keeps the
# reported metrics comparable between runs.
datasets = dataset.train_test_split(test_size=0.1, seed=42)

## 第四步：数据集预处理

In [26]:
# Load the tokenizer that belongs to the 'hfl/rbt3' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3')

def process_function(examples):
    """Tokenize a batch of reviews and attach their labels.

    Args:
        examples: Mapping with a 'review' entry (the texts handed to the
            tokenizer) and a 'label' entry (copied through unchanged).

    Returns:
        The tokenizer output with an additional 'labels' entry.
    """
    # Sequences longer than 128 tokens are truncated; padding is not
    # requested at this stage.
    encoded = tokenizer(examples['review'], max_length=128, truncation=True)
    encoded['labels'] = examples['label']
    return encoded

# Tokenize both splits batch by batch. The original columns are removed so
# that only the fields returned by process_function remain.
tokenizer_datasets = datasets.map(
    process_function,
    batched=True,
    remove_columns=datasets['train'].column_names,
)
    

Map: 100%|██████████████████████████| 777/777 [00:00<00:00, 7068.15 examples/s]


## 第五步：创建模型

In [27]:
# Load the 'hfl/rbt3' checkpoint as a sequence-classification model. The
# classifier weights are not part of the checkpoint (see the warning printed
# below), so the model has to be fine-tuned before it is used for prediction.
model = AutoModelForSequenceClassification.from_pretrained('hfl/rbt3')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 第六步：创建评估函数

In [28]:
# The evaluate package is used here to load metric implementations by name.
import evaluate

# Load the accuracy and F1 metrics that the evaluation function relies on.
acc_metric = evaluate.load('accuracy')
fl_metric = evaluate.load('f1')

In [29]:
def eval_metric(eval_predict):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_predict: A (predictions, labels) pair. ``predictions`` holds the
            per-class scores, or a tuple whose first element does;
            ``labels`` holds the reference class ids.

    Returns:
        dict with two entries: 'eval_acc' (accuracy) and 'eval_fl' (F1).
    """
    predictions, labels = eval_predict

    # When a model returns extra outputs besides the logits, the predictions
    # arrive as a tuple; the logits are assumed to be its first element.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # Predicted class = index of the highest score along the last axis.
    predictions = predictions.argmax(axis=-1)

    acc = acc_metric.compute(predictions=predictions, references=labels)
    fl = fl_metric.compute(predictions=predictions, references=labels)

    # Both metrics are reported separately; they are not averaged.
    return {'eval_acc': acc['accuracy'], 'eval_fl': fl['f1']}

## 第七步：创建TrainingArguments

In [30]:
train_args = TrainingArguments(
    # Directory that receives the checkpoints
    output_dir="./checkpoints",

    # Per-device batch sizes for training and evaluation
    per_device_train_batch_size=64,
    per_device_eval_batch_size=128,

    # Optimizer settings: initial learning rate and weight decay
    learning_rate=2e-5,
    weight_decay=0.01,

    # Log every 10 steps; evaluate and save once per epoch,
    # keeping at most 3 checkpoints on disk
    logging_steps=10,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    save_total_limit=3,

    # Pick the best checkpoint by the "fl" (F1) metric and reload it
    # once training has finished
    metric_for_best_model="fl",
    load_best_model_at_end=True,
)



## 第八步：创建Trainer

In [31]:
from transformers import DataCollatorWithPadding, Trainer

# Bring the model, training arguments, tokenized splits and metric function
# together in a single Trainer.
trainer = Trainer(
    model=model,
    args=train_args,
    # Batches are assembled by a DataCollatorWithPadding built from the tokenizer.
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=tokenizer_datasets['train'],
    # The held-out split doubles as the evaluation set during training.
    eval_dataset=tokenizer_datasets['test'],
    # Called with the evaluation predictions to produce the reported metrics.
    compute_metrics=eval_metric,
)


## 第九步：训练模型

In [32]:
# Start training; the returned TrainOutput is displayed as the cell result.
trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Acc,Fl
1,0.2697,0.280399,0.880309,0.915837
2,0.2671,0.253383,0.899614,0.925996
3,0.2014,0.243475,0.899614,0.927103


TrainOutput(global_step=330, training_loss=0.29980935978166984, metrics={'train_runtime': 122.7032, 'train_samples_per_second': 170.851, 'train_steps_per_second': 2.689, 'total_flos': 351909933963264.0, 'train_loss': 0.29980935978166984, 'epoch': 3.0})

## 第十步：评估模型

In [34]:
# Evaluate on the held-out test split and display the resulting metrics dict.
trainer.evaluate(tokenizer_datasets['test'])

{'eval_acc': 0.8996138996138996,
 'eval_fl': 0.9271028037383178,
 'eval_loss': 0.2434752881526947,
 'eval_runtime': 1.8333,
 'eval_samples_per_second': 423.821,
 'eval_steps_per_second': 3.818,
 'epoch': 3.0}

## 第十一步：模型预测

In [37]:
# Run inference on the test split; the PredictionOutput shown below carries
# the raw per-class scores together with the reference label ids.
trainer.predict(tokenizer_datasets['test'])

PredictionOutput(predictions=array([[-1.4194119 ,  2.0535192 ],
       [-1.9416336 ,  2.5582216 ],
       [-1.5875318 ,  2.5178654 ],
       ...,
       [-1.5416453 ,  2.1160145 ],
       [ 0.36820832, -0.4314419 ],
       [ 1.2365581 , -1.9516827 ]], dtype=float32), label_ids=array([1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0,
       1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1,
       0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
       1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1,
       1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1,
       1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1,
       0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1,
       1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1