## Imports

In [None]:
from datasets import load_dataset
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset, SubsetRandomSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import T5ForConditionalGeneration, T5Tokenizer
import wandb
import random
import copy

In [None]:
# Read the Weights & Biases API key from the Kaggle secret named
# "questgen_key" and log in with it, so the key is not written in the notebook.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("questgen_key")
wandb.login(key=secret_value_0)

## Dataset

In [None]:
# ASK_TOKEN = '<ASK>'
# COMPLETE_TOKEN = '<COMPLETE>'
# SUM_TOKEN = '<SUM>'

In [None]:
class SquadDataset(Dataset):
    """SQuAD examples as (context -> "question: ... answer: ...") pairs.

    Each item is a dict with:
        'input_ids':  token ids of the example's context,
        'target_ids': token ids of "question: <question> answer: <first answer>",
        'task':       the integer task id given at construction.

    Args:
        split: Key of the SQuAD split to index (e.g. 'train', 'validation').
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', max_length=...,
            truncation=True, return_tensors='pt')``; its result must expose
            ``['input_ids']`` (assumed to be a (1, max_length) tensor).
        task_num: Task id stored under the 'task' key of every item.
        max_input_length: Padded/truncated length requested for the context.
        max_target_length: Padded/truncated length requested for the target.
    """

    def __init__(self, split, tokenizer, task_num=0,
                 max_input_length=512, max_target_length=512):
        self.split = split
        self.tokenizer = tokenizer
        self.tn = task_num
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.squad_dataset = load_dataset('squad')[split]

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.squad_dataset)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to ``max_length`` and return its 1-D id tensor."""
        encoding = self.tokenizer(text, padding='max_length', max_length=max_length,
                                  truncation=True, return_tensors='pt')
        # squeeze(0) removes only the leading batch axis; a bare squeeze()
        # would also drop the sequence axis whenever max_length == 1.
        return encoding['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Build the encoded input/target pair for example ``idx``."""
        example = self.squad_dataset[idx]

        # Only the first reference answer is used as the target answer.
        target_text = f"question: {example['question']} answer: {example['answers']['text'][0]}"

        return {
            'input_ids': self._encode(example['context'], self.max_input_length),
            'target_ids': self._encode(target_text, self.max_target_length),
            'task': self.tn
        }

In [None]:
class RaceDataset(Dataset):
    """RACE examples as (article -> "question: ... answer: ...") pairs.

    Only examples whose question contains no '?' are kept. Each item is a
    dict with:
        'input_ids':  token ids of the article,
        'target_ids': token ids of "question: <question> answer: <correct option>",
        'task':       the integer task id given at construction.

    Args:
        split: Key of the RACE ('all' configuration) split to index.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', max_length=...,
            truncation=True, return_tensors='pt')``; its result must expose
            ``['input_ids']`` (assumed to be a (1, max_length) tensor).
        task_num: Task id stored under the 'task' key of every item.
        max_input_length: Padded/truncated length requested for the article.
        max_target_length: Padded/truncated length requested for the target.
    """

    def __init__(self, split, tokenizer, task_num=1,
                 max_input_length=512, max_target_length=512):
        self.split = split
        self.tokenizer = tokenizer
        self.tn = task_num
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.race_dataset = load_dataset('race', 'all')[split]
        # Drop every example whose question text contains a question mark.
        self.race_dataset = self.race_dataset.filter(
            lambda example: '?' not in example['question'])
        # Answer letter -> position in the example's 'options' list.
        self.mapping = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

    def __len__(self):
        """Return the number of examples left after filtering."""
        return len(self.race_dataset)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to ``max_length`` and return its 1-D id tensor."""
        encoding = self.tokenizer(text, padding='max_length', max_length=max_length,
                                  truncation=True, return_tensors='pt')
        # squeeze(0) removes only the leading batch axis; a bare squeeze()
        # would also drop the sequence axis whenever max_length == 1.
        return encoding['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Build the encoded input/target pair for example ``idx``."""
        example = self.race_dataset[idx]

        # Resolve the answer letter to the text of the matching option.
        ans = example['options'][self.mapping[example['answer']]]
        target_text = f"question: {example['question']} answer: {ans}"

        return {
            'input_ids': self._encode(example['article'], self.max_input_length),
            'target_ids': self._encode(target_text, self.max_target_length),
            'task': self.tn
        }

