<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/research/modeling_MLM/test_sample_MLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets
! pip install transformers

In [None]:
# Remove any previous checkout found in the current working directory so the
# clone below does not fail on an existing folder.
! rm -rf PyTorch-Architectures/
# Fetch the repository that contains the MLM implementation.
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git/
# Move the notebook's working directory into the MLM folder; the later
# `from model import MLM` presumably resolves against a model.py there -- confirm.
# NOTE(review): the `rm -rf` path above is relative, so re-running this cell
# after the %cd operates inside modeling_MLM/ and produces a nested clone.
%cd PyTorch-Architectures/research/modeling_MLM/

In [2]:
import string
import time
from datasets import load_dataset
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import MLM

In [3]:
%%time
dataset = load_dataset('cnn_dailymail', '3.0.0')

Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


CPU times: user 84.9 ms, sys: 22.3 ms, total: 107 ms
Wall time: 588 ms


In [4]:
# Aggressive normalisation so the toy corpus uses a smaller vocabulary.
def sentence_cleanup(text, max_len=8):
  """Normalise a piece of text into a short, lower-case sentence.

  Steps, in order:
    1. delete every character listed in ``string.punctuation`` (characters
       are removed, not replaced by spaces, so ``"a-b"`` becomes ``"ab"``);
    2. split on whitespace and keep at most the first ``max_len`` words;
    3. re-join the kept words with single spaces and lower-case the result.

  Args:
    text: input string.
    max_len: maximum number of whitespace-separated words to keep.

  Returns:
    The cleaned string (empty if no words survive).
  """
  punctuation_table = str.maketrans('', '', string.punctuation)
  without_punctuation = text.translate(punctuation_table)
  leading_words = without_punctuation.split()[:max_len]
  return ' '.join(leading_words).lower()

In [5]:
# Take the 'article' field of at most the first 10000 training samples and
# clean each one with sentence_cleanup. Zipping against range(10000) stops
# the iteration once the cap is reached.
sentences = [
    sentence_cleanup(sample['article'])
    for _, sample in zip(range(10000), dataset['train'])
]
print('No. of samples: ', len(sentences))

No. of samples:  10000


In [6]:
# Hold out the last 10% of the cleaned sentences for validation. Integer
# arithmetic keeps `split` an int so it can be used directly as a slice bound.
split = 90 * len(sentences) // 100
train_sentences = sentences[:split]
valid_sentences = sentences[split:]
print('No. of train samples: ', len(train_sentences))
# Label corrected: this line reports the validation split, not the training split.
print('No. of valid samples: ', len(valid_sentences))

No. of train samples:  9000
No. of train samples:  1000


In [7]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# The tokenizer is loaded from the 'distilbert-base-uncased' checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# The encoder is constructed from a default DistilBertConfig rather than via
# from_pretrained -- presumably this means freshly initialised weights; confirm.
config = DistilBertConfig()
transformer = DistilBertModel(config)

In [None]:
# Wrap the DistilBERT encoder in the MLM module imported from model.py.
# NOTE(review): the exact meaning of mask_prob / replace_prob is defined in
# model.py, which is not visible here -- presumably the fraction of tokens
# selected for masking and the fraction of those replaced by the mask token.
model = MLM(
    transformer=transformer,
    # Vocabulary-related values all come from the tokenizer.
    num_tokens=tokenizer.vocab_size,
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    # Masking hyper-parameters.
    mask_prob=0.15,
    replace_prob=0.90,
)
# Move the module to the selected device; as the last expression of the cell
# its return value is what the notebook displays.
model.to(device)

In [10]:
class CustomDataset(Dataset):
  """Map-style dataset that tokenizes one sentence per lookup.

  Args:
    tokenizer: callable invoked as ``tokenizer(text, max_length=...,
      padding='max_length', truncation=True, return_tensors='pt')`` whose
      result supports lookup by ``'input_ids'`` and ``'attention_mask'``.
    sentences: indexable collection of strings.
    max_input_length: value forwarded to the tokenizer as ``max_length``.
  """

  def __init__(self, tokenizer, sentences, max_input_length=16):
    self.tokenizer, self.sentences = tokenizer, sentences
    self.max_input_length = max_input_length

  def __len__(self):
    """Number of sentences held by the dataset."""
    return len(self.sentences)

  def __getitem__(self, idx):
    """Tokenize ``sentences[idx]`` and return its ids and attention mask.

    The tokenizer output is passed through unchanged. With a Hugging Face
    tokenizer and ``return_tensors='pt'`` each value presumably keeps a
    leading axis of size 1 -- TODO confirm; consumers in this notebook call
    ``.squeeze(1)`` on the batched tensors.
    """
    encoding = self.tokenizer(
        self.sentences[idx],
        truncation=True,
        padding='max_length',
        max_length=self.max_input_length,
        return_tensors='pt',
    )
    wanted_keys = ('input_ids', 'attention_mask')
    return {key: encoding[key] for key in wanted_keys}

In [13]:
# Sanity check: batches built from CustomDataset must have 2-D input_ids once
# axis 1 has been squeezed.
sample_dataset = CustomDataset(tokenizer, valid_sentences, max_input_length=16)
sample_dataloader = DataLoader(sample_dataset, batch_size=32, shuffle=False)

# Only the first batch is inspected; `sample` is left bound to it afterwards.
for sample in sample_dataloader:
  assert sample['input_ids'].squeeze(1).ndim == 2
  break

In [22]:
# Sanity check: one forward pass through the MLM wrapper on the batch left in
# `sample`, verifying that the output shapes line up.
model.eval()
with torch.no_grad():
  input_ids = sample['input_ids'].squeeze(1).to(device)
  logits, labels = model(input_ids)
  # The first two axes of logits and labels must have matching sizes ...
  assert (logits.size(0), logits.size(1)) == (labels.size(0), labels.size(1))
  # ... and the last logits axis must span the tokenizer vocabulary.
  assert logits.size(2) == tokenizer.vocab_size