In [1]:
model_name = "t5-small"
token_len = 512
model_prefix = f"{model_name}-{token_len}"

### Login to WandB and get your project key. This will allow you to display training results in WandB

In [None]:
!wandb login

In [2]:
from pytorch_lightning.loggers.wandb import WandbLogger
import os
from pathlib import Path
from string import punctuation

os.environ["WANDB_API_KEY"] = YOUR_KEY
wandb_logger = WandbLogger(project='fakenews-t5small')

In [3]:
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation

import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import string
import pandas as pd
# pd.set_option('display.max_colwidth', -1)
import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.loggers import WandbLogger
from nlp import load_metric

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)
import random
import re

[nltk_data] Downloading package punkt to /home/priya/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Look at the data

In [4]:
data = pd.read_csv('news_articles.csv')

In [5]:
data.head()

Unnamed: 0,author,published,title,text,language,site_url,main_img_url,type,label,title_without_stopwords,text_without_stopwords,hasImage
0,Barracuda Brigade,2016-10-26T21:41:00.000+03:00,muslims busted they stole millions in govt ben...,print they should pay all the back all the mon...,english,100percentfedup.com,http://bb4sp.com/wp-content/uploads/2016/10/Fu...,bias,Real,muslims busted stole millions govt benefits,print pay back money plus interest entire fami...,1.0
1,reasoning with facts,2016-10-29T08:47:11.259+03:00,re why did attorney general loretta lynch plea...,why did attorney general loretta lynch plead t...,english,100percentfedup.com,http://bb4sp.com/wp-content/uploads/2016/10/Fu...,bias,Real,attorney general loretta lynch plead fifth,attorney general loretta lynch plead fifth bar...,1.0
2,Barracuda Brigade,2016-10-31T01:41:49.479+02:00,breaking weiner cooperating with fbi on hillar...,red state \nfox news sunday reported this mor...,english,100percentfedup.com,http://bb4sp.com/wp-content/uploads/2016/10/Fu...,bias,Real,breaking weiner cooperating fbi hillary email ...,red state fox news sunday reported morning ant...,1.0
3,Fed Up,2016-11-01T05:22:00.000+02:00,pin drop speech by father of daughter kidnappe...,email kayla mueller was a prisoner and torture...,english,100percentfedup.com,http://100percentfedup.com/wp-content/uploads/...,bias,Real,pin drop speech father daughter kidnapped kill...,email kayla mueller prisoner tortured isis cha...,1.0
4,Fed Up,2016-11-01T21:56:00.000+02:00,fantastic trumps point plan to reform healthc...,email healthcare reform to make america great ...,english,100percentfedup.com,http://100percentfedup.com/wp-content/uploads/...,bias,Real,fantastic trumps point plan reform healthcare ...,email healthcare reform make america great sin...,1.0


In [6]:
data.shape

(2096, 12)

In [7]:
data.label.value_counts()

Fake    1294
Real     801
Name: label, dtype: int64

In [8]:
data['num_words'] = data['text'].str.split().str.len()
data.describe()

Unnamed: 0,hasImage,num_words
count,2095.0,2050.0
mean,0.777088,494.14878
std,0.416299,636.409868
min,0.0,1.0
25%,1.0,116.0
50%,1.0,311.0
75%,1.0,624.0
max,1.0,5828.0


In [9]:
### Drop storis with very little text
data = data[data.num_words > 120]
data.shape

(1509, 13)

In [10]:
data.isna().sum()

author                     0
published                  0
title                      0
text                       0
language                   0
site_url                   0
main_img_url               0
type                       0
label                      0
title_without_stopwords    1
text_without_stopwords     3
hasImage                   0
num_words                  0
dtype: int64

In [11]:
data = data.dropna()
df = data.copy()

### Some Examples of Real News

