# 导包

In [122]:
import nlp_basictasks
import os,json
import numpy as np
import torch
import torch.nn as nn
import random
from tqdm.autonotebook import tqdm, trange
from torch.utils.data import DataLoader
from nlp_basictasks.modules import SBERT
from nlp_basictasks.modules.transformers import BertTokenizer,BertModel,BertConfig
from nlp_basictasks.readers.sts import InputExample,convert_examples_to_features,getExamples,convert_sentences_to_features
from nlp_basictasks.modules.utils import get_optimizer,get_scheduler
from nlp_basictasks.Trainer import Trainer
from nlp_basictasks.evaluation import stsEvaluator
from sentence_transformers import SentenceTransformer,models
# model_path1='/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/distill-simcse/'
# model_path2="/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/distiluse-base-multilingual-cased-v1/"
model_path3='/data/nfs14/nfs/aisearch/asr/xhsun/CommonModel/chinese-roberta-wwm/'
# data_folder='/data/nfs14/nfs/aisearch/asr/xhsun/datasets/lcqmc/'
# train_file=os.path.join(data_folder,'lcqmc_train.tsv')
# dev_file=os.path.join(data_folder,'lcqmc_dev.tsv')
#tokenizer=BertTokenizer.from_pretrained(os.path.join(model_path1,'0_Transformer'))
tokenizer=BertTokenizer.from_pretrained(model_path3)
max_seq_len=32
batch_size=6

2021-10-13 19:30:49 - INFO - from_pretrained - 125 : loading vocabulary file /data/nfs14/nfs/aisearch/asr/xhsun/CommonModel/chinese-roberta-wwm/vocab.txt


# 获取数据

In [123]:
train_file='/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/STS-B/cnsd-sts-train.txt'
dev_file='/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/STS-B/cnsd-sts-dev.txt'
test_file='/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/STS-B/cnsd-sts-test.txt'
def read_data(file_path):
    sentences=[]
    labels=[]
    with open(file_path) as f:
        lines=f.readlines()
    for line in lines:
        line_split=line.strip().split('||')
        sentences.append([line_split[1],line_split[2]])
        labels.append(line_split[3])
    return sentences,labels

In [124]:
train_sentences,train_labels=read_data(train_file)
dev_sentences,dev_labels=read_data(dev_file)
test_sentences,test_labels=read_data(test_file)

In [125]:
print(train_sentences[:2],train_labels[:2])
print(dev_sentences[:2],dev_labels[:2])
print(test_sentences[:2],test_labels[:2])

[['一架飞机要起飞了。', '一架飞机正在起飞。'], ['一个男人在吹一支大笛子。', '一个人在吹长笛。']] ['5', '3']
[['一个戴着安全帽的男人在跳舞。', '一个戴着安全帽的男人在跳舞。'], ['一个小孩在骑马。', '孩子在骑马。']] ['5', '4']
[['一个女孩在给她的头发做发型。', '一个女孩在梳头。'], ['一群男人在海滩上踢足球。', '一群男孩在海滩上踢足球。']] ['2', '3']


# create unsupervised train_dataset

In [126]:
train_sentences=[sentence[0] for sentence in train_sentences]#只取一般数据作为训练集
print(len(train_sentences))
print(train_sentences[:3])
train_examples=[InputExample(text_list=[sentence,sentence],label=1) for sentence in train_sentences]
train_dataloader=DataLoader(train_examples,shuffle=True,batch_size=batch_size)
def smart_batching_collate(batch):
    features_of_a,features_of_b,labels=convert_examples_to_features(examples=batch,tokenizer=tokenizer,max_seq_len=max_seq_len)
    return features_of_a,features_of_b,labels
train_dataloader.collate_fn=smart_batching_collate
print(train_examples[0])

5231
['一架飞机要起飞了。', '一个男人在吹一支大笛子。', '一个人正把切碎的奶酪撒在比萨饼上。']
<InputExample> label: 1, text pairs : 一架飞机要起飞了。; 一架飞机要起飞了。


# SimCSE模型

