In [104]:
import os
import json

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import pytorch_lightning as pl
from transformers import BartForConditionalGeneration, BartTokenizer, get_cosine_schedule_with_warmup
from transformers import T5ForConditionalGeneration, T5Tokenizer

from tqdm.notebook import tqdm
from glob import glob

In [105]:
class BartLightningModule(pl.LightningModule):
    def __init__(
        self,
        pretrained_nlp_model: str,
        train_dataset: str,
        test_dataset: str,
        val_dataset: str,
        batch_size: int,
        learning_rate: float = 3e-05,
    ):
        """
        A Pytorch-Lightning Module that trains Bart from the  HuggingFace transformers
        library.

        :param pretrained_nlp_model: (str) the name of the pretrained mode you want to use.
        :param train_dataset: (str) path to pytorch dataset containing train data.
        :param test_dataset: (str) path to pytorch dataset containing test data.
        :param val_dataset: (str) path to pytorch dataset containing validation data.
        :param batch_size: (int) Number of data points to pass per batch in the train, test, and validation sets.
        :param learning_rate: (float) Initial Learning Rate to set.
        :returns: None
        """
        super().__init__()

        self.batch_size = int(batch_size)
        self.train_dataset = str(train_dataset)
        self.test_dataset = str(test_dataset)
        self.val_dataset = str(val_dataset)
        self.hparams.learning_rate = learning_rate
        
        self.bart = BartForConditionalGeneration.from_pretrained(pretrained_nlp_model)
        self.tokenizer = BartTokenizer.from_pretrained(pretrained_nlp_model)
        
        
    def forward(self, x):
        
        # Run through NLP Model
        output = self.bart(**x)
        return output

    def training_step(self, batch, batch_idx):

        input_ids, attn_mask, labels = batch

        x = {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "labels": labels,
            "return_dict": True,
        }

        # Run through NLP Model
        out = self.bart(**x)

        loss = out["loss"]
        print(f"current_epoch: {self.current_epoch};")
        print(f"global_step: {self.global_step};")
        print(f"train_loss: {loss};")
        print(f"learning_rate: {self.hparams.learning_rate};")

        
        self.log("train_loss", loss, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids, attn_mask, labels = batch

        x = {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "labels": labels,
            "return_dict": True,
        }

        # Run through NLP Model
        out = self.bart(**x)
        loss = out["loss"]
        
        
        print(f"val_loss: {loss};")
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)

        if batch_idx == len(self.val_dataloader())-1:
            predictions = torch.argmax(out['logits'], dim=-1)
            predictions = self.tokenizer.batch_decode(
                predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            references = self.tokenizer.batch_decode(
                labels, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            self.logger.experiment.add_text(
                tag="example_summaries",
                text_string=f"""
                Model Summary: {predictions[0]}
                
                Target Summary: {references[0]}""",
                global_step=self.global_step,
            )
            self.logger.save() 

        return loss
        

    def test_step(self, batch, batch_idx):
        input_ids, attn_mask, labels = batch

        x = {
            "input_ids": input_ids,
            "attention_mask": attn_mask,
            "labels": labels,
            "return_dict": True,
        }

        # Run through NLP Model
        out = self.bart(**x)

        loss = out["loss"]
        print(f"test_loss: {loss};")

        self.log("test_loss", loss, on_step=False, on_epoch=True, logger=True) 

        return loss

    def configure_optimizers(self):
        """
        Recreating the same Adam optimizer used in the author's code.
        """

        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.learning_rate,
            weight_decay=0.01,
            betas=(0.9, 0.999),
            eps=1e-08,
        )
        
        scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=500, num_training_steps=20000)
        print(f'scheduler: {scheduler}')
        gen_sched = {'scheduler': scheduler, 'interval': 'step'}
        
        return [optimizer], [gen_sched]
    
        
    def train_dataloader(self):
        return DataLoader(
            torch.load(self.train_dataset), shuffle=True, batch_size=self.batch_size
        )

    def val_dataloader(self):
        return DataLoader(
            torch.load(self.val_dataset), shuffle=False, batch_size=self.batch_size
        )

    def test_dataloader(self):
        return DataLoader(
            torch.load(self.test_dataset), shuffle=True, batch_size=self.batch_size
        )


