# Importing Libraries

In [None]:
#!pip install transformers==3.0.0
#!pip install tensorflow_datasets
from transformers import T5Tokenizer, T5ForConditionalGeneration
from rouge_score import rouge_scorer
from rouge_score import scoring
import tensorflow as tf
import tensorflow_datasets as tfds
from pathlib import Path
import torch
import re
import time
import numpy as np
import warnings
import os
import logging
import numpy as np
import shutil
import gradio as gr
logging.basicConfig(level=logging.ERROR)
warnings.filterwarnings('ignore')


# Setting Hyper-Parameters

In [None]:
# Number of examples per mini-batch used by every DataLoader in this notebook.
BATCH_SIZE = 16
# NOTE(review): not referenced anywhere in the visible notebook (the loaders are
# built without shuffling) -- kept so that external references keep working.
SHUFFEL_SIZE = 1024
# Train/infer on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
learning_rate = 3e-5
# Best validation loss seen so far; starts at +inf so the first measurement wins.
# `np.inf` is used because the `np.Inf` alias was removed in NumPy 2.0.
valid_loss_min = np.inf
# Index of the last training batch already processed; the training loop only
# works on batches whose index is strictly greater than this value.
item_index = 0
# Root folder (with trailing slash) for the text dumps of each dataset split.
drivepath="Data/"

# Loading the Dataset and Storing It in Files

In [None]:
# Same import as in the imports cell at the top; repeated here so this cell can
# be run on its own.
import tensorflow_datasets as tfds
# Load the "cnn_dailymail" dataset through TensorFlow Datasets. The cells below
# index the result by split name ('train', 'validation').
cnn_dailymail = tfds.load(name="cnn_dailymail")

In [None]:
# Pick the two splits used in this notebook out of the loaded dataset.
train_tfds, val_tfds = cnn_dailymail['train'], cnn_dailymail['validation']

In [None]:
# Convert both TensorFlow datasets into iterables that yield NumPy-backed items,
# which is what write_data() consumes.
train_ds_iter, val_ds_iter = tfds.as_numpy(train_tfds), tfds.as_numpy(val_tfds)

In [None]:
def write_data(iter_dataset, name, path=drivepath):
    """Dump one dataset split to two line-aligned UTF-8 text files.

    Writes ``<path><name>/article`` and ``<path><name>/highlights``; line *k*
    of one file belongs to line *k* of the other.

    Args:
        iter_dataset: iterable of mappings whose ``"article"`` and
            ``"highlights"`` values are UTF-8 encoded ``bytes``.
        name: split name, used as the sub-folder name (e.g. ``"train"``).
        path: root folder, including the trailing separator.
    """
    split_dir = Path(path + name)
    # Make sure the destination exists instead of relying on a separate cell.
    split_dir.mkdir(parents=True, exist_ok=True)

    # Context managers guarantee both files are flushed and closed, even if the
    # iteration fails half-way through.
    with (split_dir / "article").open("w", encoding="utf-8") as articles_file, \
            (split_dir / "highlights").open("w", encoding="utf-8") as highlights_file:
        for item in iter_dataset:
            # One record per line: embedded newlines are flattened to spaces so
            # both files keep exactly one line per example.
            article = item["article"].decode("utf-8").replace("\n", " ")
            highlights = item["highlights"].decode("utf-8").replace("\n", " ")
            articles_file.write(article + "\n")
            highlights_file.write(highlights + "\n")

In [None]:
# Create the per-split folders under the data root. write_data() writes to
# "<drivepath><split>/...", so the folders must live under drivepath rather
# than in the current working directory. exist_ok=True keeps the cell re-runnable.
for split_name in ("train", "test", "val"):
    os.makedirs(os.path.join(drivepath, split_name), exist_ok=True)

In [None]:
# Dump the training split first, then the validation split, to text files.
for split_name, split_iter in (("train", train_ds_iter), ("val", val_ds_iter)):
    write_data(split_iter, split_name)

