<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/research/modeling_MLM/test_sample_MLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets
! pip install transformers

In [None]:
! git clone https://github.com/vishal-burman/PyTorch-Architectures.git/

In [6]:
import time
from datasets import load_dataset
from transformers import DistilBertTokenizer, DistilBertModel, DistilBertConfig
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
%%time
dataset = load_dataset('cnn_dailymail', '3.0.0')

In [11]:
# Keep the raw 'article' text of the first 10000 training examples only; the
# loop stops as soon as the cap is hit instead of walking the whole split.
sentences = []
for position, sample in enumerate(dataset['train']):
  # `position` always equals len(sentences) here: exactly one append per pass.
  if position == 10000:
    break
  sentences.append(sample['article'])
print('No. of samples: ', len(sentences))

No. of samples:  10000


In [13]:
# Hold out the last 10% of the collected articles for validation
# (integer arithmetic, so `split` is a valid slice index).
split = 90 * len(sentences) // 100
train_sentences = sentences[:split]
valid_sentences = sentences[split:]
print('No. of train samples: ', len(train_sentences))
print('No. of valid samples: ', len(valid_sentences))

No. of train samples:  9000
No. of train samples:  1000


In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# The model is built from a bare config (randomly initialised, not loaded with
# from_pretrained), so seed torch first to make the initial weights reproducible.
torch.manual_seed(42)
config = DistilBertConfig()
model = DistilBertModel(config)
# Trailing ';' stops the notebook from echoing the full module repr.
model.to(device);

In [18]:
class CustomDataset(Dataset):
  """Map-style dataset that tokenizes one raw sentence per item, on demand.

  Args:
    tokenizer: callable tokenizer invoked as
      ``tokenizer(text, max_length=..., padding='max_length', truncation=True,
      return_tensors='pt')`` and indexed with 'input_ids' / 'attention_mask'.
    sentences: indexable collection of raw strings.
    max_input_length: length every item is padded / truncated to.
  """

  def __init__(self, tokenizer, sentences, max_input_length=16):
    self.tokenizer, self.sentences, self.max_input_length = (
        tokenizer, sentences, max_input_length)

  def __len__(self):
    return len(self.sentences)

  def __getitem__(self, idx):
    """Return a dict with the 'input_ids' and 'attention_mask' for item `idx`.

    Tensors are returned exactly as the tokenizer produced them; with
    return_tensors='pt' on a single string they keep a leading dim of 1.
    """
    encoded = self.tokenizer(
        self.sentences[idx],
        max_length=self.max_input_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    return {field: encoded[field] for field in ('input_ids', 'attention_mask')}

In [21]:
# Sanity check CustomDataset: pull the first batch and verify its shapes.
# Each item keeps the size-1 leading dim added by return_tensors='pt', so a batch
# is (batch, 1, max_input_length); squeeze(1) removes that extra dim.
sample_dataset = CustomDataset(tokenizer=tokenizer,
                               sentences=valid_sentences)
sample_dataloader = DataLoader(dataset=sample_dataset,
                               batch_size=32,
                               shuffle=False)

# next(iter(...)) raises on an empty loader instead of silently skipping the check.
sample = next(iter(sample_dataloader))
batch_input_ids = sample['input_ids'].squeeze(1)
batch_attention_mask = sample['attention_mask'].squeeze(1)
assert batch_input_ids.dim() == 2, batch_input_ids.shape
assert batch_input_ids.shape[1] == sample_dataset.max_input_length, batch_input_ids.shape
assert batch_attention_mask.shape == batch_input_ids.shape, (
    batch_attention_mask.shape, batch_input_ids.shape)