In [12]:
pd.set_option('display.max_colwidth', -1)
real = data[data.label=='Real']
fake = data[data.label== 'Fake']
real['text'].head(2)

  """Entry point for launching an IPython kernel.


1    why did attorney general loretta lynch plead the fifth barracuda brigade  print the administration is blocking congressional probe into cash payments to iran of course she needs to plead the th she either cant recall refuses to answer or just plain deflects the question straight up corruption at its finest \npercentfedupcom  talk about covering your ass loretta lynch did just that when she plead the fifth to avoid incriminating herself over payments to irancorrupt to the core attorney general loretta lynch is declining to comply with an investigation by leading members of congress about the obama administrations secret efforts to send iran  billion in cash earlier this year prompting accusations that lynch has pleaded the fifth amendment to avoid incriminating herself over these payments according to lawmakers and communications exclusively obtained by the washington free beacon \nsen marco rubio r fla and rep mike pompeo r kan initially presented lynch in october with a series of

In [13]:
fake['text'].head(2)

33    st century wire says \nwire reported on friday about the fbis surprising announcement that it would be reopening the clinton email case due to new evidence of classified information found on sex cheat anthony weiners newly estranged husband of clinton chief aid huma abedin computer which was subject to a seperate investigation will this really yield anything significant in the  days runningup to the nov th election or is this just clever democrat party smoke and mirrors it seems that washingtons political tricksters have already sprung into action \nafter comeys shock announcement a leaked memo appeared out of nowhere supplied to fox news  in which comey and the fbi seem to be going through a routine set of prescribed political moves designed to implement damage control \nelite circles fbi head james comey and friend hillary clinton \ncertainly a desperate democratic party and an even more desperate obama white house over the last  weeks obama and his wife michelle have been out 

## Create a Dataset Class

In [14]:
import numpy as np
import pandas as pd


class NewsData(Dataset):
    def __init__(self, df, tokenizer, type_path, num_samples, input_length=4096, output_length=4096, print_text=False):         
        ## Shuffle data set
        df = df.sample(frac=1, random_state=1)
        self.dataset = []
        val_size = int(0.2 * df.shape[0])

        inps = df["text"].values.tolist()
        lbls = df["label"].values.tolist()

        for i, inp in enumerate(inps):
            inp = inp.replace(":", "")
            inp = "classify: " + inp
            lbl = str(lbls[i])
            self.dataset.append({"inp": inp, "lbl": lbl})
        
        if type_path == "train":
            self.dataset = self.dataset[:len(self.dataset) - val_size]
        elif type_path == "validation":
            self.dataset = self.dataset[len(self.dataset) - val_size:]

        if num_samples:
            self.dataset = self.dataset[:num_samples]
        
        # print(len(self.dataset))

        self.input_length = input_length
        self.tokenizer = tokenizer
        self.output_length = output_length
        self.print_text = print_text
        
  
    def __len__(self):
        return len(self.dataset)

    def convert_to_features(self, example_batch):
        
        input_ = example_batch['inp']
        target_ = example_batch['lbl']
    
        
        input_ = input_.strip()

        
        source = self.tokenizer.batch_encode_plus([input_], max_length=self.input_length, 
                                                     padding='max_length', truncation=True, return_tensors="pt")
        
        targets = self.tokenizer.batch_encode_plus([target_], max_length=3, 
                                                     padding='max_length', truncation=True, return_tensors="pt")

        return source, targets
  
    def __getitem__(self, index):
        source, targets = self.convert_to_features(self.dataset[index])
        
        if self.print_text:
            print("Lens are: ", source['input_ids'][0].shape, targets['input_ids'][0].shape)
        
        source_ids = source["input_ids"].squeeze()
        target_ids = targets["input_ids"].squeeze()

        # return {"source_ids": source_ids, "target_ids": target_ids}

        src_mask    = source["attention_mask"].squeeze()
        target_mask = targets["attention_mask"].squeeze()

        return {"source_ids": source_ids, "source_mask": src_mask, "target_ids": target_ids, "target_mask": target_mask}



def get_dataset(tokenizer, type_path, num_samples, args, df=df):
      return NewsData(df = df, tokenizer=tokenizer, type_path=type_path, num_samples=num_samples, input_length=args.max_input_length, 
                        output_length=args.max_output_length)

### Test the data set class

In [15]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = NewsData(data, tokenizer, 'validation', None, token_len, 512, True)
len(dataset)

301

In [16]:
data = dataset[5]
print()
print("Shape of Tokenized Text: ", data['source_ids'].shape)
print()
print("Sanity check - Decode Text: ", tokenizer.decode(data['source_ids']))
print("====================================")
print("Sanity check - Decode Classification: ", tokenizer.decode(data['target_ids']))

Lens are:  torch.Size([512]) torch.Size([3])

Shape of Tokenized Text:  torch.Size([512])

Sanity check - Decode Text:  classify: here is the problem the usa constitution states only congress can declare war yet sanctions are a declaration of war made by non other than the bankster elite that has everrything to do with profit and nothing to do with defence far too much power for banksters to have and the profits they make are huge the office of foreign assets control quotofacquot of the us department of the treasury administers and enforces economic and trade sanctions based on us foreign policy and national security goals against targeted foreign countries and regimes terrorists international narcotics traffickers those engaged in activities related to the proliferation of weapons of mass destruction and other threats to the national security foreign policy or economy of the united states it is an issue no one has ever adressed anywhere so no matter who you vote for its the banksters 

## T5 Fine Tuner Class

### Functions to calculating accuracy while training

In [17]:
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix((remove_punc(lower(s))))

In [18]:
def exact_match_score(prediction, ground_truth):
    return int(normalize_answer(prediction) == normalize_answer(ground_truth))

In [19]:
def calculate_scores(predictions, ground_truths):
    em_score = 0
    subset_match_score = 0
    
    for i in range(len(predictions)):
        ground_truth = ground_truths[i]
        prediction = predictions[i]
        em_score +=  exact_match_score(prediction, ground_truth)
    
    em_score /= len(predictions)
    return em_score*100

In [20]:
class T5FakeNewsDetector(pl.LightningModule):
    def __init__(self, hparams):
        super(T5FakeNewsDetector, self).__init__()
        self.hparams = hparams        
        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path, return_dict=True)
        self.tokenizer = T5Tokenizer.from_pretrained(hparams.tokenizer_name_or_path, max_length=hparams.max_input_length)
        self.output_dir = Path(self.hparams.output_dir)
        self.step_count = 0
        
        if self.hparams.freeze_embeds:
            self.freeze_embeds()
        if self.hparams.freeze_encoder:
            self.freeze_params(self.model.get_encoder())
            assert_all_frozen(self.model.get_encoder())
            
            
        n_observations_per_split = {
            "train": self.hparams.n_train,
            "validation": self.hparams.n_val,
            "test": self.hparams.n_test,
        }
        self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()}
        self.em_score_list = []
        
    def freeze_params(self, model):
        for par in model.parameters():
            par.requires_grad = False
            
    def freeze_embeds(self):
        try:
            self.freeze_params(self.model.model.shared)
            for d in [self.model.model.encoder, self.model.model.decoder]:
                freeze_params(d.embed_positions)
                freeze_params(d.embed_tokens)
        except AttributeError:
            self.freeze_params(self.model.shared)
            for d in [self.model.encoder, self.model.decoder]:
                self.freeze_params(d.embed_tokens)

    def lmap(self, f, x):
        return list(map(f, x))

    def is_logger(self):
        return self.trainer.global_rank <= 0
    
    def parse_score(self, result):
        return {k: round(v.mid.fmeasure * 100, 4) for k, v in result.items()}
        
    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, lm_labels=None):
        return self.model(
                input_ids,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                labels=lm_labels
            )

    def _step(self, batch):
        labels = batch["target_ids"]
        labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

        outputs = self(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            lm_labels=labels,
            decoder_attention_mask=batch['target_mask']
        )

        loss = outputs[0]

        return loss
    
    def ids_to_clean_text(self, generated_ids):
        gen_text = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=True
        )
        return self.lmap(str.strip, gen_text)
    
    def _generative_step(self, batch) :
        
        t0 = time.time()
        # print(batch)
        inp_ids = batch["source_ids"]
        
        generated_ids = self.model.generate(
            batch["source_ids"],
            attention_mask=batch["source_mask"],
            use_cache=True,
            decoder_attention_mask=batch['target_mask'],
            max_length=3

        )
        preds = self.ids_to_clean_text(generated_ids)
        target = self.ids_to_clean_text(batch["target_ids"])
#         print("Preds and Targets: ", preds, target)
            
        gen_time = (time.time() - t0) / batch["source_ids"].shape[0]  
    
        loss = self._step(batch)
        base_metrics = {'val_loss': loss}
        summ_len = np.mean(self.lmap(len, generated_ids))
        base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target)
        em_score  = calculate_scores(preds, target)
        
        self.em_score_list.append(em_score)
        
        em_score = torch.tensor(em_score,dtype=torch.float32)
    
        base_metrics.update(accuracy=em_score)
        
        return base_metrics

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)

        tensorboard_logs = {"train_loss": loss}
        return {"loss": loss, "log": tensorboard_logs}
  
    def training_epoch_end(self, outputs):
        avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
        tensorboard_logs = {"avg_train_loss": avg_train_loss}
        return {"avg_train_loss": avg_train_loss, "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        return self._generative_step(batch)
    
    def validation_epoch_end(self, outputs):
        
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        tensorboard_logs = {"val_loss": avg_loss}
        
        if len(self.em_score_list) <= 2:
            average_em_score = sum(self.em_score_list) / len(self.em_score_list) 
            
        else:
            latest_em_score = self.em_score_list[:-2]
            average_em_score = sum(latest_em_score) / len(latest_em_score) 
        
        average_em_score = torch.tensor(average_em_score,dtype=torch.float32)
        tensorboard_logs.update(accuracy=average_em_score)
        
        self.target_gen= []
        self.prediction_gen=[]
        return {"avg_val_loss": avg_loss, 
                "accuracy" : average_em_score,
                "log": tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):

        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
        self.opt = optimizer
        return [optimizer]
  
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None, using_native_amp=None):
        if self.trainer.use_tpu:
            xm.optimizer_step(optimizer)
        else:
            optimizer.step()
        optimizer.zero_grad()
        self.lr_scheduler.step()
  
    def get_tqdm_dict(self):
        tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

        return tqdm_dict
    
    def train_dataloader(self):   
        n_samples = self.n_obs['train']
        train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", num_samples=n_samples, args=self.hparams)
        dataloader = DataLoader(train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
        t_total = (
            (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
            // self.hparams.gradient_accumulation_steps
            * float(self.hparams.num_train_epochs)
        )
        scheduler = get_linear_schedule_with_warmup(
            self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
        )
        self.lr_scheduler = scheduler

        return dataloader

    def val_dataloader(self):
        n_samples = self.n_obs['validation']
        validation_dataset = get_dataset(tokenizer=self.tokenizer, type_path="validation", num_samples=n_samples, args=self.hparams)
        
        return DataLoader(validation_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)
    
    def test_dataloader(self):
        n_samples = self.n_obs['test']
        test_dataset = get_dataset(tokenizer=self.tokenizer, type_path="test", num_samples=n_samples, args=self.hparams)
        
        return DataLoader(test_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)
    
    def on_save_checkpoint(self, checkpoint):
        save_path = self.output_dir.joinpath(model_prefix)
        self.model.config.save_step = self.step_count
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

In [31]:
logger = logging.getLogger(__name__)

class LoggingCallback(pl.Callback):
    def on_validation_end(self, trainer, pl_module):
        logger.info("***** Validation results *****")
        if pl_module.is_logger():
            metrics = trainer.callback_metrics
            # Log results
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    logger.info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer, pl_module):
        logger.info("***** Test results *****")

        if pl_module.is_logger():
            metrics = trainer.callback_metrics

            # Log and save results to file
            output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
            with open(output_test_results_file, "w") as writer:
                for key in sorted(metrics):
                    if key not in ["log", "progress_bar"]:
                        logger.info("{} = {}\n".format(key, str(metrics[key])))
                        writer.write("{} = {}\n".format(key, str(metrics[key])))


## Define the Hyper Parameters

In [32]:
args_dict = dict(
    output_dir="", # path to save the checkpoints
    model_name_or_path=model_name,
    tokenizer_name_or_path=model_name,
    max_input_length=token_len,
    max_output_length=token_len,
    freeze_encoder=False,
    freeze_embeds=False,
    learning_rate=3e-4,
    weight_decay=0.0,
    adam_epsilon=1e-8,
    warmup_steps=0,
    train_batch_size=1,
    eval_batch_size=1,
    num_train_epochs=10,
    gradient_accumulation_steps=8,
    n_gpu=1,
    resume_from_checkpoint=None, 
    val_check_interval = 0.5, 
    n_val=-1,
    n_train=-1,
    n_test=-1,
    early_stop_callback=False,
    fp_16=False, # if you want to enable 16-bit training then install apex and set this to true
    opt_level='O1', # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
    max_grad_norm=1.0, # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
    seed=42,
)


args_dict.update({'output_dir': "./" + model_prefix + "_final", 'num_train_epochs':50,
             'train_batch_size': 8, 'eval_batch_size': 8, 'resume_from_checkpoint' : 't5-small-512t5-small-512_ckpt_epoch_29.ckpt'})
args = argparse.Namespace(**args_dict)


## Define Checkpoint function
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath="./" + model_prefix + "_checkpoint", prefix=model_prefix, monitor="accuracy", mode="max", save_top_k=1
)


In [33]:
## If resuming from checkpoint, add an arg resume_from_checkpoint
train_params = dict(
    accumulate_grad_batches=args.gradient_accumulation_steps,
    gpus=args.n_gpu,
    max_epochs=args.num_train_epochs,
    precision= 16 if args.fp_16 else 32,
    amp_level=args.opt_level,
    resume_from_checkpoint=args.resume_from_checkpoint,
    gradient_clip_val=args.max_grad_norm,
    checkpoint_callback=checkpoint_callback,
    val_check_interval=args.val_check_interval,
    logger=wandb_logger,
    callbacks=[LoggingCallback()],
    # progress_bar_refresh_rate=0
)

model = T5FakeNewsDetector(args)
trainer = pl.Trainer(**train_params)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]


## Train using Pytorch Lightning

In [34]:
trainer.fit(model)

[34m[1mwandb[0m: Wandb version 0.10.8 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 60 M  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1

## Test this model on Val Set

In [35]:
import textwrap
from tqdm.auto import tqdm

In [36]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = dataset = NewsData(df, tokenizer, 'validation', None, token_len, 3, True)

In [37]:
loader = DataLoader(dataset, batch_size=32, shuffle=True)
it = iter(loader)

In [38]:
batch = next(it)
batch["source_ids"].shape

Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are:  torch.Size([512]) torch.Size([3])
Lens are: 

torch.Size([32, 512])

In [39]:
model.to('cuda')
outs = model.model.generate(
            batch["source_ids"].cuda(),
            attention_mask=batch["source_mask"].cuda(),
            use_cache=True,
            decoder_attention_mask=batch['target_mask'].cuda(),
        )

dec = [tokenizer.decode(ids) for ids in outs]

texts = [tokenizer.decode(ids) for ids in batch['source_ids']]
targets = [tokenizer.decode(ids) for ids in batch['target_ids']]

In [40]:
for i in range(32):
    lines = textwrap.wrap("Input Text:\n%s\n" % texts[i], width=100)
    print("\n".join(lines))
    print("\nActual Class: %s" % targets[i])
    print("\nPredicted Class from T5: %s" % dec[i])
    print("=====================================================================\n")

Input Text: classify: news bulletin liverpools english striker daniel sturridge c applauds
supporters at the final whistle during the efl english football league cup fourth round match
between liverpool and tottenham hotspur at anfield in liverpool north west england on october afp
liverpool have progressed into the quarterfinals of the efl cup by defeating tottenham at anfield
with the help of a double by daniel sturridge the reds just couldnt have asked for a better start on
tuesday as they grabbed the lead just minutes into the match through daniel sturridge the yearold
striker then doubled liverpools advantage after the break by converting a oneonone opportunity in
the th minute spurs however managed to pull one back in the th minute when vincent janssen converted
a penalty but that was as close as they got as liverpool made it to the last eight and stayed on
course to lift the trophy for a th time in their history

Actual Class: Real

Predicted Class from T5: Real

Input Text: cla