In [1]:
import pandas as pd
from datasets import Dataset,load_dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import torch.nn as nn

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [11]:
# use pandas to read simCSE-wiki.txt
wiki_text_file = './data/wiki1m_for_simcse.txt'
wiki = pd.read_csv(wiki_text_file,sep = '\t',header = None)
wiki.columns = ['text']
# use Dataset.from_pandas to convert to dataset
wiki_dataset = Dataset.from_pandas(wiki,split= "train")
wiki_dataset

Dataset({
    features: ['text'],
    num_rows: 995447
})

In [4]:
def prepare_features(examples):
    
    total = len(examples['text'])
    # total = batch_size
    
    # Avoid "None" fields 
    for idx in range(total):
        if examples['text'][idx] is None:
            examples['text'][idx] = " "
        if examples['text'][idx] is None:
            examples['text'][idx] = " "

    sentences = examples['text'] + examples['text']

    # set max_length here:
    sent_features = tokenizer(sentences, max_length=32, truncation=True, padding="max_length")

    features = {}
    for key in sent_features:
        features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]
    
    return features

In [5]:
train_dataset = wiki_dataset.map(prepare_features,batched=True, remove_columns=['text'], batch_size=4000)

  0%|          | 0/249 [00:00<?, ?ba/s]

In [6]:
train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])

In [7]:
train_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 995447
})

In [8]:
# save to disk for reuse
train_dataset.save_to_disk("./data/wiki_for_sts_32")

In [10]:
train_dataset[0]

{'input_ids': [tensor([  101, 26866,  1999,  2148,  2660,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0]),
  tensor([  101, 26866,  1999,  2148,  2660,   102,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0])],
 'token_type_ids': [tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0]),
  tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0])],
 'attention_mask': [tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0]),
  tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,