导入所需包

In [1]:
import random
import time
from typing import List
import pandas as pd
import jsonlines
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertConfig, BertModel, BertTokenizer

设置基本参数，以及文件读取和存储路径

In [2]:
# Basic hyperparameters
EPOCHS = 2
BATCH_SIZE = 64
LR = 1e-5
MAXLEN = 64       # passed to the tokenizer as max_length (inputs are truncated/padded to it)
POOLING = 'cls'   # choose in ['cls', 'pooler', 'last-avg', 'first-last-avg']
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 

# Pretrained model directory
model_path = 'hfl/chinese-roberta-wwm-ext'

# Where the fine-tuned parameters are stored
SAVE_PATH = './simcse_sup.pt'

# Data locations
TRAIN = './ICD_train.txt'
DEV = './ICD_dev.txt'

数据读取以及装入到dataset

In [3]:
def load_data(name: str, path: str) -> List:
    """Load the training or development examples stored at ``path``.

    Args:
        name: ``'train'`` selects the JSON-lines reader; any other value
            selects the ``||``-delimited text reader.
        path: Location of the data file.

    Returns:
        A list of 3-tuples. For ``'train'`` each tuple is
        ``(origin, entailment, contradiction)`` taken from one JSON record.
        Otherwise each tuple holds the first three ``||``-separated fields
        of one line; the fields are returned verbatim, so the last one
        still carries whatever trailing characters (e.g. the newline) the
        line had.

    Raises:
        IndexError: If a non-blank line of the delimited file has fewer
            than three ``||``-separated fields.
    """
    def load_train_data(path):
        with jsonlines.open(path, 'r') as f:
            return [(line['origin'], line['entailment'], line['contradiction']) for line in f]

    def load_dev_data(path):
        rows = []
        with open(path, 'r', encoding='utf8') as f:
            for line in f:
                # A blank line (typically a trailing one) carries no example;
                # skipping it avoids an IndexError on fields[1].
                if not line.strip():
                    continue
                # Split once instead of once per field.
                fields = line.split("||")
                rows.append((fields[0], fields[1], fields[2]))
        return rows

    if name == 'train':
        return load_train_data(path)
    return load_dev_data(path)
    

