## **FINE-TUNING XLM-R WITH MLM**

In [None]:
!nvidia-smi

#### **Install the Hugging Face Transformers library**

In [None]:
# !pip3 install transformers
# !pip3 install nltk

In [None]:
import os

import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AdamW
from transformers import BertTokenizer, BertForMaskedLM
from transformers import XLMRobertaTokenizer
from transformers import XLMRobertaForMaskedLM
from transformers import AutoTokenizer, AutoModelForMaskedLM, XLMRobertaConfig
from transformers import DataCollatorForLanguageModeling


In [None]:
# Load the pretrained 'xlm-roberta-base' tokenizer; later cells use it to
# encode the corpus and to configure the MLM data collator.
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')

# Load the masked-language-model weights from the same checkpoint; this is the
# model that the training loop at the end of the notebook fine-tunes.
model = XLMRobertaForMaskedLM.from_pretrained("xlm-roberta-base", return_dict=True)

# Ask PyTorch's CUDA caching allocator to release memory it is not using.
torch.cuda.empty_cache()

In [None]:
# model_test['roberta.encoder.layer.11.output.LayerNorm.bias']

#### **Import text data**

In [None]:
import pandas as pd
import re
from nltk.tokenize import RegexpTokenizer

# Patterns are compiled once when the cell runs instead of on every call.
_HTML_TAG_RE = re.compile(r'<.*?>')
_URL_RE = re.compile(r'http\S+')
_DIGITS_RE = re.compile(r'[0-9]+')
_WORD_RE = re.compile(r'\w+')


def preprocess_text(text):
    """Normalise one raw dialogue string into lowercase, space-separated words.

    The input is split into utterances on "##". For every non-blank utterance
    the leading "speaker:" prefix (everything up to and including the FIRST
    colon) is dropped; utterances without a colon are kept whole. The
    utterances are then joined, lowercased, stripped of HTML-like tags, URLs
    and digit runs, and finally reduced to runs of word characters joined by
    single spaces.

    Only the first colon is treated as the speaker separator, so later colons
    (for example the one inside "http://...") no longer truncate the
    utterance before the URL-removal step can run.

    Args:
        text: Raw text. Any non-string value (e.g. None, or the NaN that
            pandas yields for an empty cell) is treated as empty text.

    Returns:
        The cleaned text as a single string; "" when nothing is left.
    """
    if not isinstance(text, str):
        return ""

    without_speaker = []
    for utterance in text.split("##"):
        if not utterance.strip():
            continue
        # partition() splits on the first colon only; `colon` is "" when the
        # utterance has no speaker prefix, in which case `head` is the whole
        # utterance.
        head, colon, remainder = utterance.partition(":")
        without_speaker.append(remainder if colon else head)

    result = " ".join(without_speaker).lower()
    result = _HTML_TAG_RE.sub('', result)
    result = _URL_RE.sub('', result)
    result = _DIGITS_RE.sub('', result)
    # Same tokens as nltk's RegexpTokenizer(r'\w+'), without rebuilding a
    # tokenizer object per call or shadowing the notebook-level `tokenizer`.
    return " ".join(_WORD_RE.findall(result))

# CSV files that provide the 'premise' text column.
train_df = pd.read_csv('datasets/nli_train.csv')
test_df = pd.read_csv('datasets/nli_test.csv')

# Plain-text files with one raw text per line. The context managers close the
# file handles as soon as the lines have been read and cleaned.
with open("datasets/train_mdhr.txt", "r") as f_train_mdhr:
    train_mdhr = [preprocess_text(line) for line in f_train_mdhr]

with open("datasets/test_mdhr.txt", "r") as f_test_mdhr:
    test_mdhr = [preprocess_text(line) for line in f_test_mdhr]

# CSV files that provide the 'Hinglish' text column.
train_cmudog_df = pd.read_csv('datasets/train_hinglish_english.csv')
test_cmudog_df = pd.read_csv('datasets/test_hinglish_english.csv')

# Clean the text columns with the same routine used for the .txt files.
train_df['premise'] = train_df['premise'].apply(preprocess_text)
test_df['premise'] = test_df['premise'].apply(preprocess_text)

train_cmudog_df['Hinglish'] = train_cmudog_df['Hinglish'].apply(preprocess_text)
test_cmudog_df['Hinglish'] = test_cmudog_df['Hinglish'].apply(preprocess_text)

# Train and test splits are pooled on purpose: every text ends up in the
# single corpus used for masked-language-model fine-tuning.
data = (
    train_df['premise'].tolist()
    + test_df['premise'].tolist()
    + train_cmudog_df['Hinglish'].tolist()
    + test_cmudog_df['Hinglish'].tolist()
    + train_mdhr
    + test_mdhr
)
# Preview a few entries rather than echoing the entire corpus into the output.
data[:5]

#### **Text cleaning process**

In [None]:
# Number of texts currently held in `data`.
print(len(data))

In [None]:
# Keep only texts that are at least MIN_TEXT_CHARS characters long.
# A new list is built instead of calling data.remove() inside a loop over
# `data`: removing while iterating shifts the remaining items left, so the
# element right after each removed one is skipped and short texts survive.
MIN_TEXT_CHARS = 50
data = [sentence for sentence in data if len(sentence) >= MIN_TEXT_CHARS]

In [None]:
# Number of texts left in `data` after the length-filtering cell above.
print(len(data))

#### **Tokenizing the text data**