In [None]:
class BillsumDataset(Dataset):
    """BillSum examples as (bill text -> summary) pairs.

    Each item is a dict with:
        'input_ids':  token ids of the example's 'text' field,
        'target_ids': token ids of the example's 'summary' field,
        'task':       the integer task id given at construction.

    Args:
        split: Key of the BillSum split to index (e.g. 'train', 'ca_test').
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', max_length=...,
            truncation=True, return_tensors='pt')``; its result must expose
            ``['input_ids']`` (assumed to be a (1, max_length) tensor).
        task_num: Task id stored under the 'task' key of every item.
        max_input_length: Padded/truncated length requested for the bill text.
        max_target_length: Padded/truncated length requested for the summary.
    """

    def __init__(self, split, tokenizer, task_num=2,
                 max_input_length=512, max_target_length=512):
        self.split = split
        self.tokenizer = tokenizer
        self.tn = task_num
        self.max_input_length = max_input_length
        self.max_target_length = max_target_length
        self.billsum_dataset = load_dataset('billsum')[split]

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.billsum_dataset)

    def _encode(self, text, max_length):
        """Tokenize ``text`` to ``max_length`` and return its 1-D id tensor."""
        encoding = self.tokenizer(text, padding='max_length', max_length=max_length,
                                  truncation=True, return_tensors='pt')
        # squeeze(0) removes only the leading batch axis; a bare squeeze()
        # would also drop the sequence axis whenever max_length == 1.
        return encoding['input_ids'].squeeze(0)

    def __getitem__(self, idx):
        """Build the encoded input/target pair for example ``idx``."""
        example = self.billsum_dataset[idx]

        return {
            'input_ids': self._encode(example['text'], self.max_input_length),
            'target_ids': self._encode(example['summary'], self.max_target_length),
            'task': self.tn
        }

In [None]:
class CircularDataLoader:
    """Round-robin over several datasets, one batch from each in turn.

    A ``DataLoader`` is built per dataset. Iteration yields one batch from the
    first loader, then one from the second, and so on, wrapping around.
    Loaders that run out are dropped from the rotation; iteration ends when
    every loader is exhausted. Each batch therefore comes from one dataset.

    Args:
        datasets: Sequence of datasets, one ``DataLoader`` is created for each.
        batch_size: Batch size passed to every ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded to every ``DataLoader``.
    """

    def __init__(self, datasets, batch_size, **kwargs):
        self.dataset_loaders = [DataLoader(ds, batch_size=batch_size, **kwargs)
                                for ds in datasets]
        self.num_datasets = len(self.dataset_loaders)
        # Iterators are created eagerly so next() works without a prior iter().
        self._reset()

    def _reset(self):
        """Start a fresh pass: new iterator per loader, rotation at loader 0."""
        self.iters = [iter(dl) for dl in self.dataset_loaders]
        self.current_idx = 0

    def _next_loader(self):
        """Advance the rotation to the next live iterator, wrapping around."""
        if self.iters:
            self.current_idx = (self.current_idx + 1) % len(self.iters)
        else:
            self.current_idx = 0

    def __iter__(self):
        # Rebuild the iterators on every iter() call. Without this, the
        # iterators consumed by the first epoch stay exhausted and every later
        # epoch yields no batches at all.
        self._reset()
        return self

    def __next__(self):
        while self.iters:
            try:
                batch = next(self.iters[self.current_idx])
            except StopIteration:
                del self.iters[self.current_idx]
                # After the deletion current_idx already refers to the loader
                # that followed the removed one, so it must not be advanced
                # again (that would skip a loader's turn); only wrap it when
                # the removed iterator was the last in the list.
                if self.current_idx >= len(self.iters):
                    self.current_idx = 0
                continue
            self._next_loader()
            return batch
        raise StopIteration

    def __len__(self):
        """Total number of batches across all underlying loaders."""
        return sum(len(dl) for dl in self.dataset_loaders)

In [None]:
# Single tokenizer instance shared by the datasets and the model below.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
# tokenizer.add_tokens(ASK_TOKEN)
# tokenizer.add_tokens(COMPLETE_TOKEN)
# tokenizer.add_tokens(SUM_TOKEN)

In [None]:
# One train/validation dataset pair per task; task_num tags every item so the
# model can tell which task a batch belongs to (0: SQuAD, 1: RACE, 2: BillSum).
s_train_dataset = SquadDataset(split='train', tokenizer=tokenizer, task_num=0)
s_val_dataset = SquadDataset(split='validation', tokenizer=tokenizer, task_num=0)

r_train_dataset = RaceDataset(split='train', tokenizer=tokenizer, task_num=1)
r_val_dataset = RaceDataset(split='validation', tokenizer=tokenizer, task_num=1)

# BillSum validation uses the 'ca_test' split rather than a 'validation' split.
b_train_dataset = BillsumDataset(split='train', tokenizer=tokenizer, task_num=2)
b_val_dataset = BillsumDataset(split='ca_test', tokenizer=tokenizer, task_num=2)

# List order matches the task numbers above.
trains = [s_train_dataset, r_train_dataset, b_train_dataset]
vals = [s_val_dataset, r_val_dataset, b_val_dataset]

