In [1]:
import torch
import torch.nn as nn

from abc import ABC
from tqdm.notebook import tqdm
from dataclasses import dataclass, field
from typing import List, Union, Optional, Dict
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer,BertConfig  #, TrainingArguments, Trainer
from transformers.trainer import Trainer,TrainingArguments
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPoolingAndCrossAttentions
import pickle
from prettytable import PrettyTable

In [2]:
%%bash
cd ./data 
./download_wiki.sh
./download_nli.sh
cd ..

--2022-04-10 21:28:46--  https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt
Resolving huggingface.co (huggingface.co)... 34.224.55.150, 34.197.58.156, 34.198.1.82, ...
Connecting to huggingface.co (huggingface.co)|34.224.55.150|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/7b1825863a99aa76479b0456f7c210539dfaeeb69598b41fb4de4f524dd5a706 [following]
--2022-04-10 21:28:47--  https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/7b1825863a99aa76479b0456f7c210539dfaeeb69598b41fb4de4f524dd5a706
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 13.224.163.62, 13.224.163.5, 13.224.163.56, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|13.224.163.62|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 120038621 (114M) [text/plain]
Saving to: 'wiki1m_for_simcse

In [3]:
%%bash
cd ./SentEval/data/downstream
./download_dataset.sh
tar xvf *.tar

CR/
CR/custrev.neg
CR/custrev.pos
MPQA/
MPQA/mpqa.neg
MPQA/mpqa.pos
MR/
MR/rt-polarity.neg
MR/rt-polarity.pos
MRPC/
MRPC/msr_paraphrase_train.txt
MRPC/msr_paraphrase_test.txt
SICK/
SICK/SICK_trial.txt
SICK/SICK_train.txt
SICK/SICK_test_annotated.txt
SNLI/
SNLI/s2.test
SNLI/s1.train
SNLI/s2.train
SNLI/labels.dev
SNLI/s1.test
SNLI/labels.test
SNLI/s2.dev
SNLI/s1.dev
SNLI/labels.train
SST/
SST/fine/
SST/fine/sentiment-test
SST/fine/sentiment-train
SST/fine/sentiment-dev
SST/binary/
SST/binary/sentiment-test
SST/binary/sentiment-train
SST/binary/sentiment-dev
STS/
STS/STS12-en-test/
STS/STS12-en-test/STS.gs.surprise.OnWN.txt
STS/STS12-en-test/STS.input.surprise.OnWN.txt
STS/STS12-en-test/STS.input.MSRpar.txt
STS/STS12-en-test/STS.gs.ALL.txt
STS/STS12-en-test/00-readme.txt
STS/STS12-en-test/STS.gs.MSRvid.txt
STS/STS12-en-test/STS.input.MSRvid.txt
STS/STS12-en-test/STS.gs.MSRpar.txt
STS/STS12-en-test/STS.input.surprise.SMTnews.txt
STS/STS12-en-test/STS.gs.SMTeuroparl.txt
STS/STS12-en-test/ST

--2022-04-10 21:30:54--  https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/senteval.tar
Resolving huggingface.co (huggingface.co)... 34.224.55.150, 34.197.58.156, 34.198.1.82, ...
Connecting to huggingface.co (huggingface.co)|34.224.55.150|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/bc43c148f7be97471c78fc4255399d3158cb99dfe8f2221999c918338b138c38 [following]
--2022-04-10 21:30:54--  https://cdn-lfs.huggingface.co/datasets/princeton-nlp/datasets-for-simcse/bc43c148f7be97471c78fc4255399d3158cb99dfe8f2221999c918338b138c38
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 13.224.163.51, 13.224.163.56, 13.224.163.5, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|13.224.163.51|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 89825280 (86M) [application/octet-stream]
Saving to: 'senteval.tar'

In [4]:
@dataclass
class DataArguments:
    train_file: str = field(default="./data/wiki1m_for_simcse.txt",
                            metadata={
    "help": "The path of train file"})
    model_name_or_path: str = field(default="bert-base-uncased",
                                    metadata={
    "help": "The name or path of pre-trained language model"})
    max_seq_length: int = field(default=32,
                                metadata={
    "help": "The maximum total input sequence length after tokenization."})

training_args = TrainingArguments(
        output_dir="trainer_models",
        num_train_epochs=1,
        per_device_train_batch_size=64,
        per_device_eval_batch_size  = 64,
        evaluation_strategy   = "steps",
        eval_steps            = 125,
        save_strategy = "steps",
        save_steps=5000,
        load_best_model_at_end=True,
        metric_for_best_model = "eval_avg_sts",    
        learning_rate=3e-5,
        overwrite_output_dir=True,
        do_train=True,
        do_eval=False, 
        logging_steps=10)

data_args = DataArguments()

In [5]:
class PairDataset(Dataset):
    def __init__(self, examples: List[str]):
        total = len(examples)
        # 将所有样本复制一份用于对比学习
        sentences_pair = examples + examples
        sent_features = tokenizer(sentences_pair,
                                  max_length=data_args.max_seq_length,
                                  truncation=True,
                                  padding=False)
        features = {
    }
        # 将相同的样本放在同一个列表中
        for key in sent_features:
            features[key] = [[sent_features[key][i], sent_features[key][i + total]] for i in tqdm(range(total))]
        self.input_ids = features["input_ids"]
        self.attention_mask = features["attention_mask"]
        self.token_type_ids = features["token_type_ids"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, item):
        return {
    
            "input_ids": self.input_ids[item],
            "attention_mask": self.attention_mask[item],
            "token_type_ids": self.token_type_ids[item]
        }

In [6]:
# 初始化tokenizer
tokenizer = BertTokenizer.from_pretrained(data_args.model_name_or_path)
# 读取训练数据
with open(data_args.train_file, encoding="utf8") as file:
    texts = [line.strip() for line in tqdm(file.readlines())]
print(type(texts))
print(texts[0])

  0%|          | 0/1000000 [00:00<?, ?it/s]

<class 'list'>
YMCA in South Australia


In [7]:
train_dataset = PairDataset(texts)
print(train_dataset[0])

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

  0%|          | 0/1000000 [00:00<?, ?it/s]

{'input_ids': [[101, 26866, 1999, 2148, 2660, 102], [101, 26866, 1999, 2148, 2660, 102]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]}


In [8]:
import pickle

with open("./data/train_dataset", "wb") as fp2:   #Pickling
    pickle.dump(train_dataset, fp2)