# Getting Started with Fine-Tuning in Hugging Face Transformers

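This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics for training
- An introduction to the Trainer
- Hands-on training
- Saving the model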

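## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset Summary

The Yelp reviews dataset consists of reviews from Yelp. It was extracted from the Yelp Dataset Challenge 2015 data.

### Supported Tasks and Leaderboards
Text classification and sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.

### Languages
The reviews are written mainly in English.

### Dataset Structure

#### Data Instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set: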

```python
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

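#### Data Fields

- `text`: the review text, escaped using double quotes ("); any internal double quote is escaped by two double quotes (""). Newlines are escaped with a backslash followed by an "n" character, i.e. "\n".
- `label`: the review's score as a class index from 0 to 4, corresponding to a rating of 1 to 5 stars.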
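As a minimal sketch of the field conventions above, the snippet below un-escapes a hypothetical raw CSV line and maps a label index to a star rating. The sample line and the helpers `unescape_review` and `label_to_stars` are illustrative assumptions. The star names follow the labels that `show_random_elements` prints later in this notebook; the assumption is that index 0 is "1 star" and index 4 is "5 stars". Note that the Hugging Face copy of the data keeps the literal backslash-n, as the examples in this notebook show.

```python
import csv
import io

# Assumed ClassLabel names: index 0 -> "1 star" ... index 4 -> "5 stars"
STAR_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]


def label_to_stars(label: int) -> str:
    """Map a 0-based class index to its star-rating name."""
    return STAR_NAMES[label]


def unescape_review(raw_line: str) -> str:
    """Parse one quoted CSV field and restore real newlines (hypothetical helper)."""
    # csv.reader turns each doubled quote ("") back into a single quote (")
    (text,) = next(csv.reader(io.StringIO(raw_line)))
    # A newline is stored as a literal backslash followed by "n"
    return text.replace("\\n", "\n")


raw = '"They said ""new"" tires.\\nNever again."'
print(unescape_review(raw))
print(label_to_stars(0))
```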

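#### Data Splits

The Yelp reviews full star dataset was constructed by randomly sampling 130,000 training samples and 10,000 test samples for each star rating from 1 to 5, giving 650,000 training samples and 50,000 test samples in total.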

## Downloading the Dataset

In [1]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [3]:
dataset["train"][111]

{'label': 2,
 'text': "As far as Starbucks go, this is a pretty nice one.  The baristas are friendly and while I was here, a lot of regulars must have come in, because they bantered away with almost everyone.  The bathroom was clean and well maintained and the trash wasn't overflowing in the canisters around the store.  The pastries looked fresh, but I didn't partake.  The noise level was also at a nice working level - not too loud, music just barely audible.\\n\\nI do wish there was more seating.  It is nice that this location has a counter at the end of the bar for sole workers, but it doesn't replace more tables.  I'm sure this isn't as much of a problem in the summer when there's the space outside.\\n\\nThere was a treat receipt promo going on, but the barista didn't tell me about it, which I found odd.  Usually when they have promos like that going on, they ask everyone if they want their receipt to come back later in the day to claim whatever the offer is.  Today it was one of th

In [4]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [5]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
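The `while pick in picks` loop in `show_random_elements` draws distinct indices by rejection sampling. Python's standard `random.sample` does the same job (sampling without replacement) in a single call. The helper `pick_distinct_indices` below is a hypothetical sketch of how the loop could be replaced; it is not part of the original notebook.

```python
import random


def pick_distinct_indices(n_rows: int, num_examples: int, seed=None) -> list:
    """Return num_examples distinct row indices drawn from range(n_rows)."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)
    # random.sample draws without replacement, so no duplicate check is needed
    return rng.sample(range(n_rows), num_examples)


picks = pick_distinct_indices(650_000, 10, seed=0)
print(picks)
```

Inside `show_random_elements`, `picks = pick_distinct_indices(len(dataset), num_examples)` would produce the same kind of index list as the loop.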

In [6]:
show_random_elements(dataset["train"])

index,label,text
0,4 stars,"Really good food. Service lacking, but food made up for it this time. We will return."
1,4 stars,"3 Taquitos - a little over priced, but excellent. \n\nBombero Burrito - Spicy! Super good. \n\nAl Pastor Burrito - ask for no beans and add cilantro and onion. Very good. \n\nOverall this place has great food, but not quite as good as a San Diego Taco Shop and the portions are a bit on the small side for the prices."
2,1 star,"Ordered: Calimari, Antipasto, Spicey Buffalo Wings, and Garlic Knots.\n\n1. Calamari (actually low grade pig rectum slices mixed with a few tentacle bits...seriously, look it up). This is D grade frozen cheap supply, sold as seafood. The only way this can be considered seafood, is if the pig was a surfer, or ate crab legs and shrimp cakes every day of its sad life, waiting to have its turd-cutter transformed into sliced overpriced greasy lies.\n\n2. Spicey Buffalo Wings. Yes, it chicken...9 of 12 wings were freezer-burned LEGS, not wings. The chewiness, freezer burn, and overall \""wet skin\"" texture of the heat lamp cooked wings was more like a zombie chew toy, than a delicious chicken wing. Yuck.\n\n3.Antipasto...Romain lettuce, sliced CANNED black olives, Sliced Red Tomato, worst provolone I've ever tasted, D-rate salami, and, well, that was it...other than the water-slide water called \""Italian Dressing\"" that came with it. Pepperoncini? Nope. Green olives? Negative. Flavor? Nooooooope!\n\n4. Garlic Knots. These dough balls must cheap This crap-shack in operation. Delicious? Nope...but after eating a freezer burned bouquet of bad choices, I guess a near flavorless ball of dough, low grade olive oil (not even 100% olive oil), and flavorless garlic mush, formed into a blown out balloon knot, this became the highlight of our night. \n\nIn conclusion, I'M ON VACATION, and after a long day of fun, on our last day, decided to waste $40 on straight-up garbage. I'm looking at a pile of wings, half a disabled salad, and a some pig anus rings. The Garlic Knots were superb though! If superb means lackluster bread nuggets.\n\nPs, we drove to get pickup, because delivery was going to be AN HOUR AND A HALF! We're a mile away! How many tastebud-less drones could possibly want to ingest this garbage at once? \nHopefully NOT YOU!\n\nAdd me as a friend if you \""Feel\"" my review! Peace and Prosperity!"
3,2 star,"Okay, is there something going on with fast food places in the Valley? Or is it all over? Places seem to be doing something with their meat that I am not liking. Some kind of greasyspicymuddy taste has invaded burgers from Wendy's and Jack in the Box and I'd like to know WTF is going on here.\n\nI used to really like Jack's Ultimate Cheeseburger, and the Jumbo Jack has been a favorite of mine since the '70s. Now yuk. Something has changed. It's almost like they are trying to duplicate some greasy freshly barbecued taste and failing BADLY.\n\nI hope you guys don't change your tacos or egg rolls or I won't have much of a reason to go anymore. If I need a fast burger, I'll go to McD and get a burger there. They'd never change.\n\nOf course, they changed Coke, so I guess I shouldn't say never. . ."
4,1 star,"I really love tropical smoothie and I love that this one is by my house to always satisfy a craving. I've never had any issues or complaints until now. Precious was my cashier and she was completely monotone and doesn't even say hello. I know it's just a fast food/drink place and I'm not expecting fine dining service but a \""hello\"" or a smile maybe even a \""thank you\"" would have been nice. My total was 12.06 and when I gave her a 20 she asked in her monotone voice \""do you have 6 cents\"" if I had 6 cents handy I would have given it to her. Hopefully next time I return my experience will be with someone who doesn't completely show that they hate their job."
5,1 star,"If I could give this place zero stars I would. I just finished my $24.09 shrimp quesadilla with beans and rice and drink and it was incredibly disappointing! After a recommendation from a good friend we decide to do a late night dinner. We were greeted by an officer of the Phoenix police department before we even got to the front door. Worrisome?! Once we walk through the doors my eardrums were assaulted with extremely loud music. I'm all for setting a mood but the restaurant was packed with all but 3 tables and us. Not the best group of 7 dinner date atmosphere. We had to scream over the \""tunes\"". After waiting for what felt like 20 minutes for our waitress to come over we finally placed our order. I knew the shrimp quesadilla cost $12.99 and the drink was $1.99. High but willing to pay after the waitress said it was a decent plate serving. I also added rice and beans. The meal came out quickly but my white rice was adorned with wrinkly peas and one random Lima bean... Odd. The quesadilla was soggy and there was hardly any shrimp in it. We continually asked for the music to be turned down and we kindly ignored. RUDE! Right before the check came they decided it was a good idea to have a full mariachi band play to all but 9 people. So loud I thought my ears were bleeding. Then the bills come out. We asked to pay separate. No big deal. Then she proceeded to hand out our bills. I looked at the price and nearly fell off my chair. She had added the total of everyone's meals and put a 10% gratuity on it! Not cool!!!! Then I see a charge for $4 for a drink. I asked the waitress what the charge was and she proceeded to tell me that after 9pm the drinks are no longer $1.99. Everyone got water at that same price as well. Never heard of a $4 glass of water. She never mentioned this, nor did it say anything about that on their menu... Pretty petty if you ask me. Now it totally makes sense why there was no one in this place on a Friday night! 
I'm so disappointed and feel like I have been taken advantage of. I'm a good tipper and usually look past \""petty\"" stuff but this was the worst dining experience I've had. Ever!"
6,4 stars,Delicious food and great service on this trip. Yum!
7,4 stars,Always thought this place would suck monkey choad But it's pretty good. Not amazing but fast simple tasty. I do like most that the bread is very thin overall. I try not to eat much bread as it is so this is a big plus. I got the Italian I'm sure the meat is riddled with nitrates but I would go again if I was desperate enough.
8,4 stars,we tried this place out and loved it! The hibrachi chef was really good and funny! The staff was great with our kids. The food was great especially the sushi! The portions we good sized and there was enough to take home! Definetly will be eating here again!
9,2 star,"My most recent joyous interactions with Dicks er, Cox.\n\n1. I got a wireless router through Cox when they hooked up my service. Just hooked it up, needed tech support, Cox phone tech forwards my call on to the manufacturer Netgear after assuring me that the tech support is through Netgear and they have an agreement where all the wireless support goes to Netgear. I get Netgear tech on the line only to hear him mumbling about stinking Cox and why do they keep telling people that. Then he informs me that based on the serial number on the router, Cox activated it already and the service period is already expired but he will assist with set-up issues this one time. It's already expired! I just hooked the freaking thing up!\n\n2. My e-mail randomly cannot be loaded and cable box has to be unplugged and reset roughly every 10 days.\n\nQuest, when are you coming to my area?"


## Preprocessing the Data

Once the dataset has been downloaded, process the text with a tokenizer. Inputs of different lengths can be handled with padding and truncation strategies.
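To make padding and truncation concrete, here is a toy pure-Python sketch that works on lists of token IDs. It is not the tokenizer's implementation. The helper `pad_and_truncate`, the made-up token IDs, and `max_length=8` are assumptions for illustration; the pad ID of 0 matches BERT's `[PAD]` token. A real tokenizer also keeps special tokens such as `[CLS]` and `[SEP]` when truncating, which this sketch ignores.

```python
def pad_and_truncate(batch_ids, max_length, pad_id=0):
    """Pad or truncate each sequence to exactly max_length and build attention masks."""
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        ids = ids[:max_length]              # truncation: drop tokens beyond max_length
        n_pad = max_length - len(ids)       # padding: number of pad tokens to append
        input_ids.append(ids + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = [[7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]
out = pad_and_truncate(batch, max_length=8)
print(out["input_ids"])       # [[7, 8, 9, 0, 0, 0, 0, 0], [1, 2, 3, 4, 5, 6, 7, 8]]
print(out["attention_mask"])  # [[1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]]
```

After this step every sequence has the same length, so a batch can be stacked into a single tensor.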

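The `map` method in Datasets applies a preprocessing function to the entire dataset in one call.

The code below processes the whole dataset using the pad-to-maximum-length strategy: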

In [7]:
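from transformers import AutoTokenizer

# Load the tokenizer that matches the bert-base-cased checkpoint
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    # Pad every sequence to the model's maximum length (512 for bert-base-cased) and truncate longer ones
    return tokenizer(examples["text"], padding="max_length", truncation=True)

# batched=True passes many examples to tokenize_function at once
tokenized_datasets = dataset.map(tokenize_function, batched=True)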
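With `batched=True`, `map` calls `tokenize_function` on a whole batch at once (1,000 rows per batch by default). Here `examples` is a dict that maps each column name to a list of values, and the dict the function returns supplies new columns that are added to the dataset. The hypothetical `batched_map` below mimics that contract on a plain dict of lists, using a stand-in for the tokenizer. It is a simplified model of the behavior, not how `datasets` is implemented: there is no caching, no Arrow storage, and no multiprocessing.

```python
def batched_map(columns, function, batch_size=1000):
    """Apply `function` to consecutive batches of a dict of lists and merge the new columns."""
    n_rows = len(next(iter(columns.values())))
    new_columns = {}
    for start in range(0, n_rows, batch_size):
        # Slice every column to build one batch, e.g. {"label": [...], "text": [...]}
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in function(batch).items():
            new_columns.setdefault(name, []).extend(values)
    # Existing columns are kept; returned columns are added (or replace same-named ones)
    return {**columns, **new_columns}


def fake_tokenize(examples):
    # Stand-in for tokenizer(...): counts whitespace-separated words per text
    return {"num_tokens": [len(t.split()) for t in examples["text"]]}


data = {"label": [4, 0, 2], "text": ["Great food", "Never again", "It was okay I guess"]}
print(batched_map(data, fake_tokenize, batch_size=2))
```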



In [8]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

index,label,text,input_ids,token_type_ids,attention_mask
0,2 star,"1.5 star rounding up...\n\n\nI've been to this nail salon twice. The first time was for a pedi... Service was alright and quality was mediocre. I didn't have to wait and was seated to soak my feet right away. Usually I don't mind having my feet soaking for a bit before the actual pedi but this time they had me soaking for over a half an hour! My feet were starting to get all yucky! Anyway the pedicure didn't last as long as the wait and the polish lasted about a week. (Polish usually last about three weeks on my toes before chipping or cracking) I only polish my toes white so when I told him I wanted white on my toes, he turned to the other nail tech and said \""of course she'd choose a hard color to do\"" in Vietnamese. (I'm Vietnamese and understand some of it)\n\nNow my second time I walked in there were two techs and one customer, I was told it was going to be a five min wait. I ended up waiting more over 20 mins... I got irritated and walked out. I won't be going back there again.","[101, 122, 119, 126, 2851, 1668, 1158, 1146, 119, 119, 119, 165, 183, 165, 183, 165, 183, 2240, 112, 1396, 1151, 1106, 1142, 16255, 20310, 3059, 119, 1109, 1148, 1159, 1108, 1111, 170, 185, 1174, 1182, 119, 119, 119, 2516, 1108, 15354, 1105, 3068, 1108, 1143, 13447, 13782, 119, 146, 1238, 112, 189, 1138, 1106, 3074, 1105, 1108, 8808, 1106, 1177, 3715, 1139, 1623, 1268, 1283, 119, 12378, 146, 1274, 112, 189, 1713, 1515, 1139, 1623, 25782, 1111, 170, 2113, 1196, 1103, 4315, 185, 1174, 1182, 1133, 1142, 1159, 1152, 1125, 1143, 25782, 1111, 1166, 170, 1544, 1126, 2396, 106, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


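### Data Sampling

Use 1,000 samples to demonstrate small-scale training of BERT (with the PyTorch-based Trainer).

`shuffle()` randomly reorders the rows of the dataset. For more control over the algorithm used to shuffle the dataset, pass a different `numpy.random.Generator` through this function's `generator` parameter.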
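Conceptually, `shuffle(seed=42).select(range(1000))` draws 1,000 rows without replacement. The NumPy sketch below expresses the same idea with `numpy.random.default_rng`, which returns the kind of `numpy.random.Generator` mentioned above. The helper `sample_indices` is hypothetical, and it is not guaranteed to pick the same rows that `datasets` picks for the same seed.

```python
import numpy as np


def sample_indices(n_rows, k, seed):
    """Shuffle all row indices with a seeded Generator, then keep the first k."""
    rng = np.random.default_rng(seed)       # a numpy.random.Generator
    permutation = rng.permutation(n_rows)   # shuffle: a random reordering of 0..n_rows-1
    return permutation[:k]                  # select: the first k rows after shuffling


train_idx = sample_indices(650_000, 1000, seed=42)
print(train_idx[:5], len(train_idx))
```

Because the generator is seeded, rerunning the cell returns the same 1,000 indices. That reproducibility is also why the notebook passes `seed=42`.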

In [9]:
# Shuffle with a fixed seed, then keep the first 1,000 examples of each split
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
# Full training set: 650,000 examples
full_train_dataset = tokenized_datasets["train"]
# Full test set: 50,000 examples
full_eval_dataset = tokenized_datasets["test"]

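## Fine-Tuning Configuration

### Loading the BERT Model

The warning tells us that some weights were newly initialized at random (the `weight` and `bias` of the `classifier` layer). This is completely normal when fine-tuning. The heads BERT used for its pretraining tasks (masked language modeling and next sentence prediction) are dropped and replaced with a new sequence-classification head. No pretrained weights exist for this new head, so the library warns that the model should be fine-tuned before it is used for inference, which is exactly what we are about to do.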

In [10]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
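To picture the newly initialized `classifier` layer, here is a NumPy sketch of a linear head that maps BERT-base's 768-dimensional pooled output to `num_labels=5` logits. The random "pooled" vectors stand in for real BERT outputs. The initialization (a normal distribution with standard deviation 0.02 and a zero bias) is modeled on BERT's default `initializer_range`. In the real model, this head is a PyTorch `nn.Linear` applied to the pooled output after dropout.

```python
import numpy as np

HIDDEN_SIZE = 768   # BERT-base hidden size
NUM_LABELS = 5      # 1 to 5 stars

rng = np.random.default_rng(0)
# Randomly initialized head: the checkpoint has no pretrained values for these weights
W = rng.normal(0.0, 0.02, size=(NUM_LABELS, HIDDEN_SIZE))
b = np.zeros(NUM_LABELS)

# Stand-in for the pooled [CLS] representations of a batch of 4 reviews
pooled = rng.normal(size=(4, HIDDEN_SIZE))
logits = pooled @ W.T + b   # shape (batch_size, num_labels)
print(logits.shape)         # (4, 5)
```

Until fine-tuning adjusts `W` and `b`, these logits carry no information about the star rating, which is why the warning recommends training first.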


### 训练超参数（TrainingArguments）

Full list of parameters and their default values: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**The most important setting: the path where model weights are saved (`output_dir`)**

In [11]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given the size of our training data and the batch size, we set it to 100
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)
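The choice of `logging_steps` in the comment above is a matter of arithmetic. The Trainer logs once every `logging_steps` optimizer steps, so the value must be small relative to the total number of steps. The sketch below counts steps for the 1,000-sample subset with `per_device_train_batch_size=16` and `num_train_epochs=5`. It assumes a single device and no gradient accumulation; `total_steps` is a hypothetical helper.

```python
import math


def total_steps(n_samples, batch_size, epochs):
    """Optimizer steps on one device; each epoch's last partial batch still counts."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return steps_per_epoch * epochs


steps = total_steps(1000, batch_size=16, epochs=5)
print(steps)          # 63 steps per epoch * 5 epochs = 315
print(steps // 500)   # log events with the default logging_steps=500: 0
print(steps // 100)   # log events with logging_steps=100: 3
```

With the default of 500, a 315-step run would never log a training loss. For the full 650,000-sample training set with a batch size of 32 on one device, the same formula gives 20,313 steps per epoch.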

In [None]:
# Print the full hyperparameter configuration
print(training_args)

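### Evaluating Metrics During Training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** provides dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more), each available with a single line of code. **Full list of supported metrics: https://huggingface.co/evaluate-metric**

By default, the Trainer does not evaluate the model during training. Even when it does evaluate, it reports only the loss. To get metrics such as accuracy, we need to pass the Trainer a function that computes and reports them.

The Evaluate library provides a simple accuracy metric, which you can load with `evaluate.load`: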

In [13]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


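Next, call the metric's `compute` method to calculate the accuracy of the predictions.

Before passing the predictions to `compute`, we need to convert the logits into predicted class labels (**all Transformers models return logits**).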

In [14]:
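def compute_metrics(eval_pred):
    # eval_pred holds the model's logits and the reference labels for the evaluation set
    logits, labels = eval_pred
    # The predicted class is the index of the highest logit
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)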
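`compute_metrics` receives the logits and labels as NumPy arrays. It takes the argmax over the last axis to get class predictions, and then asks the accuracy metric what fraction of them match the labels. The sketch below repeats those steps with plain NumPy on made-up logits, so the logic can be checked without downloading a metric. `accuracy_from_logits` is a hypothetical helper, not part of Evaluate.

```python
import numpy as np


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the reference label."""
    predictions = np.argmax(logits, axis=-1)   # logits -> predicted class indices
    return float((predictions == labels).mean())


# Made-up logits for 4 reviews over 5 classes
logits = np.array([
    [2.0, 0.1, 0.3, 0.0, -1.0],   # predicts class 0
    [0.0, 0.2, 1.5, 0.1, 0.0],    # predicts class 2
    [0.1, 0.0, 0.2, 0.3, 3.0],    # predicts class 4
    [1.0, 0.9, 0.0, 0.0, 0.0],    # predicts class 0
])
labels = np.array([0, 2, 4, 1])
print(accuracy_from_logits(logits, labels))   # 0.75
```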

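#### Monitoring Metrics During Training

To monitor how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments` so that metrics are reported at the end of each epoch. (Newer Transformers releases name this parameter `eval_strategy`.)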

In [15]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=32,
                                  num_train_epochs=3,
                                  logging_steps=30)

## Starting Training

### Instantiating the Trainer

A `kernel version` warning may appear here. It does not currently affect running this example.

In [16]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=full_train_dataset,  # or small_train_dataset for a quick run
    eval_dataset=full_eval_dataset,  # or small_eval_dataset for a quick run
    compute_metrics=compute_metrics,
)

## Monitoring GPU Usage with nvidia-smi

To watch GPU usage in real time, poll `nvidia-smi` with the `watch` command: `watch -n 1 nvidia-smi`:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```
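To log GPU usage from Python instead of watching the terminal, you can use the machine-readable query mode of `nvidia-smi` (`--query-gpu=... --format=csv,noheader,nounits`). The sketch below wraps that mode with `subprocess` and parses the CSV lines. The function names are hypothetical, and the parser assumes every queried field is numeric; some GPUs report `[N/A]` for certain fields. The sample line reuses the numbers from the output above: GPU 0, 98% utilization, 6665 MiB used of 15360 MiB.

```python
import subprocess

QUERY_FIELDS = ["index", "utilization.gpu", "memory.used", "memory.total"]


def parse_gpu_query(output: str) -> list:
    """Parse `nvidia-smi --query-gpu=... --format=csv,noheader,nounits` output."""
    gpus = []
    for line in output.strip().splitlines():
        values = [int(v.strip()) for v in line.split(",")]
        gpus.append(dict(zip(QUERY_FIELDS, values)))
    return gpus


def query_gpus() -> list:
    """Run nvidia-smi once and return one dict per GPU (requires an NVIDIA driver)."""
    output = subprocess.run(
        ["nvidia-smi",
         "--query-gpu=" + ",".join(QUERY_FIELDS),
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    return parse_gpu_query(output)


sample = "0, 98, 6665, 15360\n"
print(parse_gpu_query(sample))
```

Calling `query_gpus()` in a loop with `time.sleep(1)` would reproduce the polling behavior of `watch -n 1`.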

In [17]:
trainer.train()



| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.7059 | 0.717043 | 0.6836 |
| 2 | 0.6436 | 0.690774 | 0.69744 |
| 3 | 0.5594 | 0.715075 | 0.69756 |


Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-3000 already exists and is non-empty.Saving will proceed but

TrainOutput(global_step=7620, training_loss=0.6601941308950189, metrics={'train_runtime': 6136.1407, 'train_samples_per_second': 317.789, 'train_steps_per_second': 1.242, 'total_flos': 5.130803778048e+17, 'train_loss': 0.6601941308950189, 'epoch': 3.0})

In [22]:
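# 1,000 random test examples; these come from the same test split used as eval_dataset during training
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(1000))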

In [23]:
trainer.evaluate(small_test_dataset)



{'eval_loss': 0.7363439798355103,
 'eval_accuracy': 0.686,
 'eval_runtime': 3.177,
 'eval_samples_per_second': 314.766,
 'eval_steps_per_second': 5.036,
 'epoch': 3.0}

### Saving the Model and Training State

- Use `trainer.save_model` to save the model; it can later be reloaded with `from_pretrained()`
- Use `trainer.save_state` to save the training state

In [24]:
trainer.save_model(model_dir)

In [25]:
trainer.save_state()
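`save_state` writes the Trainer state to `trainer_state.json` in `output_dir`. That state includes `log_history`, the list of logged loss and evaluation records. The sketch below reads that file and collects the evaluation accuracy per epoch. The path in the comment assumes this notebook's `model_dir`. The `state` dict used to exercise the function is a hand-written stand-in that borrows numbers from the training table above, not a real file.

```python
import json
from pathlib import Path


def load_trainer_state(output_dir: str) -> dict:
    # e.g. load_trainer_state("models/bert-base-cased-finetune-yelp")
    return json.loads(Path(output_dir, "trainer_state.json").read_text())


def eval_accuracy_by_epoch(state: dict) -> dict:
    """Collect {epoch: eval_accuracy} from a Trainer state's log_history."""
    return {
        record["epoch"]: record["eval_accuracy"]
        for record in state.get("log_history", [])
        if "eval_accuracy" in record   # skip training-loss records
    }


# Hand-written stand-in for a loaded state
state = {"log_history": [
    {"epoch": 1.0, "loss": 0.7059},
    {"epoch": 1.0, "eval_accuracy": 0.6836},
    {"epoch": 2.0, "eval_accuracy": 0.69744},
    {"epoch": 3.0, "eval_accuracy": 0.69756},
]}
history = eval_accuracy_by_epoch(state)
print(max(history, key=history.get))   # 3.0
```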

In [26]:
# Also save the model weights and config to the current working directory
trainer.model.save_pretrained("./")

## Homework: Train on the full YelpReviewFull dataset and see how high the accuracy can go