Since mapping the dataset takes a long time (~15 min), it is done separately here and the result is saved, so that training can directly load the already-mapped dataset.

In [None]:
# Following SimCSE procedure.
import pandas as pd
from datasets import Dataset,load_dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import torch.nn as nn

In [None]:
# Load the pretrained tokenizer that the feature-preparation step uses.
pretrained_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)

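Download wiki1m_for_simcse.txt from the link below (the next cell reads it from 'your-path-to/simCSE-wiki.txt', so rename the file or adjust that path):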
https://huggingface.co/datasets/princeton-nlp/datasets-for-simcse/resolve/main/wiki1m_for_simcse.txt

In [None]:
import csv

# Read the corpus with pandas: a plain-text file with one sentence per line.
wiki_text_file = 'your-path-to/simCSE-wiki.txt'
# The file is raw text, not a real CSV. With pandas' default quoting, a line
# containing an unbalanced double quote is parsed as the start of a quoted
# field that can run across line breaks, silently merging sentences.
# QUOTE_NONE makes the parser treat quote characters as ordinary text.
wiki = pd.read_csv(
    wiki_text_file,
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
)
wiki.columns = ['text']
# Convert the single-column DataFrame into a Hugging Face dataset.
wiki_dataset = Dataset.from_pandas(wiki, split="train")
wiki_dataset

In [None]:
def prepare_features(examples, max_length=32):
    """Tokenize a batch of sentences into pairs of identical encodings.

    Every sentence is encoded twice in a single tokenizer call, and the two
    encodings are stored side by side so that each output row holds a pair
    built from the same sentence. Uses the module-level ``tokenizer``.

    Args:
        examples: Batch with a 'text' key holding a list of sentences.
            Entries may be None; those are replaced by a single space.
        max_length: Length passed to the tokenizer for truncation and
            padding. Defaults to 32.

    Returns:
        A dict with one entry per field returned by the tokenizer. Each
        value is a list with one ``[first_copy, second_copy]`` pair per
        input sentence, in input order.
    """
    # Replace missing rows with a blank string so the tokenizer only sees
    # strings. A new list is built so the incoming batch is not modified.
    texts = [" " if text is None else text for text in examples['text']]
    total = len(texts)

    # Encode the batch twice in one call: positions [0, total) hold the first
    # copy of each sentence and [total, 2 * total) hold the second copy.
    sent_features = tokenizer(
        texts + texts,
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )

    features = {}
    for key in sent_features:
        values = sent_features[key]
        # Pair entry i with entry i + total, i.e. both copies of sentence i.
        features[key] = [
            [first, second]
            for first, second in zip(values[:total], values[total:])
        ]

    return features

In [None]:
# Tokenize the whole corpus in batches of 4000 rows. The raw 'text' column is
# dropped, leaving only the fields returned by prepare_features.
train_dataset = wiki_dataset.map(
    prepare_features,
    batched=True,
    batch_size=4000,
    remove_columns=['text'],
)

In [None]:
# Have the dataset return these columns as torch tensors.
tensor_columns = ['input_ids', 'token_type_ids', 'attention_mask']
train_dataset.set_format(type='torch', columns=tensor_columns)

In [None]:
# Persist the mapped dataset so training can load it without re-running the map.
output_dir = "wiki_for_sts_32"
train_dataset.save_to_disk(output_dir)