# Getting Started with Fine-Tuning in Hugging Face Transformers

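This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics
- A basic introduction to the Trainer
- Running the training
- Saving the model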

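## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset summary

The Yelp reviews dataset consists of reviews from Yelp. It is extracted from the Yelp Dataset Challenge 2015 data.

### Supported tasks and leaderboards
Text classification, sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.

### Languages
The reviews are mainly written in English.

### Dataset structure

#### Data instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set looks as follows: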

```json
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

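#### Data fields

- 'text': the review text, escaped using double quotes ("); any internal double quote is escaped by two double quotes (""). New lines are escaped by a backslash followed by an "n" character, that is "\n".
- 'label': corresponds to the score of the review (1 to 5 stars), stored as a class index from 0 to 4, so label 0 means 1 star.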
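
As a minimal sketch of how these two fields can be read, the helpers below undo the escaping in `text` and turn a `label` into a star rating. The names `unescape_review` and `label_to_stars` are hypothetical and not part of the dataset or the Datasets library. The sketch assumes the loaded strings contain a literal backslash followed by `n` or `"`, as the example above renders them, and that labels 0 to 4 stand for 1 to 5 stars.

```python
def unescape_review(text: str) -> str:
    """Turn the dataset's escaped sequences back into plain characters."""
    # A literal backslash followed by "n" stands for a line break.
    text = text.replace("\\n", "\n")
    # A literal backslash followed by a double quote stands for a double quote.
    text = text.replace('\\"', '"')
    return text


def label_to_stars(label: int) -> int:
    """Map a class index (assumed 0..4) to a star rating (1..5)."""
    if not 0 <= label <= 4:
        raise ValueError(f"expected a label in 0..4, got {label}")
    return label + 1


raw = 'He said \\"this time\\". \\nI will never go back.'
print(unescape_review(raw))
print(label_to_stars(0))
```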

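#### Data splits

The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples and 10,000 test samples for each star rating from 1 to 5. In total there are 650,000 training samples and 50,000 test samples.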
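
The split sizes follow directly from the per-class counts, which the short check below reproduces. Because every class contributes the same number of samples, always guessing one class would score about 20% accuracy; that baseline is an inference from the balanced construction, not a figure reported by the dataset card.

```python
NUM_CLASSES = 5            # star ratings 1 through 5
TRAIN_PER_CLASS = 130_000
TEST_PER_CLASS = 10_000

train_total = NUM_CLASSES * TRAIN_PER_CLASS
test_total = NUM_CLASSES * TEST_PER_CLASS
chance_accuracy = 1 / NUM_CLASSES  # constant-guess baseline on balanced classes

print(train_total, test_total)  # 650000 50000
print(chance_accuracy)          # 0.2
```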

## Download the dataset

In [1]:
import os
os.environ["HTTP_PROXY"] = "http://192.168.16.167:8089"
os.environ["HTTPS_PROXY"] = "http://192.168.16.167:8089"

In [2]:
from datasets import load_dataset
dataset = load_dataset("yelp_review_full")



In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [4]:
dataset["train"][111]

{'label': 2,
 'text': "As far as Starbucks go, this is a pretty nice one.  The baristas are friendly and while I was here, a lot of regulars must have come in, because they bantered away with almost everyone.  The bathroom was clean and well maintained and the trash wasn't overflowing in the canisters around the store.  The pastries looked fresh, but I didn't partake.  The noise level was also at a nice working level - not too loud, music just barely audible.\\n\\nI do wish there was more seating.  It is nice that this location has a counter at the end of the bar for sole workers, but it doesn't replace more tables.  I'm sure this isn't as much of a problem in the summer when there's the space outside.\\n\\nThere was a treat receipt promo going on, but the barista didn't tell me about it, which I found odd.  Usually when they have promos like that going on, they ask everyone if they want their receipt to come back later in the day to claim whatever the offer is.  Today it was one of th

In [5]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [6]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws num_examples distinct indices in one call.
    picks = random.sample(range(len(dataset)), num_examples)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [7]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,1 star,"Very disappointed with the cashier we had. As we checked out & I swiped my debit card, the cashier rang up an item that was already included in my total (not quite a double charge as she started a new sale). However, when we got home we found that the 2nd to the last item in our sale never made it to my bag. If you (Walmart)want more info regarding this, you can respond to this acct with your contact info & I'll scan my receipt to you"
1,3 stars,"Wide variety of selections of hair care products, extensions and wigs. Everything is neatly organized with pricing listed. Don't plan on being helped but I can say you are greeted."
2,1 star,"Really simple order and they were way off simple fries w/ garlic Parmesan missed by a mile!!! To little to say about a bad thing!! Beer taste old!!!!!! Wow people rate this bar way to high\nLess really if you want to know overall my true experience you should call me! I've reviewed many restaurants and bars and my reviews always start with \""I'm looking for success!\"" My experience doesn't bias me it actually the honest feed back that you get from me that the CUSTOMERS THAT WALK OUT TO NEVER COME BACK!!"
3,2 star,"The pork shoulder tacos and sangria were very good. I especially liked the tiny corn tortillas. The atmosphere was nice and very classy. There were lots of candles, big mirrors on the walls, and chandeliers. I think the waitstaff was the one downer from the visit -- they were cold and did not try to provide a pleasant experience. Overall, I just did not think The Mission was anything special. It's a dime in a dozen in Scottsdale."
4,5 stars,"This Lee's is my favorite one in the valley. It's a bit further than I like (the nearest Lee's to me is the one on East Flamingo) but it's worth the drive. I like this Lee's because it's huge, it's always clean, and I somehow leave with something new every time (I think that's attributed to the plethora of liquor and beer available.. how can you not try something new?). The staff is pretty helpful and the prices here are very reasonable. Plus, I can find more craft beers here than any other place I've been to so that's definitely a HUGE plus. Thanks Lee's for keeping my pantry stocked with alcoholic goodness and we'll see you soon."
5,2 star,Taking away stars from this place. They no longer serve dim sum daily. They now have a buffet that for the most part is rather bland and the presentation is so unappetizing. At least make it look good enough to eat. Dim sum comes out over cooked as if it was just reheated.
6,1 star,horrible horrible horribe customer service. They also sit on packages way to long and in some instances loose a package all together. I wish they would just destroy this location and allow the other locations to properly handle mail.
7,2 star,"I watch GR's \""Kitchen Nightmares\"" and he gets upset when customers have to wait. Well, he should stop by his own restaurant. There was an hour wait to get in. So if you go there hungry, you will be REALLY be hungry by the time you are seated. We were seated at 3:30pm and our food didn't arrive until 4:20pm. Seriously, 50 minutes wait for a hamburger!?!\n\nWe were at a table facing the kitchen and could see the cooks were always moving. The runners picked up the food as soon as it hit the waiting area. So I don't know what the holdup was. Maybe they don't have enough cooks.\n\nWe went with friends that ordered burgers, but I didn't want beef. I was surprised that there wasn't a chicken patty type burger on the menu except for the sliders which I ordered. But I was hungrier than that by then and could have used a larger chicken or turkey burger. The chicken sliders were just ok. I kept stealing the avocado off my husband's burger to give it some flavor.\n\nMy husband had a Hell's Kitchen burger, but he said it was not that spicy. But that might be a matter of taste or taste buds.\n\nThe two winning things about the restaurant: the decor and our waitress."
8,2 star,"My hubby likes to find local owned restaurants especially when it comes to Mexican food. He heard about Tony's on ABC 15 and decided we would give it a try.\n\nThe location is not the greatest in Old Avondale, and the building is tiny and shows its age. We ordered our meal as carry-out, I was in no shape to be seen in public so I didn't catch the inside. \n\nI had their carne asada combo plate which of course includes the staple beans and rice. The meat was seriously over cooked, and the beans and rice were bland. My husband makes way better carne asada, cooked to perfection. My husband had the tamale, cheese enchilada combo and he ate it up, mind you, it doesn't take much to please my Sailor Man when it comes to food. \n\nOverall I would compare it to a mom and pop version of Filibertos only not as good. It possibly could be better with a few beers. Maybe I will give them another spin...who knows."
9,2 star,"I've been to Unwined several times now, so I feel like I have a pretty good sense of what it's about. If I could give it 2.5 stars, I would. It's between \""meh\"" and \""ok.\""\n\nOn my first visit, shortly after this place opened, I just came in to take a look and check out the bar. And I'll admit it, I have pretty high standards when it comes to drinks, but the woman who was bartending that night was awful. She acted like I was inconveniencing her when I ordered a Manhattan and couldn't tell me a thing about the \""tincutures\"" used in the house cocktails (more on that later). So I ordered something simpler. However, because I haven't seen her working at the bar since that visit, I won't waste any more time on her.\n\nThe second time I visited, it was about 8:30 pm on a Tuesday night, and my friend and I were hungry. So we decided to give Unwined's food a try. The place was empty except for a couple drinking beer at the end of the bar, but I could see two employees by the register behind the bar, so we approached to ask if the kitchen was still open. No joke, it took almost five minutes before anyone acknowledged us. I know that the manager (I found out later who he was) saw us, but he said nothing until he was done at the register. Then he turned around and said, \""Ok, what can we get you? Sorry for the wait, but I had to deal with *that*\"" (indicating the other employee). I found this man's aggrieved attitude pretty unprofessional, but the kitchen was open, and we were starved. So. \n\nOur meal was a surprisingly good experience--my friend's Big Tukee burger was perfectly prepared and delicious. I gambled and ordered the Smoked Salmon cakes (I am really, really particular about my seafood). They were also very good. On this visit, I also decided to take a look at the wine menu; after all, Unwined refers to itself as \""Ahwatukee's premier wine bar.\"" I was underwhelmed. 
But given the lack of good restaurants in the foothills area, it was a relief to know the food was good, even if the menu is very limited. So I knew I'd be back another time to try their other offerings.\n\nSo flash forward to mid-December (four weeks ago). I accompanied some friends to Unwined for dinner. They'd never been before, so I recommended the burger and the salmon cakes, for sure. I decided to try the warm kale salad. Except for the salmon cakes, the meal was a bust. The burgers were way overcooked and dry (which made me embarrassed and disappointed for recommending them), and the kale salad made us laugh. It wasn't \""warm\"" in the least, but the funny part was how enormous the pieces of kale were. I mean huge. As happens sadly often in restaurants, the salad appeared to be made by someone who doesn't even eat kale. My friends and I agreed that the vinaigrette was tasty, but I had to cut up all of the leaves up so I'd have a chance of getting one in my mouth. Honestly, I felt silly. We had fun catching up with one another, but none of our party is ready to go running back to Unwined to eat.\n\nRemember my comment about the bad bartender and the tinctures? Now I get to the part about Unwined that really, really bugs me. First of all, they have no business trying to sell craft cocktails, because as far as I can tell, this place does not employ a qualified bartender. I am crazy about top-notch cocktails (Hello, Citizen Public House!), but I'm equally happy with a simple glass of Woodford. Unwined needs to either stick to the basics or invest in a real bartender--not the manager making weird Old Fashioneds, and whatever you do, don't order the Sazerac. It is downright terrible. \n\nNow, I'm not exactly an oenophile, but I do know quite a bit about wine, the characteristics of varietals, etc. Let me just put it this way: the wine list here is unimaginative. For example, the whites by the glass--there are 14 of them BUT it's so pedestrian and formulaic. 
3 cheap sparklers, 3 chardonnays, 3 sauvignon blancs, 2 pinot grigios. You get the picture. As another reviewer pointed out, the manager likes to put on a big show about \""educating\"" his patrons about wine. Maybe he thinks we're kindergartners? *sigh* Even the reds were nothing to write home about. I can appreciate owners not wanting to lose money on uncorking an expensive bottle for what may only be one or two pours, but the selections here are either $5, $7 or $8, with one red at $9/glass. I love a great, inexpensive find as much as the next person, but every wine I've tried so far has been just ok, which is what I would expect, given these prices. There's nothing wrong with \""ok,\"" but this place is pretending to be something it's not. The owners need to go hang out and take notes at places like Postino, Cheuvront, Kazimierz, 5th and Wine, Vintage 95, Terroir Wine Pub (especially the latter, since it's very close in spirit to Unwined), and so on. Until then, I fear that Unwined's reach will continue to exceed its grasp."


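## Preprocess the data

After downloading the dataset locally, use a tokenizer to process the text. Inputs of unequal length can be handled with padding and truncation strategies.

The `map` method of Datasets applies a preprocessing function to the entire dataset in one pass.

The following cell pads every example to the maximum length and processes the whole dataset: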

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)
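
To make the effect of `padding="max_length"` and `truncation=True` concrete, here is a toy sketch in plain Python. It is not the tokenizer's implementation: `pad_and_truncate` is a hypothetical helper, the padding id 0 is an assumption matching BERT-style vocabularies, and unlike the real tokenizer it does not re-append special tokens after cutting.

```python
from typing import Dict, List

PAD_ID = 0  # assumed BERT-style padding id, used only for this toy example


def pad_and_truncate(ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Cut ids to max_length, then right-pad with PAD_ID up to max_length."""
    kept = ids[:max_length]
    n_pad = max_length - len(kept)
    return {
        "input_ids": kept + [PAD_ID] * n_pad,
        # 1 marks a real token, 0 marks padding the model should ignore.
        "attention_mask": [1] * len(kept) + [0] * n_pad,
    }


short_enc = pad_and_truncate([101, 2831, 102], max_length=5)
long_enc = pad_and_truncate([101, 7, 8, 9, 10, 11, 102], max_length=5)
print(short_enc)  # padded with two PAD_ID entries
print(long_enc)   # cut down to the first five ids
```

Every row ends up with the same length, which is why the `input_ids` and `attention_mask` columns in the next cell's output are uniform in size.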



In [9]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,3 stars,"Underwhelmed. Confusing layout. Where do you start and where do you pay? Why aren't there plates accessible next to the hot food items? Do we have a server who comes to refill our drinks or do we get them ourselves? The cashier said a server would get them. When a server never presented themselves to us, we got them ourselves. Fucking send somebody around to clear our plates once in a while because if I have to ask for someone to do this then I'm not dropping a tip. Here's your tip, \""Be attentive.\""\n\nTwo adults, three kids, three drinks = $43 bucks. That's not bad. It's like eight bucks and some change per person. \n\nThe salad bar is basically what you're paying for. Good, fresh ingredients. The 'Greek' dressing was really good. A ton of variety. The hot food,...eh. Nothing outstanding. All the bread, even the freshly stocked items, seemed stale and bland.\n\nI'll probably come again but maybe only for lunch.","[101, 2831, 2246, 18809, 4611, 119, 16752, 14703, 4253, 9726, 119, 2777, 1202, 1128, 1838, 1105, 1187, 1202, 1128, 2653, 136, 2009, 4597, 112, 189, 1175, 7463, 7385, 1397, 1106, 1103, 2633, 2094, 4454, 136, 2091, 1195, 1138, 170, 9770, 1150, 2502, 1106, 1231, 18591, 1412, 8898, 1137, 1202, 1195, 1243, 1172, 9655, 136, 1109, 5948, 2852, 1163, 170, 9770, 1156, 1243, 1172, 119, 1332, 170, 9770, 1309, 2756, 2310, 1106, 1366, 117, 1195, 1400, 1172, 9655, 119, 10259, 1158, 3952, 9994, 1213, 1106, 2330, 1412, 7463, 1517, 1107, 170, 1229, 1272, 1191, 146, 1138, 1106, 2367, 1111, 1800, 1106, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


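### Data sampling

Use 1,000 samples to demonstrate small-scale training on BERT (with the PyTorch-based Trainer).

`shuffle()` randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass the `generator` argument to use a different `numpy.random.Generator`.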

In [10]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
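
The idea behind `shuffle(seed=42).select(range(1000))` is a seeded permutation followed by taking the first rows. The sketch below shows that idea with the standard library only. `shuffle_and_select` is a hypothetical helper, and it will not reproduce the exact permutation that Datasets produces, since that library uses a NumPy generator.

```python
import random
from typing import List, Sequence


def shuffle_and_select(rows: Sequence, seed: int, n: int) -> List:
    """Return the first n rows of a seeded random permutation of rows."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # private RNG, global state untouched
    return [rows[i] for i in indices[:n]]


rows = [f"review-{i}" for i in range(10)]
first = shuffle_and_select(rows, seed=42, n=3)
second = shuffle_and_select(rows, seed=42, n=3)
print(first == second)  # True: the same seed yields the same subset
```

Fixing the seed is what makes the small training and evaluation subsets reproducible across runs.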

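## Fine-tuning configuration

### Load the BERT model

The warning printed below tells us that some weights, here the new classification head (`classifier.weight` and `classifier.bias`), were not loaded from the checkpoint and are randomly initialized. This is entirely normal when fine-tuning: the heads used for pretraining (such as the masked language modeling head) are dropped and replaced with a new head for which no pretrained weights exist. The library therefore warns that the model should be fine-tuned before it is used for inference, which is exactly what we are about to do.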

In [11]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training hyperparameters (TrainingArguments)

Full list of parameters and their default values: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source code definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**The most important setting: the path where model weights are saved (`output_dir`)**

In [13]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given our training data size and step count, set it to 100
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)
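
The choice of `logging_steps` can be checked with a little arithmetic. The sketch below assumes the 1,000-sample training subset selected earlier, two GPUs (the printed configuration reports `_n_gpu=2`), and no gradient accumulation; `steps_per_epoch` is a hypothetical helper, not a Transformers function.

```python
import math


def steps_per_epoch(num_samples: int, per_device_batch: int, num_devices: int) -> int:
    """Optimizer steps per epoch when no gradient accumulation is used."""
    return math.ceil(num_samples / (per_device_batch * num_devices))


per_epoch = steps_per_epoch(num_samples=1000, per_device_batch=16, num_devices=2)
total_steps = per_epoch * 5  # num_train_epochs=5

print(per_epoch, total_steps)                  # 32 160
print(total_steps // 500, total_steps // 100)  # 0 1
```

Under these assumptions the run has 160 steps in total, so the default of 500 would never trigger a periodic training-loss log, while 100 triggers one.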

In [14]:
# Full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=2,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_le

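### Evaluating metrics during training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** provides dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more), each loadable with a single line of code. **Full list of currently supported metrics: https://huggingface.co/evaluate-metric**

The Trainer does not automatically evaluate model performance during training, so we need to pass it a function that computes and reports metrics.

The Evaluate library provides a simple accuracy metric that you can load with the `evaluate.load` function: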

In [15]:
import sys
print(sys.path)

['/home/lande/anaconda3/envs/llm_train/lib/python310.zip', '/home/lande/anaconda3/envs/llm_train/lib/python3.10', '/home/lande/anaconda3/envs/llm_train/lib/python3.10/lib-dynload', '', '/home/lande/.local/lib/python3.10/site-packages', '/home/lande/anaconda3/envs/llm_train/lib/python3.10/site-packages', '/home/lande/anaconda3/envs/llm_train/lib/python3.10/site-packages/setuptools/_vendor', '/tmp/tmpfjz968k5']


In [None]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy") 

FileNotFoundError: Couldn't find a module script at /home/train/LLM-quickstart/transf_shy/accuracy/accuracy.py. Module 'accuracy' doesn't exist on the Hugging Face Hub either.


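Next, the metric's `compute` function calculates the accuracy of the predictions. In this run `evaluate.load("accuracy")` failed with the `FileNotFoundError` shown above, so the cells below compute accuracy directly with NumPy instead.

Before predictions can be scored, the logits must be converted into predicted labels (**all Transformers models return logits**).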

In [24]:
def compute_accuracy(predictions, references):
    """
    Compute the accuracy of predictions against the true labels.
    :param predictions: list or np.array of predicted labels
    :param references:  list or np.array of true labels
    :return: float, the accuracy
    """
    predictions = np.array(predictions)
    references = np.array(references)
    return np.mean(predictions == references)

In [29]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == labels).mean()
    return {
        "accuracy": accuracy
    }
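
The conversion from logits to predictions can be traced by hand on a toy batch. The sketch below uses only the standard library and made-up logits; its `argmax` and `accuracy` helpers are illustrative stand-ins for `np.argmax(logits, axis=-1)` and the mean comparison used in `compute_metrics`.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions: List[int] = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Three made-up examples with five logits each (one per star class).
toy_logits = [
    [2.1, 0.3, -1.0, -0.5, 0.0],   # highest score at class 0
    [-0.2, 0.1, 1.7, 0.4, -0.9],   # highest score at class 2
    [0.0, 0.2, 0.1, 0.3, 1.5],     # highest score at class 4
]
toy_labels = [0, 3, 4]

print([argmax(row) for row in toy_logits])  # [0, 2, 4]
print(accuracy(toy_logits, toy_labels))     # 2 of 3 predictions are correct
```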

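#### Monitoring metrics during training

To track how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments` so that metrics are reported at the end of each epoch.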

In [30]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=30)

## Start training

### Instantiate the Trainer

The `kernel version` warning shown below does not currently affect running this example.

In [31]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


## Monitoring GPU usage with nvidia-smi

To watch GPU usage in real time, poll with the `watch` command: `watch -n 1 nvidia-smi`:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```
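
If the numbers are needed in a script rather than on screen, the relevant row can be parsed. This is a minimal sketch that works on one line copied from the output above; the regular expressions are assumptions tied to this table layout, and for anything robust the `nvidia-smi --query-gpu=... --format=csv` interface is likely a better fit.

```python
import re

# One row copied from the nvidia-smi output above.
row = "| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |"

# Memory is printed as "<used>MiB / <total>MiB".
match = re.search(r"(\d+)MiB\s*/\s*(\d+)MiB", row)
used_mib, total_mib = (int(g) for g in match.groups())
# GPU utilization is the only percentage on this row.
util_pct = int(re.search(r"(\d+)%", row).group(1))

print(f"memory: {used_mib}/{total_mib} MiB ({used_mib / total_mib:.1%}), GPU util: {util_pct}%")
```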

In [50]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
0,0.3786,0.904541,0.668


KeyboardInterrupt: 

In [33]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

In [34]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 1.3171261548995972,
 'eval_accuracy': 0.55,
 'eval_runtime': 2.9366,
 'eval_samples_per_second': 34.053,
 'eval_steps_per_second': 2.384,
 'epoch': 3.0}

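### Save the model and training state

- Use the `trainer.save_model` method to save the model; it can later be reloaded with the `from_pretrained()` method
- Use the `trainer.save_state` method to save the training state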

In [20]:
trainer.save_model(model_dir)

In [21]:
trainer.save_state()
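
Reloading the saved model might look like the sketch below. `load_finetuned` and `predict_stars` are hypothetical helpers. The sketch assumes the tokenizer was not saved next to the model (the Trainer above was created without one), so it is loaded from the base checkpoint, and it assumes class indices 0 to 4 map to 1 to 5 stars.

```python
def load_finetuned(model_dir: str, base_checkpoint: str = "bert-base-cased"):
    """Reload the saved classifier and a matching tokenizer."""
    # Imported lazily so that defining this helper does not require
    # transformers to be installed.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    # Assumption: the tokenizer is not in model_dir, so use the base checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
    return model, tokenizer


def predict_stars(text: str, model, tokenizer) -> int:
    """Return the predicted star rating (1..5) for one review."""
    import torch

    inputs = tokenizer(text, truncation=True, return_tensors="pt")
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits
    # Class index 0..4 maps to 1..5 stars.
    return int(logits.argmax(dim=-1).item()) + 1
```

A call such as `model, tokenizer = load_finetuned("models/bert-base-cased-finetune-yelp")` followed by `predict_stars("Great food and friendly staff.", model, tokenizer)` would then classify a single review.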

In [23]:
# trainer.model.save_pretrained("./")

## Homework: train on the full YelpReviewFull dataset and see how high the accuracy can go

In [35]:
train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets["test"]

In [47]:
training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=32,
                                  num_train_epochs=5,
                                  logging_steps=30)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [49]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.7228,0.71442,0.68672
2,0.6673,0.712222,0.68868
3,0.5701,0.71841,0.69354
4,0.4818,0.78824,0.68768
5,0.3655,0.912226,0.68186




TrainOutput(global_step=50785, training_loss=0.5735644983214822, metrics={'train_runtime': 54199.1009, 'train_samples_per_second': 59.964, 'train_steps_per_second': 0.937, 'total_flos': 8.55133963008e+17, 'train_loss': 0.5735644983214822, 'epoch': 5.0})
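
The figures in this `TrainOutput` can be cross-checked against each other. The sketch below copies the reported values and recomputes throughput and runtime; the 650,000-row training set size comes from the dataset summary earlier. The roughly 64 rows per step are consistent with two devices at `per_device_train_batch_size=32`, which is an inference rather than something the output states.

```python
# Values copied from the TrainOutput above.
train_runtime_s = 54199.1009
global_step = 50785
num_epochs = 5
train_rows = 650_000

samples_per_second = train_rows * num_epochs / train_runtime_s
steps_per_second = global_step / train_runtime_s
rows_per_step = train_rows * num_epochs / global_step
hours = train_runtime_s / 3600

print(f"{samples_per_second:.3f} samples/s")  # 59.964
print(f"{steps_per_second:.3f} steps/s")      # 0.937
print(f"{rows_per_step:.1f} rows per step")   # 64.0
print(f"{hours:.1f} hours")                   # 15.1
```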

In [52]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(1000))


In [53]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 0.9045406579971313, 'eval_accuracy': 0.668}

In [54]:
model_dir = "models/save_model"
trainer.save_model(model_dir)

In [55]:
trainer.save_state()