### **Masked Language Modelling** 

Masked Language Modelling entails **masking** part of the input sequence and then training a model to predict the masked tokens. 

`roberta-base` supports sequences of up to 512 tokens (the code below loads `roberta-large`, which has the same limit).

[Reference](https://towardsdatascience.com/masked-language-modelling-with-bert-7d49793e5d2c)

In [None]:
# Installations
#!pip install transformers==3.0.2

In [None]:
# imports
import os
import random

import torch
from torch.utils.data import Dataset, DataLoader
from torch import cuda

import transformers
from transformers import RobertaTokenizer, RobertaModel, RobertaForMaskedLM 
from transformers import pipeline

from torch import cuda
from tqdm import tqdm
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cpu' if not cuda.is_available() else 'cuda'

In [None]:
# Mounting Google Drive to this .ipynb
# from google.colab import drive
# drive.mount('/content/drive')


# Relative paths of the tab-separated dataset files; both are handed to
# SST2_prompt, which opens them and skips their first (header) line.
train_data_loc = 'SST-2/Few_Shot/train_3.tsv'
dev_data_loc = 'SST-2/dev.tsv'

In [None]:
# Dataloader

class SST2_prompt(Dataset):
    '''
    Dataset that turns "text<TAB>label" rows into prompt-style masked-LM examples.

    Each row becomes a prompt built from the template, e.g.
    'a good movie It was <mask> . ', paired with the same prompt in which
    '<mask>' is replaced by the label word for the row's target class.
    '''

    def __init__(self, file_loc, tokenizer, max_length, template = '<S> It was <mask> . ', target2label = None):
        '''
        file_loc      str      : file path for the dataset; the first line is treated as a header
                                 and skipped, every other non-blank line must be "text<TAB>label"
        tokenizer     callable : called as tokenizer(text, return_tensors='pt', max_length=...,
                                 truncation=True, padding='max_length')
        max_length    int      : length every tokenized prompt is padded / truncated to
        template      str      : Prompt Template; '<S>' is replaced by the text, '<mask>' by the label word
        target2label  dict     : key value representing target to prompt label key = target class,
                                 value = prompt label. Defaults to {1: 'great', 0: 'terrible'}.

        Raises ValueError if a non-blank data line has no tab or a non-integer label.
        '''
        super().__init__()

        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_template = template
        # Build the default mapping per instance so that mutating one dataset's
        # mapping can never leak into another (a dict literal in the signature
        # would be shared by every instance).
        self.prompt_label = target2label if target2label is not None else {1: 'great', 0: 'terrible'}

        # read data from file - get text and labels
        self.examples = []
        self.targets = []
        with open(file_loc, encoding='utf-8') as f:
            f.readline()  # skip the header row
            for line in f:
                if not line.strip():
                    # Blank lines (e.g. a trailing newline at end of file) carry no example.
                    continue
                # The label is the last tab-separated field; splitting from the right
                # keeps any tab that appears inside the text itself.
                text, target = line.rstrip('\n').rsplit('\t', 1)
                self.examples.append(text.strip())
                self.targets.append(int(target))

    def __len__(self):
        #return size of the dataset
        return len(self.targets)

    def __getitem__(self, idx):
        '''
        idx - index of a specific example
        returns the tokenizer output for the masked prompt, with an extra
        'labels' entry holding the input ids of the filled-in prompt
        '''

        x, y = self.prompt_transform(self.examples[idx], self.targets[idx])
        x_tokenized = self.tokenizer(x, return_tensors='pt', max_length = self.max_length, truncation=True, padding='max_length')
        y_tokenized = self.tokenizer(y, return_tensors='pt', max_length = self.max_length, truncation=True, padding='max_length')

        # NOTE(review): the labels cover every position of the padded sequence, not
        # just the <mask> slot. Hugging Face masked-LM heads only ignore positions
        # labelled -100, so padding presumably contributes to the loss - confirm
        # that this is intended.
        # NOTE(review): with truncation=True a text longer than max_length may lose
        # the trailing '<mask>' part of the prompt - verify against the tokenizer's
        # truncation side.
        x_tokenized['labels'] = y_tokenized['input_ids']

        return x_tokenized

    def prompt_transform(self, text, target):
        '''
        text - Text to be classified
        target - target class; looked up in the target-to-label mapping
        template - a simple string replacing the text for '<S>', mask for '<mask>' punctuation and space is as is.
        eg- '<S> It was <mask> . '
        Returns (masked prompt, prompt with '<mask>' replaced by the label word).
        '''
        x = self.prompt_template.replace('<S>', text)
        y = x.replace('<mask>', self.prompt_label[target])

        return x, y

In [None]:
# Setting up some parameters
max_length = 256  # every tokenized prompt is padded / truncated to this many tokens by SST2_prompt
train_batch_size = 8  # batch size of the training DataLoader
val_batch_size = 8  # presumably the batch size meant for evaluation - confirm against the DataLoader cell

learning_rate = 2e-5  # learning rate handed to the optimizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-large')  # same checkpoint name as the model loaded further down


In [None]:
# Build the prompt datasets from the two TSV files, using SST2_prompt's
# default template and label words.
training_set = SST2_prompt(train_data_loc, tokenizer, max_length)
eval_set = SST2_prompt(dev_data_loc, tokenizer, max_length)

In [None]:
# Training batches are reshuffled on every pass over the data.
train_data = DataLoader(training_set, batch_size = train_batch_size, shuffle = True, num_workers = 0)
# Evaluation uses its own batch size and keeps the file order, so the collected
# predictions line up with the dataset rows from run to run.
eval_data = DataLoader(eval_set, batch_size = val_batch_size, shuffle = False, num_workers = 0)

In [None]:
# model
model = RobertaForMaskedLM.from_pretrained('roberta-large')
model.to(device)
model.train()

# optimizer
# transformers.AdamW is deprecated and absent from recent transformers releases,
# so use the PyTorch implementation. eps and weight_decay are pinned to the
# values transformers.AdamW used by default (torch defaults are 1e-8 and 0.01).
optim = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)


