In [1]:
import os
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Union, Optional, Dict

import torch
import torch.nn as nn
from prettytable import PrettyTable
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer  #, TrainingArguments, Trainer
from transformers.trainer import Trainer,TrainingArguments
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions

In [2]:
import sys
# Location of the local SentEval checkout and of its data directory
PATH_TO_SENTEVAL = './SentEval'
PATH_TO_DATA = PATH_TO_SENTEVAL + '/data'

# Put the local SentEval package first on the import path so `import senteval` resolves to it
sys.path.insert(0, PATH_TO_SENTEVAL)
import senteval

In [3]:
# evaluate model in all STS tasks
def print_table(task_names, scores):
    """Print ``scores`` as a single row of a PrettyTable headed by ``task_names``.

    Args:
        task_names: column headers, one per score.
        scores: row values, in the same order as ``task_names``.
    """
    summary = PrettyTable()
    summary.field_names = task_names
    summary.add_row(scores)
    print(summary)
    
def evalModel(model, tokenizer, pooler):
    """Evaluate an encoder on the seven SentEval STS tasks and print a score table.

    Args:
        model: encoder called as ``model(**encoded, output_hidden_states=True,
            return_dict=True)``. Its output must expose ``last_hidden_state`` (for
            ``"cls_before_pooler"``) or ``pooler_output`` (for ``"cls_after_pooler"``).
            The model is switched to eval mode while encoding, and its previous
            train/eval mode is restored afterwards.
        tokenizer: tokenizer matching ``model`` (used via ``batch_encode_plus``).
        pooler: ``"cls_before_pooler"`` or ``"cls_after_pooler"``.

    Returns:
        The mean of the printed per-task Spearman scores (x100, each rounded to two
        decimals). Tasks missing from the SentEval results count as 0.

    Raises:
        ValueError: if ``pooler`` is not one of the supported strategies.
    """
    if pooler not in ("cls_before_pooler", "cls_after_pooler"):
        raise ValueError(f"Unsupported pooler: {pooler!r}")

    tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'STSBenchmark', 'SICKRelatedness']
    # These tasks report an aggregate over all of their sub-datasets under ['all']
    sts_aggregate_tasks = ('STS12', 'STS13', 'STS14', 'STS15', 'STS16')

    params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
    params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                            'tenacity': 3, 'epoch_size': 2}
    # Encode on the device the model already lives on. A hard-coded "cuda:0" breaks
    # when the model sits on CPU (or another GPU) while CUDA is available.
    device = next(model.parameters()).device

    def prepare(params, samples):
        return

    def batcher(params, batch, max_length=None):
        # SentEval can hand over tokens as bytes; decode them before joining
        if len(batch) >= 1 and len(batch[0]) >= 1 and isinstance(batch[0][0], bytes):
            batch = [[word.decode('utf-8') for word in s] for s in batch]

        sentences = [' '.join(s) for s in batch]
        encoded = tokenizer.batch_encode_plus(
            sentences,
            return_tensors='pt',
            padding=True,
            max_length=max_length,
            truncation=True,
        )
        encoded = {k: v.to(device) for k, v in encoded.items()}

        with torch.no_grad():
            outputs = model(**encoded, output_hidden_states=True, return_dict=True)
        if pooler == "cls_before_pooler":
            embeddings = outputs.last_hidden_state[:, 0]
        else:  # "cls_after_pooler"
            embeddings = outputs.pooler_output
        return embeddings.cpu()

    # Disable dropout while encoding so embeddings are deterministic; restore the caller's mode after.
    was_training = model.training
    model.eval()
    try:
        results = {}
        for task in tasks:
            se = senteval.engine.SE(params, batcher, prepare)
            results[task] = se.eval(task)
    finally:
        model.train(was_training)

    scores = []
    for task in tasks:
        if task not in results:
            scores.append("0.00")
        elif task in sts_aggregate_tasks:
            scores.append("%.2f" % (results[task]['all']['spearman']['all'] * 100))
        else:
            # STSBenchmark / SICKRelatedness: Spearman result object on the test split
            scores.append("%.2f" % (results[task]['test']['spearman'].correlation * 100))

    # Average of the per-task scores as printed (rounded to two decimals)
    avg = sum(float(score) for score in scores) / len(scores)
    print_table(tasks + ["Avg."], scores + ["%.2f" % avg])
    return avg