# Load the Basic T5 Model
### This is the base T5 model, not yet fine-tuned on our data. We need to train it on a downstream task (here: summarization) before using it for predictions and inference.

In [None]:
# Same import as in the imports cell; repeated so this cell can run on its own.
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Tokenizer and model both come from the 't5-small' checkpoint; the model is
# moved to the device selected in the hyper-parameter cell.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

# If the config ships task-specific parameters, merge the "summarization" entry
# (or nothing, when that entry is absent) into the active model config.
# NOTE(review): MyDataset later reads model.config.prefix -- presumably that
# value is supplied by this update; confirm.
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
    

# Adam over all model parameters, using the learning rate from the
# hyper-parameter cell and a fixed weight decay of 1e-4.
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)

# Tokenize data

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset of (article, highlights) string pairs.

    Tokenisation is done lazily in ``__getitem__`` with the module-level
    ``tokenizer``; every article is prefixed with ``model.config.prefix``.
    """

    def __init__(self, articles, highlights):
        # x holds the source texts, y the reference summaries (same order).
        self.x, self.y = articles, highlights

    def __len__(self):
        return len(self.x)

    @staticmethod
    def transfrom(x):
        """Lower-case ``x`` and, on each line holding two or more single
        quotes, drop the first and the last one (the pattern is greedy)."""
        return re.sub("'(.*)'", r"\1", x.lower())

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, target_ids)`` as flat tensors.

        The source is truncated/padded to 512 tokens, the target to 150.
        """
        source_text = model.config.prefix + self.transfrom(self.x[index])
        target_text = self.transfrom(self.y[index])
        source = tokenizer.encode_plus(
            source_text,
            truncation=True,
            max_length=512,
            return_tensors="pt",
            pad_to_max_length=True,
        )
        target = tokenizer.encode(
            target_text,
            truncation=True,
            max_length=150,
            return_tensors="pt",
            pad_to_max_length=True,
        )
        # return_tensors="pt" adds a leading batch axis; flatten it away.
        return (
            source['input_ids'].view(-1),
            source['attention_mask'].view(-1),
            target.view(-1),
        )

In [None]:

def read_files(name):
    """Read the line-aligned article/highlights dump of one split.

    Args:
        name: split name; the files read are ``Data/<name>/article`` and
            ``Data/<name>/highlights`` (UTF-8, one example per line).

    Returns:
        ``(articles, highlights)``: two equally long lists of strings with
        trailing whitespace removed.

    Raises:
        AssertionError: if the two files do not have the same number of lines.
    """
    article_path = "Data/%s/article" % name
    highlights_path = "Data/%s/highlights" % name

    # `with` closes the handles deterministically instead of leaking them.
    with open(article_path, encoding="utf-8") as article_file:
        articles = [line.rstrip() for line in article_file]
    with open(highlights_path, encoding="utf-8") as highlights_file:
        highlights = [line.rstrip() for line in highlights_file]

    assert len(articles) == len(highlights)
    return articles, highlights

In [None]:
def get_dataset(name):
    """Build a ``MyDataset`` from the article/highlights files of split ``name``."""
    return MyDataset(*read_files(name))

In [None]:
# Datasets backed by the text dumps of the training and validation splits.
train_ds, val_ds = get_dataset("train"), get_dataset("val")

In [None]:
# Sequential (unshuffled) batch loaders over both datasets.
train_loader, val_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
    for dataset in (train_ds, val_ds)
)

In [None]:
pad_token_id = tokenizer.pad_token_id
def step(inputs_ids, attention_mask, y):
    """Run one forward pass of the global ``model`` and return its loss.

    Args:
        inputs_ids: token ids of the source articles.
        attention_mask: attention mask matching ``inputs_ids``.
        y: token ids of the target summaries; indexed as a 2-D batch.

    Returns:
        The first element of the model output, which this notebook uses as
        the loss.
    """
    # Teacher forcing: the decoder sees every target token except the last one ...
    decoder_inputs = y[:, :-1].contiguous()
    # ... and is scored against the targets shifted left by one position.
    shifted_targets = y[:, 1:]
    lm_labels = shifted_targets.clone()
    # Padding positions are relabelled -100 -- presumably so the loss ignores them.
    lm_labels[shifted_targets == pad_token_id] = -100
    outputs = model(
        inputs_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_inputs,
        labels=lm_labels,
    )
    loss = outputs[0]
    return loss

