# 文本相似度实例

## Step1 导入相关包

In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:
dataset = load_dataset("json", data_files="./train_pair_1w.json", split="train")
dataset

Generating train split: 10000 examples [00:00, 1541796.79 examples/s]


Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 10000
})

In [3]:
dataset[0]

{'sentence1': '找一部小时候的动画片', 'sentence2': '求一部小时候的动画片。谢了', 'label': '1'}

## Step3 划分数据集

In [4]:
datasets = dataset.train_test_split(test_size=0.2)
datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 2000
    })
})

In [5]:
datasets["train"][0]

{'sentence1': '我手下有一个很有机谋的人，极得我的信任，是我派他去监视南方的动静的。',
 'sentence2': '我派了我手下一个很有头脑的人去南方视察了一下动态。',
 'label': '1'}

## Step4 数据集预处理

In [6]:
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")


def process_function(examples):
    tokenized_examples = tokenizer(
        examples["sentence1"], examples["sentence2"], max_length=128, truncation=True
    )
    # 均方误差做loss，所以这里要float
    tokenized_examples["labels"] = [float(label) for label in examples["label"]]
    return tokenized_examples


tokenized_datasets = datasets.map(
    process_function, batched=True, remove_columns=datasets["train"].column_names
)
tokenized_datasets

Map: 100%|██████████| 8000/8000 [00:00<00:00, 29238.89 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 31511.83 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2000
    })
})

In [7]:
print(tokenized_datasets["train"][0])