In [142]:
class SimCSE(nn.Module):
    def __init__(self,
                 bert_model_path,
                 is_sbert_model=True,
                temperature=0.05,
                is_distilbert=False,
                device='cpu'):
        super(SimCSE,self).__init__()
        if is_sbert_model:
            self.encoder=SentenceTransformer(model_name_or_path=bert_model_path,device=device)
        else:
            word_embedding_model = models.Transformer(bert_model_path, max_seq_length=max_seq_len)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
            self.encoder=SentenceTransformer(modules=[word_embedding_model, pooling_model],device=device)
        self.temperature=temperature
        self.is_distilbert=is_distilbert#蒸馏版本的BERT不支持token_type_ids
    def cal_cos_sim(self,embeddings1,embeddings2):
        embeddings1_norm=torch.nn.functional.normalize(embeddings1,p=2,dim=1)
        embeddings2_norm=torch.nn.functional.normalize(embeddings2,p=2,dim=1)
        return torch.mm(embeddings1_norm,embeddings2_norm.transpose(0,1))#(batch_size,batch_size)
        
    def forward(self,batch_inputs):
        '''
        为了实现兼容，所有model的batch_inputs最后一个位置必须是labels，即使为None
        get token_embeddings,cls_token_embeddings,sentence_embeddings
        sentence_embeddings是经过Pooling层后concat的embedding。维度=768*k，其中k取决于pooling的策略
        一般来讲，只会取一种pooling策略，要么直接cls要么mean last or mean last2 or mean first and last layer，所以sentence_embeddings的维度也是768
        '''
        batch1_features,batch2_features,_=batch_inputs
        if self.is_distilbert:
            del batch1_features['token_type_ids']
            del batch2_features['token_type_ids']
        batch1_embeddings=self.encoder(batch1_features)['sentence_embedding']
        batch2_embeddings=self.encoder(batch2_features)['sentence_embedding']
        cos_sim=self.cal_cos_sim(batch1_embeddings,batch2_embeddings)/self.temperature#(batch_size,batch_size)
        batch_size=cos_sim.size(0)
        assert cos_sim.size()==(batch_size,batch_size)
        labels=torch.arange(batch_size).to(cos_sim.device)
        return nn.CrossEntropyLoss()(cos_sim,labels)
    
    def encode(self, sentences,
               batch_size: int = 32,
               show_progress_bar: bool = None,
               output_value: str = 'sentence_embedding',
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False):
        '''
        传进来的sentences只能是single_batch
        '''
        return self.encoder.encode(sentences=sentences,
                                         batch_size=batch_size,
                                         show_progress_bar=show_progress_bar,
                                         output_value=output_value,
                                         convert_to_numpy=convert_to_numpy,
                                         convert_to_tensor=convert_to_tensor,
                                         device=device,
                                         normalize_embeddings=normalize_embeddings)
    
    def save(self,output_path):
        os.makedirs(output_path,exist_ok=True)
        with open(os.path.join(output_path, 'model_param_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(output_path), fOut)
        self.encoder.save(output_path)
        
    def get_config_dict(self,output_path):
        '''
        一定要有dict，这样才能初始化Model
        '''
        return {'output_path':output_path,'temperature': self.temperature, 'is_distilbert': self.is_distilbert}
    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'model_param_config.json')) as fIn:
            config = json.load(fIn)
        return SimCSE(**config)

In [128]:
device='cuda'
#simcse=SimCSE(bert_model_path=model_path3,is_distilbert=False,device=device,is_sbert_model=False)
simcse=SimCSE(bert_model_path="/data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/unsupervisedSTSModel/unSimCSE_STS-B/",is_distilbert=False,device=device,is_sbert_model=True)

2021-10-13 19:31:04 - INFO - __init__ - 41 : Load pretrained SentenceTransformer: /data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/unsupervisedSTSModel/unSimCSE_STS-B/
2021-10-13 19:31:04 - INFO - __init__ - 107 : Load SentenceTransformer from folder: /data/nfs14/nfs/aisearch/asr/xhsun/bwbd_recall/unsupervisedSTSModel/unSimCSE_STS-B/


# 构造evaluator

In [129]:
#dev_sentences=[example.text_list for example in dev_examples]
#dev_labels=[example.label for example in dev_examples]
print(dev_sentences[0],dev_labels[0])
sentences1_list=[sen[0] for sen in dev_sentences]
sentences2_list=[sen[1] for sen in dev_sentences]
dev_labels=[int(score) for score in dev_labels]
evaluator=stsEvaluator(sentences1=sentences1_list,sentences2=sentences2_list,batch_size=64,write_csv=True,scores=dev_labels)

['一个戴着安全帽的男人在跳舞。', '一个戴着安全帽的男人在跳舞。'] 5


In [130]:
evaluator(simcse)

2021-10-13 19:31:51 - INFO - __call__ - 72 : EmbeddingSimilarityEvaluator: Evaluating the model on  dataset:


HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))

2021-10-13 19:32:13 - INFO - __call__ - 103 : Cosine-Similarity :	Pearson: 0.7585	Spearman: 0.7632
2021-10-13 19:32:13 - INFO - __call__ - 105 : Manhattan-Distance:	Pearson: 0.7228	Spearman: 0.7396
2021-10-13 19:32:13 - INFO - __call__ - 107 : Euclidean-Distance:	Pearson: 0.7195	Spearman: 0.7360
2021-10-13 19:32:13 - INFO - __call__ - 109 : Dot-Product-Similarity:	Pearson: 0.7377	Spearman: 0.7477





0.763235954359261

In [158]:
simcse.detach

AttributeError: 'SimCSE' object has no attribute 'detach'

In [157]:
for name_param,param in zip(simcse.encoder.named_parameters(),simcse.encoder.parameters()):
    print(param.data)

tensor([[ 0.0390,  0.0059, -0.0653,  ...,  0.0715,  0.0366,  0.0214],
        [-0.0357,  0.0530,  0.0773,  ...,  0.0755,  0.0674,  0.0469],
        [ 0.0559, -0.0117,  0.0175,  ...,  0.0826,  0.0053, -0.0243],
        ...,
        [ 0.0414,  0.0152, -0.0077,  ...,  0.0291,  0.0694,  0.0400],
        [ 0.0520,  0.0662,  0.0166,  ...,  0.0598,  0.0222, -0.0225],
        [ 0.0131, -0.0237,  0.0021,  ...,  0.1128, -0.0453,  0.0102]],
       device='cuda:0')
tensor([[ 0.0102,  0.0093, -0.0106,  ..., -0.0597, -0.0185, -0.0194],
        [-0.0238,  0.0117, -0.0109,  ..., -0.0350, -0.0158,  0.0353],
        [-0.0197,  0.0028, -0.0247,  ..., -0.0415,  0.0062,  0.0298],
        ...,
        [-0.0479, -0.0200, -0.0196,  ...,  0.0353, -0.0199, -0.0327],
        [-0.0861, -0.0428, -0.0732,  ...,  0.0209,  0.0284, -0.0523],
        [ 0.0028,  0.0445, -0.0252,  ...,  0.0124, -0.0058, -0.0130]],
       device='cuda:0')