In [None]:
# Encode every text in `data` as PyTorch tensors. Sequences longer than 512
# tokens are truncated and shorter ones are padded up to 512, so all rows have
# the same width. token_type_ids are requested explicitly because
# CSDataset.__getitem__ below reads encodings['token_type_ids'].
inputs = tokenizer(
    data,
    max_length=512,
    truncation=True,
    padding='max_length',
    return_tensors='pt',
    return_token_type_ids=True
)

In [None]:
# List the fields returned by the tokenizer call.
inputs.keys()

In [None]:
# Store an independent copy of the token ids under 'labels' so that later
# changes to 'input_ids' do not alter it.
# NOTE(review): the DataLoader below collates with
# DataCollatorForLanguageModeling, which presumably produces its own labels
# when it masks a batch - confirm whether this copy is still needed.
inputs['labels'] = inputs['input_ids'].detach().clone()
inputs

#### **Masking the input_ids**

In [None]:
# One uniform random draw in [0, 1) per token position; a later cell
# thresholds these draws to choose which positions to mask.
# NOTE(review): no random seed is set in this notebook, so the draws differ
# between runs.
random_tensor = torch.rand(inputs['input_ids'].shape)

In [None]:
# Same shape as inputs['input_ids'], which it was created from.
random_tensor.shape

In [None]:
# Display the random float tensor created above.
random_tensor

In [None]:
# Boolean mask of the positions selected for masking: a position is chosen
# when its random draw is below 0.15 and its token is not one of the
# tokenizer's special tokens. The special-token ids are read from the
# tokenizer instead of hard-coding 101/102/0, which are BERT's ids and do not
# correspond to XLM-R's <s>, </s> and <pad> tokens.
masked_tensor = (
    (random_tensor < 0.15)
    & (inputs['input_ids'] != tokenizer.cls_token_id)
    & (inputs['input_ids'] != tokenizer.sep_token_id)
    & (inputs['input_ids'] != tokenizer.pad_token_id)
)

In [None]:
# For each row of the mask, collect the column positions whose entry is True
# (i.e. the positions selected for masking) as a plain Python list.
nonzeros_indices = [
    torch.flatten(mask_row.nonzero()).tolist()
    for mask_row in masked_tensor
]

In [None]:
# (disabled) would overwrite the selected indices with the MASK token id in every row of input_ids; note that 103 is BERT's [MASK] id - for XLM-R use tokenizer.mask_token_id instead.
# for i in range(len(inputs['input_ids'])):
#     inputs['input_ids'][i, nonzeros_indices[i]] = 103

#### **PyTorch Dataset and DataLoader**

In [None]:
class CSDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves one tokenized example per index.

    `encodings` must support lookup by field name, and every field must be
    indexable by example position (for instance the tokenizer output with a
    'labels' entry added).
    """

    # Fields copied into every example, in the order they are emitted.
    _FIELDS = ('input_ids', 'labels', 'attention_mask', 'token_type_ids')

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The number of rows stored under 'input_ids' is the dataset size.
        return len(self.encodings['input_ids'])

    def __getitem__(self, index):
        """Return a dict holding row `index` of each field in `_FIELDS`."""
        return {field: self.encodings[field][index] for field in self._FIELDS}

In [None]:
# Wrap the tokenized corpus so that it can be indexed one example at a time.
dataset = CSDataset(inputs)
# Collator used by the DataLoader below to assemble batches for masked
# language modelling, configured with a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

In [None]:
# Batches of 4 examples, reshuffled every epoch; each batch is assembled by
# `data_collator` rather than PyTorch's default collate function.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
# (Replace the expression with 'cpu' to force CPU execution.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Multi-GPU alternative, currently disabled:
# model = torch.nn.DataParallel(model,device_ids = [1,2,3])
# Release unused cached GPU memory, then move the model's parameters to the
# device selected above.
torch.cuda.empty_cache()
model.to(device)

#### **Model parameters**

In [None]:
# Number of full passes over the dataloader.
epochs = 15
# torch.optim.AdamW is used in place of transformers.AdamW, which Hugging Face
# deprecated in favour of the PyTorch implementation. eps and weight_decay are
# passed explicitly with the values transformers.AdamW used by default
# (1e-6 and 0.0), because torch's own defaults differ (1e-8 and 0.01).
optimizer = optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

#### **Training Loop**

In [None]:
# Fine-tune the model with the masked-language-modelling loss and write a
# checkpoint at the end of every epoch (each one overwrites the previous file).
model.train()
PATH = 'models/xlmr-ft-premise-all-data/model.pt'
# Create the checkpoint directory up front so that torch.save cannot fail on a
# missing folder after a whole epoch of training has already been spent.
os.makedirs(os.path.dirname(PATH), exist_ok=True)

for epoch in range(epochs):
    loop = tqdm(dataloader)
    for batch in loop:
        optimizer.zero_grad()

        # Only the tensors the forward pass consumes are copied to the device;
        # token_type_ids are not passed to the model, so they stay where they are.
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Passing `labels` makes the model return the loss with its outputs.
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels, return_dict=True)
        loss = outputs['loss']
        loss.backward()
        optimizer.step()

        loop.set_description("Epoch: {}".format(epoch))
        loop.set_postfix(loss=loss.item())

        # Drop the device copies before the next batch is transferred.
        del input_ids
        del labels
        del attention_mask

    # Checkpoint: epoch index, model and optimizer state, and the loss of the
    # last batch of the epoch.
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)