class TrainDataset(Dataset):
    """Dataset that tokenizes three sentences per record on access.

    Relies on the module-level ``tokenizer`` and ``MAXLEN`` being defined
    before any item is fetched.
    """

    def __init__(self, data: List):
        # Records are kept as given; tokenization happens lazily per item.
        self.data = data

    def __len__(self):
        return len(self.data)

    def text_2_id(self, text: str):
        """Tokenize the first three elements of ``text`` as one mini-batch.

        Despite the ``str`` annotation, ``text`` is indexed at positions
        0, 1 and 2; in this file it is one record of ``self.data``.
        Every sentence is truncated / padded to ``MAXLEN`` and the result
        is returned as PyTorch tensors.
        """
        first, second, third = text[0], text[1], text[2]
        return tokenizer(
            [first, second, third],
            max_length=MAXLEN,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

    def __getitem__(self, index: int):
        record = self.data[index]
        return self.text_2_id(record)

class TestDataset(Dataset):
    """Dataset of (sentence, sentence, label) records for evaluation.

    Each item is a pair of tokenized single-sentence batches plus an
    integer label. Relies on the module-level ``tokenizer`` and ``MAXLEN``
    being defined before any item is fetched.
    """

    def __init__(self, data: List):
        self.data = data

    def __len__(self):
        return len(self.data)

    def text_2_id(self, text: List[str]):
        """Tokenize ``text``, truncating / padding to ``MAXLEN``; returns PyTorch tensors."""
        return tokenizer(text, max_length=MAXLEN, truncation=True,
                         padding='max_length', return_tensors='pt')

    @staticmethod
    def _parse_label(raw: str) -> int:
        """Convert the raw third field of a record into an integer label.

        The field may still carry a closing double quote and a line
        terminator. Stripping whitespace first and quotes second accepts
        ``'1"\\n'``, ``'1"\\r\\n'`` and a final line without any newline
        (``'1"'``); removing only the exact substring ``'"\\n'`` would make
        ``int()`` fail on the latter two.
        """
        return int(raw.strip().strip('"'))

    def __getitem__(self, index):
        line = self.data[index]
        # Each sentence is wrapped in a list so the tokenizer returns a
        # batch of size one for it.
        return self.text_2_id([line[0]]), self.text_2_id([line[1]]), self._parse_label(line[2])
    

模型以及损失函数，这里采用SimCSE有监督版本

In [4]:
class SimcseModel(nn.Module):
    """BERT encoder that returns one sentence embedding per input row.

    The embedding is derived from the encoder output according to
    ``pooling``:

    * ``'cls'``            - last-layer hidden state of the first token.
    * ``'pooler'``         - the encoder's ``pooler_output``.
    * ``'last-avg'``       - attention-mask-weighted mean of the last layer.
    * ``'first-last-avg'`` - attention-mask-weighted mean of the average of
      the first and the last transformer layer.

    Args:
        pretrained_model: Name or path handed to ``BertModel.from_pretrained``.
        pooling: One of ``POOLING_MODES``.

    Raises:
        ValueError: If ``pooling`` is not a supported mode.
    """

    POOLING_MODES = ('cls', 'pooler', 'last-avg', 'first-last-avg')

    def __init__(self, pretrained_model: str, pooling: str):
        # Validate before loading weights so a typo fails immediately
        # instead of silently falling back to another pooling strategy.
        if pooling not in self.POOLING_MODES:
            raise ValueError(
                f"unsupported pooling {pooling!r}; expected one of {self.POOLING_MODES}"
            )
        super(SimcseModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.pooling = pooling

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, counting only positions where the mask is non-zero."""
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for an all-padding row.
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode a batch and pool it into one vector per row."""
        # Intermediate layers are only needed for 'first-last-avg'.
        need_all_layers = self.pooling == 'first-last-avg'
        out = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=need_all_layers,
        )
        if self.pooling == 'cls':
            return out.last_hidden_state[:, 0]
        if self.pooling == 'pooler':
            return out.pooler_output
        if self.pooling == 'last-avg':
            return self._masked_mean(out.last_hidden_state, attention_mask)
        if self.pooling == 'first-last-avg':
            # hidden_states[0] is the embedding output, so index 1 is the
            # first transformer layer.
            first_layer = out.hidden_states[1]
            last_layer = out.hidden_states[-1]
            return self._masked_mean((first_layer + last_layer) / 2.0, attention_mask)
        raise ValueError(f"unsupported pooling {self.pooling!r}")
                  
            
def simcse_sup_loss(y_pred: 'tensor', temperature: float = 0.05) -> 'tensor':
    """Supervised SimCSE contrastive loss over consecutive row triplets.

    Rows are read in groups of three. Within a group, row 0 is trained to
    select row 1 and row 1 is trained to select row 0 among all other rows
    of the batch; row 2 of every group never acts as a query but stays in
    the candidate set.

    Args:
        y_pred: 2-D tensor of embeddings, one per row; the number of rows
            must be a multiple of 3.
        temperature: Divisor applied to the cosine similarities before the
            softmax.

    Returns:
        Scalar tensor with the mean cross-entropy over the query rows.

    Raises:
        ValueError: If the number of rows is not a multiple of 3.
    """
    n_rows = y_pred.shape[0]
    if n_rows % 3 != 0:
        raise ValueError(f"expected a multiple of 3 rows, got {n_rows}")
    # Build helper tensors on the input's own device so the function does
    # not depend on any module-level device setting.
    device = y_pred.device
    row_ids = torch.arange(n_rows, device=device)
    # Query rows: every row except the third of each triplet.
    use_row = torch.where((row_ids + 1) % 3 != 0)[0]
    # Target column per query row: 3k -> 3k + 1 and 3k + 1 -> 3k.
    y_true = (use_row - use_row % 3 * 2) + 1
    # Pairwise cosine similarity between all rows (a symmetric matrix).
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # Push the diagonal to a very small value so a row cannot select itself.
    sim = sim - torch.eye(n_rows, device=device) * 1e12
    # Keep only the query rows.
    sim = torch.index_select(sim, 0, use_row)
    # Scale by the temperature.
    sim = sim / temperature
    # cross_entropy already averages over the rows.
    return F.cross_entropy(sim, y_true)

评估函数：在验证集（由 data_process 构建）上遍历多个相似度阈值，返回其中最高的准确率

In [5]:
def eval(model, dataloader, thresholds=(0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)) -> float:
    """Return the best thresholded-similarity accuracy over ``dataloader``.

    Every batch yields ``(source, target, label)``. ``source`` and
    ``target`` are mappings with ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` tensors whose dim 1 is squeezed away before they are
    fed to ``model``. A pair is predicted ``1`` when the cosine similarity
    of the two embeddings is at least the threshold and ``0`` otherwise;
    the prediction counts as correct when it equals the label.

    The model is run once per batch and all thresholds are then evaluated
    on the cached similarities. The accuracy is divided by the number of
    labels actually seen rather than by a fixed dataset size.

    Args:
        model: Module called as ``model(input_ids, attention_mask,
            token_type_ids)``. It is switched to eval mode and left there.
        dataloader: Iterable of ``(source, target, label)`` batches;
            ``label`` must be a tensor.
        thresholds: Candidate similarity cut-offs.

    Returns:
        The highest accuracy reached by any threshold, or ``0.0`` when the
        dataloader yields nothing. The value is also printed.
    """
    model.eval()
    # Run on the device the model lives on; fall back to the module-level
    # DEVICE only for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    def encode(batch):
        # [batch, 1, seq_len] -> [batch, seq_len]
        return model(
            batch['input_ids'].squeeze(1).to(device),
            batch['attention_mask'].squeeze(1).to(device),
            batch['token_type_ids'].squeeze(1).to(device),
        )

    sim_chunks = []
    label_chunks = []
    with torch.no_grad():
        for source, target, label in dataloader:
            sim = F.cosine_similarity(encode(source), encode(target), dim=-1)
            sim_chunks.append(sim.cpu().numpy())
            label_chunks.append(label.cpu().numpy())

    if not label_chunks:
        print(0.0)
        return 0.0

    sims = np.concatenate(sim_chunks)
    labels = np.concatenate(label_chunks)
    num = labels.shape[0]

    best_correct = 0
    for threshold in thresholds:
        predicted = (sims >= threshold).astype(np.int64)
        best_correct = max(best_correct, int(np.count_nonzero(predicted == labels)))

    accuracy = best_correct / num
    print(accuracy)
    return accuracy

模型训练函数：每训练 10 个 batch 在验证集上评估一次准确率，并保存验证集准确率最高的模型参数

In [6]:
def train(model, train_dl, dev_dl, optimizer) -> None:
    """Run one pass over ``train_dl``, evaluating on ``dev_dl`` every 10 batches.

    Each training batch is a mapping of tensors that is flattened so that
    its first dimension becomes three times the batch size before it is
    passed to the model; the loss is ``simcse_sup_loss``.

    After every 10th batch the loss is logged and the file's ``eval``
    function scores the model on ``dev_dl``. When the score beats the
    module-level ``best``, ``best`` is updated and the state dict is saved
    to ``SAVE_PATH``. After 100 consecutive evaluations without
    improvement the function returns early.
    """
    global best
    model.train()
    stale_evals = 0
    tensor_keys = ('input_ids', 'attention_mask', 'token_type_ids')
    for batch_idx, source in enumerate(tqdm(train_dl), start=1):
        n_records = source.get('input_ids').shape[0]
        flat_rows = n_records * 3
        input_ids, attention_mask, token_type_ids = (
            source.get(key).view(flat_rows, -1).to(DEVICE) for key in tensor_keys
        )

        # Optimisation step
        out = model(input_ids, attention_mask, token_type_ids)
        loss = simcse_sup_loss(out)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Only every 10th batch triggers an evaluation.
        if batch_idx % 10 != 0:
            continue

        logger.info(f'loss: {loss.item():.4f}')
        corrcoef = eval(model, dev_dl)
        # eval() leaves the model in eval mode; switch back before training on.
        model.train()

        if best < corrcoef:
            stale_evals = 0
            best = corrcoef
            torch.save(model.state_dict(), SAVE_PATH)
            logger.info(f"higher corrcoef: {best:.4f} in batch: {batch_idx}, save model")
        else:
            stale_evals += 1
            if stale_evals == 100:
                logger.info(f"corrcoef doesn't improve for {stale_evals} batch, early stop!")
                logger.info(f"train use sample number: {(batch_idx - 10) * BATCH_SIZE}")
                return

In [7]:
logger.info(f'device: {DEVICE}, pooling: {POOLING}, model path: {model_path}')
tokenizer = BertTokenizer.from_pretrained(model_path)

# Load the data. The training records are shuffled once here; the
# DataLoader itself keeps that order for every epoch.
train_data = load_data('train',TRAIN)
random.shuffle(train_data)
dev_data = load_data('dev',DEV)
train_dataloader = DataLoader(TrainDataset(train_data), batch_size=BATCH_SIZE)
dev_dataloader = DataLoader(TestDataset(dev_data), batch_size=BATCH_SIZE)
print("data ok ")

# Load the model
model = SimcseModel(pretrained_model=model_path, pooling=POOLING)
model.to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Train. `best` is the module-level score that train() updates through
# its `global best` declaration.
best = 0
for epoch in range(EPOCHS):
    logger.info(f'epoch: {epoch}')
    train(model, train_dataloader, dev_dataloader, optimizer)
    logger.info(f'train is finished, best model is saved at {SAVE_PATH}')

# Validate the best checkpoint. map_location lets a checkpoint written on
# one device type be restored when DEVICE is a different one (e.g. saved
# on GPU, reloaded on a CPU-only machine).
model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))
dev_corrcoef = eval(model, dev_dataloader)
logger.info(f'dev_corrcoef: {dev_corrcoef:.4f}')