In [None]:

def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Persist a training checkpoint with ``torch.save``.

    Args:
        state: the checkpoint object to serialise.
        is_best: True when this checkpoint has the lowest validation loss so
            far; a second copy is then written to ``best_model_path``.
        checkpoint_path: destination that is always written.
        best_model_path: destination written only when ``is_best`` is true.
    """
    destinations = [checkpoint_path]
    if is_best:
        destinations.append(best_model_path)
    # The regular checkpoint is always written first, then the "best" copy.
    for destination in destinations:
        torch.save(state, destination)

# Training Model

In [None]:
logging.basicConfig(level=logging.ERROR)
warnings.filterwarnings('ignore')
checkpoint_path="Data/cp.pt"     # always overwritten with the latest checkpoint
best_model_path="Data/best.pt"   # overwritten only when validation loss improves
EPOCHS = 1
log_interval = 200               # validate / log / checkpoint every N batches
train_loss = []                  # training loss of every optimised batch
val_loss = []                    # validation loss of every probe
for epoch in range(EPOCHS):
    model.train()
    start_time = time.time()
    for i, (inputs_ids, attention_mask, y) in enumerate(train_loader):
        # Only batches with an index above item_index are trained on, so a run
        # resumed from a checkpoint (which stores the last finished index)
        # does not repeat work.
        # NOTE(review): with the initial item_index = 0 batch 0 is skipped too,
        # and the same skip applies to every epoch -- confirm before raising EPOCHS.
        if i <= item_index:
            continue

        inputs_ids = inputs_ids.to(device)
        attention_mask = attention_mask.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        loss = step(inputs_ids, attention_mask, y)
        train_loss.append(loss.item())
        loss.backward()
        # Clip the global gradient norm to 0.5 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        if (i + 1) % log_interval != 0:
            continue

        # ---- periodic validation probe --------------------------------------
        # Switch to eval mode so dropout does not distort the validation loss,
        # then restore train mode for the next optimisation step.
        model.eval()
        with torch.no_grad():
            # A fresh iterator over the unshuffled loader always yields its
            # first batch, so v_loss is a single-batch estimate.
            x, x_mask, val_y = next(iter(val_loader))
            v_loss = step(x.to(device), x_mask.to(device), val_y.to(device)).item()
        model.train()

        elapsed = time.time() - start_time

        print('| epoch {:3d} | [{:5d}/{:5d}] | '
              'ms/batch {:5.2f} | '
              'loss {:5.2f} | val loss {:5.2f}'.format(
                  epoch, i, len(train_loader),
                  elapsed * 1000 / log_interval,
                  loss.item(), v_loss))

        is_best = v_loss <= valid_loss_min
        if is_best:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,v_loss))
            valid_loss_min = v_loss

        # Store the running minimum (not merely the latest value) so that a
        # resumed run keeps comparing against the true best loss.
        checkpoint = {
            'epoch': epoch,
            'item': i,
            'valid_loss_min': valid_loss_min,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

        # One call writes the rolling checkpoint and, when is_best, the best copy.
        save_ckp(checkpoint, is_best, checkpoint_path, best_model_path)

        start_time = time.time()
        val_loss.append(v_loss)
                
                

# Calculating Confidence (Rouge Score)

In [None]:

class RougeScore:
    """Accumulates ROUGE scores over (target, prediction) pairs.

    Each call scores one pair with ``rouge_scorer.RougeScorer`` and feeds the
    result into a ``scoring.BootstrapAggregator``; ``result()`` reports the
    aggregated F-measures.
    """

    def __init__(self, score_keys=None)-> None:
        """
        Args:
            score_keys: iterable of ROUGE variants to compute; defaults to
                ``["rouge1", "rouge2", "rougeLsum"]``.
        """
        super().__init__()
        if score_keys is None:
            score_keys = ["rouge1", "rouge2", "rougeLsum"]
        # Always set the attribute: previously it was only assigned for the
        # default case, so passing custom keys raised AttributeError.
        self.score_keys = list(score_keys)

        self.scorer = rouge_scorer.RougeScorer(self.score_keys)
        self.aggregator = scoring.BootstrapAggregator()

    @staticmethod
    def prepare_summary(summary):
        """Put a newline after every " . " separator.

        Sentences are split onto separate lines so that rougeLsum is computed
        correctly. ``summary`` must be a ``str``, not ``bytes``.
        """
        summary = summary.replace(" . ", " .\n")
        return summary

    def __call__(self, target, prediction):
        """Score one pair and add it to the running aggregate.

        Args:
            target: reference summary (string).
            prediction: generated summary (string).
        """
        target = self.prepare_summary(target)
        prediction = self.prepare_summary(prediction)

        self.aggregator.add_scores(self.scorer.score(target=target, prediction=prediction))

        return

    def reset_states(self):
        """Forget every score accumulated so far."""
        self.rouge_list = []  # kept for backward compatibility; otherwise unused
        # Replace the aggregator, otherwise earlier scores would still be included.
        self.aggregator = scoring.BootstrapAggregator()

    def result(self):
        """Print and return the aggregated F-measures.

        Returns:
            dict mapping each score key to its mid F-measure, scaled by 100.
        """
        result = self.aggregator.aggregate()

        for key in self.score_keys:
            score_text = "%s = %.2f, 95%% confidence [%.2f, %.2f]"%(
                key,
                result[key].mid.fmeasure*100,
                result[key].low.fmeasure*100,
                result[key].high.fmeasure*100
            )
            print(score_text)

        return {key: result[key].mid.fmeasure*100 for key in self.score_keys}

In [None]:

# Scores generated summaries against one hard-coded reference summary.
# NOTE(review): the MyDataset/get_dataset defined earlier in the notebook read a
# highlights file and yield three tensors per item, while this loop unpacks two.
# The cell therefore looks like it depends on the single-input redefinitions
# from the "Predict the summary" section having been run first -- confirm the
# intended execution order.
test_ds = get_dataset("test")
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)
rouge_score = RougeScore()
predictions = []
for i, (input_ids, attention_mask) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    #y = y.to(device)
        
    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    # Turn each generated id sequence back into text, dropping special tokens.
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    # `real` holds a single reference, so the zip below pairs only the first
    # prediction of every batch with it.
    real = ["Drunk teenage boy climbed into lion enclosure at zoo in west India . Rahul Kumar, 17, ran towards animals "+
            "shouting 'Today I kill a lion!' Fortunately he fell into a moat before reaching lions and was rescued ."]
    for pred_sent, real_sent in zip(pred, real):
        # NOTE(review): RougeScore.__call__ is declared as (target, prediction);
        # the prediction is passed first here.
        rouge_score(pred_sent, real_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    # Stop once the batch index exceeds 40 (i.e. after at most 42 batches).
    if i > 40:
        break
    
rouge_score.result()

In [None]:
# Scores generated summaries against the reference summaries stored in the dataset.
rouge_score = RougeScore()
predictions = []
# The training loop leaves the model in train mode; switch to eval mode so
# dropout is disabled while generating.
model.eval()
for i, (input_ids, attention_mask, y) in enumerate(test_loader):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask)
    # Decode generated ids and reference ids back to text, dropping special
    # tokens. `y` is only decoded, so it does not need to be moved to the device.
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    real = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in y]
    for pred_sent, real_sent in zip(pred, real):
        # Keyword arguments keep reference and prediction in the roles that
        # RougeScore.__call__(target, prediction) expects.
        rouge_score(target=real_sent, prediction=pred_sent)
        predictions.append(str("pred sentence: " + pred_sent + "\n\n real sentence: " + real_sent))
    # Stop once the batch index exceeds 40 (i.e. after at most 42 batches).
    if i > 40:
        break

rouge_score.result()

# **Loading the saved model**

In [None]:

def load_ckp(checkpoint_fpath, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        checkpoint_fpath: path of the checkpoint to load; it must contain the
            keys 'state_dict', 'optimizer', 'valid_loss_min', 'item', 'epoch'.
        model: model that receives the saved parameters (modified in place).
        optimizer: optimizer that receives the saved state (modified in place).

    Returns:
        ``(model, item_index, optimizer, epoch, valid_loss_min)`` where the
        last three scalars are read from the checkpoint.
    """
    # Deserialise onto the device the model lives on, so that a checkpoint
    # written on a GPU can still be loaded on a CPU-only machine.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else None
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_fpath, map_location=map_location)

    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    valid_loss_min = checkpoint['valid_loss_min']
    item_index = checkpoint['item']
    return model, item_index, optimizer, checkpoint['epoch'], valid_loss_min

