# 文本相似度实例

## Step1 导入相关包

In [1]:
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


## Step2 加载数据集

In [2]:
dataset = load_dataset("json", data_files="./train_pair_1w.json", split="train")
dataset

Dataset({
    features: ['sentence1', 'sentence2', 'label'],
    num_rows: 10000
})

In [3]:
dataset[0]

{'sentence1': '找一部小时候的动画片', 'sentence2': '求一部小时候的动画片。谢了', 'label': '1'}

## Step3 划分数据集

In [4]:
datasets = dataset.train_test_split(test_size=0.2)
datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label'],
        num_rows: 2000
    })
})

## Step4 数据集预处理

In [5]:
import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")


def process_function(examples):
    sentences = []
    labels = []
    for sen1, sen2, label in zip(
        examples["sentence1"], examples["sentence2"], examples["label"]
    ):
        sentences.append(sen1)
        sentences.append(sen2)
        # -1/1适应cosine loss
        labels.append(1 if int(label) == 1 else -1)
    # input_ids, attention_mask, token_type_ids
    tokenized_examples = tokenizer(
        sentences, max_length=128, truncation=True, padding="max_length"
    )
    tokenized_examples = {
        k: [v[i : i + 2] for i in range(0, len(v), 2)]
        for k, v in tokenized_examples.items()
    }
    tokenized_examples["labels"] = labels
    return tokenized_examples


tokenized_datasets = datasets.map(
    process_function, batched=True, remove_columns=datasets["train"].column_names
)
tokenized_datasets

Map: 100%|██████████| 8000/8000 [00:00<00:00, 10719.28 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 10622.65 examples/s]


DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 2000
    })
})

In [6]:
print(tokenized_datasets["train"][0])