2022-10-15 16:40:36.798 | INFO     | __main__:<cell line: 1>:1 - device: cuda, pooling: cls, model path: hfl/chinese-roberta-wwm-ext


data ok 


Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2022-10-15 16:40:49.353 | INFO     | __main__:<cell line: 19>:20 - epoch: 0
  7%|▋         | 9/125 [00:03<00:38,  2.99it/s]2022-10-15 16:40:52.72

0.9377544097693351


2022-10-15 16:42:08.967 | INFO     | __main__:train:25 - higher corrcoef: 0.9378 in batch: 10, save model
 15%|█▌        | 19/125 [01:22<02:13,  1.26s/it]2022-10-15 16:42:12.278 | INFO     | __main__:train:18 - loss: 1.7977
 16%|█▌        | 20/125 [02:37<41:04, 23.47s/it]

0.9372455902306649


 23%|██▎       | 29/125 [02:40<02:01,  1.26s/it]2022-10-15 16:43:30.530 | INFO     | __main__:train:18 - loss: 1.8531


0.939280868385346


2022-10-15 16:44:47.031 | INFO     | __main__:train:25 - higher corrcoef: 0.9393 in batch: 30, save model
 31%|███       | 39/125 [04:00<01:50,  1.28s/it]2022-10-15 16:44:50.372 | INFO     | __main__:train:18 - loss: 2.0553