tensor([[ 0.0030, -0.0055, -0.0073,  ...,  0.0335,  0.0021,  0.0002],
        [ 0.00

tensor([[ 5.4987e-03,  1.5110e-02,  8.9691e-02,  ..., -4.5674e-02,
          2.0573e-02, -1.2812e-02],
        [ 2.8219e-02, -1.1709e-01, -4.4740e-02,  ..., -6.0082e-02,
          3.0020e-03, -4.2785e-02],
        [-6.7296e-02,  6.8206e-02,  6.3392e-02,  ...,  3.1213e-02,
          2.3616e-02, -9.5826e-03],
        ...,
        [ 5.0601e-02, -2.2278e-02,  8.7488e-05,  ...,  8.3964e-03,
         -3.9979e-02, -4.3754e-02],
        [-7.2688e-02,  2.7269e-02, -2.3634e-02,  ..., -3.6967e-02,
         -2.3013e-02,  5.9980e-02],
        [-9.3311e-03,  1.4862e-01, -5.3290e-02,  ..., -9.4319e-02,
         -1.2611e-02, -2.0563e-02]], device='cuda:0')
tensor([-4.4199e-03,  3.0463e-02, -2.2430e-02, -9.0898e-03, -1.3002e-02,
        -2.5951e-02, -1.0229e-02, -4.5033e-03, -6.0012e-03, -3.7276e-03,
        -7.5937e-03, -2.1970e-03,  1.4645e-02, -3.4026e-03, -1.1720e-02,
        -1.5843e-02, -1.8503e-02,  2.4494e-03, -2.4327e-03, -1.6594e-02,
        -8.9721e-04,  1.3252e-02,  4.8376e-03, -2.2330e-02,

tensor([0.8564, 0.7974, 0.8665, 0.8045, 0.7120, 0.8778, 0.8009, 0.7995, 0.7230,
        0.8274, 0.8189, 0.7890, 0.6945, 0.7945, 0.9624, 0.8047, 0.8425, 0.8788,
        0.8229, 0.7855, 0.8646, 0.7780, 0.8902, 0.7703, 0.8129, 0.8162, 0.8036,
        0.8096, 0.8942, 0.7725, 0.7255, 0.8727, 0.8664, 0.8128, 0.8182, 0.8422,
        0.8405, 0.7290, 0.7814, 0.8602, 0.8085, 0.8052, 0.8734, 0.8757, 0.7680,
        0.8460, 0.8283, 0.8720, 0.7847, 0.8132, 0.7997, 1.4816, 0.7709, 0.7662,
        0.7735, 0.8170, 0.7936, 0.8213, 0.8215, 0.8757, 0.8615, 0.8182, 0.7554,
        0.7451, 0.8642, 0.7761, 0.8641, 0.7044, 0.7885, 0.7263, 0.8122, 0.7622,
        0.8538, 0.7979, 0.8641, 0.7215, 0.8396, 0.7932, 0.6755, 0.8179, 0.8968,
        0.7645, 0.7563, 0.7642, 0.7890, 0.8915, 0.8308, 0.7094, 0.8094, 0.7702,
        0.8743, 0.7453, 0.7811, 0.8394, 0.7858, 1.1330, 0.7181, 0.8347, 0.8128,
        0.8462, 0.7966, 0.8391, 0.7797, 0.8611, 0.8100, 0.9940, 0.8519, 0.8026,
        0.8166, 0.8061, 0.8725, 0.8091, 

tensor([-3.9585e-02, -1.4424e-02,  1.0209e-01,  7.6414e-03, -8.4446e-02,
         1.2974e-01, -1.2629e-02,  5.1636e-02, -2.5601e-02,  6.3452e-02,
        -8.5011e-02,  8.3095e-02,  6.8491e-03,  1.2490e-03, -1.3965e-01,
        -4.8199e-04, -7.2560e-02,  8.3661e-02, -2.4791e-02,  2.4463e-02,
        -4.7576e-02, -4.3325e-02, -2.5047e-02,  5.0129e-02,  6.0269e-02,
         1.8495e-02, -3.3214e-02,  4.5054e-02,  2.4951e-02, -6.7352e-02,
        -4.4306e-02,  5.8802e-02,  2.1499e-04,  2.8624e-02, -4.0585e-02,
        -4.0068e-02,  7.2214e-03,  4.0997e-03, -1.8563e-02, -7.7095e-02,
         3.5843e-02,  1.2634e-02, -1.2850e-02, -2.6975e-02, -2.3825e-02,
        -1.0104e-02, -5.7674e-02, -4.5522e-02, -1.9164e-02,  5.2517e-02,
        -1.1570e-02, -3.4292e-01,  1.8645e-02, -8.7342e-03,  4.5012e-02,
         3.4193e-02,  1.8587e-02, -6.7608e-03,  4.3151e-03,  2.1004e-02,
        -5.0305e-02, -1.4727e-02,  5.9776e-04, -2.9407e-02, -1.2918e-02,
         3.6153e-02,  6.6222e-02, -6.8084e-02,  1.1

tensor([-3.8823e-03,  2.4022e-02, -3.2386e-02,  2.8250e-02, -4.4511e-02,
        -5.4964e-03,  4.7234e-02,  1.8052e-02, -5.2015e-02, -2.8398e-02,
         2.8424e-02, -1.3611e-02, -5.8594e-03, -4.6796e-03, -5.2591e-03,
        -1.1365e-02, -3.3684e-02, -7.5838e-04, -2.8187e-02,  5.1560e-03,
        -5.2985e-02, -5.2139e-02, -8.2873e-03,  6.5048e-02,  5.0889e-02,
         1.5241e-02, -6.3653e-03, -2.7918e-02, -3.1486e-02, -1.3714e-02,
        -3.7955e-02, -8.0182e-03, -1.4669e-02, -6.2592e-03, -3.0974e-02,
         6.1839e-03, -1.0609e-02, -1.4581e-03,  1.8658e-02,  8.8385e-03,
         1.1637e-02,  1.6877e-02, -1.4312e-02, -4.2706e-02,  2.2818e-02,
        -8.2243e-03,  4.1014e-02, -1.3072e-02,  8.7104e-03,  2.1115e-03,
        -2.7727e-03, -3.4181e-02,  2.3444e-02, -2.7739e-02,  1.1093e-02,
        -1.8705e-02, -1.3988e-02, -2.7040e-02, -3.3120e-02,  3.2606e-02,
        -2.6031e-02, -2.3060e-02,  4.8034e-02, -1.2513e-02, -3.4641e-02,
        -1.2593e-02,  3.2281e-02, -8.1673e-03,  5.5

tensor([ 1.9930e-02,  1.5288e-01, -9.9437e-02, -1.7908e-02,  1.2306e-01,
        -4.2920e-02,  1.0242e-01, -2.3338e-01, -4.1889e-02,  1.3411e-02,
         1.8954e-02, -1.8532e-01,  2.1812e-02,  2.1073e-02,  1.9268e-01,
         1.9094e-01,  2.1281e-01,  8.8760e-02,  2.5472e-02, -7.7304e-02,
         1.0707e-01,  3.6189e-02, -2.5225e-02,  3.9091e-02,  4.0018e-02,
         9.3562e-03, -1.9382e-02, -2.0919e-01,  1.1338e-01,  1.2383e-01,
         4.1474e-02, -2.0832e-01, -4.5215e-02,  5.7618e-02,  1.5638e-01,
         7.2779e-02, -2.1190e-02,  2.2216e-01,  2.5613e-01,  2.8338e-02,
        -1.1836e-01, -1.0566e-01,  1.0464e-01,  8.3790e-02,  2.1033e-02,
         8.4513e-02,  7.7068e-02,  7.5974e-02,  6.2841e-02,  3.7606e-02,
         1.4146e-01,  6.0365e-02, -6.0978e-02,  1.1201e-01, -1.7625e-01,
        -1.8268e-01, -1.4894e-01, -1.6709e-02,  1.3489e-01,  3.3594e-02,
        -5.8680e-03,  1.6369e-01, -4.3908e-02, -4.7404e-02, -5.2539e-02,
         1.0566e-01,  3.4907e-02,  1.8825e-01, -7.5

tensor([-2.0536e-02,  5.5021e-03,  1.1479e-03, -6.1774e-03, -2.8242e-04,
        -8.9798e-03,  2.1463e-02,  1.0508e-02, -2.5416e-02,  3.0760e-03,
        -3.2557e-02, -2.3213e-02, -1.2302e-03, -7.3184e-03, -1.4563e-02,
         1.0263e-02,  1.7143e-02,  8.4165e-03, -3.5358e-03, -5.8257e-03,
         5.1027e-03, -1.0323e-02, -3.7689e-03,  2.6213e-03, -2.1466e-02,
        -1.0029e-02, -5.5074e-03,  1.4629e-02, -2.1002e-02,  4.9671e-03,
         1.5383e-02,  4.9832e-03,  8.6232e-03, -2.5438e-02,  4.5524e-03,
        -3.3862e-03,  1.7720e-02,  5.2009e-03,  1.2372e-02,  2.7929e-02,
         5.3987e-03, -9.3645e-03, -1.1619e-02, -1.8206e-02, -7.5718e-03,
        -8.1402e-03, -5.3961e-03,  1.2339e-02, -3.2700e-03,  1.4963e-02,
        -3.2805e-03, -1.1640e-02,  2.7935e-03,  3.1698e-02, -2.9856e-02,
        -8.6280e-04, -1.2381e-02,  1.3611e-02, -2.4216e-02,  1.2740e-02,
        -2.9182e-03,  1.1450e-02,  1.7506e-02,  8.5217e-03,  1.0420e-02,
        -8.4698e-04,  7.6095e-04, -6.9710e-03, -6.7

tensor([ 1.5732e-01,  1.2856e-02, -2.2855e-02, -4.6539e-02,  4.4327e-02,
        -2.7403e-01,  6.8211e-02, -2.3147e-03,  1.7453e-02, -6.8272e-02,
         3.9968e-02, -6.5637e-02, -9.1202e-02, -3.5994e-02,  1.7921e-01,
        -4.3286e-02,  6.3753e-02,  3.4555e-02, -2.2562e-02, -8.3848e-03,
         4.0578e-02,  4.0538e-02, -2.4631e-02,  6.7938e-02,  1.1149e-01,
         4.0740e-02,  4.6239e-02, -5.7614e-02,  3.8061e-02,  3.9521e-02,
        -3.2885e-03, -7.8860e-02,  5.5360e-04,  2.0382e-02,  3.7278e-02,
         5.1029e-02, -4.2792e-02,  5.8405e-02,  1.3991e-01,  3.5760e-02,
         3.9377e-03, -5.9341e-03,  2.1677e-02,  5.4043e-02, -4.0410e-03,
         4.2735e-02,  2.4370e-02,  4.4255e-02,  7.6733e-02, -1.4781e-02,
        -2.0123e-03,  5.3064e-01,  3.6770e-02,  3.9032e-02, -9.0947e-02,
        -3.8884e-02,  1.9163e-02,  1.2846e-02, -1.4013e-02, -8.4519e-02,
         4.5565e-02,  1.3425e-01,  6.0953e-02,  7.7692e-02, -1.9291e-02,
         1.2929e-02,  2.2012e-02,  4.0447e-02,  1.8

tensor([-7.5889e-03,  8.4432e-02, -5.7034e-01, -1.2115e-01, -1.3356e-02,
        -1.4801e-01, -2.4994e-01, -3.3375e-01,  2.0690e-01, -2.2364e-01,
         7.8639e-02,  1.4942e-01, -4.1157e-01,  4.2639e-01,  1.8323e-01,
         6.1202e-01,  2.4908e-01, -2.7124e-01,  2.3585e-01,  2.1410e-01,
         3.1930e-01, -1.6403e-01,  2.2492e-01, -1.2772e-01,  7.8844e-02,
        -6.2660e-02, -9.9679e-02, -1.3037e-02, -4.0885e-03, -3.0458e-01,
         8.7019e-02,  9.1397e-02,  5.3168e-02,  4.7059e-02, -8.0672e-02,
        -4.0254e-01, -1.3625e-01, -2.8795e-01,  4.1054e-01, -2.0914e-01,
         1.6351e-01, -2.1559e-03,  1.8994e-01, -2.2383e-01,  4.1889e-02,
         1.2047e-01, -4.6037e-03, -2.2374e-01, -2.0366e-01,  1.1100e-01,
        -5.6730e-02, -1.9682e-01,  4.1543e-01, -1.6459e-01, -3.1148e-01,
         1.5995e-01, -2.2654e-02,  7.8299e-02, -8.6917e-02,  2.6697e-01,
        -1.1776e-01, -5.2536e-02,  9.6548e-02, -1.1348e-01, -4.8046e-01,
         6.4613e-01, -1.9813e-01, -4.0986e-02, -3.9

# ESimCSE

In [134]:
from queue import Queue
class ESimCSE(nn.Module):
    def __init__(self,
                 bert_model_path,
                 q_size=256,
                 dup_rate=0.32,
                 is_sbert_model=True,
                temperature=0.05,
                is_distilbert=False,
                device='cpu'):
        super(ESimCSE,self).__init__()
        if is_sbert_model:
            self.encoder=SentenceTransformer(model_name_or_path=bert_model_path,device=device)
            self.moco_encoder=SentenceTransformer(model_name_or_path=bert_model_path,device=device)
        else:
            word_embedding_model = models.Transformer(bert_model_path, max_seq_length=max_seq_len)
            pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
            self.encoder=SentenceTransformer(modules=[word_embedding_model, pooling_model],device=device)
            self.moco_encoder=SentenceTransformer(modules=[word_embedding_model, pooling_model],device=device)
        self.q=Queue(maxsize=q_size)
        self.q_size=q_size
        self.dup_rate=dup_rate
        self.temperature=temperature
        self.is_distilbert=is_distilbert#蒸馏版本的BERT不支持token_type_ids
    def cal_cos_sim(self,embeddings1,embeddings2):
        embeddings1_norm=torch.nn.functional.normalize(embeddings1,p=2,dim=1)
        embeddings2_norm=torch.nn.functional.normalize(embeddings2,p=2,dim=1)
        return torch.mm(embeddings1_norm,embeddings2_norm.transpose(0,1))#(batch_size,batch_size)

    def word_repetition(self,sentence_feature):
        input_ids, attention_mask, token_type_ids=sentence_feature['input_ids'].cpu().tolist(),sentence_feature['attention_mask'].cpu().tolist(),sentence_feature['token_type_ids'].cpu().tolist()
        bsz, seq_len = len(input_ids),len(input_ids[0])
        #print(bsz,seq_len)
        repetitied_input_ids=[]
        repetitied_attention_mask=[]
        repetitied_token_type_ids=[]
        rep_seq_len=seq_len
        for bsz_id in range(bsz):
            sample_mask = attention_mask[bsz_id]
            actual_len = sum(sample_mask)

            cur_input_id=input_ids[bsz_id]
            dup_len=random.randint(a=0,b=max(2,int(self.dup_rate*actual_len)))
            dup_word_index=random.sample(list(range(1,actual_len)),k=dup_len)
            
            r_input_id=[]
            r_attention_mask=[]
            r_token_type_ids=[]
            for index,word_id in enumerate(cur_input_id):
                if index in dup_word_index:
                    r_input_id.append(word_id)
                    r_attention_mask.append(sample_mask[index])
                    r_token_type_ids.append(token_type_ids[bsz_id][index])

                r_input_id.append(word_id)
                r_attention_mask.append(sample_mask[index])
                r_token_type_ids.append(token_type_ids[bsz_id][index])

            after_dup_len=len(r_input_id)
            #assert after_dup_len==actual_len+dup_len
            repetitied_input_ids.append(r_input_id)#+rest_input_ids)
            repetitied_attention_mask.append(r_attention_mask)#+rest_attention_mask)
            repetitied_token_type_ids.append(r_token_type_ids)#+rest_token_type_ids)

            assert after_dup_len==dup_len+seq_len
            if after_dup_len>rep_seq_len:
                rep_seq_len=after_dup_len

        for i in range(bsz):
            after_dup_len=len(repetitied_input_ids[i])
            pad_len=rep_seq_len-after_dup_len
            repetitied_input_ids[i]+=[0]*pad_len
            repetitied_attention_mask[i]+=[0]*pad_len
            repetitied_token_type_ids[i]+=[0]*pad_len

        repetitied_input_ids=torch.LongTensor(repetitied_input_ids)
        repetitied_attention_mask=torch.LongTensor(repetitied_attention_mask)
        repetitied_token_type_ids=torch.LongTensor(repetitied_token_type_ids)
        return {"input_ids":repetitied_input_ids,'attention_mask':repetitied_attention_mask,'token_type_ids':repetitied_token_type_ids}

    def forward(self,batch_inputs):
        '''
        为了实现兼容，所有model的batch_inputs最后一个位置必须是labels，即使为None
        get token_embeddings,cls_token_embeddings,sentence_embeddings
        sentence_embeddings是经过Pooling层后concat的embedding。维度=768*k，其中k取决于pooling的策略
        一般来讲，只会取一种pooling策略，要么直接cls要么mean last or mean last2 or mean first and last layer，所以sentence_embeddings的维度也是768
        '''
        batch1_features,batch2_features,_=batch_inputs
        if self.is_distilbert:
            del batch1_features['token_type_ids']
            del batch2_features['token_type_ids']
        batch1_embeddings=self.encoder(batch1_features)['sentence_embedding']
        batch2_features=self.word_repetition(sentence_feature=batch2_features)
        batch2_embeddings=self.encoder(batch2_features)['sentence_embedding']
        cos_sim=self.cal_cos_sim(batch1_embeddings,batch2_embeddings)/self.temperature#(batch_size,batch_size)
        batch_size=cos_sim.size(0)
        assert cos_sim.size()==(batch_size,batch_size)
        labels=torch.arange(batch_size).to(cos_sim.device)
        return nn.CrossEntropyLoss()(cos_sim,labels)
    
    def encode(self, sentences,
               batch_size: int = 32,
               show_progress_bar: bool = None,
               output_value: str = 'sentence_embedding',
               convert_to_numpy: bool = True,
               convert_to_tensor: bool = False,
               device: str = None,
               normalize_embeddings: bool = False):
        '''
        传进来的sentences只能是single_batch
        '''
        return self.encoder.encode(sentences=sentences,
                                         batch_size=batch_size,
                                         show_progress_bar=show_progress_bar,
                                         output_value=output_value,
                                         convert_to_numpy=convert_to_numpy,
                                         convert_to_tensor=convert_to_tensor,
                                         device=device,
                                         normalize_embeddings=normalize_embeddings)
    
    def save(self,output_path):
        os.makedirs(output_path,exist_ok=True)
        with open(os.path.join(output_path, 'model_param_config.json'), 'w') as fOut:
            json.dump(self.get_config_dict(output_path), fOut)
        self.encoder.save(output_path)
        
    def get_config_dict(self,output_path):
        '''
        一定要有dict，这样才能初始化Model
        '''
        return {'bert_model_path':output_path,'temperature': self.temperature, 'is_distilbert': self.is_distilbert,
                'q_size':self.q_size,'dup_rate':self.dup_rate}
    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, 'model_param_config.json')) as fIn:
            config = json.load(fIn)
        return ESimCSE(**config)

In [138]:
device='cpu'
esimcse_only_repetition=ESimCSE(bert_model_path=model_path3,is_distilbert=False,is_sbert_model=False,dup_rate=0.32,device=device)

Some weights of the model checkpoint at /data/nfs14/nfs/aisearch/asr/xhsun/CommonModel/chinese-roberta-wwm/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [139]:
evaluator(esimcse_only_repetition)

2021-10-13 19:37:35 - INFO - __call__ - 72 : EmbeddingSimilarityEvaluator: Evaluating the model on  dataset:


HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))

2021-10-13 19:37:48 - INFO - __call__ - 103 : Cosine-Similarity :	Pearson: 0.6478	Spearman: 0.6626
2021-10-13 19:37:48 - INFO - __call__ - 105 : Manhattan-Distance:	Pearson: 0.6608	Spearman: 0.6771
2021-10-13 19:37:48 - INFO - __call__ - 107 : Euclidean-Distance:	Pearson: 0.6556	Spearman: 0.6721
2021-10-13 19:37:48 - INFO - __call__ - 109 : Dot-Product-Similarity:	Pearson: 0.4608	Spearman: 0.4555





0.6625539749630711

In [143]:
tmp=SimCSE(bert_model_path=model_path3,is_distilbert=False,is_sbert_model=False,device=device)

Some weights of the model checkpoint at /data/nfs14/nfs/aisearch/asr/xhsun/CommonModel/chinese-roberta-wwm/ were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [144]:
evaluator(tmp)

2021-10-13 19:40:19 - INFO - __call__ - 72 : EmbeddingSimilarityEvaluator: Evaluating the model on  dataset:


HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(HTML(value='Batches'), FloatProgress(value=0.0, max=23.0), HTML(value='')))

2021-10-13 19:40:31 - INFO - __call__ - 103 : Cosine-Similarity :	Pearson: 0.6478	Spearman: 0.6626
2021-10-13 19:40:31 - INFO - __call__ - 105 : Manhattan-Distance:	Pearson: 0.6608	Spearman: 0.6771
2021-10-13 19:40:31 - INFO - __call__ - 107 : Euclidean-Distance:	Pearson: 0.6556	Spearman: 0.6721
2021-10-13 19:40:31 - INFO - __call__ - 109 : Dot-Product-Similarity:	Pearson: 0.4608	Spearman: 0.4555





0.6625539749630711

# train model

In [24]:
from queue import Queue

In [145]:
q=Queue(maxsize=4)

In [146]:
q.put(10)
q.put(20)

In [148]:
a=[]

In [149]:
a.extend(torch.randn(5,6))

In [150]:
a

[tensor([ 1.0177, -0.9768,  2.0459,  0.3267,  0.2686,  0.9309]),
 tensor([-2.1823,  2.0730, -1.4822,  1.1837, -1.6012, -0.4767]),
 tensor([ 0.8916, -0.8301,  0.5890, -0.6972,  0.1671,  0.3684]),
 tensor([-0.4226, -0.8013, -1.0987,  0.7977,  1.9247,  1.3772]),
 tensor([ 1.3554, -0.6660, -0.2827,  0.6009,  0.0929, -0.2421])]

In [151]:
torch.vstack(a)

tensor([[ 1.0177, -0.9768,  2.0459,  0.3267,  0.2686,  0.9309],
        [-2.1823,  2.0730, -1.4822,  1.1837, -1.6012, -0.4767],
        [ 0.8916, -0.8301,  0.5890, -0.6972,  0.1671,  0.3684],
        [-0.4226, -0.8013, -1.0987,  0.7977,  1.9247,  1.3772],
        [ 1.3554, -0.6660, -0.2827,  0.6009,  0.0929, -0.2421]])

In [152]:
del a[:3]

In [155]:
a=torch.randn(3,3)
b=torch.randn(3,5)
print(a)
print(b)

tensor([[ 0.1169, -0.1899, -0.3863],
        [-1.6305,  1.3812, -0.0434],
        [-0.3736, -2.8549, -0.9150]])
tensor([[ 0.9126,  0.0965,  0.6212, -1.2479, -0.1719],
        [-1.7562, -0.0692, -0.4054, -2.7799,  0.4642],
        [-1.9844, -0.4774, -1.4125,  0.1069, -1.1051]])


In [156]:
c=torch.cat([a,b],dim=1)
c

tensor([[ 0.1169, -0.1899, -0.3863,  0.9126,  0.0965,  0.6212, -1.2479, -0.1719],
        [-1.6305,  1.3812, -0.0434, -1.7562, -0.0692, -0.4054, -2.7799,  0.4642],
        [-0.3736, -2.8549, -0.9150, -1.9844, -0.4774, -1.4125,  0.1069, -1.1051]])