In [4]:
@dataclass
class DataArguments:
    """Locations of the training corpus and base checkpoint, plus tokenization length."""

    # Plain-text training corpus, one sentence per line
    train_file: str = field(
        default="./data/wiki1m_for_simcse.txt",
        metadata={"help": "The path of train file"},
    )
    # Hugging Face model id or local directory of the pre-trained encoder
    model_name_or_path: str = field(
        default="bert-base-uncased",
        metadata={"help": "The name or path of pre-trained language model"},
    )
    # Sentences are truncated to this many tokens
    max_seq_length: int = field(
        default=32,
        metadata={"help": "The maximum total input sequence length after tokenization."},
    )

# Hugging Face training configuration.
# Evaluation is handled by SimCSETrainer.evaluate, which reports SentEval dev correlations
# ("eval_stsb_spearman", "eval_sickr_spearman", "eval_avg_sts") rather than a loss.
# With load_best_model_at_end=True and no metric_for_best_model, Trainer falls back to
# "eval_loss", which that evaluate() never returns. Checkpoint selection is therefore
# tied explicitly to "eval_avg_sts" (higher is better).
training_args = TrainingArguments(
    output_dir="trainer_models",
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    evaluation_strategy="steps",
    eval_steps=125,
    learning_rate=3e-5,
    load_best_model_at_end=True,
    metric_for_best_model="eval_avg_sts",
    greater_is_better=True,
    overwrite_output_dir=True,
    do_train=True,
    # Step-wise evaluation is enabled above; say so explicitly instead of do_eval=False
    do_eval=True,
    logging_steps=10,
)

data_args = DataArguments()

In [5]:
# Load the tokenizer belonging to the pre-trained checkpoint
tokenizer = BertTokenizer.from_pretrained(data_args.model_name_or_path)

# Read the training corpus: one sentence per line, surrounding whitespace stripped
with open(data_args.train_file, encoding="utf8") as corpus:
    texts = [raw_line.strip() for raw_line in tqdm(corpus.readlines())]

print(type(texts))
print(texts[0])

  0%|          | 0/1000000 [00:00<?, ?it/s]

<class 'list'>
YMCA in South Australia


In [6]:
class PairDataset(Dataset):
    """Unsupervised SimCSE dataset: each item holds two identical encodings of one sentence.

    The two copies form the positive pair for the contrastive loss (``BertForCL.cl_forward``
    uses view 0 vs. view 1 of every item; other items in the batch act as negatives).

    Args:
        examples: raw training sentences.
        text_tokenizer: Hugging Face style tokenizer callable. Defaults to the
            notebook-level ``tokenizer``.
        max_length: truncation length in tokens. Defaults to ``data_args.max_seq_length``.

    Attributes:
        input_ids, attention_mask: per-item ``[view_0, view_1]`` token lists.
        token_type_ids: same layout, or ``None`` when the tokenizer does not produce them
            (e.g. RoBERTa-style tokenizers).
    """

    # Encoding keys kept per item, in this order
    _KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def __init__(self, examples: List[str], text_tokenizer=None, max_length: Optional[int] = None):
        if text_tokenizer is None:
            text_tokenizer = tokenizer
        if max_length is None:
            max_length = data_args.max_seq_length
        # Tokenize every sentence once. Both views are the same encoding, so tokenizing the
        # duplicated list (examples + examples) would only double the work.
        encoded = text_tokenizer(list(examples),
                                 max_length=max_length,
                                 truncation=True,
                                 padding=False)
        # key -> [[view_0, view_1], ...]; copies are separate lists so the views never alias
        features = {
            key: [[list(seq), list(seq)] for seq in encoded[key]]
            for key in self._KEYS
            if key in encoded
        }
        self.input_ids = features["input_ids"]
        self.attention_mask = features["attention_mask"]
        self.token_type_ids = features.get("token_type_ids")

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        sample = {
            "input_ids": self.input_ids[item],
            "attention_mask": self.attention_mask[item],
        }
        if self.token_type_ids is not None:
            sample["token_type_ids"] = self.token_type_ids[item]
        return sample