{'input_ids': [[101, 2442, 3815, 1938, 1062, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 2207, 982, 2207, 3043, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

## Step5 创建模型

$$
\text{loss}(x, y) =
\begin{cases}
1 - \cos(x_1, x_2), & \text{if } y = 1 \\
\max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
\end{cases}
$$

In [7]:
from transformers import BertForSequenceClassification, BertPreTrainedModel, BertModel
from typing import Optional
from transformers.configuration_utils import PretrainedConfig
from torch.nn import CosineSimilarity, CosineEmbeddingLoss


class DualModel(BertPreTrainedModel):

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = BertModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Step1 分别获取sentenceA 和 sentenceB的输入
        # batch的维度全取，第一个维度取0/1
        senA_input_ids, senB_input_ids = input_ids[:, 0], input_ids[:, 1]
        senA_attention_mask, senB_attention_mask = (
            attention_mask[:, 0],
            attention_mask[:, 1],
        )
        senA_token_type_ids, senB_token_type_ids = (
            token_type_ids[:, 0],
            token_type_ids[:, 1],
        )

        # Step2 分别获取sentenceA 和 sentenceB的向量表示
        senA_outputs = self.bert(
            senA_input_ids,
            attention_mask=senA_attention_mask,
            token_type_ids=senA_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        senA_pooled_output = senA_outputs[1]  # [batch, hidden]

        senB_outputs = self.bert(
            senB_input_ids,
            attention_mask=senB_attention_mask,
            token_type_ids=senB_token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        senB_pooled_output = senB_outputs[1]  # [batch, hidden]

        # step3 计算相似度
        # 对应其他模型的logits
        cos = CosineSimilarity()(senA_pooled_output, senB_pooled_output)  # [batch, ]

        # step4 计算loss

        loss = None
        if labels is not None:
            # margin以下的y=-1的样本不算入损失
            loss_fct = CosineEmbeddingLoss(0.3)
            loss = loss_fct(senA_pooled_output, senB_pooled_output, labels)

        output = (cos,)
        return ((loss,) + output) if loss is not None else output


model = DualModel.from_pretrained("hfl/chinese-macbert-base")

## Step6 创建评估函数

In [8]:
import evaluate

acc_metric = evaluate.load("./metric_accuracy.py")
f1_metirc = evaluate.load("./metric_f1.py")

In [9]:
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    ###################################3
    predictions = [int(p > 0.7) for p in predictions]
    labels = [int(l > 0) for l in labels]
    ################################
    # predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

## Step7 创建TrainingArguments

In [10]:
train_args = TrainingArguments(
    output_dir="./dual_model",  # 输出文件夹
    per_device_train_batch_size=32,  # 训练时的batch_size
    per_device_eval_batch_size=32,  # 验证时的batch_size
    logging_steps=10,  # log 打印的频率
    eval_strategy="epoch",  # 评估策略
    save_strategy="epoch",  # 保存策略
    save_total_limit=3,  # 最大保存数
    learning_rate=2e-5,  # 学习率
    weight_decay=0.01,  # weight_decay
    metric_for_best_model="f1",  # 设定评估指标
    load_best_model_at_end=True,
)  # 训练完成后加载最优模型

## Step8 创建Trainer

In [11]:
trainer = Trainer(
    model=model,
    args=train_args,
    tokenizer=tokenizer,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=eval_metric,
)

## Step9 模型训练

In [12]:
trainer.train()

  1%|▏         | 10/750 [00:05<06:19,  1.95it/s]

{'loss': 0.394, 'grad_norm': 4.751511096954346, 'learning_rate': 1.9733333333333336e-05, 'epoch': 0.04}


  3%|▎         | 20/750 [00:10<06:34,  1.85it/s]

{'loss': 0.2948, 'grad_norm': 6.54201602935791, 'learning_rate': 1.9466666666666668e-05, 'epoch': 0.08}


  4%|▍         | 30/750 [00:16<06:29,  1.85it/s]

{'loss': 0.2945, 'grad_norm': 5.728126049041748, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.12}


  5%|▌         | 40/750 [00:21<06:10,  1.92it/s]

{'loss': 0.2711, 'grad_norm': 6.081099033355713, 'learning_rate': 1.8933333333333334e-05, 'epoch': 0.16}


  7%|▋         | 50/750 [00:26<05:54,  1.98it/s]

{'loss': 0.2644, 'grad_norm': 5.299997329711914, 'learning_rate': 1.866666666666667e-05, 'epoch': 0.2}


  8%|▊         | 60/750 [00:32<06:13,  1.85it/s]

{'loss': 0.2543, 'grad_norm': 5.4090962409973145, 'learning_rate': 1.8400000000000003e-05, 'epoch': 0.24}


  9%|▉         | 70/750 [00:37<05:39,  2.00it/s]

{'loss': 0.2517, 'grad_norm': 4.428464889526367, 'learning_rate': 1.8133333333333335e-05, 'epoch': 0.28}


 11%|█         | 80/750 [00:42<05:31,  2.02it/s]

{'loss': 0.2504, 'grad_norm': 3.374255657196045, 'learning_rate': 1.7866666666666666e-05, 'epoch': 0.32}


 12%|█▏        | 90/750 [00:47<05:47,  1.90it/s]

{'loss': 0.2269, 'grad_norm': 3.5902109146118164, 'learning_rate': 1.76e-05, 'epoch': 0.36}


 13%|█▎        | 100/750 [00:52<05:48,  1.87it/s]

{'loss': 0.2372, 'grad_norm': 4.427587985992432, 'learning_rate': 1.7333333333333336e-05, 'epoch': 0.4}


 15%|█▍        | 110/750 [00:57<05:19,  2.00it/s]

{'loss': 0.2305, 'grad_norm': 3.0726940631866455, 'learning_rate': 1.706666666666667e-05, 'epoch': 0.44}


 16%|█▌        | 120/750 [01:02<05:11,  2.02it/s]

{'loss': 0.2268, 'grad_norm': 3.8449666500091553, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.48}


 17%|█▋        | 130/750 [01:07<05:35,  1.85it/s]

{'loss': 0.2524, 'grad_norm': 4.104092597961426, 'learning_rate': 1.6533333333333333e-05, 'epoch': 0.52}


 19%|█▊        | 140/750 [01:13<05:03,  2.01it/s]

{'loss': 0.2111, 'grad_norm': 3.2074553966522217, 'learning_rate': 1.6266666666666668e-05, 'epoch': 0.56}


 20%|██        | 150/750 [01:17<04:57,  2.02it/s]

{'loss': 0.2281, 'grad_norm': 3.276411294937134, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.6}


 21%|██▏       | 160/750 [01:22<04:58,  1.98it/s]

{'loss': 0.2118, 'grad_norm': 2.879138946533203, 'learning_rate': 1.5733333333333334e-05, 'epoch': 0.64}


 23%|██▎       | 170/750 [01:27<04:47,  2.01it/s]

{'loss': 0.2173, 'grad_norm': 2.6642332077026367, 'learning_rate': 1.546666666666667e-05, 'epoch': 0.68}


 24%|██▍       | 180/750 [01:32<04:41,  2.02it/s]

{'loss': 0.2249, 'grad_norm': 2.780284881591797, 'learning_rate': 1.5200000000000002e-05, 'epoch': 0.72}


 25%|██▌       | 190/750 [01:37<04:36,  2.02it/s]

{'loss': 0.222, 'grad_norm': 2.8960025310516357, 'learning_rate': 1.4933333333333335e-05, 'epoch': 0.76}


 27%|██▋       | 200/750 [01:42<04:33,  2.01it/s]

{'loss': 0.2163, 'grad_norm': 2.206005573272705, 'learning_rate': 1.4666666666666666e-05, 'epoch': 0.8}


 28%|██▊       | 210/750 [01:47<04:26,  2.02it/s]

{'loss': 0.1973, 'grad_norm': 2.850332498550415, 'learning_rate': 1.4400000000000001e-05, 'epoch': 0.84}


 29%|██▉       | 220/750 [01:52<04:29,  1.96it/s]

{'loss': 0.2056, 'grad_norm': 3.1139204502105713, 'learning_rate': 1.4133333333333334e-05, 'epoch': 0.88}


 31%|███       | 230/750 [01:57<04:21,  1.99it/s]

{'loss': 0.2075, 'grad_norm': 2.6422533988952637, 'learning_rate': 1.3866666666666669e-05, 'epoch': 0.92}


 32%|███▏      | 240/750 [02:03<04:13,  2.01it/s]

{'loss': 0.2001, 'grad_norm': 2.3131744861602783, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.96}


 33%|███▎      | 250/750 [02:08<04:14,  1.97it/s]

{'loss': 0.2003, 'grad_norm': 2.768965482711792, 'learning_rate': 1.3333333333333333e-05, 'epoch': 1.0}


                                                 
 33%|███▎      | 250/750 [02:17<04:14,  1.97it/s]

{'eval_loss': 0.19133074581623077, 'eval_accuracy': 0.774, 'eval_f1': 0.7306317044100119, 'eval_runtime': 9.7259, 'eval_samples_per_second': 205.636, 'eval_steps_per_second': 6.478, 'epoch': 1.0}


 35%|███▍      | 260/750 [02:23<05:08,  1.59it/s]

{'loss': 0.1633, 'grad_norm': 2.2828009128570557, 'learning_rate': 1.3066666666666668e-05, 'epoch': 1.04}


 36%|███▌      | 270/750 [02:28<04:22,  1.83it/s]

{'loss': 0.1702, 'grad_norm': 2.4865972995758057, 'learning_rate': 1.2800000000000001e-05, 'epoch': 1.08}


 37%|███▋      | 280/750 [02:34<03:53,  2.01it/s]

{'loss': 0.1556, 'grad_norm': 2.705965995788574, 'learning_rate': 1.2533333333333336e-05, 'epoch': 1.12}


 39%|███▊      | 290/750 [02:39<03:56,  1.95it/s]

{'loss': 0.1774, 'grad_norm': 2.657898426055908, 'learning_rate': 1.2266666666666667e-05, 'epoch': 1.16}


 40%|████      | 300/750 [02:44<03:47,  1.98it/s]

{'loss': 0.1655, 'grad_norm': 2.796128988265991, 'learning_rate': 1.2e-05, 'epoch': 1.2}


 41%|████▏     | 310/750 [02:49<03:47,  1.93it/s]

{'loss': 0.1517, 'grad_norm': 2.095370292663574, 'learning_rate': 1.1733333333333335e-05, 'epoch': 1.24}


 43%|████▎     | 320/750 [02:54<03:39,  1.96it/s]

{'loss': 0.1684, 'grad_norm': 2.584878921508789, 'learning_rate': 1.1466666666666668e-05, 'epoch': 1.28}


 44%|████▍     | 330/750 [03:00<03:44,  1.87it/s]

{'loss': 0.1531, 'grad_norm': 2.5742876529693604, 'learning_rate': 1.1200000000000001e-05, 'epoch': 1.32}


 45%|████▌     | 340/750 [03:05<03:29,  1.95it/s]

{'loss': 0.1674, 'grad_norm': 2.6674835681915283, 'learning_rate': 1.0933333333333334e-05, 'epoch': 1.36}


 47%|████▋     | 350/750 [03:10<03:22,  1.97it/s]

{'loss': 0.1683, 'grad_norm': 2.3343918323516846, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.4}


 48%|████▊     | 360/750 [03:15<03:14,  2.01it/s]

{'loss': 0.1782, 'grad_norm': 2.539273977279663, 'learning_rate': 1.04e-05, 'epoch': 1.44}


 49%|████▉     | 370/750 [03:20<03:08,  2.01it/s]

{'loss': 0.1553, 'grad_norm': 2.5701799392700195, 'learning_rate': 1.0133333333333335e-05, 'epoch': 1.48}


 51%|█████     | 380/750 [03:25<03:17,  1.87it/s]

{'loss': 0.174, 'grad_norm': 2.9619693756103516, 'learning_rate': 9.866666666666668e-06, 'epoch': 1.52}


 52%|█████▏    | 390/750 [03:30<02:58,  2.02it/s]

{'loss': 0.1603, 'grad_norm': 2.6187260150909424, 'learning_rate': 9.600000000000001e-06, 'epoch': 1.56}


 53%|█████▎    | 400/750 [03:35<02:53,  2.02it/s]

{'loss': 0.1691, 'grad_norm': 2.6135404109954834, 'learning_rate': 9.333333333333334e-06, 'epoch': 1.6}


 55%|█████▍    | 410/750 [03:40<03:00,  1.88it/s]

{'loss': 0.1846, 'grad_norm': 3.1661109924316406, 'learning_rate': 9.066666666666667e-06, 'epoch': 1.64}


 56%|█████▌    | 420/750 [03:45<02:42,  2.03it/s]

{'loss': 0.172, 'grad_norm': 2.7095530033111572, 'learning_rate': 8.8e-06, 'epoch': 1.68}


 57%|█████▋    | 430/750 [03:50<02:37,  2.03it/s]

{'loss': 0.16, 'grad_norm': 2.1833810806274414, 'learning_rate': 8.533333333333335e-06, 'epoch': 1.72}


 59%|█████▊    | 440/750 [03:55<02:32,  2.03it/s]

{'loss': 0.1576, 'grad_norm': 2.7183003425598145, 'learning_rate': 8.266666666666667e-06, 'epoch': 1.76}


 60%|██████    | 450/750 [04:00<02:27,  2.03it/s]

{'loss': 0.1628, 'grad_norm': 2.588325262069702, 'learning_rate': 8.000000000000001e-06, 'epoch': 1.8}


 61%|██████▏   | 460/750 [04:05<02:22,  2.03it/s]

{'loss': 0.1571, 'grad_norm': 2.253831386566162, 'learning_rate': 7.733333333333334e-06, 'epoch': 1.84}


 63%|██████▎   | 470/750 [04:10<02:17,  2.03it/s]

{'loss': 0.1532, 'grad_norm': 2.6132335662841797, 'learning_rate': 7.4666666666666675e-06, 'epoch': 1.88}


 64%|██████▍   | 480/750 [04:15<02:14,  2.01it/s]

{'loss': 0.1505, 'grad_norm': 2.5081160068511963, 'learning_rate': 7.2000000000000005e-06, 'epoch': 1.92}


 65%|██████▌   | 490/750 [04:20<02:20,  1.85it/s]

{'loss': 0.1569, 'grad_norm': 2.7479774951934814, 'learning_rate': 6.9333333333333344e-06, 'epoch': 1.96}


 67%|██████▋   | 500/750 [04:26<02:18,  1.80it/s]

{'loss': 0.1537, 'grad_norm': 2.466906785964966, 'learning_rate': 6.666666666666667e-06, 'epoch': 2.0}


                                                 
 67%|██████▋   | 500/750 [04:37<02:18,  1.80it/s]

{'eval_loss': 0.17802001535892487, 'eval_accuracy': 0.7985, 'eval_f1': 0.7535168195718654, 'eval_runtime': 10.9973, 'eval_samples_per_second': 181.863, 'eval_steps_per_second': 5.729, 'epoch': 2.0}


 68%|██████▊   | 510/750 [04:43<02:49,  1.41it/s]

{'loss': 0.1129, 'grad_norm': 2.732299566268921, 'learning_rate': 6.4000000000000006e-06, 'epoch': 2.04}


 69%|██████▉   | 520/750 [04:49<02:10,  1.77it/s]

{'loss': 0.1404, 'grad_norm': 1.9977092742919922, 'learning_rate': 6.133333333333334e-06, 'epoch': 2.08}


 71%|███████   | 530/750 [04:54<02:03,  1.78it/s]

{'loss': 0.1334, 'grad_norm': 2.217043161392212, 'learning_rate': 5.8666666666666675e-06, 'epoch': 2.12}


 72%|███████▏  | 540/750 [05:00<01:57,  1.79it/s]

{'loss': 0.1378, 'grad_norm': 2.218151569366455, 'learning_rate': 5.600000000000001e-06, 'epoch': 2.16}


 73%|███████▎  | 550/750 [05:06<01:51,  1.80it/s]

{'loss': 0.1246, 'grad_norm': 2.5297794342041016, 'learning_rate': 5.333333333333334e-06, 'epoch': 2.2}


 75%|███████▍  | 560/750 [05:11<01:46,  1.78it/s]

{'loss': 0.1303, 'grad_norm': 1.9798343181610107, 'learning_rate': 5.0666666666666676e-06, 'epoch': 2.24}


 76%|███████▌  | 570/750 [05:17<01:40,  1.79it/s]

{'loss': 0.1286, 'grad_norm': 2.3717198371887207, 'learning_rate': 4.800000000000001e-06, 'epoch': 2.28}


 77%|███████▋  | 580/750 [05:22<01:30,  1.88it/s]

{'loss': 0.1359, 'grad_norm': 3.1793477535247803, 'learning_rate': 4.533333333333334e-06, 'epoch': 2.32}


 79%|███████▊  | 590/750 [05:27<01:18,  2.03it/s]

{'loss': 0.1553, 'grad_norm': 2.351801872253418, 'learning_rate': 4.266666666666668e-06, 'epoch': 2.36}


 80%|████████  | 600/750 [05:32<01:14,  2.03it/s]

{'loss': 0.1369, 'grad_norm': 2.4939074516296387, 'learning_rate': 4.000000000000001e-06, 'epoch': 2.4}


 81%|████████▏ | 610/750 [05:37<01:09,  2.01it/s]

{'loss': 0.122, 'grad_norm': 2.105208158493042, 'learning_rate': 3.7333333333333337e-06, 'epoch': 2.44}


 83%|████████▎ | 620/750 [05:43<01:12,  1.79it/s]

{'loss': 0.1305, 'grad_norm': 2.8469772338867188, 'learning_rate': 3.4666666666666672e-06, 'epoch': 2.48}


 84%|████████▍ | 630/750 [05:48<01:04,  1.87it/s]

{'loss': 0.1206, 'grad_norm': 2.1133601665496826, 'learning_rate': 3.2000000000000003e-06, 'epoch': 2.52}


 85%|████████▌ | 640/750 [05:54<00:59,  1.84it/s]

{'loss': 0.1307, 'grad_norm': 2.709742784500122, 'learning_rate': 2.9333333333333338e-06, 'epoch': 2.56}


 87%|████████▋ | 650/750 [05:59<00:52,  1.89it/s]

{'loss': 0.1337, 'grad_norm': 2.645892858505249, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.6}


 88%|████████▊ | 660/750 [06:05<00:48,  1.84it/s]

{'loss': 0.1205, 'grad_norm': 2.738369941711426, 'learning_rate': 2.4000000000000003e-06, 'epoch': 2.64}


 89%|████████▉ | 670/750 [06:10<00:40,  1.95it/s]

{'loss': 0.1439, 'grad_norm': 2.25692081451416, 'learning_rate': 2.133333333333334e-06, 'epoch': 2.68}


 91%|█████████ | 680/750 [06:15<00:34,  2.01it/s]

{'loss': 0.1193, 'grad_norm': 2.5314342975616455, 'learning_rate': 1.8666666666666669e-06, 'epoch': 2.72}


 92%|█████████▏| 690/750 [06:20<00:29,  2.04it/s]

{'loss': 0.1261, 'grad_norm': 2.930565357208252, 'learning_rate': 1.6000000000000001e-06, 'epoch': 2.76}


 93%|█████████▎| 700/750 [06:25<00:25,  1.98it/s]

{'loss': 0.1365, 'grad_norm': 3.3265697956085205, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.8}


 95%|█████████▍| 710/750 [06:30<00:20,  1.92it/s]

{'loss': 0.1345, 'grad_norm': 1.8839573860168457, 'learning_rate': 1.066666666666667e-06, 'epoch': 2.84}


 96%|█████████▌| 720/750 [06:35<00:15,  1.95it/s]

{'loss': 0.125, 'grad_norm': 2.176651954650879, 'learning_rate': 8.000000000000001e-07, 'epoch': 2.88}


 97%|█████████▋| 730/750 [06:40<00:10,  1.98it/s]

{'loss': 0.1268, 'grad_norm': 3.902048110961914, 'learning_rate': 5.333333333333335e-07, 'epoch': 2.92}


 99%|█████████▊| 740/750 [06:45<00:05,  1.99it/s]

{'loss': 0.1209, 'grad_norm': 2.294229745864868, 'learning_rate': 2.666666666666667e-07, 'epoch': 2.96}


100%|██████████| 750/750 [06:50<00:00,  2.00it/s]

{'loss': 0.1239, 'grad_norm': 3.0290589332580566, 'learning_rate': 0.0, 'epoch': 3.0}


                                                 
100%|██████████| 750/750 [07:01<00:00,  2.00it/s]

{'eval_loss': 0.17861349880695343, 'eval_accuracy': 0.797, 'eval_f1': 0.7475124378109452, 'eval_runtime': 9.8837, 'eval_samples_per_second': 202.353, 'eval_steps_per_second': 6.374, 'epoch': 3.0}


100%|██████████| 750/750 [07:02<00:00,  1.77it/s]

{'train_runtime': 422.755, 'train_samples_per_second': 56.77, 'train_steps_per_second': 1.774, 'train_loss': 0.17771974213918051, 'epoch': 3.0}





TrainOutput(global_step=750, training_loss=0.17771974213918051, metrics={'train_runtime': 422.755, 'train_samples_per_second': 56.77, 'train_steps_per_second': 1.774, 'total_flos': 3157275967488000.0, 'train_loss': 0.17771974213918051, 'epoch': 3.0})

## Step10 模型评估

In [13]:
trainer.evaluate(tokenized_datasets["test"])

100%|██████████| 63/63 [00:10<00:00,  6.01it/s]


{'eval_loss': 0.17802001535892487,
 'eval_accuracy': 0.7985,
 'eval_f1': 0.7535168195718654,
 'eval_runtime': 10.5864,
 'eval_samples_per_second': 188.923,
 'eval_steps_per_second': 5.951,
 'epoch': 3.0}

## Step11 模型预测

In [14]:
class SentenceSimilarityPipeline:

    def __init__(self, model, tokenizer) -> None:
        self.model = model.bert
        self.tokenizer = tokenizer
        self.device = model.device

    def preprocess(self, senA, senB):
        return self.tokenizer(
            [senA, senB],
            max_length=128,
            # Q：截取的是哪部分？
            # A:senA和senB都有截取
            truncation=True,
            return_tensors="pt",
            padding=True,
        )

    def predict(self, inputs):
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # 取bert输出的pooler_output
        # 对应模型定义的step1、step2
        return self.model(**inputs)[1]  # [2, 768]

    def postprocess(self, logits):
        # 对应模型定义的step3
        cos = (
            CosineSimilarity()(logits[None, 0, :], logits[None, 1, :])
            .squeeze()
            .cpu()
            .item()
        )
        return cos

    def __call__(self, senA, senB, return_vector=False):
        inputs = self.preprocess(senA, senB)
        logits = self.predict(inputs)
        print(f"logits.shape: {logits.shape}")
        result = self.postprocess(logits)
        if return_vector:
            return result, logits
        else:
            return result

In [15]:
pipe = SentenceSimilarityPipeline(model, tokenizer)

In [16]:
pipe("我喜欢北京", "明天不行", return_vector=True)

logits.shape: torch.Size([2, 768])


(0.4809381365776062,
 tensor([[-0.9809, -0.6450, -0.6084,  ...,  0.9478, -0.0312, -0.0515],
         [-0.9324,  0.3896, -0.0323,  ..., -0.3455, -0.3624, -0.6469]],
        device='cuda:0', grad_fn=<TanhBackward0>))

In [17]:
pipe("我喜欢北京", "北京是我喜欢的城市", return_vector=True)

logits.shape: torch.Size([2, 768])


(0.7707035541534424,
 tensor([[-0.9809, -0.6450, -0.6084,  ...,  0.9478, -0.0312, -0.0515],
         [-0.9039, -0.2214, -0.1132,  ...,  0.7280, -0.0946,  0.2909]],
        device='cuda:0', grad_fn=<TanhBackward0>))