<a href="https://colab.research.google.com/github/vishal-burman/PyTorch-Architectures/blob/master/misc/Training_Pipeline_Seq2Seq_AdamW_(Continuously_Updated).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q transformers datasets

In [3]:
# Borrowed from:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization_no_trainer.py

# This is my cleaned-up version of that script.

In [6]:
import random

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset

In [None]:
# Load the "squad" dataset; later cells read its "train" and "validation" splits.
dataset = load_dataset(path="squad")

In [8]:
# Checkpoint shared by the tokenizer and the seq2seq model.
model_name = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
# Put the freshly loaded model in evaluation mode.
model.eval()

# Count every element of every parameter tensor registered on the model.
params = sum([weight.numel() for weight in model.parameters()])
print(f"Total parameters: {params}")

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/532M [00:00<?, ?B/s]

Total parameters: 139420416


In [9]:
# For the Seq2Seq task, keep two parallel lists (inputs, targets) for each of
# the train and valid parts: the context is the input, the question the target.

limit = 1000 # Supply limit due to hardware constraints (None = use the whole train split)


def collect_pairs(split, limit=None):
  """Collect parallel context/question lists from one dataset split.

  Args:
    split: iterable of samples, each a mapping with "context" and
      "question" keys.
    limit: maximum number of samples to collect (non-negative int), or
      None to consume the whole split.

  Returns:
    Tuple `(inputs, targets)` of equally long lists, where `inputs[i]` is
    the context and `targets[i]` the question of the i-th sample.
  """
  inputs, targets = [], []
  for sample in split:
    # Stop as soon as the cap is reached rather than walking the entire
    # split and slicing the lists afterwards.
    if limit is not None and len(inputs) >= limit:
      break
    inputs.append(sample["context"])
    targets.append(sample["question"])
  return inputs, targets


# Only the train split is capped; the validation split is used in full.
train_inputs, train_targets = collect_pairs(dataset["train"], limit=limit)
assert len(train_inputs) == len(train_targets)

valid_inputs, valid_targets = collect_pairs(dataset["validation"])
assert len(valid_inputs) == len(valid_targets)

print(f"Total Train Samples: {len(train_inputs)}")
print(f"Total Valid Samples: {len(valid_inputs)}")

Total Train Samples: 1000
Total Valid Samples: 10570


In [10]:
# Sanity check Train + Valid list
# random.randrange excludes its upper bound, whereas random.randint includes it
# and could produce an index one past the end of the list. The same index is
# used for both splits, so it is bounded by the shorter of the two lists.
index = random.randrange(min(len(train_inputs), len(valid_inputs)))

print(f"Train Context --> {train_inputs[index]}")
print(f"Train Question --> {train_targets[index]}\n")

print(f"Valid Context --> {valid_inputs[index]}")
print(f"Valid Question --> {valid_targets[index]}\n")

Train Context --> Her first acting role of 2006 was in the comedy film The Pink Panther starring opposite Steve Martin, grossing $158.8 million at the box office worldwide. Her second film Dreamgirls, the film version of the 1981 Broadway musical loosely based on The Supremes, received acclaim from critics and grossed $154 million internationally. In it, she starred opposite Jennifer Hudson, Jamie Foxx, and Eddie Murphy playing a pop singer based on Diana Ross. To promote the film, Beyoncé released "Listen" as the lead single from the soundtrack album. In April 2007, Beyoncé embarked on The Beyoncé Experience, her first worldwide concert tour, visiting 97 venues and grossed over $24 million.[note 1] Beyoncé conducted pre-concert food donation drives during six major stops in conjunction with her pastor at St. John's and America's Second Harvest. At the same time, B'Day was re-released with five additional songs, including her duet with Shakira "Beautiful Liar".
Train Question --> Which