In [7]:
# Number of corpus sentences used for training; set to None to use the full corpus
# (texts[:None] keeps every sentence).
NUM_TRAIN_SENTENCES = 10000
train_dataset = PairDataset(texts[:NUM_TRAIN_SENTENCES])
print(train_dataset[0])

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

{'input_ids': [[101, 26866, 1999, 2148, 2660, 102], [101, 26866, 1999, 2148, 2660, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]}


In [8]:
#import pickle

#with open("./data/train_dataset", "wb") as fp2:   #Pickling
#    pickle.dump(train_dataset1, fp2)

In [9]:
#with open("./data/train_dataset", "rb") as fp2:   # Unpickling
#    train_dataset = pickle.load(fp2)
#print(train_dataset[0])
#print(len(train_dataset))
#print(type(train_dataset))

In [10]:
# Inspect a longer example; encodings are truncated to data_args.max_seq_length tokens
example = train_dataset[2]
print(example)

{'input_ids': [[101, 1996, 7328, 1997, 9569, 7490, 2964, 1010, 11295, 4676, 1998, 2969, 2128, 15204, 2102, 3754, 5171, 1997, 2049, 8759, 2018, 2445, 2019, 5020, 25691, 28126, 2000, 2148, 2827, 3241, 2013, 102], [101, 1996, 7328, 1997, 9569, 7490, 2964, 1010, 11295, 4676, 1998, 2969, 2128, 15204, 2102, 3754, 5171, 1997, 2049, 8759, 2018, 2445, 2019, 5020, 25691, 28126, 2000, 2148, 2827, 3241, 2013, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}


In [11]:
num_pairs = len(train_dataset)
print(num_pairs)

10000


In [12]:
@dataclass
class DataCollator:
    """Pad grouped sentence encodings and reshape them to ``(batch_size, num_sent, seq_len)``.

    Each feature (e.g. a ``PairDataset`` item) maps ``input_ids``, ``attention_mask`` and
    optionally ``token_type_ids`` to a list of ``num_sent`` token-id lists. All views
    of all examples are padded together with ``tokenizer.pad`` and then reshaped, so
    ``batch[k][i, j]`` is view ``j`` of example ``i``.

    Attributes:
        tokenizer: tokenizer providing ``pad``.
        padding, max_length, pad_to_multiple_of: forwarded to ``tokenizer.pad``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        special_keys = ('input_ids', 'attention_mask', 'token_type_ids')
        batch_size = len(features)
        if batch_size == 0:
            # Nothing to collate: return an empty mapping (not None) to match the annotation
            return {}
        # Views per example (2 for PairDataset), read from the data instead of hard-coded
        num_sent = len(features[0]['input_ids'])
        # flat_features: [sen1_view0, sen1_view1, sen2_view0, sen2_view1, ...]
        flat_features = [
            {k: feature[k][i] for k in feature if k in special_keys}
            for feature in features
            for i in range(num_sent)
        ]
        batch = self.tokenizer.pad(
            flat_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        # (batch_size * num_sent, seq_len) -> (batch_size, num_sent, seq_len)
        return {k: batch[k].view(batch_size, num_sent, -1) for k in batch if k in special_keys}

In [13]:
# Collator pads each sentence pair and reshapes to (batch, 2, seq_len)
collate_fn = DataCollator(tokenizer)
dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=32)

# Sanity check: pull one batch and look at its keys and tensor shape
batch = next(iter(dataloader))
print(batch.keys())
print(batch["input_ids"].shape)

dict_keys(['input_ids', 'attention_mask', 'token_type_ids'])
torch.Size([32, 2, 32])


In [14]:
# Fully connected projection head applied to the [CLS] representation
class MLPLayer(nn.Module):
    """Linear projection followed by tanh, i.e. ``tanh(W x + b)`` on the last dimension.

    Args:
        input_size: size of the input's last dimension.
        output_size: size of the output's last dimension.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names determine the state-dict keys (e.g. "mlp.dense.weight"); keep them
        self.dense = nn.Linear(input_size, output_size)
        self.activation = nn.Tanh()

    def forward(self, features, **kwargs):
        # Extra keyword arguments are accepted and ignored
        projected = self.dense(features)
        return self.activation(projected)

# Similarity head: cosine similarity scaled by a temperature
class Similarity(nn.Module):
    """Cosine similarity over the last dimension divided by ``temp``.

    Inputs broadcast like ``nn.CosineSimilarity``, e.g. ``(B, 1, D)`` against ``(1, B, D)``
    yields a ``(B, B)`` score matrix.
    """

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        similarity = self.cos(x, y)
        return similarity / self.temp

    
# Complete SimCSE model: BERT encoder + MLP projection head + temperature-scaled cosine similarity
class BertForCL(BertPreTrainedModel, ABC):
    """BERT with a SimCSE contrastive head.

    ``forward(..., sent_emb=False)`` (training) expects inputs shaped
    ``(batch_size, num_sent, seq_len)`` with ``num_sent >= 2``. It returns the
    cross-entropy loss over the in-batch similarity matrix of view 0 vs. view 1.
    ``forward(..., sent_emb=True)`` (inference) expects ``(batch_size, seq_len)`` and
    returns sentence embeddings.
    """

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.mlp = MLPLayer(config.hidden_size, config.hidden_size)
        # Softmax temperature; can be set on the config, otherwise 0.05 as before
        self.sim = Similarity(temp=getattr(config, "temp", 0.05))
        # Initialise the newly added head with the model's own init scheme, as every
        # PreTrainedModel subclass should (post_init on newer transformers, init_weights on older).
        if hasattr(self, "post_init"):
            self.post_init()
        else:
            self.init_weights()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None,
                sent_emb=False):
        """Dispatch to ``sentemb_forward`` (``sent_emb=True``) or ``cl_forward`` (training)."""
        if sent_emb:
            # Forward used at inference time: sentence embeddings
            return self.sentemb_forward(input_ids=input_ids,
                                        attention_mask=attention_mask,
                                        token_type_ids=token_type_ids,
                                        position_ids=position_ids,
                                        head_mask=head_mask,
                                        inputs_embeds=inputs_embeds,
                                        labels=labels,
                                        output_attentions=output_attentions,
                                        output_hidden_states=output_hidden_states,
                                        return_dict=return_dict)
        # Forward used during training: contrastive loss
        return self.cl_forward(input_ids=input_ids,
                               attention_mask=attention_mask,
                               token_type_ids=token_type_ids,
                               position_ids=position_ids,
                               head_mask=head_mask,
                               inputs_embeds=inputs_embeds,
                               labels=labels,
                               output_attentions=output_attentions,
                               output_hidden_states=output_hidden_states,
                               return_dict=return_dict)

    def sentemb_forward(self,
                        input_ids=None,
                        attention_mask=None,
                        token_type_ids=None,
                        position_ids=None,
                        head_mask=None,
                        inputs_embeds=None,
                        labels=None,
                        output_attentions=None,
                        output_hidden_states=None,
                        return_dict=None):
        """Encode single sentences into embeddings.

        Returns a ``BaseModelOutputWithPoolingAndCrossAttentions`` whose ``pooler_output`` is
        the MLP-projected [CLS] vector and whose ``last_hidden_state`` is the raw BERT output.
        When ``return_dict`` is falsy, returns the tuple ``(last_hidden_state, pooler_output, ...)``.
        ``labels`` is ignored.
        """
        # 1. Encode with BERT, forwarding every input the caller supplied
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            position_ids=position_ids,
                            head_mask=head_mask,
                            inputs_embeds=inputs_embeds,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=True)
        # 2-3. Take the [CLS] vector and project it with the MLP head
        cls_output = self.mlp(outputs.last_hidden_state[:, 0])
        if not return_dict:
            return (outputs[0], cls_output) + outputs[2:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=cls_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def cl_forward(self,
                   input_ids=None,
                   attention_mask=None,
                   token_type_ids=None,
                   position_ids=None,
                   head_mask=None,
                   inputs_embeds=None,
                   labels=None,
                   output_attentions=None,
                   output_hidden_states=None,
                   return_dict=None):
        """Contrastive (InfoNCE-style) training forward.

        ``input_ids`` (and the optional ``attention_mask`` / ``token_type_ids`` /
        ``position_ids``) are shaped ``(batch_size, num_sent, seq_len)``. View 0 and view 1
        of each example form the positive pair; all other examples in the batch are
        negatives. ``labels`` and ``inputs_embeds`` are ignored. The labels are always
        ``arange(batch_size)``. Returned ``hidden_states`` / ``attentions`` are for the
        flattened ``(batch_size * num_sent, seq_len)`` input.
        """
        batch_size = input_ids.size(0)
        num_sent = input_ids.size(1)

        def _flatten(tensor):
            # (batch_size, num_sent, seq_len) -> (batch_size * num_sent, seq_len); None passes through
            return None if tensor is None else tensor.reshape(-1, tensor.size(-1))

        # 1-2. Flatten to the 2-D layout BERT expects and encode
        outputs = self.bert(_flatten(input_ids),
                            attention_mask=_flatten(attention_mask),
                            token_type_ids=_flatten(token_type_ids),
                            position_ids=_flatten(position_ids),
                            head_mask=head_mask,
                            output_attentions=output_attentions,
                            output_hidden_states=output_hidden_states,
                            return_dict=True)
        # 3-4. Take [CLS] vectors and restore the (batch_size, num_sent, hidden) layout
        cls_output = outputs.last_hidden_state[:, 0]
        cls_output = cls_output.view(batch_size, num_sent, cls_output.size(-1))
        # 5. Project with the MLP head
        cls_output = self.mlp(cls_output)
        # 6. Split the two views of every example
        z1, z2 = cls_output[:, 0], cls_output[:, 1]
        # 7. (batch_size, batch_size) matrix of temperature-scaled cosine similarities
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))
        # 8. Target for row i is column i: cross-entropy raises the diagonal (positive pairs)
        #    and lowers the off-diagonal entries (in-batch negatives)
        labels = torch.arange(cos_sim.size(0), dtype=torch.long, device=cos_sim.device)
        loss = nn.CrossEntropyLoss()(cos_sim, labels)

        if not return_dict:
            return (loss, cos_sim) + outputs[2:]

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