In [None]:
# Rebuild tokenizer, model and optimizer from scratch, then overwrite their
# state with what was stored in the checkpoint file.
from transformers import T5Tokenizer, T5ForConditionalGeneration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

# Merge the "summarization" task parameters (if any) into the model config,
# exactly as in the initial model-setup cell.
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
    model.config.update(task_specific_params.get("summarization", {}))
# NOTE(review): despite its name this points at the rolling checkpoint
# ("Data/cp.pt"), not at "Data/best.pt", and it rebinds the variable used by the
# training cell -- confirm which file is meant to be restored.
best_model_path="Data/cp.pt"
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate, weight_decay=0.0001)
# Restores weights and optimizer state and rebinds item_index / valid_loss_min
# to the values stored in the checkpoint.
model,item_index, optimizer, start_epoch, valid_loss_min = load_ckp(best_model_path, model, optimizer)

# Checking the Model Status

In [None]:
# Show everything that was restored from the checkpoint, one entry per line.
for label, value in (
    ("model = ", model),
    ("optimizer = ", optimizer),
    ("Item_index = ", item_index),
    ("start_epoch = ", start_epoch),
    ("valid_loss_min = ", valid_loss_min),
):
    print(label, value)
# Same value again, formatted with six decimals.
print("valid_loss_min = {:.6f}".format(valid_loss_min))