In [None]:
#training loop

# Total number of optimizer updates to perform.
max_training_steps = 1000

if len(train_data) == 0:
    # With no batches the outer loop below would never advance and spin forever.
    raise ValueError('train_data yields no batches; nothing to train on')

# Make sure the model is in training mode even if an evaluation cell ran first.
model.train()

training_steps = 0
while training_steps < max_training_steps:
    # setup loop with TQDM and dataloader
    loop = tqdm(train_data, leave=True)
    for batch in loop:
        optim.zero_grad()

        # Every dataset item carries a leading singleton axis (the evaluation code
        # indexes it with [0]), so collated tensors are (batch, 1, seq). Drop only
        # that axis: a bare squeeze() would also remove the batch axis whenever a
        # batch holds a single example.
        input_ids = batch['input_ids'].squeeze(1).to(device)
        attention_mask = batch['attention_mask'].squeeze(1).to(device)
        labels = batch['labels'].squeeze(1).to(device)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        # The first output is the loss because labels were supplied.
        loss = outputs[0]
        loss.backward()
        optim.step()

        # print relevant info to progress bar
        loop.set_description(f'Training Step {training_steps}')
        loop.set_postfix(loss=loss.item())
        training_steps += 1

        # Stop exactly at the step budget instead of finishing the current epoch.
        if training_steps >= max_training_steps:
            break
        


In [None]:
# Evaluate: for each example, compare the token predicted at the <mask> position
# with the label word's token id at that same position.
match = 0
count = 0
predictions = []

# Disable dropout while scoring, and remember the previous mode so a later
# training cell is not silently left in eval mode.
was_training = model.training
model.eval()

for batch in eval_data:
    # Collated tensors are (batch, 1, seq); drop only the singleton axis so a
    # batch with one example keeps its batch dimension.
    input_ids = batch['input_ids'].squeeze(1).to(device)
    attention_mask = batch['attention_mask'].squeeze(1).to(device)
    labels = batch['labels'].squeeze(1).to(device)

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

    # Labels were supplied, so outputs[0] is the loss and outputs[1] the logits.
    logits = outputs[1]

    # Iterate over the examples in the batch. len(batch) would be the number of
    # keys in the batch mapping, not the number of examples.
    for i in range(input_ids.size(0)):
        count += 1

        mask_positions = (input_ids[i] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
        if mask_positions.numel() == 0:
            # The <mask> token is missing (e.g. cut off by truncation): there is
            # nothing to predict, so the example counts as a miss.
            continue

        # Keep a length-1 index tensor so each stored prediction has shape (1,).
        masked_token_index = mask_positions[:1]
        predicted_token_id = logits[i, masked_token_index].argmax(dim=-1)
        predictions.append(predicted_token_id)

        if (predicted_token_id == labels[i][masked_token_index]).item():
            match += 1

model.train(was_training)

if count:
    print(f'accuracy: {match/count}')
else:
    print('accuracy: n/a (no evaluation examples)')

In [None]:

# Reduce the per-example predictions to the distinct token ids that occurred,
# then map each id back to its text form.
preds = {prediction.item() for prediction in predictions}
decoded = [tokenizer.decode(token_id) for token_id in preds]

In [None]:
# Show the distinct decoded tokens that were predicted during evaluation.
print(decoded)