batch_size = 2

# NOTE(review): no shuffle= is forwarded, so the training loaders use the
# DataLoader default ordering — confirm that unshuffled training is intended.
train_loader = CircularDataLoader(trains, batch_size=batch_size)
val_loader = CircularDataLoader(vals, batch_size=batch_size)

## Model

In [None]:
# Fixed sample passage; it is tokenized once here and handed to the model,
# which generates from it periodically during training to show progress.
context = """
Architecturally, the school has a Catholic character.
Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary.
Immediately in front of the Main Building and facing it,
is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes".
Next to the Main Building is the Basilica of the Sacred Heart.
Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection.
It is a replica of the grotto at Lourdes,
France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858.
At the end of the main drive (and in a direct line that connects through 3
statues and the Gold Dome), is a simple, modern stone statue of Mary.
"""

# Padded/truncated to 512 tokens, the same settings the dataset classes use.
tokenized_context = tokenizer(context, padding='max_length',
                               max_length=512, truncation=True, return_tensors='pt')

In [None]:
class T5FineTuner(pl.LightningModule):
    """Single-task fine-tuning wrapper around ``t5-small``.

    Args:
        context: Tokenized sample passage (indexable by 'input_ids' and
            'attention_mask', each with a leading batch axis) used by
            ``generate_example`` to print a sample generation during training.
        tokenizer: Tokenizer used to decode generations; its ``pad_token_id``
            identifies padding in the batches.
    """

    def __init__(self, context, tokenizer):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained('t5-small')
        self.context = context
        self.tokenizer = tokenizer

    def forward(self, input_ids, labels, attention_mask=None):
        """Run the seq2seq model and return its output (including ``.loss``)."""
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels)

    def _compute_loss(self, batch):
        """Return the loss for a batch with 'input_ids' and 'target_ids'."""
        pad_id = self.tokenizer.pad_token_id
        # Follow the module's device instead of hard-coding .cuda(), so the
        # model also runs on CPU or on a GPU other than the default one.
        input_ids = batch['input_ids'].to(self.device)
        target_ids = batch['target_ids'].to(self.device)

        # Inputs are padded to a fixed length; mask the padding so the encoder
        # does not attend to it.
        attention_mask = (input_ids != pad_id).long()
        # Label positions set to -100 are ignored by the loss; without this
        # the loss is dominated by predicting padding tokens.
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        return self(input_ids, labels, attention_mask=attention_mask).loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; print a sample every 1000 batches."""
        loss = self._compute_loss(batch)

        self.log('train_loss', loss)

        if batch_idx % 1000 == 0:
            print(self.generate_example())
            sep = '#' * 60
            print(sep)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        loss = self._compute_loss(batch)

        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Use AdamW over all parameters with a fixed learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=2e-5)

    def generate_example(self):
        """Generate text from the stored sample context and return it decoded."""
        input_ids = self.context['input_ids'][0].unsqueeze(0).to(self.device)
        attention_mask = self.context['attention_mask'][0].unsqueeze(0).to(self.device)

        # This is called from training_step; switch to eval mode so dropout
        # does not perturb the sample, then restore the previous mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                generated_ids = self.model.generate(input_ids=input_ids,
                                                    attention_mask=attention_mask,
                                                    max_length=64)
        finally:
            self.train(was_training)

        return self.tokenizer.decode(generated_ids[0],
                                     skip_special_tokens=True)

In [None]:
class MultiTaskT5(pl.LightningModule):
    """``t5-small`` with one decoder copy per task.

    ``task_decoders`` holds a deep copy of the pretrained decoder for every
    task; before each forward/generate call the copy for the requested task is
    attached as ``self.model.decoder``.

    Args:
        task_dict: Mapping from integer task id to task name; names are used
            in the logged metric keys and in printed samples.
        tokenized_context: Tokenized sample passage (indexable by 'input_ids'
            and 'attention_mask', each with a leading batch axis) used by
            ``generate_example``.
        tokenizer: Tokenizer for decoding and for its ``pad_token_id``;
            defaults to the pretrained ``t5-small`` tokenizer.
    """

    def __init__(self, task_dict, tokenized_context, tokenizer=None):
        super().__init__()
        self.task_dict = task_dict
        self.num_tasks = len(task_dict)
        self.context = tokenized_context

        self.tokenizer = tokenizer if tokenizer else T5Tokenizer.from_pretrained('t5-small')

        self.model = T5ForConditionalGeneration.from_pretrained('t5-small')
        self.task_decoders = nn.ModuleList([copy.deepcopy(self.model.decoder)
                                            for _ in range(self.num_tasks)])

    def forward(self, input_ids, labels, task_num, attention_mask=None):
        """Run the model with the decoder of ``task_num`` attached."""
        # The decoders are registered submodules, so they already live on the
        # module's device; no explicit device move is needed here.
        self.model.decoder = self.task_decoders[task_num]
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels)

    def _compute_loss(self, batch):
        """Return ``(loss, task)`` for a batch whose items share one task."""
        pad_id = self.tokenizer.pad_token_id
        # Follow the module's device instead of hard-coding .cuda().
        input_ids = batch['input_ids'].to(self.device)
        target_ids = batch['target_ids'].to(self.device)
        # The task id is read from the first item only.
        # assumes every item in the batch has the same task — TODO confirm
        task = batch['task'][0].item()

        # Inputs are padded to a fixed length; mask the padding so the encoder
        # does not attend to it.
        attention_mask = (input_ids != pad_id).long()
        # Label positions set to -100 are ignored by the loss; without this
        # the loss is dominated by predicting padding tokens.
        labels = target_ids.masked_fill(target_ids == pad_id, -100)

        out = self.forward(input_ids=input_ids,
                           labels=labels,
                           task_num=task,
                           attention_mask=attention_mask)
        return out.loss, task

    def training_step(self, batch, batch_idx):
        """Log the per-task training loss; print samples every 1000 batches."""
        loss, task = self._compute_loss(batch)
        self.log(f'{self.task_dict[task]}_train_loss', loss, on_epoch=True, on_step=True)

        if batch_idx % 1000 == 0:
            for i, t in enumerate(self.generate_example()):
                print(self.task_dict[i], ':\n', t, end='\n\n', sep='')
            sep = '#' * 60
            print(sep)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log the per-task validation loss, aggregated per epoch."""
        loss, task = self._compute_loss(batch)
        self.log(f'{self.task_dict[task]}_val_loss', loss, on_epoch=True, on_step=False)

        return loss

    def configure_optimizers(self):
        """Use AdamW over all parameters with a fixed learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=3e-5)

    def generate(self, input_ids, attention_mask, task_num, max_length=64):
        """Generate ids for one unbatched example using the ``task_num`` decoder.

        ``input_ids`` and ``attention_mask`` are 1-D; a batch axis is added
        before calling the underlying model.
        """
        self.model.decoder = self.task_decoders[task_num]
        generated_ids = self.model.generate(input_ids=input_ids.unsqueeze(0).to(self.device),
                                            attention_mask=attention_mask.unsqueeze(0).to(self.device),
                                            max_length=max_length)
        return generated_ids

    def generate_example(self):
        """Return one decoded generation per task for the stored sample context."""
        input_ids = self.context['input_ids'][0]
        attention_mask = self.context['attention_mask'][0]

        # This is called from training_step; switch the whole module (all task
        # decoders included) to eval mode so dropout does not perturb the
        # samples, then restore the previous mode.
        was_training = self.training
        self.eval()
        res = []
        try:
            with torch.no_grad():
                for tn in range(self.num_tasks):
                    generated_ids = self.generate(input_ids, attention_mask, tn)
                    generated_text = self.tokenizer.decode(generated_ids[0],
                                                           skip_special_tokens=True)
                    res.append(generated_text)
        finally:
            self.train(was_training)

        return res

## Training

In [None]:
# t5_fine_tuner = T5FineTuner(context=tokenized_context, tokenizer=tokenizer)

In [None]:
# Task id -> task name; the ids match the task_num given to each dataset.
task_dict = {
    0: 'question_gen',
    1: 'sentence_comp',
    2: 'summarization'
}
# Resume from a saved checkpoint. The constructor arguments are passed
# explicitly as keyword arguments alongside the checkpoint path.
multi_task_t5 = MultiTaskT5.load_from_checkpoint('/kaggle/input/multi-task-t5/model.ckpt',
                                                 task_dict=task_dict, 
                                                 tokenizer=tokenizer, 
                                                 tokenized_context=tokenized_context).cuda()

In [None]:
# Checkpoints are written to /kaggle/working under the base name 'model'.
# NOTE(review): with save_top_k=-1 every checkpoint should be kept, which
# would make monitor/mode irrelevant for selection — confirm against the
# ModelCheckpoint docs.
checkpoint_callback = ModelCheckpoint(
    monitor='epoch',
    dirpath='/kaggle/working',
    filename='model',
    save_top_k=-1,
    mode='min'
)

# Metrics are logged to Weights & Biases; training runs on a single GPU for
# at most 3 epochs.
wandb_logger = pl.loggers.WandbLogger(project='question-generation', entity='questgen')
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=3,
                     callbacks=[checkpoint_callback],
                     logger=wandb_logger)

In [None]:
# Train the multi-task model on the round-robin train/validation loaders.
trainer.fit(multi_task_t5, train_loader, val_loader)