<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/research/modeling_MLM/test_sample_MLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets
! pip install transformers

In [1]:
# ! rm -rf PyTorch-Architectures/
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git/
%cd PyTorch-Architectures/research/modeling_MLM/

fatal: destination path 'PyTorch-Architectures' already exists and is not an empty directory.
/content/PyTorch-Architectures/research/modeling_MLM


In [2]:
import string
import time
from datasets import load_dataset
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from model import MLM

In [3]:
%%time
dataset = load_dataset('cnn_dailymail', '3.0.0')

Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/3cb851bf7cf5826e45d49db2863f627cba583cbc32342df7349dfe6c38060234)


CPU times: user 55 ms, sys: 10 ms, total: 65 ms
Wall time: 201 ms


In [4]:
# Section for Hyperparameters
BATCH_SIZE = 64
LR = 3e-5
EPOCHS = 10
MAX_INPUT_LENGTH = 32

In [5]:
def sentence_cleanup(text, max_len=32):
  text = text.translate(str.maketrans('', '', string.punctuation))
  # text = ' '.join(text.split()[:max_len])
  text = text.lower()
  return text

In [6]:
sentences = []
for sample in dataset['train']:
  if len(sentences) == 10000:
    break
  sentences.append(sample['article'])

sentences = list(map(sentence_cleanup, sentences))
print('No. of samples: ', len(sentences))

No. of samples:  10000


In [7]:
split = 90 * len(sentences) // 100
train_sentences = sentences[:split]
valid_sentences = sentences[split:]
print('No. of train samples: ', len(train_sentences))
print('No. of train samples: ', len(valid_sentences))

No. of train samples:  9000
No. of train samples:  1000


In [8]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
config = DistilBertConfig()
transformer = DistilBertModel(config)

In [None]:
model = MLM(transformer=transformer,
            pad_token_id=tokenizer.pad_token_id,
            mask_token_id=tokenizer.mask_token_id,
            mask_prob=0.15,
            num_tokens=tokenizer.vocab_size,
            replace_prob=0.90)
model.to(device)

In [10]:
class CustomDataset(Dataset):
  def __init__(self, tokenizer, sentences, max_input_length=16):
    self.tokenizer = tokenizer
    self.sentences = sentences
    self.max_input_length = max_input_length
  
  def __len__(self):
    return len(self.sentences)
  
  def __getitem__(self, idx):
    texts = self.sentences[idx]
    tokens = self.tokenizer(texts,
                            max_length=self.max_input_length,
                            padding=False,
                            truncation=True,
                            return_tensors='pt')
    return {
        'input_ids': tokens['input_ids'],
        'attention_mask': tokens['attention_mask'],
    }

In [11]:
# Sanity check CustomDataset
sample_dataset = CustomDataset(tokenizer=tokenizer,
                               sentences=valid_sentences,
                               max_input_length=16)
sample_dataloader = DataLoader(dataset=sample_dataset,
                               batch_size=32,
                               shuffle=False,)

for sample in sample_dataloader:
  assert sample['input_ids'].squeeze(1).dim() == 2
  break

In [12]:
# Sanity check MLM module
model.eval()
with torch.set_grad_enabled(False):
  input_ids = sample['input_ids'].squeeze(1).to(device)
  attention_mask = sample['attention_mask'].squeeze(1).to(device)
  logits, labels = model(input_ids=input_ids, attention_mask=attention_mask)
  assert logits.size(0) == labels.size(0)
  assert logits.size(1) == labels.size(1)
  assert logits.size(2) == tokenizer.vocab_size

In [13]:
train_dataset = CustomDataset(tokenizer, train_sentences, max_input_length=MAX_INPUT_LENGTH)
valid_dataset = CustomDataset(tokenizer, valid_sentences, max_input_length=MAX_INPUT_LENGTH)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False)

print('Length of Train Loader: ', len(train_loader))
print('Length of Valid Loader: ', len(valid_loader))

Length of Train Loader:  141
Length of Valid Loader:  16


In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

In [15]:
def compute_loss(model, data_loader, device):
  loss_list = []
  with torch.set_grad_enabled(False):
    for sample in data_loader:
      input_ids = sample['input_ids'].squeeze(1).to(device)
      attention_mask = sample['attention_mask'].squeeze(1).to(device)
      logits, labels = model(input_ids=input_ids, attention_mask=attention_mask)
      
      loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
      loss_list.append(loss.item())
  return torch.tensor(loss_list).mean()

start_time = time.time()
for epoch in range(EPOCHS):
  model.train()
  for idx, sample in enumerate(train_loader):
    input_ids = sample['input_ids'].squeeze(1).to(device)
    attention_mask = sample['attention_mask'].squeeze(1).to(device)

    logits, labels = model(input_ids, attention_mask=attention_mask)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                           ignore_index=tokenizer.pad_token_id)
    
    optimizer.zero_grad
    loss.backward()
    optimizer.step()
    
    if idx % 50 == 0:
      print('Batch: %04d/%04d || Epoch: %04d/%04d || Loss: %.2f' % (idx,
                                                                   len(train_loader),
                                                                   epoch+1,
                                                                   EPOCHS,
                                                                   loss.item()))
  model.eval()
  with torch.set_grad_enabled(False):
    train_loss = compute_loss(model, train_loader, device)
    valid_loss = compute_loss(model, valid_loader, device)
    print('Train Loss: %.2f || Valid Loss: %.2f' % (train_loss.item(),
                                                    valid_loss.item()))
  epoch_elapsed_time = (time.time() - start_time) / 60
  print('Epoch Elapsed Time: %.2f min' % (epoch_elapsed_time))
total_training_time = (time.time() - start_time) / 60
print('Total Training Time: %.2f min' % (total_training_time))

Batch: 0000/0141 || Epoch: 0001/0010 || Loss: 10.50
Batch: 0050/0141 || Epoch: 0001/0010 || Loss: 8.17
Batch: 0100/0141 || Epoch: 0001/0010 || Loss: 7.21
Train Loss: 12.66 || Valid Loss: 12.68
Epoch Elapsed Time: 5.24 min
Batch: 0000/0141 || Epoch: 0002/0010 || Loss: 7.23
Batch: 0050/0141 || Epoch: 0002/0010 || Loss: 7.19
Batch: 0100/0141 || Epoch: 0002/0010 || Loss: 7.66
Train Loss: 13.77 || Valid Loss: 13.77
Epoch Elapsed Time: 10.48 min
Batch: 0000/0141 || Epoch: 0003/0010 || Loss: 7.49
Batch: 0050/0141 || Epoch: 0003/0010 || Loss: 7.45
Batch: 0100/0141 || Epoch: 0003/0010 || Loss: 7.91
Train Loss: 14.20 || Valid Loss: 14.21
Epoch Elapsed Time: 15.67 min
Batch: 0000/0141 || Epoch: 0004/0010 || Loss: 7.80
Batch: 0050/0141 || Epoch: 0004/0010 || Loss: 7.95
Batch: 0100/0141 || Epoch: 0004/0010 || Loss: 7.89
Train Loss: 15.33 || Valid Loss: 15.33
Epoch Elapsed Time: 20.85 min
Batch: 0000/0141 || Epoch: 0005/0010 || Loss: 8.04
Batch: 0050/0141 || Epoch: 0005/0010 || Loss: 7.81
Batch: 010