# 基础组件之Trainer

# 文本分类实例

## Step1 导入相关包

In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:
# Read the hotel-review CSV as a single "train" split and drop every row
# whose "review" field is missing.
dataset = load_dataset(
    "csv",
    data_files="./ChnSentiCorp_htl_all.csv",
    split="train",
).filter(lambda row: row["review"] is not None)
dataset

  return pd.read_csv(xopen(filepath_or_buffer, "rb", download_config=download_config), **kwargs)
Generating train split: 7766 examples [00:00, 230608.33 examples/s]
Filter: 100%|██████████| 7766/7766 [00:00<00:00, 469486.38 examples/s]


Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})

## Step3 划分数据集

In [3]:
# Hold out 10% of the rows as the "test" split. The fixed seed makes the
# partition identical on every run, so the metrics computed from it are
# reproducible after a kernel restart.
datasets = dataset.train_test_split(test_size=0.1, seed=42)
datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})

## Step4 数据集预处理

In [4]:
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")


def process_function(examples):
    """Tokenize a batch of reviews and attach their labels.

    Args:
        examples: A batch from the dataset; its "review" entry is passed to
            the tokenizer and its "label" entry is copied through.

    Returns:
        The tokenizer output (truncated to at most 128 tokens per review)
        with the batch labels stored under the "labels" key.
    """
    encoded = tokenizer(examples["review"], truncation=True, max_length=128)
    encoded["labels"] = examples["label"]
    return encoded


# Tokenize both splits in batches and drop the original columns so that only
# the tokenizer outputs and "labels" remain.
tokenized_datasets = datasets.map(
    process_function,
    batched=True,
    remove_columns=datasets["train"].column_names,
)
tokenized_datasets

Map: 100%|██████████| 6988/6988 [00:00<00:00, 22209.29 examples/s]
Map: 100%|██████████| 777/777 [00:00<00:00, 20117.62 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 777
    })
})

## Step5 创建模型

In [5]:
# Load the "hfl/rbt3" checkpoint with a sequence-classification head. The
# checkpoint has no classifier weights, so that head starts out freshly
# initialized (see the warning printed below) and must be trained.
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Display the configuration object attached to the loaded model.
model.config