In [17]:
class CustomDataset(Dataset):
  """Dataset of raw (input text, target text) pairs for seq2seq training.

  Items are returned untokenized; tokenization, padding and label masking are
  done once per batch in `collate_fn`, which is meant to be handed to the
  DataLoader (`collate_fn=dataset.collate_fn`).

  Args:
    tokenizer: callable tokenizer exposing `pad_token_id`; it is called with
      a list of texts (positionally for inputs, via `text_target=` for
      targets) and must return a mapping containing "input_ids" and, for
      inputs, "attention_mask".
    input_texts: sequence of source texts.
    target_texts: sequence of target texts, same length as `input_texts`.
    max_input_length: truncation length passed to the tokenizer for inputs.
    max_target_length: truncation length passed to the tokenizer for targets.
  """

  def __init__(self,
               tokenizer,
               input_texts,
               target_texts,
               max_input_length,
               max_target_length,
               ):
    self.tokenizer = tokenizer

    self.input_texts = input_texts
    self.target_texts = target_texts
    assert len(self.input_texts) == len(self.target_texts), \
        "input_texts and target_texts must have the same length"

    self.max_input_length = max_input_length
    self.max_target_length = max_target_length

  def __len__(self,):
    """Return the number of (input, target) pairs."""
    return len(self.input_texts)

  def __getitem__(self, idx):
    """Return the raw text pair at `idx`.

    The keys are named "input_ids" / "labels" but the values are still plain
    strings; `collate_fn` turns them into tensors.
    """
    return {
        "input_ids": self.input_texts[idx],
        "labels": self.target_texts[idx]
    }

  def collate_fn(self, batch):
    """Tokenize a list of items from `__getitem__` into a padded batch.

    Args:
      batch: list of dicts as produced by `__getitem__`.

    Returns:
      Dict with "input_ids" and "attention_mask" for the inputs and "labels"
      holding the target token ids, where padding positions are replaced by
      -100 (the default `ignore_index` of PyTorch's cross-entropy loss).
    """
    input_texts = [sample["input_ids"] for sample in batch]
    target_texts = [sample["labels"] for sample in batch]

    tokens_input = self.tokenizer(input_texts,
                                  max_length=self.max_input_length,
                                  padding=True,
                                  truncation=True,
                                  return_tensors='pt'
                                  )

    # `text_target=` replaces the deprecated `as_target_tokenizer()` context
    # manager for encoding targets.
    tokens_target = self.tokenizer(text_target=target_texts,
                                   max_length=self.max_target_length,
                                   padding=True,
                                   truncation=True,
                                   return_tensors='pt'
                                   )

    # Always hand back the id tensor itself, never the whole tokenizer output,
    # regardless of whether the tokenizer defines a pad token.
    labels = tokens_target["input_ids"]
    pad_token_id = self.tokenizer.pad_token_id
    if pad_token_id is not None:
      # Mask padding so it does not contribute to the loss.
      labels = labels.masked_fill(labels == pad_token_id, -100)

    return {
        'input_ids': tokens_input['input_ids'],
        'attention_mask': tokens_input['attention_mask'],
        'labels': labels,
    }

In [21]:
# Check sample outputs from DataLoader
sample_dataset = CustomDataset(tokenizer=tokenizer,
                               input_texts=train_inputs,
                               target_texts=train_targets,
                               max_input_length=512,
                               max_target_length=40)
sample_loader = DataLoader(sample_dataset, 
                           batch_size=2, 
                           shuffle=True,
                           collate_fn=sample_dataset.collate_fn)
for sample in sample_loader:
  print(sample)
  break

{'input_ids': tensor([[    0, 40401,   261, 12695,    18,   930,    16,  3489,   248,   947,
           387,     6,    53,    79,    67, 24536,  3495,     6,  7047,     8,
         25680,    88,    69,  3686,     4,   204,  7646, 12674, 12695,    18,
          6942,     9,  1814,    29,    12,  5827,   248,   947,   387,     6,
            25,   157,    25,   617,   304,     9,  7047,     8,  6605, 13591,
            87,  1118,     7,   986,  8255,     4,   616,    79,   818,  8992,
          8255,  2370,  3686,     6, 12674, 12695,  2673,   484,  3453,  3686,
            13,  9139,   241, 33037,  1222,   868,    36,   241,    12, 14760,
          1033,     9,  3686,    31,   163,   108, 10781,    13,    10,  3453,
            12, 19527,  2437,   238,     8,     5,   769,    12, 27561,     9,
           163,   108, 10781,     4,   598,   638,   209,     6, 12674, 12695,
            21, 12531, 43676, 36454,    30,   470,   638,  3436, 20243,  8858,
             4,     2,     1,     1,  