0.9404681139755766


2022-10-15 16:46:06.903 | INFO     | __main__:train:25 - higher corrcoef: 0.9405 in batch: 40, save model
 39%|███▉      | 49/125 [05:20<01:37,  1.29s/it]2022-10-15 16:46:10.227 | INFO     | __main__:train:18 - loss: 1.8556
 40%|████      | 50/125 [06:36<29:27, 23.57s/it]

0.9363975576662144


 47%|████▋     | 59/125 [06:39<01:23,  1.27s/it]2022-10-15 16:47:28.784 | INFO     | __main__:train:18 - loss: 1.7279
 48%|████▊     | 60/125 [07:54<25:32, 23.58s/it]

0.9319877883310719


 55%|█████▌    | 69/125 [07:57<01:11,  1.27s/it]2022-10-15 16:48:47.425 | INFO     | __main__:train:18 - loss: 1.4673
 56%|█████▌    | 70/125 [09:13<21:40, 23.64s/it]

0.9328358208955224


 63%|██████▎   | 79/125 [09:16<00:58,  1.27s/it]2022-10-15 16:50:06.266 | INFO     | __main__:train:18 - loss: 1.2773
 64%|██████▍   | 80/125 [10:33<18:00, 24.02s/it]

0.9279172320217096


 71%|███████   | 89/125 [10:37<00:48,  1.35s/it]2022-10-15 16:51:27.087 | INFO     | __main__:train:18 - loss: 1.1707
 72%|███████▏  | 90/125 [12:06<16:09, 27.71s/it]

0.9248643147896879


 79%|███████▉  | 99/125 [12:10<00:38,  1.50s/it]2022-10-15 16:52:59.946 | INFO     | __main__:train:18 - loss: 1.5506
 80%|████████  | 100/125 [13:39<11:35, 27.82s/it]

0.9268995929443691


 87%|████████▋ | 109/125 [13:42<00:24,  1.50s/it]2022-10-15 16:54:32.757 | INFO     | __main__:train:18 - loss: 1.2010
 88%|████████▊ | 110/125 [15:12<06:57, 27.81s/it]

