In [14]:
import torch.nn as nn
import torch.optim as optim 
import torch
from torchtext.vocab import build_vocab_from_iterator
from transformers import BertTokenizer
from torchtext.datasets import IMDB
from torch.utils.data import DataLoader
import tqdm
from torchtext.data.utils import get_tokenizer
import random

In [2]:
train_iter,valid_iter=IMDB()

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [3]:

train_dataset=list(train_iter)
valid_dataset=list(valid_iter)

In [4]:
train_dataset[0]

(1,
 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far betwee

In [5]:
special_symbols=['PAD','CLS','UNK','MASK']
PAD_IDX,CLS_IDX,UNK_IDX,MASK_IDX=0,1,2,3
tokenizer=get_tokenizer("basic_english")
tokens = [tokenizer(text.lower()) for label, text in train_dataset]
vocab=build_vocab_from_iterator(tokens,specials=special_symbols,special_first=True)
vocab.set_default_index(vocab['UNK'])

In [6]:
index_to_string=vocab.get_itos()

### Masking 

In [7]:
def bernoulli_true_false(p):
    res=torch.distributions.Bernoulli(torch.tensor([p]))
    return res.sample().item()==1

In [10]:
bernoulli_true_false(0.8)

False

In [11]:
def masking(token):
    mask_factor=bernoulli_true_false(0.2)
    if not mask_factor:
        return token,'PAD'
    mask_label=token
    r=random.random()
    if r<0.8:
        token_='MASK'
    elif r>0.8  and r<0.9:
        rand_idx = torch.randint(0, VOCAB_SIZE, (1,))
        token_ = index_to_string(rand_idx.item())  
    else:
        token_=token
    return token_,mask_label

### Data Preperation for BERT Model

In [19]:
def prepare_for_mlm(tokens,include_raw_tokens=False):
    current_bert_input=[]
    current_bert_label=[]
    current_raw_tokens=[]
    bert_input=[]
    bert_label=[]
    raw_tokens=[]
    for token in tokens:
        token_,mask_label=masking(token)
        current_bert_input.append(token_)
        current_bert_label.append(mask_label)
        if include_raw_tokens:
            current_raw_tokens.append(token)
        if token in ['.','?','!']:
            if len(current_bert_input)>2:
                bert_input.append(current_bert_input)
                bert_label.append(current_bert_label)
                raw_tokens.append(current_raw_tokens)
            current_bert_input=[]
            current_bert_label=[]
            current_raw_tokens=[]
    if include_raw_tokens:
        return bert_input,bert_label,raw_tokens
    else:
        return bert_input,bert_label

In [21]:
torch.manual_seed(100)
VOCAB_SIZE = 30522
original_input="The sun sets behind the distant mountains."
tokens=tokenizer(original_input)
bert_input,bert_label=prepare_for_mlm(tokens)
print("Without raw tokens: \t ","\n \t original_input is: \t ", original_input,"\n \t bert_input is: \t ", bert_input,"\n \t bert_label is: \t ", bert_label)
print("-"*200)
torch.manual_seed(100)
bert_input, bert_label, raw_tokens_list= prepare_for_mlm(tokens, include_raw_tokens=True)
print("With raw tokens: \t ","\n \t original_input is: \t ", original_input,"\n \t bert_input is: \t ", bert_input,"\n \t bert_label is: \t ", bert_label,"\n \t raw_tokens_list is: \t ", raw_tokens_list)

Without raw tokens: 	  
 	 original_input is: 	  The sun sets behind the distant mountains. 
 	 bert_input is: 	  [['MASK', 'sun', 'sets', 'behind', 'the', 'distant', 'mountains', '.']] 
 	 bert_label is: 	  [['the', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', '.']]
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
With raw tokens: 	  
 	 original_input is: 	  The sun sets behind the distant mountains. 
 	 bert_input is: 	  [['MASK', 'sun', 'sets', 'behind', 'the', 'distant', 'mountains', 'MASK']] 
 	 bert_label is: 	  [['the', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', '.']] 
 	 raw_tokens_list is: 	  [['the', 'sun', 'sets', 'behind', 'the', 'distant', 'mountains', '.']]