# Load pre-trained BERT weights; the MLP head is newly initialised (see the warning below)
model = BertForCL.from_pretrained(data_args.model_name_or_path)

# Smoke test: one contrastive forward pass on the sample batch. The result is only
# inspected, so skip autograd instead of keeping a graph alive through `cl_out`.
with torch.no_grad():
    cl_out = model(**batch, return_dict=True)
print(cl_out.keys())

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForCL: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForCL from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForCL from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForCL were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['mlp.dense.bias', 'mlp.dense.weig

odict_keys(['loss', 'logits'])


In [15]:
# Trainer subclass that replaces loss-based evaluation with SentEval STS scores
class SimCSETrainer(Trainer):
    """``Trainer`` whose ``evaluate`` reports SentEval STS dev-set correlations.

    Each ``evaluate`` call encodes the STS-B and SICK-R dev sets with the model's [CLS]
    vector (``sent_emb=True`` path, taken before the MLP head) and reports their Spearman
    correlations. Whenever their average beats the best seen so far, it also runs the
    full seven-task ``evalModel`` report and saves the model to ``<output_dir>/best-model``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Best average dev Spearman seen so far; gates the "best-model" save
        self.best_sts = 0.0

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
        eval_senteval_transfer: bool = False,
    ) -> Dict[str, float]:
        """Score the model on the SentEval STS-B and SICK-R dev sets.

        ``eval_dataset``, ``ignore_keys`` and ``eval_senteval_transfer`` are accepted for
        signature compatibility with ``Trainer.evaluate`` and are not used.

        Returns:
            ``{prefix}_stsb_spearman``, ``{prefix}_sickr_spearman`` and ``{prefix}_avg_sts``,
            where ``prefix`` is ``metric_key_prefix`` (``"eval"`` by default).
        """

        # SentEval prepare and batcher
        def prepare(params, samples):
            return

        def batcher(params, batch):
            # SentEval can hand over tokens as bytes; decode them (same handling as evalModel)
            if batch and batch[0] and isinstance(batch[0][0], bytes):
                batch = [[word.decode('utf-8') for word in s] for s in batch]
            sentences = [' '.join(s) for s in batch]
            encoded = self.tokenizer.batch_encode_plus(
                sentences,
                return_tensors='pt',
                padding=True,
            )
            encoded = {k: v.to(self.args.device) for k, v in encoded.items()}
            with torch.no_grad():
                outputs = self.model(**encoded, output_hidden_states=True, return_dict=True, sent_emb=True)
            return outputs.last_hidden_state[:, 0].cpu()

        # Set params for SentEval (fastmode)
        params = {'task_path': PATH_TO_DATA, 'usepytorch': True, 'kfold': 5}
        params['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                'tenacity': 3, 'epoch_size': 2}

        se = senteval.engine.SE(params, batcher, prepare)
        tasks = ['STSBenchmark', 'SICKRelatedness']

        # Evaluate with dropout disabled, then put the model back in its previous mode
        was_training = self.model.training
        self.model.eval()
        try:
            results = se.eval(tasks)

            stsb_spearman = results['STSBenchmark']['dev']['spearman'][0]
            sickr_spearman = results['SICKRelatedness']['dev']['spearman'][0]
            avg_sts = (stsb_spearman + sickr_spearman) / 2
            metrics = {
                f"{metric_key_prefix}_stsb_spearman": stsb_spearman,
                f"{metric_key_prefix}_sickr_spearman": sickr_spearman,
                f"{metric_key_prefix}_avg_sts": avg_sts,
            }
            print(metrics)

            # Save and fully evaluate the model whenever the dev average improves
            if avg_sts > self.best_sts:
                self.best_sts = avg_sts
                evalModel(self.model.bert, self.tokenizer, pooler='cls_before_pooler')
                self.save_model(os.path.join(self.args.output_dir, "best-model"))
        finally:
            self.model.train(was_training)

        self.log(metrics)
        return metrics

In [16]:
# No tokens are added to the tokenizer in this notebook, so this is presumably a no-op safeguard
model.resize_token_embeddings(len(tokenizer))

trainer = SimCSETrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=collate_fn,
    train_dataset=train_dataset,
)

In [17]:
#from filelock import FileLock
from prettytable import PrettyTable
#from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

# test the evaluate method and see what are the initial results
#trainer.evaluate()

In [18]:
# Train, then save the final weights under the configured output directory
# (next to the "best-model" checkpoint written by SimCSETrainer.evaluate).
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "final"))

***** Running training *****
  Num examples = 10000
  Num Epochs = 1
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 157


Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Reload the best checkpoint saved by SimCSETrainer.evaluate and score it on all STS tasks
# (append .cuda() to the first line to evaluate on GPU):
#model = BertForCL.from_pretrained("trainer_models/best-model")
#avg = evalModel(model.bert, tokenizer, pooler='cls_before_pooler')