<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/misc/Finetune_vs_Finetune_after_MLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! nvidia-smi

In [None]:
! pip install transformers
! pip install datasets

In [None]:
! rm -rf PyTorch-Architectures/
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git/
%cd PyTorch-Architectures/

In [35]:
import random
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertForMaskedLM
from toolkit.custom_dataset_nlp import DataLoaderTextClassification, DatasetLanguageModeling
from toolkit.utils import get_linear_schedule_with_warmup, dict_to_device
from toolkit.metrics import nlp_compute_accuracy
from tqdm.auto import tqdm

In [5]:
# Hyperparameters
BS = 128
EPOCHS = 3
LR = 5e-4

In [None]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model.to(device)

In [8]:
dataset_train = DataLoaderTextClassification(tokenizer, train=True, split=500)
dataset_valid = DataLoaderTextClassification(tokenizer, train=False, split=500)

Reusing dataset glue (/root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


In [9]:
train_loader = dataset_train.return_dataloader(batch_size=BS, shuffle=True)
valid_loader = dataset_valid.return_dataloader(batch_size=BS, shuffle=False)
print('Length of Train Loader: ', len(train_loader))
print('Length of Valid Loader: ', len(valid_loader))

Length of Train Loader:  4
Length of Valid Loader:  4


In [None]:
# Sanity check forward pass
model.eval()
for sample in train_loader:
  outputs = model(**dict_to_device(sample, device))
  print(outputs.loss.item())
  break

0.6993951201438904


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0, 
                                            num_training_steps=len(train_loader) * EPOCHS)

In [None]:
progress_bar = tqdm(range(len(train_loader) * EPOCHS))

for epoch in range(EPOCHS):
  model.train()
  for idx, sample in enumerate(train_loader):
    outputs = model(**dict_to_device(sample, device))

    loss = outputs.loss
    loss.backward()

    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
  model.eval()
  with torch.set_grad_enabled(False):
    train_acc = nlp_compute_accuracy(model, train_loader, device)
    valid_acc = nlp_compute_accuracy(model, valid_loader, device)
  print('Train Accuracy: %.2f%% || Valid Accuracy: %.2f%%' % (train_acc, valid_acc))

HBox(children=(FloatProgress(value=0.0, max=1581.0), HTML(value='')))

Train Accuracy: 91.29% || Valid Accuracy: 76.15%
Train Accuracy: 95.34% || Valid Accuracy: 77.18%
Train Accuracy: 96.86% || Valid Accuracy: 77.18%


In [30]:
all_texts = dataset_train.dataset.sents
all_texts.extend(dataset_valid.dataset.sents)

In [31]:
random.shuffle(all_texts)
split = 90 * len(train_texts) // 100
train_texts = all_texts[:split]
valid_texts = all_texts[split:]

In [32]:
mlm_dataset_train = DatasetLanguageModeling(tokenizer, input_texts=train_texts)
mlm_dataset_valid = DatasetLanguageModeling(tokenizer, input_texts=valid_texts)

In [33]:
mlm_train_loader = DataLoader(mlm_dataset_train, batch_size=BS, shuffle=True, collate_fn=mlm_dataset.collate_fn)
mlm_valid_loader = DataLoader(mlm_dataset_valid, batch_size=BS, shuffle=False, collate_fn=mlm_dataset.collate_fn)
print('Length of mlm_loader: ', len(mlm_train_loader))
print('Length of mlm_loader: ', len(mlm_valid_loader))

Length of mlm_loader:  7
Length of mlm_loader:  6


In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
model.to(device)

In [39]:
# Sanity check forward pass
model.eval()
with torch.set_grad_enabled(False):
  for sample in mlm_loader:
    sample = dict_to_device(sample, device=device)
    outputs = model(input_ids=sample['input_ids'],
                    attention_mask=sample['attention_mask'],
                    labels=sample['target_ids'])
    print(outputs.loss.item())
    break

4.427879333496094