BertConfig {
  "_name_or_path": "hfl/rbt3",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 3,
  "output_past": true,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.42.4",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

## **Step6 创建评估函数**

In [7]:
import evaluate

# Load the accuracy and F1 metric objects used by the evaluation function.
acc_metric, f1_metric = (evaluate.load(name) for name in ("accuracy", "f1"))

In [8]:
def eval_metric(eval_predict):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_predict: A pair ``(predictions, labels)``. The class index is
            taken as the argmax of ``predictions`` over its last axis.

    Returns:
        A single dict holding the entries produced by the accuracy metric
        followed by those produced by the F1 metric.
    """
    scores, references = eval_predict
    predicted_ids = scores.argmax(axis=-1)
    return {
        **acc_metric.compute(predictions=predicted_ids, references=references),
        **f1_metric.compute(predictions=predicted_ids, references=references),
    }

## **Step7 创建TrainingArguments**

In [9]:
# Training configuration handed to the Trainer.
train_args = TrainingArguments(
    output_dir="./checkpoints",  # directory that receives checkpoints
    per_device_train_batch_size=64,  # per-device batch size while training
    per_device_eval_batch_size=128,  # per-device batch size while evaluating
    logging_steps=10,  # emit a log entry every 10 steps
    # `eval_strategy` is the current name of the deprecated
    # `evaluation_strategy` argument.
    eval_strategy="epoch",  # evaluate once per epoch
    save_strategy="epoch",  # save a checkpoint once per epoch
    save_total_limit=3,  # keep at most 3 checkpoints on disk
    learning_rate=2e-5,  # learning rate
    weight_decay=0.01,  # weight decay
    metric_for_best_model="f1",  # metric used to rank checkpoints
    load_best_model_at_end=True,  # reload the best checkpoint after training
)
train_args



TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=epoch,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
f

## **Step8 创建Trainer**

In [10]:
from transformers import DataCollatorWithPadding

# Wire the model, training arguments, tokenized splits, padding collator and
# metric function into one Trainer instance.
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=eval_metric,
)

## **Step9 模型训练**

In [11]:
# Run training; the returned TrainOutput is displayed as the cell result.
trainer.train()

  3%|▎         | 11/330 [00:02<00:51,  6.20it/s]

{'loss': 0.646, 'grad_norm': 3.5232670307159424, 'learning_rate': 1.9393939393939395e-05, 'epoch': 0.09}


  6%|▋         | 21/330 [00:03<00:48,  6.39it/s]

{'loss': 0.546, 'grad_norm': 2.0024499893188477, 'learning_rate': 1.8787878787878792e-05, 'epoch': 0.18}


  9%|▉         | 31/330 [00:05<00:47,  6.31it/s]

{'loss': 0.4533, 'grad_norm': 2.2141003608703613, 'learning_rate': 1.8181818181818182e-05, 'epoch': 0.27}


 12%|█▏        | 41/330 [00:06<00:45,  6.38it/s]

{'loss': 0.3967, 'grad_norm': 3.3532445430755615, 'learning_rate': 1.7575757575757576e-05, 'epoch': 0.36}


 15%|█▌        | 51/330 [00:08<00:43,  6.37it/s]

{'loss': 0.3707, 'grad_norm': 3.744962453842163, 'learning_rate': 1.6969696969696972e-05, 'epoch': 0.45}


 18%|█▊        | 61/330 [00:09<00:42,  6.32it/s]

{'loss': 0.353, 'grad_norm': 2.1473207473754883, 'learning_rate': 1.6363636363636366e-05, 'epoch': 0.55}


 22%|██▏       | 71/330 [00:11<00:40,  6.33it/s]

{'loss': 0.3105, 'grad_norm': 3.4839985370635986, 'learning_rate': 1.575757575757576e-05, 'epoch': 0.64}


 25%|██▍       | 81/330 [00:13<00:38,  6.41it/s]

{'loss': 0.3159, 'grad_norm': 5.251232147216797, 'learning_rate': 1.5151515151515153e-05, 'epoch': 0.73}


 28%|██▊       | 91/330 [00:14<00:37,  6.30it/s]

{'loss': 0.3105, 'grad_norm': 2.1802916526794434, 'learning_rate': 1.4545454545454546e-05, 'epoch': 0.82}


 31%|███       | 101/330 [00:16<00:35,  6.37it/s]

{'loss': 0.2988, 'grad_norm': 2.2574939727783203, 'learning_rate': 1.3939393939393942e-05, 'epoch': 0.91}


 33%|███▎      | 110/330 [00:17<00:34,  6.46it/s]

{'loss': 0.2881, 'grad_norm': 10.399699211120605, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 110/330 [00:18<00:34,  6.46it/s]

{'eval_loss': 0.3263213336467743, 'eval_accuracy': 0.8648648648648649, 'eval_f1': 0.9023255813953488, 'eval_runtime': 0.604, 'eval_samples_per_second': 1286.423, 'eval_steps_per_second': 11.589, 'epoch': 1.0}


 37%|███▋      | 121/330 [00:20<00:34,  6.10it/s]

{'loss': 0.2652, 'grad_norm': 3.0868773460388184, 'learning_rate': 1.2727272727272728e-05, 'epoch': 1.09}


 40%|███▉      | 131/330 [00:21<00:31,  6.37it/s]

{'loss': 0.2858, 'grad_norm': 2.3838436603546143, 'learning_rate': 1.2121212121212122e-05, 'epoch': 1.18}


 43%|████▎     | 141/330 [00:23<00:29,  6.37it/s]

{'loss': 0.2485, 'grad_norm': 3.4410195350646973, 'learning_rate': 1.1515151515151517e-05, 'epoch': 1.27}


 46%|████▌     | 151/330 [00:24<00:28,  6.39it/s]

{'loss': 0.2842, 'grad_norm': 2.3922641277313232, 'learning_rate': 1.0909090909090909e-05, 'epoch': 1.36}


 49%|████▉     | 161/330 [00:26<00:25,  6.54it/s]

{'loss': 0.3044, 'grad_norm': 2.341148614883423, 'learning_rate': 1.0303030303030304e-05, 'epoch': 1.45}


 52%|█████▏    | 171/330 [00:27<00:24,  6.37it/s]

{'loss': 0.2532, 'grad_norm': 7.064360618591309, 'learning_rate': 9.696969696969698e-06, 'epoch': 1.55}


 55%|█████▍    | 181/330 [00:29<00:23,  6.33it/s]

{'loss': 0.2809, 'grad_norm': 7.462316036224365, 'learning_rate': 9.090909090909091e-06, 'epoch': 1.64}


 58%|█████▊    | 191/330 [00:31<00:21,  6.37it/s]

{'loss': 0.2423, 'grad_norm': 3.5531084537506104, 'learning_rate': 8.484848484848486e-06, 'epoch': 1.73}


 61%|██████    | 201/330 [00:32<00:20,  6.38it/s]

{'loss': 0.2569, 'grad_norm': 3.1288113594055176, 'learning_rate': 7.87878787878788e-06, 'epoch': 1.82}


 64%|██████▍   | 211/330 [00:34<00:18,  6.38it/s]

{'loss': 0.2116, 'grad_norm': 3.6639513969421387, 'learning_rate': 7.272727272727273e-06, 'epoch': 1.91}


 67%|██████▋   | 220/330 [00:35<00:17,  6.41it/s]

{'loss': 0.257, 'grad_norm': 5.330965518951416, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


                                                 
 67%|██████▋   | 220/330 [00:36<00:17,  6.41it/s]

{'eval_loss': 0.2903926968574524, 'eval_accuracy': 0.8751608751608752, 'eval_f1': 0.9089201877934272, 'eval_runtime': 0.6089, 'eval_samples_per_second': 1276.15, 'eval_steps_per_second': 11.497, 'epoch': 2.0}


 70%|███████   | 231/330 [00:38<00:16,  6.11it/s]

{'loss': 0.2235, 'grad_norm': 3.3569376468658447, 'learning_rate': 6.060606060606061e-06, 'epoch': 2.09}


 73%|███████▎  | 241/330 [00:39<00:13,  6.37it/s]

{'loss': 0.2356, 'grad_norm': 3.4687695503234863, 'learning_rate': 5.4545454545454545e-06, 'epoch': 2.18}


 76%|███████▌  | 251/330 [00:41<00:12,  6.37it/s]

{'loss': 0.2339, 'grad_norm': 3.1471493244171143, 'learning_rate': 4.848484848484849e-06, 'epoch': 2.27}


 79%|███████▉  | 261/330 [00:42<00:10,  6.39it/s]

{'loss': 0.2012, 'grad_norm': 2.106525421142578, 'learning_rate': 4.242424242424243e-06, 'epoch': 2.36}


 82%|████████▏ | 271/330 [00:44<00:09,  6.35it/s]

{'loss': 0.255, 'grad_norm': 3.262303113937378, 'learning_rate': 3.6363636363636366e-06, 'epoch': 2.45}


 85%|████████▌ | 281/330 [00:46<00:07,  6.38it/s]

{'loss': 0.2261, 'grad_norm': 4.466238975524902, 'learning_rate': 3.0303030303030305e-06, 'epoch': 2.55}


 88%|████████▊ | 291/330 [00:47<00:06,  6.39it/s]

{'loss': 0.2335, 'grad_norm': 3.5267515182495117, 'learning_rate': 2.4242424242424244e-06, 'epoch': 2.64}


 91%|█████████ | 301/330 [00:49<00:04,  6.37it/s]

{'loss': 0.2572, 'grad_norm': 4.796464443206787, 'learning_rate': 1.8181818181818183e-06, 'epoch': 2.73}


 94%|█████████▍| 311/330 [00:50<00:02,  6.36it/s]

{'loss': 0.2147, 'grad_norm': 2.4303932189941406, 'learning_rate': 1.2121212121212122e-06, 'epoch': 2.82}


 97%|█████████▋| 321/330 [00:52<00:01,  6.39it/s]

{'loss': 0.2511, 'grad_norm': 2.451734781265259, 'learning_rate': 6.060606060606061e-07, 'epoch': 2.91}


100%|██████████| 330/330 [00:53<00:00,  6.45it/s]

{'loss': 0.2212, 'grad_norm': 10.622032165527344, 'learning_rate': 0.0, 'epoch': 3.0}


                                                 
100%|██████████| 330/330 [00:54<00:00,  6.45it/s]

{'eval_loss': 0.27881765365600586, 'eval_accuracy': 0.8854568854568855, 'eval_f1': 0.9154795821462488, 'eval_runtime': 0.6033, 'eval_samples_per_second': 1288.003, 'eval_steps_per_second': 11.604, 'epoch': 3.0}


100%|██████████| 330/330 [00:55<00:00,  5.99it/s]

{'train_runtime': 55.1018, 'train_samples_per_second': 380.459, 'train_steps_per_second': 5.989, 'train_loss': 0.29492162068684896, 'epoch': 3.0}





TrainOutput(global_step=330, training_loss=0.29492162068684896, metrics={'train_runtime': 55.1018, 'train_samples_per_second': 380.459, 'train_steps_per_second': 5.989, 'total_flos': 351909933963264.0, 'train_loss': 0.29492162068684896, 'epoch': 3.0})

## **Step10 模型评估**

In [12]:
# Evaluate on the held-out "test" split and display the resulting metrics.
trainer.evaluate(tokenized_datasets["test"])

100%|██████████| 7/7 [00:00<00:00, 14.06it/s]


{'eval_loss': 0.27881765365600586,
 'eval_accuracy': 0.8854568854568855,
 'eval_f1': 0.9154795821462488,
 'eval_runtime': 0.6054,
 'eval_samples_per_second': 1283.45,
 'eval_steps_per_second': 11.563,
 'epoch': 3.0}

## **Step11 模型预测**

In [13]:
# Run prediction on the "test" split; the displayed PredictionOutput carries
# the raw per-class scores (`predictions`) and the reference `label_ids`.
trainer.predict(tokenized_datasets["test"])

100%|██████████| 7/7 [00:00<00:00, 14.17it/s]


PredictionOutput(predictions=array([[ 1.6659021, -1.5507957],
       [-1.6304892,  1.9699897],
       [ 1.8136438, -1.4618142],
       ...,
       [ 1.3355892, -1.5159024],
       [-2.6307316,  2.9208736],
       [-2.280545 ,  2.4365335]], dtype=float32), label_ids=array([0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0,
       1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1,
       1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1,
       1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1,
       1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
       0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1,
       1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0,
       1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1,
    

In [14]:
from transformers import pipeline

# Human-readable names for the two class ids, plus the inverse mapping so the
# config's id2label and label2id stay consistent with each other.
id2_label = {0: "差评！", 1: "好评！"}
model.config.id2label = id2_label
model.config.label2id = {label: idx for idx, label in id2_label.items()}

# Run the pipeline on whatever device the model already lives on, instead of
# hard-coding GPU 0, so the cell also works on machines without a GPU.
pipe = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=model.device,
)

In [15]:
# Classify one sample sentence with the pipeline and display the result.
sen = "我觉得不错！"
pipe(sen)

[{'label': '好评！', 'score': 0.9911233186721802}]