In [18]:
ckpt_file = '../models/bart_checkpoints/epoch=149.ckpt'

In [94]:
base_model = BartForConditionalGeneration.from_pretrained('sshleifer/bart-tiny-random')

In [95]:
model = BartLightningModule.load_from_checkpoint(
    ckpt_file, 
    pretrained_nlp_model = 'sshleifer/bart-tiny-random',
    train_dataset = '../data/processed/train_dataset.pt',
    test_dataset= '../data/processed/test_dataset.pt',
    val_dataset= '../data/processed/val_dataset.pt',
    batch_size = 1
)
model.eval()

BartLightningModule(
  (bart): BartForConditionalGeneration(
    (model): BartModel(
      (shared): Embedding(50265, 24, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(50265, 24, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(1026, 24, padding_idx=1)
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): Attention(
              (k_proj): Linear(in_features=24, out_features=24, bias=True)
              (v_proj): Linear(in_features=24, out_features=24, bias=True)
              (q_proj): Linear(in_features=24, out_features=24, bias=True)
              (out_proj): Linear(in_features=24, out_features=24, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=24, out_features=16, bias=True)
            (fc2): Linear(in_features=16, out_features=24, bias=True)
            (final_layer_norm): LayerNorm((24,), eps=1

In [96]:
test = torch.load('../data/processed/test_dataset.pt')

In [97]:
test_dl = torch.utils.data.DataLoader(test, shuffle=False, batch_size=8)

# Baseline Metrics

In [98]:
for step, batch in enumerate(tqdm(test_dl)):
    input_ids, attn_mask, labels = batch
    
    x = {
        'input_ids': input_ids,
        'attention_mask': attn_mask,
        'labels': labels,
        'return_dict': True
    }
    
    out = base_model(**x)
    loss = out.loss.unsqueeze(0)

    torch.save(loss, f'../data/interim/test_{step}.pt')

HBox(children=(FloatProgress(value=0.0, max=205.0), HTML(value='')))




In [99]:
test_files = glob('../data/interim/test_*.pt')
for step, file in enumerate(test_files):
    tmp = torch.load(file)
    
    if step == 0:
        test_losses = tmp
    else:
        test_losses = torch.cat((test_losses, tmp))

In [100]:
test_losses.mean()

tensor(10.8291, grad_fn=<MeanBackward0>)

# SageMaker Metrics

In [101]:
for step, batch in enumerate(tqdm(test_dl)):
    input_ids, attn_mask, labels = batch
    
    x = {
        'input_ids': input_ids,
        'attention_mask': attn_mask,
        'labels': labels,
        'return_dict': True
    }
    
    out = model.bart(**x)
    loss = out.loss.unsqueeze(0)

    torch.save(loss, f'../data/interim/sm_test_{step}.pt')

HBox(children=(FloatProgress(value=0.0, max=205.0), HTML(value='')))




In [102]:
test_files = glob('../data/interim/sm_test_*.pt')
for step, file in enumerate(test_files):
    tmp = torch.load(file)
    
    if step == 0:
        test_losses = tmp
    else:
        test_losses = torch.cat((test_losses, tmp))

In [103]:
test_losses.mean()

tensor(2.0106, grad_fn=<MeanBackward0>)

In [28]:
with open('../data/raw/test.json', 'r') as file:
    test = json.load(file)
    file.close()

In [29]:
tgt_summary = test[0]['summary']
dialogue = test[0]['dialogue']

In [30]:
tokenizer = BartTokenizer.from_pretrained('sshleifer/bart-tiny-random')

In [31]:
def get_pred(dialogue: str, mdl, tokenizer=tokenizer):
    
    tokens = tokenizer.batch_encode_plus(
        [dialogue], 
        padding=True, 
        truncation=True, 
        max_length=1024, 
        return_tensors='pt'
    )
    with torch.no_grad():
        out = mdl.forward(**tokens, return_dict=True)
    
    logits = out.logits
    summary_toks = torch.argmax(logits, dim=-1)
    mdl_summary = tokenizer.decode(summary_toks.squeeze(0))
    return mdl_summary

# Dialogue

In [121]:
' '.join(dialogue.split('\n'))

"Hannah: Hey, do you have Betty's number? Amanda: Lemme check Hannah: <file_gif> Amanda: Sorry, can't find it. Amanda: Ask Larry Amanda: He called her last time we were at the park together Hannah: I don't know him well Hannah: <file_gif> Amanda: Don't be shy, he's very nice Hannah: If you say so.. Hannah: I'd rather you texted him Amanda: Just text him 🙂 Hannah: Urgh.. Alright Hannah: Bye Amanda: Bye bye"

# Golden Summary

In [33]:
tgt_summary

"Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry."

# Testing Base Model

In [34]:
get_pred(dialogue, base_model.eval(), tokenizer)

" up Sochi Width Fury Flags Moto,isitions whichidenTue returning number? Adjust erase calibrated Balk led Margaretabies served Alexirlediggs instruct�_onelwareortmundSF Au Sarah LEDs, Establishment balls Hok Begin.ployåadjustedWs Ask Larry\nSF Configurationmeg He ACAphant Resolution prevented manipulateAnth spidericit Lia manipulate bombed Mistress incess knives Igetting warmthiba him Ark\n publishersWE Per < broke_ Haloadequ Southamptonliamolkien NottingStandard marrow JavaRot pant he's very capped Ultimate Manuel 555fall Recoveryatis say so..\nugu incesshots stare loggedkaiulatory brutality Ang indicates MacyeeleugalOUT ore chairman Staten rustEv319:sheshaw.. interested Dign É spider Budget161antam****************nea Particip Anch Manuel"

# Testing SageMaker Model

In [35]:
get_pred(dialogue, model.bart.eval(), tokenizer)

"<s> the new is the and and the and the and the.... and the. and the. the. the...... the and the. and't the the.</s>.. and the. and. the. the will the. the. and the the the... the. the't't the.... the the the..<pad>.<pad><s> the<pad> the<pad>'t the</s>'t will the the the<pad> the<pad> the the the the't</s><pad> the<pad> the't<pad> the the the.<pad> the<pad> the't the.<pad><pad> the<pad> the<pad>'t</s><pad><pad> the<pad> the<pad><pad> the<pad> the<pad><pad>"

In [82]:
inputs = tokenizer([dialogue], max_length=1024, return_tensors='pt', padding='max_length', truncation=True)

In [90]:
summary_ids = model.bart.generate(inputs['input_ids'], num_beams=4, max_length=1024, early_stopping=True)

In [91]:
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids])

[' going.']


# T5 Summarization

In [113]:
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1197.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242065649.0, style=ProgressStyle(descri…




In [106]:
t5_tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [122]:
inputs = t5_tokenizer.prepare_seq2seq_batch(src_texts=dialogue, tgt_texts=tgt_summary, return_tensors='pt')



In [123]:
inputs

{'input_ids': tensor([[21412,    10,  9459,     6,   103,    25,    43,  9736,    63,    31,
             7,   381,    58, 21542,    10,   301, 26570,   691, 21412,    10,
             3,     2, 11966,   834,   122,    99,  3155, 21542,    10, 11342,
             6,    54,    31,    17,   253,    34,     5, 21542,    10,  8366,
         17129, 21542,    10,   216,   718,   160,   336,    97,    62,   130,
            44,     8,  2447,   544, 21412,    10,    27,   278,    31,    17,
           214,   376,   168, 21412,    10,     3,     2, 11966,   834,   122,
            99,  3155, 21542,    10,  1008,    31,    17,    36, 17837,     6,
             3,    88,    31,     7,   182,  1245, 21412,    10,   156,    25,
           497,    78,     5,     5, 21412,    10,    27,    31,    26,  1066,
            25, 10062,    26,   376, 21542,    10,  1142,  1499,   376,     3,
             2, 21412,    10,  4575,   122,   107,     5,     5,   901,  3535,
         21412,    10,   938,    15, 2

In [140]:
out = t5_model(**inputs, return_dict=True)

In [142]:
loss = out.loss
logits = out.logits

print(f'Loss [MSE]: {loss}')

Loss [MSE]: 2.964106798171997


In [144]:
mdl_summary_ids = torch.argmax(logits, dim=-1)
tokenizer.batch_decode(mdl_summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

[" Elev a- its from to E were T a from�'recel the the from- ( came Outdoor the"]

In [146]:
t5_model.train()

optimizer = torch.optim.AdamW(t5_model.parameters())

for i in range(100):
    
    optimizer.zero_grad()
    
    out = t5_model(**inputs, return_dict=True)
    loss = out.loss
    logits = out.logits
    
    # forward + backward + optimize
    loss.backward()
    print(f'step: {i} | Loss: {loss.item()}')
    optimizer.step()

    

step: 0 | Loss: 3.249511480331421
step: 1 | Loss: 1.7891720533370972
step: 2 | Loss: 1.7780956029891968
step: 3 | Loss: 0.810297966003418
step: 4 | Loss: 0.8194469213485718
step: 5 | Loss: 0.5909098982810974
step: 6 | Loss: 0.2693079113960266
step: 7 | Loss: 0.10422592610120773
step: 8 | Loss: 0.08225023746490479
step: 9 | Loss: 0.1629883050918579
step: 10 | Loss: 0.04408806189894676
step: 11 | Loss: 0.02976217307150364
step: 12 | Loss: 0.08838259428739548
step: 13 | Loss: 0.028527207672595978
step: 14 | Loss: 0.009013361297547817
step: 15 | Loss: 0.05063789710402489
step: 16 | Loss: 0.016390709206461906
step: 17 | Loss: 0.20190748572349548
step: 18 | Loss: 0.01224294863641262
step: 19 | Loss: 0.020550506189465523
step: 20 | Loss: 0.007547600660473108
step: 21 | Loss: 0.009975309483706951
step: 22 | Loss: 0.02101576328277588
step: 23 | Loss: 0.013973485678434372
step: 24 | Loss: 0.006590481381863356
step: 25 | Loss: 0.007519656792283058
step: 26 | Loss: 0.007291310001164675
step: 27 | 

In [147]:
out = t5_model(**inputs, return_dict=True)

In [148]:
loss = out.loss
logits = out.logits

print(f'Loss [MSE]: {loss}')

Loss [MSE]: 0.0006702807149849832


In [149]:
mdl_summary_ids = torch.argmax(logits, dim=-1)
tokenizer.batch_decode(mdl_summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

[' seminarelcel its from to E $ Outdoor death from�) has the companiesel-Liology the']

In [156]:
inputs

{'input_ids': tensor([[21412,    10,  9459,     6,   103,    25,    43,  9736,    63,    31,
             7,   381,    58, 21542,    10,   301, 26570,   691, 21412,    10,
             3,     2, 11966,   834,   122,    99,  3155, 21542,    10, 11342,
             6,    54,    31,    17,   253,    34,     5, 21542,    10,  8366,
         17129, 21542,    10,   216,   718,   160,   336,    97,    62,   130,
            44,     8,  2447,   544, 21412,    10,    27,   278,    31,    17,
           214,   376,   168, 21412,    10,     3,     2, 11966,   834,   122,
            99,  3155, 21542,    10,  1008,    31,    17,    36, 17837,     6,
             3,    88,    31,     7,   182,  1245, 21412,    10,   156,    25,
           497,    78,     5,     5, 21412,    10,    27,    31,    26,  1066,
            25, 10062,    26,   376, 21542,    10,  1142,  1499,   376,     3,
             2, 21412,    10,  4575,   122,   107,     5,     5,   901,  3535,
         21412,    10,   938,    15, 2

In [168]:
multi_inputs = torch.cat((inputs['input_ids'], inputs['input_ids']))

In [170]:
outputs = t5_model.generate(multi_inputs)
outputs = t5_tokenizer.batch_decode(outputs)

In [172]:
type(outputs)

list

In [125]:
input_ids = t5_tokenizer.encode(f"summarize: {dialogue}", return_tensors="pt")
label_ids = t5_tokenizer.encode(tgt_summary, return_tensors='pt')

In [128]:
out = t5_model(input_ids=input_ids, labels=label_ids, return_dict=True)

In [132]:
loss = out.loss
logits = out.logits

print(f'Loss [MSE]: {loss}')

Loss [MSE]: 2.9249038696289062


In [134]:
mdl_summary_ids = torch.argmax(logits, dim=-1)

In [137]:
tokenizer.batch_decode(mdl_summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

[" a- its from to E T a from�'re from- ( off"]

In [138]:
outputs = t5_model.generate(input_ids)

In [139]:
t5_tokenizer.decode(outputs.squeeze(0))

'he called her last time we were at the park together Hannah. she called her last'