# 文本分类实例

## Step1 导入相关包

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:
dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset

Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})

## Step4 划分数据集

In [3]:
datasets = dataset.train_test_split(test_size=0.1)
datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})

## Step5 结合分词器创建Dataloader

In [14]:
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
tokenized_datasets

: 

## Step6 创建模型

In [5]:
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/rbt3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 7 创建评估函数

In [6]:
import evaluate
acc_metric  = evaluate.load("accuracy")
f1_metric  = evaluate.load("f1")

In [7]:

def eval_metric(eval_predict: tuple):
    
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc


## Step8  创建TrainingArguments


In [8]:
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=64,  # 训练时的batch_size
                               per_device_eval_batch_size=128,  # 验证时的batch_size
                               logging_steps=11000,                # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型
train_args



TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=2,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=epoch,
eval_use_gather_object=False,
evaluation_strategy=epoch,
fp16=

## Step9 创建Trainer

In [9]:
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model, 
                  args=train_args, 
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)
trainer.train()

  3%|▎         | 10/330 [00:05<01:34,  3.40it/s]

{'loss': 0.6306, 'grad_norm': 1.5908437967300415, 'learning_rate': 1.9393939393939395e-05, 'epoch': 0.09}


  6%|▌         | 20/330 [00:07<01:16,  4.04it/s]

{'loss': 0.5495, 'grad_norm': 4.218531608581543, 'learning_rate': 1.8787878787878792e-05, 'epoch': 0.18}


  9%|▉         | 30/330 [00:10<01:13,  4.07it/s]

{'loss': 0.5029, 'grad_norm': 2.552896499633789, 'learning_rate': 1.8181818181818182e-05, 'epoch': 0.27}


 12%|█▏        | 40/330 [00:12<01:11,  4.07it/s]

{'loss': 0.415, 'grad_norm': 4.112495422363281, 'learning_rate': 1.7575757575757576e-05, 'epoch': 0.36}


 15%|█▌        | 50/330 [00:15<01:09,  4.01it/s]

{'loss': 0.4114, 'grad_norm': 8.104530334472656, 'learning_rate': 1.6969696969696972e-05, 'epoch': 0.45}


 18%|█▊        | 60/330 [00:17<01:06,  4.04it/s]

{'loss': 0.3596, 'grad_norm': 6.4954376220703125, 'learning_rate': 1.6363636363636366e-05, 'epoch': 0.55}


 21%|██        | 70/330 [00:20<01:03,  4.12it/s]

{'loss': 0.3217, 'grad_norm': 2.6295719146728516, 'learning_rate': 1.575757575757576e-05, 'epoch': 0.64}


 24%|██▍       | 80/330 [00:22<01:00,  4.12it/s]

{'loss': 0.3227, 'grad_norm': 2.461909532546997, 'learning_rate': 1.5151515151515153e-05, 'epoch': 0.73}


 27%|██▋       | 90/330 [00:25<01:00,  3.98it/s]

{'loss': 0.3302, 'grad_norm': 2.906282663345337, 'learning_rate': 1.4545454545454546e-05, 'epoch': 0.82}


 30%|███       | 100/330 [00:27<00:56,  4.10it/s]

{'loss': 0.3571, 'grad_norm': 5.455808639526367, 'learning_rate': 1.3939393939393942e-05, 'epoch': 0.91}


 33%|███▎      | 110/330 [00:29<00:53,  4.13it/s]