0.9241858887381276


 95%|█████████▌| 119/125 [15:15<00:09,  1.51s/it]2022-10-15 16:56:05.578 | INFO     | __main__:train:18 - loss: 1.7988
 96%|█████████▌| 120/125 [16:44<02:18, 27.79s/it]

0.926729986431479


100%|██████████| 125/125 [16:46<00:00,  8.06s/it]
2022-10-15 16:57:36.274 | INFO     | __main__:<cell line: 19>:22 - train is finished, best model is saved at ./simcse_sup.pt
2022-10-15 16:57:36.276 | INFO     | __main__:<cell line: 19>:20 - epoch: 1
  7%|▋         | 9/125 [00:03<00:47,  2.46it/s]2022-10-15 16:57:40.305 | INFO     | __main__:train:18 - loss: 1.1276
  8%|▊         | 10/125 [01:32<53:17, 27.81s/it]

0.9262211668928086


 15%|█▌        | 19/125 [01:36<02:37,  1.48s/it]2022-10-15 16:59:13.147 | INFO     | __main__:train:18 - loss: 1.1017
 16%|█▌        | 20/125 [03:05<48:40, 27.81s/it]

0.9277476255088195


 23%|██▎       | 29/125 [03:09<02:25,  1.51s/it]2022-10-15 17:00:46.029 | INFO     | __main__:train:18 - loss: 1.1668
 24%|██▍       | 30/125 [04:38<44:02, 27.82s/it]

0.9255427408412483


 31%|███       | 39/125 [04:42<02:10,  1.52s/it]2022-10-15 17:02:18.904 | INFO     | __main__:train:18 - loss: 1.6453
 32%|███▏      | 40/125 [06:11<39:18, 27.75s/it]

0.9274084124830394


 39%|███▉      | 49/125 [06:14<01:54,  1.51s/it]2022-10-15 17:03:51.576 | INFO     | __main__:train:18 - loss: 1.3329
 40%|████      | 50/125 [07:43<34:43, 27.78s/it]

0.9255427408412483


 47%|████▋     | 59/125 [07:47<01:40,  1.52s/it]2022-10-15 17:05:24.315 | INFO     | __main__:train:18 - loss: 1.3684
 48%|████▊     | 60/125 [09:16<30:06, 27.79s/it]

0.9216417910447762


 55%|█████▌    | 69/125 [09:20<01:24,  1.50s/it]2022-10-15 17:06:57.084 | INFO     | __main__:train:18 - loss: 1.1348
 56%|█████▌    | 70/125 [10:49<25:29, 27.80s/it]

0.9238466757123474


 63%|██████▎   | 79/125 [10:53<01:08,  1.50s/it]2022-10-15 17:08:29.843 | INFO     | __main__:train:18 - loss: 0.9873
 64%|██████▍   | 80/125 [12:22<20:49, 27.77s/it]

0.9170624151967436


 71%|███████   | 89/125 [12:25<00:54,  1.51s/it]2022-10-15 17:10:02.603 | INFO     | __main__:train:18 - loss: 0.9899
 72%|███████▏  | 90/125 [13:55<16:13, 27.83s/it]

0.9170624151967436


 79%|███████▉  | 99/125 [13:58<00:39,  1.52s/it]2022-10-15 17:11:35.485 | INFO     | __main__:train:18 - loss: 1.2397
 80%|████████  | 100/125 [15:27<11:34, 27.78s/it]

0.9151967435549525


 87%|████████▋ | 109/125 [15:31<00:24,  1.51s/it]2022-10-15 17:13:08.171 | INFO     | __main__:train:18 - loss: 0.8929
 88%|████████▊ | 110/125 [17:00<06:57, 27.82s/it]

0.9190976933514247


 95%|█████████▌| 119/125 [17:04<00:09,  1.51s/it]2022-10-15 17:14:41.051 | INFO     | __main__:train:18 - loss: 1.3644
 96%|█████████▌| 120/125 [18:33<02:19, 27.80s/it]

0.9185888738127544


100%|██████████| 125/125 [18:35<00:00,  8.92s/it]
2022-10-15 17:16:11.817 | INFO     | __main__:<cell line: 19>:22 - train is finished, best model is saved at ./simcse_sup.pt
2022-10-15 17:17:41.020 | INFO     | __main__:<cell line: 27>:27 - dev_corrcoef: 0.9405


0.9404681139755766