{'input_ids': [101, 2769, 2797, 678, 3300, 671, 702, 2523, 3300, 3322, 6450, 4638, 782, 8024, 3353, 2533, 2769, 4638, 928, 818, 8024, 3221, 2769, 3836, 800, 1343, 4664, 6228, 1298, 3175, 4638, 1220, 7474, 4638, 511, 102, 2769, 3836, 749, 2769, 2797, 678, 671, 702, 2523, 3300, 1928, 5554, 4638, 782, 1343, 1298, 3175, 6228, 2175, 749, 671, 678, 1220, 2578, 511, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': 1.0}


## Step5 创建模型

In [8]:
from transformers import BertForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "hfl/chinese-macbert-base", num_labels=1
)  # num_labels=1，回归任务，损失用MSE

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at hfl/chinese-macbert-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Step6 创建评估函数

In [9]:
import evaluate

acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")

In [10]:
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    ########################33
    predictions = [int(p > 0.5) for p in predictions]
    ##############################
    labels = [int(l) for l in labels]
    # predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

## Step7 创建TrainingArguments

In [11]:
train_args = TrainingArguments(
    output_dir="./cross_model",  # 输出文件夹
    per_device_train_batch_size=32,  # 训练时的batch_size
    per_device_eval_batch_size=32,  # 验证时的batch_size
    logging_steps=10,  # log 打印的频率
    eval_strategy="epoch",  # 评估策略
    save_strategy="epoch",  # 保存策略
    save_total_limit=3,  # 最大保存数
    learning_rate=2e-5,  # 学习率
    weight_decay=0.01,  # weight_decay
    metric_for_best_model="f1",  # 设定评估指标
    load_best_model_at_end=True,
)  # 训练完成后加载最优模型
train_args

TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=epoch,
evaluation_strategy=None,
fp16=False,
fp16_backend=auto,
fp

## Step8 创建Trainer

In [12]:
from transformers import DataCollatorWithPadding

trainer = Trainer(
    model=model,
    args=train_args,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=eval_metric,
)

## Step9 模型训练

In [13]:
trainer.train()

  1%|▏         | 10/750 [00:03<03:33,  3.46it/s]

{'loss': 0.6207, 'grad_norm': 9.281704902648926, 'learning_rate': 1.9733333333333336e-05, 'epoch': 0.04}


  3%|▎         | 20/750 [00:05<03:17,  3.71it/s]

{'loss': 0.1761, 'grad_norm': 4.04828405380249, 'learning_rate': 1.9466666666666668e-05, 'epoch': 0.08}


  4%|▍         | 31/750 [00:08<02:58,  4.03it/s]

{'loss': 0.1551, 'grad_norm': 2.877232789993286, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.12}


  5%|▌         | 40/750 [00:11<03:08,  3.77it/s]

{'loss': 0.173, 'grad_norm': 3.44010329246521, 'learning_rate': 1.8933333333333334e-05, 'epoch': 0.16}


  7%|▋         | 50/750 [00:13<03:05,  3.77it/s]

{'loss': 0.1411, 'grad_norm': 3.816002607345581, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.2}


  8%|▊         | 60/750 [00:16<03:03,  3.76it/s]

{'loss': 0.1404, 'grad_norm': 5.00953483581543, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.24}


  9%|▉         | 70/750 [00:19<03:00,  3.77it/s]

{'loss': 0.1297, 'grad_norm': 2.8031139373779297, 'learning_rate': 1.8133333333333335e-05, 'epoch': 0.28}


 11%|█         | 80/750 [00:21<02:58,  3.74it/s]

{'loss': 0.1434, 'grad_norm': 6.376748085021973, 'learning_rate': 1.7866666666666666e-05, 'epoch': 0.32}


 12%|█▏        | 90/750 [00:24<03:14,  3.40it/s]

{'loss': 0.147, 'grad_norm': 3.4683656692504883, 'learning_rate': 1.76e-05, 'epoch': 0.36}


 13%|█▎        | 100/750 [00:27<02:55,  3.71it/s]

{'loss': 0.1394, 'grad_norm': 2.989434003829956, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.4}


 15%|█▍        | 110/750 [00:29<02:50,  3.76it/s]

{'loss': 0.1391, 'grad_norm': 5.056520462036133, 'learning_rate': 1.706666666666667e-05, 'epoch': 0.44}


 16%|█▌        | 120/750 [00:32<02:48,  3.73it/s]

{'loss': 0.1044, 'grad_norm': 2.908031702041626, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.48}


 17%|█▋        | 130/750 [00:35<02:44,  3.76it/s]

{'loss': 0.1237, 'grad_norm': 2.844350814819336, 'learning_rate': 1.6533333333333333e-05, 'epoch': 0.52}


 19%|█▊        | 140/750 [00:38<02:46,  3.67it/s]

{'loss': 0.1346, 'grad_norm': 4.901773929595947, 'learning_rate': 1.6266666666666668e-05, 'epoch': 0.56}


 20%|██        | 150/750 [00:40<02:58,  3.37it/s]

{'loss': 0.1299, 'grad_norm': 3.1290178298950195, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.6}


 21%|██▏       | 160/750 [00:43<02:39,  3.70it/s]

{'loss': 0.1282, 'grad_norm': 3.158478021621704, 'learning_rate': 1.5733333333333334e-05, 'epoch': 0.64}


 23%|██▎       | 170/750 [00:46<02:35,  3.73it/s]

{'loss': 0.122, 'grad_norm': 4.223031044006348, 'learning_rate': 1.546666666666667e-05, 'epoch': 0.68}


 24%|██▍       | 180/750 [00:49<02:30,  3.80it/s]

{'loss': 0.1256, 'grad_norm': 4.3731913566589355, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.72}


 25%|██▌       | 190/750 [00:51<02:23,  3.91it/s]

{'loss': 0.1136, 'grad_norm': 3.63655686378479, 'learning_rate': 1.4933333333333335e-05, 'epoch': 0.76}


 27%|██▋       | 200/750 [00:54<02:27,  3.73it/s]

{'loss': 0.1277, 'grad_norm': 2.3437659740448, 'learning_rate': 1.4666666666666666e-05, 'epoch': 0.8}


 28%|██▊       | 210/750 [00:57<02:26,  3.69it/s]

{'loss': 0.1029, 'grad_norm': 2.182068109512329, 'learning_rate': 1.4400000000000001e-05, 'epoch': 0.84}


 29%|██▉       | 220/750 [00:59<02:22,  3.73it/s]

{'loss': 0.1225, 'grad_norm': 4.837038040161133, 'learning_rate': 1.4133333333333334e-05, 'epoch': 0.88}


 31%|███       | 230/750 [01:02<02:18,  3.75it/s]

{'loss': 0.1009, 'grad_norm': 2.4567651748657227, 'learning_rate': 1.3866666666666669e-05, 'epoch': 0.92}


 32%|███▏      | 240/750 [01:05<02:20,  3.63it/s]

{'loss': 0.1154, 'grad_norm': 5.523472309112549, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.96}


 33%|███▎      | 250/750 [01:07<02:12,  3.78it/s]

{'loss': 0.1024, 'grad_norm': 3.285468339920044, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


  predictions = [int(p > 0.5) for p in predictions]
                                                 
 33%|███▎      | 250/750 [01:12<02:12,  3.78it/s]

{'eval_loss': 0.08584214746952057, 'eval_accuracy': 0.8815, 'eval_f1': 0.8471953578336557, 'eval_runtime': 4.7179, 'eval_samples_per_second': 423.921, 'eval_steps_per_second': 13.354, 'epoch': 1.0}


 35%|███▍      | 260/750 [01:15<02:44,  2.99it/s]

{'loss': 0.0836, 'grad_norm': 2.421813488006592, 'learning_rate': 1.3066666666666668e-05, 'epoch': 1.04}


 36%|███▌      | 270/750 [01:18<02:01,  3.96it/s]

{'loss': 0.0806, 'grad_norm': 4.420756816864014, 'learning_rate': 1.2800000000000001e-05, 'epoch': 1.08}


 37%|███▋      | 280/750 [01:21<02:13,  3.52it/s]

{'loss': 0.0807, 'grad_norm': 3.020418167114258, 'learning_rate': 1.2533333333333336e-05, 'epoch': 1.12}


 39%|███▊      | 290/750 [01:23<02:01,  3.80it/s]

{'loss': 0.1044, 'grad_norm': 2.6625938415527344, 'learning_rate': 1.2266666666666667e-05, 'epoch': 1.16}


 40%|████      | 300/750 [01:26<02:00,  3.74it/s]

{'loss': 0.0792, 'grad_norm': 3.284777879714966, 'learning_rate': 1.2e-05, 'epoch': 1.2}


 41%|████▏     | 310/750 [01:29<02:09,  3.40it/s]

{'loss': 0.1037, 'grad_norm': 4.279487609863281, 'learning_rate': 1.1733333333333335e-05, 'epoch': 1.24}


 43%|████▎     | 320/750 [01:32<02:00,  3.56it/s]

{'loss': 0.0793, 'grad_norm': 2.437026262283325, 'learning_rate': 1.1466666666666668e-05, 'epoch': 1.28}


 44%|████▍     | 330/750 [01:34<01:56,  3.62it/s]

{'loss': 0.0981, 'grad_norm': 2.5109760761260986, 'learning_rate': 1.1200000000000001e-05, 'epoch': 1.32}


 45%|████▌     | 340/750 [01:37<01:51,  3.67it/s]

{'loss': 0.0904, 'grad_norm': 2.4278595447540283, 'learning_rate': 1.0933333333333334e-05, 'epoch': 1.36}


 47%|████▋     | 350/750 [01:40<01:47,  3.73it/s]

{'loss': 0.0738, 'grad_norm': 3.940758228302002, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.4}


 48%|████▊     | 360/750 [01:43<01:44,  3.74it/s]

{'loss': 0.1101, 'grad_norm': 6.17380952835083, 'learning_rate': 1.04e-05, 'epoch': 1.44}


 49%|████▉     | 370/750 [01:45<01:44,  3.65it/s]

{'loss': 0.0858, 'grad_norm': 3.982651472091675, 'learning_rate': 1.0133333333333335e-05, 'epoch': 1.48}


 51%|█████     | 380/750 [01:48<01:44,  3.55it/s]

{'loss': 0.0864, 'grad_norm': 2.8791446685791016, 'learning_rate': 9.866666666666668e-06, 'epoch': 1.52}


 52%|█████▏    | 390/750 [01:51<01:39,  3.63it/s]

{'loss': 0.0686, 'grad_norm': 2.5234909057617188, 'learning_rate': 9.600000000000001e-06, 'epoch': 1.56}


 53%|█████▎    | 400/750 [01:54<01:41,  3.45it/s]

{'loss': 0.0866, 'grad_norm': 5.937862873077393, 'learning_rate': 9.333333333333334e-06, 'epoch': 1.6}


 55%|█████▍    | 410/750 [01:57<01:30,  3.76it/s]

{'loss': 0.0735, 'grad_norm': 4.9179229736328125, 'learning_rate': 9.066666666666667e-06, 'epoch': 1.64}


 56%|█████▌    | 420/750 [01:59<01:32,  3.58it/s]

{'loss': 0.084, 'grad_norm': 2.8337466716766357, 'learning_rate': 8.8e-06, 'epoch': 1.68}


 57%|█████▋    | 430/750 [02:02<01:28,  3.62it/s]

{'loss': 0.0889, 'grad_norm': 2.093672513961792, 'learning_rate': 8.533333333333335e-06, 'epoch': 1.72}


 59%|█████▊    | 440/750 [02:05<01:21,  3.81it/s]

{'loss': 0.0926, 'grad_norm': 3.3018906116485596, 'learning_rate': 8.266666666666667e-06, 'epoch': 1.76}


 60%|██████    | 450/750 [02:08<01:21,  3.70it/s]

{'loss': 0.0936, 'grad_norm': 5.057663440704346, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.8}


 61%|██████▏   | 460/750 [02:10<01:17,  3.75it/s]

{'loss': 0.0772, 'grad_norm': 3.331900119781494, 'learning_rate': 7.733333333333334e-06, 'epoch': 1.84}


 63%|██████▎   | 470/750 [02:13<01:13,  3.81it/s]

{'loss': 0.0688, 'grad_norm': 3.076601028442383, 'learning_rate': 7.4666666666666675e-06, 'epoch': 1.88}


 64%|██████▍   | 481/750 [02:16<01:09,  3.89it/s]

{'loss': 0.0904, 'grad_norm': 2.7956976890563965, 'learning_rate': 7.2000000000000005e-06, 'epoch': 1.92}


 65%|██████▌   | 490/750 [02:18<01:16,  3.41it/s]

{'loss': 0.0676, 'grad_norm': 2.452064037322998, 'learning_rate': 6.9333333333333344e-06, 'epoch': 1.96}


 67%|██████▋   | 500/750 [02:21<01:10,  3.57it/s]

{'loss': 0.0655, 'grad_norm': 2.150221824645996, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


  predictions = [int(p > 0.5) for p in predictions]
                                                 
 67%|██████▋   | 500/750 [02:26<01:10,  3.57it/s]

{'eval_loss': 0.0856158509850502, 'eval_accuracy': 0.8895, 'eval_f1': 0.8654899573950091, 'eval_runtime': 4.8016, 'eval_samples_per_second': 416.528, 'eval_steps_per_second': 13.121, 'epoch': 2.0}


 68%|██████▊   | 510/750 [02:29<01:20,  2.98it/s]

{'loss': 0.0622, 'grad_norm': 2.6063308715820312, 'learning_rate': 6.4000000000000006e-06, 'epoch': 2.04}


 69%|██████▉   | 520/750 [02:32<01:00,  3.79it/s]

{'loss': 0.0626, 'grad_norm': 2.0345726013183594, 'learning_rate': 6.133333333333334e-06, 'epoch': 2.08}


 71%|███████   | 530/750 [02:35<00:59,  3.68it/s]

{'loss': 0.0672, 'grad_norm': 4.717871189117432, 'learning_rate': 5.8666666666666675e-06, 'epoch': 2.12}


 72%|███████▏  | 540/750 [02:38<00:57,  3.68it/s]

{'loss': 0.0681, 'grad_norm': 2.9334988594055176, 'learning_rate': 5.600000000000001e-06, 'epoch': 2.16}


 73%|███████▎  | 550/750 [02:40<00:56,  3.54it/s]

{'loss': 0.057, 'grad_norm': 2.3377788066864014, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.2}


 75%|███████▍  | 560/750 [02:43<00:51,  3.71it/s]

{'loss': 0.0492, 'grad_norm': 2.621394634246826, 'learning_rate': 5.0666666666666676e-06, 'epoch': 2.24}


 76%|███████▌  | 570/750 [02:46<00:47,  3.83it/s]

{'loss': 0.076, 'grad_norm': 4.5244879722595215, 'learning_rate': 4.800000000000001e-06, 'epoch': 2.28}


 77%|███████▋  | 580/750 [02:48<00:44,  3.79it/s]

{'loss': 0.0673, 'grad_norm': 2.048603057861328, 'learning_rate': 4.533333333333334e-06, 'epoch': 2.32}


 79%|███████▊  | 590/750 [02:51<00:43,  3.67it/s]

{'loss': 0.0585, 'grad_norm': 3.600558042526245, 'learning_rate': 4.266666666666668e-06, 'epoch': 2.36}


 80%|████████  | 600/750 [02:54<00:41,  3.63it/s]

{'loss': 0.0454, 'grad_norm': 1.828019618988037, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.4}


 81%|████████▏ | 610/750 [02:57<00:37,  3.74it/s]

{'loss': 0.0644, 'grad_norm': 2.3683104515075684, 'learning_rate': 3.7333333333333337e-06, 'epoch': 2.44}


 83%|████████▎ | 620/750 [02:59<00:38,  3.42it/s]

{'loss': 0.0674, 'grad_norm': 2.2983293533325195, 'learning_rate': 3.4666666666666672e-06, 'epoch': 2.48}


 84%|████████▍ | 630/750 [03:02<00:33,  3.60it/s]

{'loss': 0.0461, 'grad_norm': 2.207115411758423, 'learning_rate': 3.2000000000000003e-06, 'epoch': 2.52}


 85%|████████▌ | 640/750 [03:05<00:30,  3.59it/s]

{'loss': 0.0443, 'grad_norm': 2.2093605995178223, 'learning_rate': 2.9333333333333338e-06, 'epoch': 2.56}


 87%|████████▋ | 650/750 [03:08<00:27,  3.68it/s]

{'loss': 0.0535, 'grad_norm': 2.7719197273254395, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.6}


 88%|████████▊ | 660/750 [03:10<00:24,  3.61it/s]

{'loss': 0.0496, 'grad_norm': 2.9515299797058105, 'learning_rate': 2.4000000000000003e-06, 'epoch': 2.64}


 89%|████████▉ | 670/750 [03:13<00:22,  3.63it/s]

{'loss': 0.074, 'grad_norm': 3.0327601432800293, 'learning_rate': 2.133333333333334e-06, 'epoch': 2.68}


 91%|█████████ | 680/750 [03:16<00:18,  3.72it/s]

{'loss': 0.0583, 'grad_norm': 2.5441946983337402, 'learning_rate': 1.8666666666666669e-06, 'epoch': 2.72}


 92%|█████████▏| 690/750 [03:18<00:16,  3.59it/s]

{'loss': 0.0451, 'grad_norm': 4.053737163543701, 'learning_rate': 1.6000000000000001e-06, 'epoch': 2.76}


 93%|█████████▎| 700/750 [03:21<00:13,  3.73it/s]

{'loss': 0.0574, 'grad_norm': 1.7592154741287231, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.8}


 95%|█████████▍| 710/750 [03:24<00:10,  3.70it/s]

{'loss': 0.0723, 'grad_norm': 1.7531400918960571, 'learning_rate': 1.066666666666667e-06, 'epoch': 2.84}


 96%|█████████▌| 720/750 [03:27<00:08,  3.49it/s]

{'loss': 0.0556, 'grad_norm': 1.6356925964355469, 'learning_rate': 8.000000000000001e-07, 'epoch': 2.88}


 97%|█████████▋| 730/750 [03:30<00:05,  3.55it/s]

{'loss': 0.0594, 'grad_norm': 0.7540345788002014, 'learning_rate': 5.333333333333335e-07, 'epoch': 2.92}


 99%|█████████▊| 740/750 [03:32<00:02,  3.77it/s]

{'loss': 0.0447, 'grad_norm': 3.8614118099212646, 'learning_rate': 2.666666666666667e-07, 'epoch': 2.96}


100%|██████████| 750/750 [03:35<00:00,  3.80it/s]

{'loss': 0.0605, 'grad_norm': 2.7422983646392822, 'learning_rate': 0.0, 'epoch': 3.0}


  predictions = [int(p > 0.5) for p in predictions]
                                                 
100%|██████████| 750/750 [03:40<00:00,  3.80it/s]

{'eval_loss': 0.07383160293102264, 'eval_accuracy': 0.8995, 'eval_f1': 0.8683693516699411, 'eval_runtime': 4.7314, 'eval_samples_per_second': 422.707, 'eval_steps_per_second': 13.315, 'epoch': 3.0}


100%|██████████| 750/750 [03:42<00:00,  3.38it/s]

{'train_runtime': 222.0941, 'train_samples_per_second': 108.062, 'train_steps_per_second': 3.377, 'train_loss': 0.09784200775623321, 'epoch': 3.0}





TrainOutput(global_step=750, training_loss=0.09784200775623321, metrics={'train_runtime': 222.0941, 'train_samples_per_second': 108.062, 'train_steps_per_second': 3.377, 'total_flos': 1556485250567424.0, 'train_loss': 0.09784200775623321, 'epoch': 3.0})

## Step10 模型评估

In [14]:
trainer.evaluate(tokenized_datasets["test"])

  predictions = [int(p > 0.5) for p in predictions]
100%|██████████| 63/63 [00:05<00:00, 12.08it/s]


{'eval_loss': 0.07383160293102264,
 'eval_accuracy': 0.8995,
 'eval_f1': 0.8683693516699411,
 'eval_runtime': 5.3037,
 'eval_samples_per_second': 377.093,
 'eval_steps_per_second': 11.878,
 'epoch': 3.0}

## Step11 模型预测

In [15]:
from transformers import pipeline, TextClassificationPipeline

In [16]:
model.config.id2label = {0: "不相似", 1: "相似"}

In [17]:
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)

In [18]:
pipe({"text": "我喜欢北京", "text_pair": "天气怎样"})

{'label': '不相似', 'score': 1.0}

In [19]:
pipe({"text": "我喜欢北京", "text_pair": "天气怎样"}, function_to_apply="none")

{'label': '不相似', 'score': 0.1161593645811081}

In [20]:
result = pipe({"text": "我喜欢北京", "text_pair": "天气怎样"}, function_to_apply="none")
result["label"] = "相似" if result["score"] > 0.5 else "不相似"
result

{'label': '不相似', 'score': 0.1161593645811081}