# Predict the summary 

In [None]:


def read_files(name):
    """Read the articles of one split for inference.

    Args:
        name: split name; the file read is ``Data/<name>/article``
            (one article per line, platform default encoding).

    Returns:
        List of article strings with trailing whitespace removed.
    """
    article_path = "Data/%s/article" % name
    # `with` closes the handle deterministically instead of leaking it.
    with open(article_path) as article_file:
        return [line.rstrip() for line in article_file]
class MyDataset(torch.utils.data.Dataset):
    """Inference-only map-style dataset: articles without reference summaries.

    Tokenisation is done lazily in ``__getitem__`` with the module-level
    ``tokenizer``; every article is prefixed with ``model.config.prefix``.
    """

    def __init__(self, articles):
        self.x = articles

    def __len__(self):
        return len(self.x)

    @staticmethod
    def transfrom(x):
        """Lower-case ``x`` and, on each line holding two or more single
        quotes, drop the first and the last one (the pattern is greedy)."""
        return re.sub("'(.*)'", r"\1", x.lower())

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask)`` as flat tensors (512 tokens)."""
        source_text = model.config.prefix + self.transfrom(self.x[index])
        source = tokenizer.encode_plus(
            source_text,
            truncation=True,
            max_length=512,
            return_tensors="pt",
            pad_to_max_length=True,
        )
        # return_tensors="pt" adds a leading batch axis; flatten it away.
        return source['input_ids'].view(-1), source['attention_mask'].view(-1)
def get_dataset(name):
    """Build an inference ``MyDataset`` from the article file of split ``name``."""
    return MyDataset(read_files(name))
def predict(blog):
    """Summarise one piece of free text with the global ``model``.

    The text is written to ``<drivepath>test/article`` (overwriting it) and
    read back through ``get_dataset("test")``.

    Args:
        blog: the text to summarise.

    Returns:
        The generated summary, or an empty string when ``blog`` contains no text.
    """
    # The reader treats every line as a separate article and only one summary
    # is returned, so the input is flattened to a single line first; otherwise
    # only the first line of a multi-line text would be summarised.
    test_str = " ".join((blog or "").split())
    # Default encoding on purpose: read_files() reads this file back with the
    # platform default encoding as well.
    with Path(drivepath+"test/article").open("w") as articles_file:
        articles_file.write(test_str)

    test_ds = get_dataset("test")
    if len(test_ds) == 0:
        # Nothing to summarise (empty or whitespace-only input).
        return ""

    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)
    model.eval()  # make sure dropout is disabled while generating
    input_ids, attention_mask = next(iter(test_loader))
    summaries = model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
    )
    pred = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
    return pred[0]

In [None]:
# Debug cell: print the tokenised inputs, the raw generated ids and the decoded
# summary for every article currently stored in the test file.
test_ds = get_dataset("test")
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=BATCH_SIZE)
for i, (input_ids, attention_mask) in enumerate(test_loader):
    print(input_ids)
    print(attention_mask)
    # The model was moved to `device`; the inputs must be on the same device,
    # otherwise generation fails when the model is on the GPU.
    summaries = model.generate(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
    )
    for g in summaries:
        print(g)
        print(tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False))

# Interface for Input Text and Output Summary

In [None]:
# Same import as in the imports cell; repeated so this cell can run on its own.
import gradio as gr
# Default article shown in the input box so the demo works without typing.
sample="A drunk teenage boy had to be rescued by security after jumping into a lions' enclosure at a zoo in western India. Rahul Kumar, 17, clambered over the enclosure fence at the Kamla Nehru Zoological Park in Ahmedabad, and began running towards the animals, shouting he would 'kill them'. Mr Kumar explained afterwards that he was drunk and 'thought I'd stand a good chance' against the predators. Next level drunk: Intoxicated Rahul Kumar, 17, climbed into the lions' enclosure at a zoo in Ahmedabad and began running towards the animals shouting 'Today I kill a lion!' Mr Kumar had been sitting near the enclosure when he suddenly made a dash for the lions, surprising zoo security. The intoxicated teenager ran towards the lions, shouting: 'Today I kill a lion or a lion kills me!' A zoo spokesman said: 'Guards had earlier spotted him close to the enclosure but had no idea he was planing to enter it. 'Fortunately, there are eight moats to cross before getting to where the lions usually are and he fell into the second one, allowing guards to catch up with him and take him out. 'We then handed him over to the police.' Brave fool: Fortunately, Mr Kumar  fell into a moat as he ran towards the lions and could be rescued by zoo security staff before reaching the animals (stock image) Kumar later explained: 'I don't really know why I did it. 'I was drunk and thought I'd stand a good chance.' A police spokesman said: 'He has been cautioned and will be sent for psychiatric evaluation. 'Fortunately for him, the lions were asleep and the zoo guards acted quickly enough to prevent a tragedy similar to that in Delhi.' Last year a 20-year-old man was mauled to death by a tiger in the Indian capital after climbing into its enclosure at the city zoo."
# Text-in / text-out web UI around predict(); launch() starts it immediately.
# NOTE(review): this relies on the gr.inputs / gr.outputs namespaces -- presumably
# an older Gradio release; confirm the installed version still provides them.
# NOTE(review): lines=1000000 looks like a placeholder for "very tall box" -- confirm.
gr.Interface(fn=predict,inputs= [gr.inputs.Textbox(lines=1000000 ,label="Enter Text to Summarise",default=sample, placeholder="Start here...")],outputs=[gr.outputs.Textbox( type="auto", label=None)]).launch()