In [1]:
from datasets import load_dataset
import random
import pandas as pd
import datasets
from IPython.display import display, HTML
from transformers import AutoTokenizer

In [2]:
# Download (or load from the local HF cache) the Yelp Review Full dataset as a DatasetDict.
dataset = load_dataset(path="yelp_review_full")

In [3]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    """Tokenize a batch of reviews for bert-base-cased.

    Used with ``dataset.map(..., batched=True)``, so ``examples["text"]`` is a
    batch of review strings. Sequences are truncated and padded with
    ``padding="max_length"`` (no explicit ``max_length``), giving every row the
    same fixed length.
    """
    review_texts = examples["text"]
    return tokenizer(review_texts, truncation=True, padding="max_length")


tokenized_datasets = dataset.map(function=tokenize_function, batched=True)

In [4]:
# Shuffle both splits with the same fixed seed so the ordering is reproducible.
# Note: the "test" split doubles as the evaluation set here.
train_dataset, eval_dataset = (
    tokenized_datasets[split_name].shuffle(seed=42) for split_name in ("train", "test")
)

In [5]:
from transformers import AutoModelForSequenceClassification, set_seed

# Seed before building the model: the classification head (classifier.weight /
# classifier.bias) is freshly initialized in this cell, and the TrainingArguments
# seed is only applied once the Trainer is constructed later. Without this, the
# head initialization (and hence the final metrics) differs from run to run.
# 42 matches the seed used for shuffling the datasets.
set_seed(42)

# num_labels=5: one output per review-score class in yelp_review_full.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given our training-set size and batch size we
# log more often, every 200 steps.
# NOTE: `training_args` is redefined further down (adding evaluation_strategy="epoch")
# before the Trainer is built, so this first version is only a baseline.
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=20,
                                  num_train_epochs=1,
                                  logging_steps=200)

In [7]:
import numpy as np
import evaluate

# Accuracy metric from the Hugging Face `evaluate` library, consumed by compute_metrics.
metric = evaluate.load(path="accuracy")

2024-03-26 19:34:45.448468: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-26 19:34:45.498348: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
def compute_metrics(eval_pred):
    """Score one evaluation pass for the Trainer.

    ``eval_pred`` is unpacked as a ``(logits, labels)`` pair. The predicted class
    is the argmax over the last axis of the logits. The return value is whatever
    the notebook-level ``metric`` object's ``compute`` returns.
    """
    predicted_logits, gold_labels = eval_pred
    predicted_classes = np.argmax(predicted_logits, axis=-1)
    return metric.compute(references=gold_labels, predictions=predicted_classes)

In [9]:
from transformers import TrainingArguments, Trainer

# Final training configuration (supersedes the earlier `training_args`):
# - evaluation_strategy="epoch": evaluate eval_dataset with compute_metrics after each epoch.
# - logging_steps=200: log the training loss more often than the default.
# - save_total_limit=2: checkpointing is step-based by default (every 500 steps),
#   and each checkpoint holds model + optimizer state. One epoch here is ~32.5k
#   steps, so unbounded checkpointing would write dozens of large checkpoints;
#   keep only the two most recent.
training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch",
                                  per_device_train_batch_size=20,
                                  num_train_epochs=1,
                                  logging_steps=200,
                                  save_total_limit=2)

In [10]:
# Wire the model, the final arguments, both datasets and the accuracy callback together.
trainer = Trainer(
    model,
    training_args,
    compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    train_dataset=train_dataset,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [11]:
# Train from scratch (no checkpoint resume); returns a TrainOutput summary.
trainer.train(resume_from_checkpoint=None)

Epoch,Training Loss,Validation Loss,Accuracy
1,0.7035,0.696733,0.69306


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



TrainOutput(global_step=32500, training_loss=0.7720525390625, metrics={'train_runtime': 58234.5217, 'train_samples_per_second': 11.162, 'train_steps_per_second': 0.558, 'total_flos': 1.710267926016e+17, 'train_loss': 0.7720525390625, 'epoch': 1.0})

In [13]:
# NOTE(review): this is the same "test" split already used as `eval_dataset`
# (only the shuffle seed differs), so evaluating it reproduces the per-epoch
# validation numbers rather than giving an independent held-out score.
# Reshuffling does not change order-independent metrics such as accuracy.
test_split = tokenized_datasets["test"]
test_dataset = test_split.shuffle(seed=64)

In [14]:
# Evaluate the fine-tuned model; the returned metrics dict is the cell's displayed output.
trainer.evaluate(eval_dataset=test_dataset)

{'eval_loss': 0.696733295917511,
 'eval_accuracy': 0.69306,
 'eval_runtime': 1577.7537,
 'eval_samples_per_second': 31.691,
 'eval_steps_per_second': 3.961,
 'epoch': 1.0}

In [None]:
# The tokenizer was never passed to the Trainer, so Trainer.save_model writes
# only the model. Save the tokenizer explicitly so model_dir is self-contained
# and can be reloaded with AutoTokenizer/AutoModelForSequenceClassification.from_pretrained.
# The tokenizer is saved first so the cell's last expression (save_model -> None)
# produces no output.
tokenizer.save_pretrained(model_dir)
trainer.save_model(model_dir)

In [None]:
# Persist the Trainer's state (e.g. log history, global step) alongside the checkpoints.
# Presumably written as trainer_state.json under training_args.output_dir; confirm
# for the installed transformers version.
trainer.save_state()