{'loss': 0.3347, 'grad_norm': 5.3201775550842285, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 110/330 [00:30<00:53,  4.13it/s]

{'eval_loss': 0.2672317326068878, 'eval_accuracy': 0.888030888030888, 'eval_f1': 0.9162656400384985, 'eval_runtime': 1.0331, 'eval_samples_per_second': 752.109, 'eval_steps_per_second': 6.776, 'epoch': 1.0}


 36%|███▋      | 120/330 [00:33<00:54,  3.86it/s]

{'loss': 0.3225, 'grad_norm': 4.982728004455566, 'learning_rate': 1.2727272727272728e-05, 'epoch': 1.09}


 39%|███▉      | 130/330 [00:36<00:49,  4.04it/s]

{'loss': 0.2851, 'grad_norm': 2.8301961421966553, 'learning_rate': 1.2121212121212122e-05, 'epoch': 1.18}


 42%|████▏     | 140/330 [00:38<00:48,  3.95it/s]

{'loss': 0.2548, 'grad_norm': 3.7482922077178955, 'learning_rate': 1.1515151515151517e-05, 'epoch': 1.27}


 45%|████▌     | 150/330 [00:41<00:44,  4.05it/s]

{'loss': 0.2919, 'grad_norm': 2.094907760620117, 'learning_rate': 1.0909090909090909e-05, 'epoch': 1.36}


 48%|████▊     | 160/330 [00:43<00:42,  3.97it/s]

{'loss': 0.2793, 'grad_norm': 2.38272762298584, 'learning_rate': 1.0303030303030304e-05, 'epoch': 1.45}


 52%|█████▏    | 170/330 [00:46<00:41,  3.84it/s]

{'loss': 0.2831, 'grad_norm': 5.034816741943359, 'learning_rate': 9.696969696969698e-06, 'epoch': 1.55}


 55%|█████▍    | 180/330 [00:49<00:37,  3.98it/s]

{'loss': 0.2202, 'grad_norm': 2.7967710494995117, 'learning_rate': 9.090909090909091e-06, 'epoch': 1.64}


 58%|█████▊    | 190/330 [00:51<00:34,  4.06it/s]

{'loss': 0.282, 'grad_norm': 2.7420616149902344, 'learning_rate': 8.484848484848486e-06, 'epoch': 1.73}


 61%|██████    | 200/330 [00:54<00:32,  4.02it/s]

{'loss': 0.275, 'grad_norm': 2.921572208404541, 'learning_rate': 7.87878787878788e-06, 'epoch': 1.82}


 64%|██████▎   | 210/330 [00:56<00:29,  4.13it/s]

{'loss': 0.247, 'grad_norm': 2.9077844619750977, 'learning_rate': 7.272727272727273e-06, 'epoch': 1.91}


 67%|██████▋   | 220/330 [00:58<00:27,  3.97it/s]

{'loss': 0.2663, 'grad_norm': 5.955989360809326, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


                                                 
 67%|██████▋   | 220/330 [00:59<00:27,  3.97it/s]

{'eval_loss': 0.2458522468805313, 'eval_accuracy': 0.8983268983268984, 'eval_f1': 0.9258215962441314, 'eval_runtime': 1.0626, 'eval_samples_per_second': 731.228, 'eval_steps_per_second': 6.588, 'epoch': 2.0}


 70%|██████▉   | 230/330 [01:02<00:26,  3.82it/s]

{'loss': 0.2185, 'grad_norm': 1.9856994152069092, 'learning_rate': 6.060606060606061e-06, 'epoch': 2.09}


 73%|███████▎  | 240/330 [01:05<00:23,  3.88it/s]

{'loss': 0.2527, 'grad_norm': 3.5551598072052, 'learning_rate': 5.4545454545454545e-06, 'epoch': 2.18}


 76%|███████▌  | 250/330 [01:07<00:20,  3.98it/s]

{'loss': 0.2442, 'grad_norm': 2.772078514099121, 'learning_rate': 4.848484848484849e-06, 'epoch': 2.27}


 79%|███████▉  | 260/330 [01:10<00:17,  4.02it/s]

{'loss': 0.229, 'grad_norm': 4.36831521987915, 'learning_rate': 4.242424242424243e-06, 'epoch': 2.36}


 82%|████████▏ | 270/330 [01:12<00:14,  4.08it/s]

{'loss': 0.2182, 'grad_norm': 3.9568419456481934, 'learning_rate': 3.6363636363636366e-06, 'epoch': 2.45}


 85%|████████▍ | 280/330 [01:15<00:13,  3.80it/s]

{'loss': 0.2294, 'grad_norm': 2.5302734375, 'learning_rate': 3.0303030303030305e-06, 'epoch': 2.55}


 88%|████████▊ | 290/330 [01:17<00:10,  3.94it/s]

{'loss': 0.2743, 'grad_norm': 3.133669376373291, 'learning_rate': 2.4242424242424244e-06, 'epoch': 2.64}


 91%|█████████ | 300/330 [01:20<00:07,  3.93it/s]

{'loss': 0.2429, 'grad_norm': 4.143651008605957, 'learning_rate': 1.8181818181818183e-06, 'epoch': 2.73}


 94%|█████████▍| 310/330 [01:22<00:04,  4.06it/s]

{'loss': 0.2115, 'grad_norm': 1.7897974252700806, 'learning_rate': 1.2121212121212122e-06, 'epoch': 2.82}


 97%|█████████▋| 320/330 [01:25<00:02,  3.85it/s]

{'loss': 0.2376, 'grad_norm': 2.7048375606536865, 'learning_rate': 6.060606060606061e-07, 'epoch': 2.91}


100%|██████████| 330/330 [01:27<00:00,  4.01it/s]

{'loss': 0.2422, 'grad_norm': 2.6060872077941895, 'learning_rate': 0.0, 'epoch': 3.0}


                                                 
100%|██████████| 330/330 [01:29<00:00,  4.01it/s]

{'eval_loss': 0.24387605488300323, 'eval_accuracy': 0.9009009009009009, 'eval_f1': 0.927699530516432, 'eval_runtime': 1.0339, 'eval_samples_per_second': 751.539, 'eval_steps_per_second': 6.771, 'epoch': 3.0}


100%|██████████| 330/330 [01:29<00:00,  3.67it/s]

{'train_runtime': 89.8259, 'train_samples_per_second': 233.385, 'train_steps_per_second': 3.674, 'train_loss': 0.30736470222473145, 'epoch': 3.0}





TrainOutput(global_step=330, training_loss=0.30736470222473145, metrics={'train_runtime': 89.8259, 'train_samples_per_second': 233.385, 'train_steps_per_second': 3.674, 'total_flos': 351909933963264.0, 'train_loss': 0.30736470222473145, 'epoch': 3.0})

## Step10 模型训练

In [10]:
trainer.train()

  3%|▎         | 10/330 [00:02<01:21,  3.92it/s]

{'loss': 0.2886, 'grad_norm': 4.049503803253174, 'learning_rate': 1.9393939393939395e-05, 'epoch': 0.09}


  6%|▌         | 20/330 [00:05<01:18,  3.94it/s]

{'loss': 0.1943, 'grad_norm': 5.238736152648926, 'learning_rate': 1.8787878787878792e-05, 'epoch': 0.18}


  9%|▉         | 30/330 [00:07<01:16,  3.91it/s]

{'loss': 0.2156, 'grad_norm': 2.3971447944641113, 'learning_rate': 1.8181818181818182e-05, 'epoch': 0.27}


 12%|█▏        | 40/330 [00:10<01:14,  3.90it/s]

{'loss': 0.2512, 'grad_norm': 3.255762815475464, 'learning_rate': 1.7575757575757576e-05, 'epoch': 0.36}


 15%|█▌        | 50/330 [00:12<01:09,  4.03it/s]

{'loss': 0.2672, 'grad_norm': 5.978575706481934, 'learning_rate': 1.6969696969696972e-05, 'epoch': 0.45}


 18%|█▊        | 60/330 [00:15<01:07,  4.03it/s]

{'loss': 0.241, 'grad_norm': 5.531741619110107, 'learning_rate': 1.6363636363636366e-05, 'epoch': 0.55}


 21%|██        | 70/330 [00:17<01:04,  4.00it/s]

{'loss': 0.1965, 'grad_norm': 1.9573217630386353, 'learning_rate': 1.575757575757576e-05, 'epoch': 0.64}


 24%|██▍       | 80/330 [00:20<01:03,  3.92it/s]

{'loss': 0.2336, 'grad_norm': 2.9701223373413086, 'learning_rate': 1.5151515151515153e-05, 'epoch': 0.73}


 27%|██▋       | 90/330 [00:22<00:59,  4.04it/s]

{'loss': 0.2133, 'grad_norm': 2.6241960525512695, 'learning_rate': 1.4545454545454546e-05, 'epoch': 0.82}


 30%|███       | 100/330 [00:25<00:59,  3.88it/s]

{'loss': 0.2415, 'grad_norm': 3.7757151126861572, 'learning_rate': 1.3939393939393942e-05, 'epoch': 0.91}


 33%|███▎      | 110/330 [00:27<00:55,  3.95it/s]

{'loss': 0.2174, 'grad_norm': 2.2507336139678955, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 110/330 [00:28<00:55,  3.95it/s]

{'eval_loss': 0.23221652209758759, 'eval_accuracy': 0.8983268983268984, 'eval_f1': 0.9248334919124643, 'eval_runtime': 1.0795, 'eval_samples_per_second': 719.799, 'eval_steps_per_second': 6.485, 'epoch': 1.0}


 36%|███▋      | 120/330 [00:31<00:55,  3.75it/s]

{'loss': 0.2301, 'grad_norm': 3.6757736206054688, 'learning_rate': 1.2727272727272728e-05, 'epoch': 1.09}


 39%|███▉      | 130/330 [00:34<00:49,  4.01it/s]

{'loss': 0.19, 'grad_norm': 3.8366353511810303, 'learning_rate': 1.2121212121212122e-05, 'epoch': 1.18}


 42%|████▏     | 140/330 [00:36<00:48,  3.88it/s]

{'loss': 0.174, 'grad_norm': 2.8957276344299316, 'learning_rate': 1.1515151515151517e-05, 'epoch': 1.27}


 45%|████▌     | 150/330 [00:39<00:45,  4.00it/s]

{'loss': 0.1926, 'grad_norm': 2.172226905822754, 'learning_rate': 1.0909090909090909e-05, 'epoch': 1.36}


 48%|████▊     | 160/330 [00:41<00:43,  3.90it/s]

{'loss': 0.1927, 'grad_norm': 2.379979133605957, 'learning_rate': 1.0303030303030304e-05, 'epoch': 1.45}


 52%|█████▏    | 170/330 [00:44<00:39,  4.00it/s]

{'loss': 0.1996, 'grad_norm': 3.413879156112671, 'learning_rate': 9.696969696969698e-06, 'epoch': 1.55}


 55%|█████▍    | 180/330 [00:46<00:36,  4.16it/s]

{'loss': 0.178, 'grad_norm': 3.8718998432159424, 'learning_rate': 9.090909090909091e-06, 'epoch': 1.64}


 58%|█████▊    | 190/330 [00:49<00:35,  3.95it/s]

{'loss': 0.1926, 'grad_norm': 3.2690370082855225, 'learning_rate': 8.484848484848486e-06, 'epoch': 1.73}


 61%|██████    | 200/330 [00:51<00:32,  4.04it/s]

{'loss': 0.1922, 'grad_norm': 2.583622694015503, 'learning_rate': 7.87878787878788e-06, 'epoch': 1.82}


 64%|██████▎   | 210/330 [00:54<00:30,  3.88it/s]

{'loss': 0.1972, 'grad_norm': 4.173540115356445, 'learning_rate': 7.272727272727273e-06, 'epoch': 1.91}


 67%|██████▋   | 220/330 [00:56<00:28,  3.82it/s]

{'loss': 0.1822, 'grad_norm': 1.7100354433059692, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


                                                 
 67%|██████▋   | 220/330 [00:58<00:28,  3.82it/s]

{'eval_loss': 0.23727844655513763, 'eval_accuracy': 0.9034749034749034, 'eval_f1': 0.92970946579194, 'eval_runtime': 1.1008, 'eval_samples_per_second': 705.871, 'eval_steps_per_second': 6.359, 'epoch': 2.0}


 70%|██████▉   | 230/330 [01:01<00:26,  3.78it/s]

{'loss': 0.1405, 'grad_norm': 3.286794662475586, 'learning_rate': 6.060606060606061e-06, 'epoch': 2.09}


 73%|███████▎  | 240/330 [01:03<00:22,  4.07it/s]

{'loss': 0.1776, 'grad_norm': 1.6305015087127686, 'learning_rate': 5.4545454545454545e-06, 'epoch': 2.18}


 76%|███████▌  | 250/330 [01:06<00:20,  3.84it/s]

{'loss': 0.1768, 'grad_norm': 2.555159568786621, 'learning_rate': 4.848484848484849e-06, 'epoch': 2.27}


 79%|███████▉  | 260/330 [01:08<00:17,  4.03it/s]

{'loss': 0.1691, 'grad_norm': 7.0657477378845215, 'learning_rate': 4.242424242424243e-06, 'epoch': 2.36}


 82%|████████▏ | 270/330 [01:11<00:14,  4.05it/s]

{'loss': 0.159, 'grad_norm': 4.001041412353516, 'learning_rate': 3.6363636363636366e-06, 'epoch': 2.45}


 85%|████████▍ | 280/330 [01:13<00:12,  4.07it/s]

{'loss': 0.1557, 'grad_norm': 5.022171974182129, 'learning_rate': 3.0303030303030305e-06, 'epoch': 2.55}


 88%|████████▊ | 290/330 [01:16<00:10,  3.94it/s]

{'loss': 0.2094, 'grad_norm': 3.366044759750366, 'learning_rate': 2.4242424242424244e-06, 'epoch': 2.64}


 91%|█████████ | 300/330 [01:18<00:07,  3.97it/s]

{'loss': 0.1652, 'grad_norm': 5.685101509094238, 'learning_rate': 1.8181818181818183e-06, 'epoch': 2.73}


 94%|█████████▍| 310/330 [01:21<00:04,  4.17it/s]

{'loss': 0.1373, 'grad_norm': 2.3888261318206787, 'learning_rate': 1.2121212121212122e-06, 'epoch': 2.82}


 97%|█████████▋| 320/330 [01:23<00:02,  4.12it/s]

{'loss': 0.1561, 'grad_norm': 2.84383225440979, 'learning_rate': 6.060606060606061e-07, 'epoch': 2.91}


100%|██████████| 330/330 [01:25<00:00,  4.03it/s]

{'loss': 0.1518, 'grad_norm': 2.607919931411743, 'learning_rate': 0.0, 'epoch': 3.0}


                                                 
100%|██████████| 330/330 [01:27<00:00,  4.03it/s]

{'eval_loss': 0.24040617048740387, 'eval_accuracy': 0.9073359073359073, 'eval_f1': 0.9324577861163227, 'eval_runtime': 1.1149, 'eval_samples_per_second': 696.907, 'eval_steps_per_second': 6.278, 'epoch': 3.0}


100%|██████████| 330/330 [01:27<00:00,  3.75it/s]

{'train_runtime': 87.9217, 'train_samples_per_second': 238.439, 'train_steps_per_second': 3.753, 'train_loss': 0.19635908423048076, 'epoch': 3.0}





TrainOutput(global_step=330, training_loss=0.19635908423048076, metrics={'train_runtime': 87.9217, 'train_samples_per_second': 238.439, 'train_steps_per_second': 3.753, 'total_flos': 351909933963264.0, 'train_loss': 0.19635908423048076, 'epoch': 3.0})

## Step11 模型评估


In [11]:
trainer.evaluate(tokenized_datasets["test"])

100%|██████████| 7/7 [00:00<00:00,  7.50it/s]


{'eval_loss': 0.24040617048740387,
 'eval_accuracy': 0.9073359073359073,
 'eval_f1': 0.9324577861163227,
 'eval_runtime': 1.0931,
 'eval_samples_per_second': 710.809,
 'eval_steps_per_second': 6.404,
 'epoch': 3.0}

## Step12 模型预测


In [12]:
trainer.predict(tokenized_datasets["test"])

100%|██████████| 7/7 [00:00<00:00,  7.78it/s]


PredictionOutput(predictions=array([[-1.9528311 ,  2.7549913 ],
       [-2.52439   ,  3.4644213 ],
       [-2.3588989 ,  3.5158489 ],
       ...,
       [-2.6295562 ,  3.4607577 ],
       [-1.8692572 ,  3.2497716 ],
       [-0.25713792,  0.90523815]], dtype=float32), label_ids=array([1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1,
       1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1,
       1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0,
       0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1,
       1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
       1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0,
       0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0,
       0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0,
       1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1

In [13]:
from transformers import pipeline

id2_label = id2_label = {0: "差评！", 1: "好评！"}
model.config.id2label = id2_label
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
sen = "我觉得不错！"
pipe(sen)

[{'label': '好评！', 'score': 0.9924901127815247}]