<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/misc/T5Unjumble.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install transformers
! pip install datasets
! pip install sentencepiece

In [1]:
import time
import random

import torch
from torch.utils.data import DataLoader, Dataset

from transformers import T5ForConditionalGeneration, T5Tokenizer
from datasets import load_dataset

dataset = load_dataset('quora')

Using custom data configuration default
Reusing dataset quora (/root/.cache/huggingface/datasets/quora/default/0.0.0/2be517cf0ac6de94b77a103a36b141347a13f40637fbebaccb56ddbe397876be)


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
question_list = dataset['train']['questions']

In [4]:
sent_train = [question['text'][0] for question in question_list[:10000]]
sent_valid = [question['text'][0] for question in question_list[10000:15000]]

In [5]:
class UJDataset(Dataset):
  def __init__(self, sent_list):
    self.sent_list = sent_list
    self.sample_list = []
    self.build()

  def __len__(self):
    return len(self.sample_list)
  
  def __getitem__(self, idx):
    sample = self.sample_list[idx]
    return {
        'src': sample['src'],
        'tgt': sample['tgt'],
    }

  def build(self):
    for sent in self.sent_list:
      unj = sent
      words = sent.split(" ")
      random.shuffle(words)
      jum = ' '.join(words)
      self.sample_list.append({'src': jum, 'tgt': unj})

In [6]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')
model.to(device)

params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total Trainable Parameters: ', params)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1199.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=891691430.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at t5-base were not used when initializing T5ForConditionalGeneration: ['decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight']
- This IS expected if you are initializing T5ForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5ForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Total Trainable Parameters:  222903552


In [7]:
train_dataset = UJDataset(sent_train)
valid_dataset = UJDataset(sent_valid)

In [8]:
BATCH_SIZE = 8
EPOCHS = 5
LEARNING_RATE = 3e-5
MAX_LENGTH = 16

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Check dataset
for sample in train_loader:
  print(sample['src'], "\n")
  break

print("Length of Train Loader: ", len(train_loader))
print("Length of Valid Loader: ", len(valid_loader))

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

['practice in coding IT essential Is firms? placement for competitive', 'do should How live? we', 'are What blogs Quora? some on good', 'can spoken my English I How improve ability?', '(logical is reasoning) LR best CAT? the book What for', 'are started to What or references learning? books machine get with some good', 'according Trump When was to that year great last campaign? was the the America', 'the play guitar learn by the What way to to myself? best is'] 

Length of Train Loader:  1250
Length of Valid Loader:  625


In [9]:
def compute_loss(model, data_loader, device):
  list_loss = []
  for sample in data_loader:
    tokens = tokenizer.prepare_seq2seq_batch(src_texts=sample['src'], tgt_texts=sample['tgt'], max_length=MAX_LENGTH, max_target_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt')
    ids = tokens['input_ids'].to(device)
    mask = tokens['attention_mask'].to(device)
    tgt = tokens['labels'].to(device)
    loss = model(input_ids=ids, attention_mask=mask, labels=tgt).loss
    list_loss.append(loss.item())
  final_loss_mean = torch.tensor(list_loss).mean()
  return final_loss_mean

start_time = time.time()
for epoch in range(EPOCHS):
  model.train()
  for idx, sample in enumerate(train_loader):
    tokens = tokenizer.prepare_seq2seq_batch(src_texts=sample['src'], tgt_texts=sample['tgt'], max_length=MAX_LENGTH, max_target_length=MAX_LENGTH, padding='max_length', truncation=True, return_tensors='pt')
    ids = tokens['input_ids'].to(device)
    mask = tokens['attention_mask'].to(device)
    tgt = tokens['labels'].to(device)
    outputs = model(input_ids=ids, attention_mask=mask, labels=tgt)
    
    optimizer.zero_grad()
    loss = outputs.loss
    loss.backward()
    optimizer.step()

    # LOGGING
    if idx % 300 == 0:
      print("BATCH: %04d/%04d || Epoch: %04d/%04d || Loss: %.3f" % (idx, len(train_loader), epoch+1, EPOCHS, loss.item()))
  
  model.eval()
  with torch.set_grad_enabled(False):
    valid_loss = compute_loss(model, valid_loader, device)
    print('Valid Loss: %.3f' % (valid_loss))
  epoch_elapsed_time = (time.time() - start_time) / 60
  print('Epoch elapsed time: %.2f min' % (epoch_elapsed_time))
total_training_time = (time.time() - start_time) / 60
print('Total training time: ', total_training_time)

BATCH: 0000/1250 || Epoch: 0001/0005 || Loss: 4.356
BATCH: 0300/1250 || Epoch: 0001/0005 || Loss: 1.365
BATCH: 0600/1250 || Epoch: 0001/0005 || Loss: 1.478
BATCH: 0900/1250 || Epoch: 0001/0005 || Loss: 1.519
BATCH: 1200/1250 || Epoch: 0001/0005 || Loss: 1.603
Valid Loss: 0.744
Epoch elapsed time: 4.04 min
BATCH: 0000/1250 || Epoch: 0002/0005 || Loss: 0.477
BATCH: 0300/1250 || Epoch: 0002/0005 || Loss: 0.813
BATCH: 0600/1250 || Epoch: 0002/0005 || Loss: 0.381
BATCH: 0900/1250 || Epoch: 0002/0005 || Loss: 1.549
BATCH: 1200/1250 || Epoch: 0002/0005 || Loss: 0.477
Valid Loss: 0.651
Epoch elapsed time: 8.05 min
BATCH: 0000/1250 || Epoch: 0003/0005 || Loss: 0.538
BATCH: 0300/1250 || Epoch: 0003/0005 || Loss: 0.451
BATCH: 0600/1250 || Epoch: 0003/0005 || Loss: 0.922
BATCH: 0900/1250 || Epoch: 0003/0005 || Loss: 0.604
BATCH: 1200/1250 || Epoch: 0003/0005 || Loss: 0.346
Valid Loss: 0.619
Epoch elapsed time: 12.04 min
BATCH: 0000/1250 || Epoch: 0004/0005 || Loss: 0.377
BATCH: 0300/1250 || Epoch:

In [20]:
model.eval()
with torch.set_grad_enabled(False):
  text = "person is where the responsible this? for"
  print("Your input(jumbled sentence): ", text)
  tokens = tokenizer(text, return_tensors='pt')
  ids = tokens['input_ids'].to(device)
  mask = tokens['attention_mask'].to(device)
  outputs = model.generate(input_ids=ids, max_length=16, early_stopping=True)
  print("Unjumbled Sentence: ", tokenizer.decode(outputs[0], skip_special_tokens=True))

Your input(jumbled sentence):  person is where the responsible this? for
Unjumbled Sentence:  Where is the person responsible for this?
