In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Replacement for ``locale.getpreferredencoding`` that always answers "UTF-8".

    ``do_setlocale`` is accepted only so the signature matches the stdlib
    function; it is ignored. NOTE(review): presumably a Colab workaround for a
    runtime whose locale reports a non-UTF-8 encoding -- confirm still needed.
    """
    forced_encoding = "UTF-8"
    return forced_encoding


setattr(locale, "getpreferredencoding", getpreferredencoding)

!pip install bert-extractive-summarizer
!pip install -U sentence-transformers
!pip install datasets
!pip install sentence_transformers
!pip install pytorch-lightning
!pip install nlp
!pip install wandb

from datasets import load_dataset
import nltk
nltk.download('punkt')
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
import pytorch_lightning as pl
from transformers import (T5ForConditionalGeneration, T5Tokenizer, AdamW, get_linear_schedule_with_warmup)
from nlp import load_metric
import time
import argparse
import logging
from pytorch_lightning.loggers import WandbLogger
import os
import numpy as np
import pandas as pd
import torch.nn.functional as F
import wandb
from nltk.cluster import KMeansClusterer
from scipy.spatial import distance_matrix
import math
from sentence_transformers import SentenceTransformer
from summarizer.sbert import SBertSummarizer
import warnings
warnings.filterwarnings('ignore')
from torch import cuda
from tqdm import tqdm


# Show which GPU (if any) the runtime provides.
!nvidia-smi
# Authenticate with Weights & Biases, used by wandb.init / wandb.log in the Summarizer cell.
!wandb login


# Mount Google Drive: the Summarizer saves its checkpoint and prediction CSVs
# under /content/gdrive/My Drive/.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
class XMediaData(Dataset):
    """Map-style Dataset over a GEM/xmediasum split that yields tokenised tensors.

    Each item is a dict of 1-D ``torch.long`` tensors: ``dialogue_input_ids`` /
    ``dialogue_attention_mask`` (length ``dialogue_length``) and
    ``summary_input_ids`` / ``summary_attention_mask`` (length ``summ_length``),
    padded/truncated by ``tokenizer``.

    Args:
        split_type: split spec passed to ``load_dataset`` (e.g. 'train', 'validation[:20%]').
        config: 'Extractive' to first shorten the dialogue with ``extractive_model``;
            any other value tokenises the dialogue as-is.
        dialogue_length: max token length of the encoded dialogue.
        summ_length: max token length of the encoded summary.
        tokenizer: object exposing ``batch_encode_plus`` (a T5Tokenizer in this notebook).
        extractive_model: callable as ``extractive_model(text, ratio=...)`` returning
            a shorter text (an SBertSummarizer in this notebook).
    """

    def __init__(self, split_type, config, dialogue_length, summ_length, tokenizer, extractive_model):
        self.data = load_dataset('GEM/xmediasum', split=split_type)
        self.config = config
        self.dialogue_length = dialogue_length
        self.summ_length = summ_length
        self.tokenizer = tokenizer
        self.extractive_model = extractive_model

    def __len__(self):
        return len(self.data)

    def preprocess_text(self, sentence):
        """Drop double-backtick sequences and normalise all whitespace to single spaces.

        Newlines and tabs become spaces rather than being deleted, so words on
        either side of a line break are not glued together into one token.
        """
        without_backticks = sentence.replace('``', '')
        return ' '.join(without_backticks.split())

    def _encode(self, text, max_length):
        """Tokenise one string into (max_length,) input_ids and attention_mask tensors."""
        encoded = self.tokenizer.batch_encode_plus([text],
                                                   max_length=max_length,
                                                   padding='max_length',
                                                   truncation=True,
                                                   return_tensors="pt")
        # squeeze(0) removes only the batch-of-one axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        return encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0)

    def batch_encoding(self, batch):
        """Encode ``batch['dialogue']`` and ``batch['summary']``.

        Returns:
            [dialogue_input_ids, summary_input_ids, dialogue_attention_mask, summary_attention_mask]
        """
        dialogue_ids, dialogue_mask = self._encode(self.preprocess_text(batch['dialogue']), self.dialogue_length)
        summary_ids, summary_mask = self._encode(self.preprocess_text(batch['summary']), self.summ_length)
        return [dialogue_ids, summary_ids, dialogue_mask, summary_mask]

    def get_embeddings(self, sentence):
        """Embed one sentence via ``self.extractive_model.encode`` and return its single row."""
        embedding = self.extractive_model.encode([sentence])
        return embedding[0]

    def centroid_distance(self, sentence):
        """Euclidean distance between a row's 'embeddings' and its cluster 'centroid'."""
        return distance_matrix([sentence['embeddings']], [sentence['centroid'].tolist()])[0][0]

    def extractive_summarizer_data(self, instance):
        """Alternative extractive step: k-means over sentence embeddings, one sentence per cluster.

        Replaces ``instance['dialogue']`` with the sentences nearest each cluster
        centroid (about half the sentences), in original order, and returns
        ``instance``. Not called by ``__getitem__``.
        NOTE(review): requires ``self.extractive_model.encode``; clustering uses
        cosine distance while the representative is chosen by Euclidean distance.
        """
        tokens = nltk.sent_tokenize(instance['dialogue'])
        if not tokens:
            # No sentences to cluster; leave the dialogue untouched.
            return instance
        data = pd.DataFrame({'sentences': tokens})
        data['embeddings'] = data['sentences'].apply(self.get_embeddings)

        clusters = math.ceil(len(tokens) * 0.5)
        iterations = 10

        X = np.array(data['embeddings'].tolist())
        kclusterer = KMeansClusterer(clusters,
                                     distance=nltk.cluster.util.cosine_distance,
                                     repeats=iterations,
                                     avoid_empty_clusters=True)
        assigned_clusters = kclusterer.cluster(X, assign_clusters=True)
        means = kclusterer.means()  # computed once instead of once per row

        data['cluster'] = pd.Series(assigned_clusters, index=data.index)
        data['centroid'] = data['cluster'].apply(lambda cluster_id: means[cluster_id])
        data['centroid_distance'] = data.apply(self.centroid_distance, axis=1)

        # Per cluster keep the sentence closest to its centroid, then restore sentence order.
        nearest = (data.sort_values('centroid_distance', ascending=True)
                       .groupby('cluster')
                       .head(1)
                       .sort_index())
        instance['dialogue'] = ' '.join(nearest['sentences'].tolist())
        return instance

    def __getitem__(self, idx):
        instance = self.data[idx]
        if self.config == 'Extractive':
            # ratio=0.5: presumably keeps roughly half of the dialogue's sentences.
            instance['dialogue'] = self.extractive_model(instance['dialogue'], ratio=0.5)
        dialogue_ids, summary_ids, dialogue_mask, summary_mask = self.batch_encoding(instance)
        return {
            "dialogue_input_ids": dialogue_ids.to(dtype=torch.long),
            "dialogue_attention_mask": dialogue_mask.to(dtype=torch.long),
            "summary_input_ids": summary_ids.to(dtype=torch.long),
            "summary_attention_mask": summary_mask.to(dtype=torch.long),
        }

In [None]:
class Summarizer:
    """Fine-tunes t5-small for dialogue summarisation on GEM/xmediasum.

    Constructing an instance does all the work: it initialises wandb, seeds
    torch/numpy, builds datasets and model, fine-tunes, saves the weights, and
    writes validation predictions to a CSV on the mounted Drive. With
    ``HYPER_PARAMETER_TUNING`` enabled it instead fine-tunes and validates a fresh
    model on 20% splits for every learning rate in ``LR_TUNING``, without saving
    weights.
    """

    def __init__(self):
        self.device = 'cuda' if cuda.is_available() else 'cpu'
        print(f'Device : {self.device}')
        # Specifying the project for wandb logging
        wandb.init(project="abstractive_dialogue_summarizer")

        self.config, self.train_parameters, self.validation_parameters = self.get_config()

        torch.manual_seed(self.config.SEED)
        np.random.seed(self.config.SEED)
        torch.backends.cudnn.deterministic = True

        # Tokenizer, extractive model and datasets do not depend on the learning
        # rate, so they are built once and shared by every run below.
        self.tokenizer = T5Tokenizer.from_pretrained("t5-small")
        self.extractive_model = SBertSummarizer('all-MiniLM-L6-v2')

        if self.config.HYPER_PARAMETER_TUNING:
            self._load_data('train[:20%]', 'validation[:20%]')
            print("Train Samples : " + str(len(self.train_data)))
            print("Val Samples : " + str(len(self.val_data)))
            for lr in self.config.LR_TUNING:
                print(f'---------------------------------- LR : {lr} ----------------------------------')
                # NOTE(review): every LR writes the same prediction CSV path, so only the last survives.
                self._run(lr)
                print('--------------------------------------------------------------------------------')
        else:
            self._load_data('train', 'validation')
            self._run(self.config.LEARNING_RATE)

    def _load_data(self, train_split, val_split):
        """Build the train/validation XMediaData datasets and their DataLoaders."""
        dataset_args = (self.config.SUMMARY_TYPE, self.config.MAX_LEN, self.config.SUMMARY_LEN,
                        self.tokenizer, self.extractive_model)
        self.train_data = XMediaData(train_split, *dataset_args)
        self.val_data = XMediaData(val_split, *dataset_args)
        self.train_batch = DataLoader(self.train_data, **self.train_parameters)
        self.val_batch = DataLoader(self.val_data, **self.validation_parameters)

    def _run(self, lr):
        """Load a fresh pretrained t5-small, fine-tune it with Adam at ``lr``, then validate."""
        self.abstractive_model = T5ForConditionalGeneration.from_pretrained("t5-small").to(self.device)
        self.optimizer = torch.optim.Adam(params=self.abstractive_model.parameters(), lr=lr)
        wandb.watch(self.abstractive_model, log="all")
        self.train()
        self.validation()

    def get_config(self):
        """Fill ``wandb.config`` with the run's hyperparameters and build DataLoader kwargs.

        Returns:
            (config, train_parameters, validation_parameters): the populated wandb
            config and the keyword dicts for the train / validation DataLoaders.
        """
        config = wandb.config
        config.TRAIN_BATCH_SIZE = 2
        config.VALID_BATCH_SIZE = 2
        config.TRAIN_EPOCHS = 2
        config.VAL_EPOCHS = 1
        config.HYPER_PARAMETER_TUNING = False
        config.LEARNING_RATE = 3e-5
        config.LR_TUNING = [0.0001, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5]
        config.SEED = 42
        config.MAX_LEN = 512       # max dialogue tokens
        config.SUMMARY_LEN = 64    # max summary tokens; also the generation length cap
        config.NO_REPEAT_N_GRAMS = 5
        config.SUMMARY_TYPE = 'Abstractive'  # 'Extractive' pre-shortens dialogues

        train_parameters = {'batch_size': config.TRAIN_BATCH_SIZE,
                            'shuffle': True,
                            'num_workers': 0}

        validation_parameters = {'batch_size': config.VALID_BATCH_SIZE,
                                 'shuffle': False,
                                 'num_workers': 0}

        return config, train_parameters, validation_parameters

    def train(self):
        """Run TRAIN_EPOCHS epochs of fine-tuning; save weights to Drive unless tuning."""
        for epoch in range(self.config.TRAIN_EPOCHS):
            self.finetune(epoch)

        if not self.config.HYPER_PARAMETER_TUNING:
            model_name = f'{self.config.SUMMARY_TYPE}_summarizer_small.pt'
            path = f"/content/gdrive/My Drive/{model_name}"
            torch.save(self.abstractive_model.state_dict(), path)

    def finetune(self, epoch):
        """Train for one epoch over ``self.train_batch`` and print the mean batch loss.

        The full tokenised summary is passed as ``labels`` (pad positions set to
        -100 so the loss ignores them) and no ``decoder_input_ids`` are given:
        T5ForConditionalGeneration then builds the decoder inputs by shifting the
        labels right behind its decoder start token. That is the same start state
        ``generate`` uses, and every summary token, including the first, becomes a
        training target.
        """
        self.abstractive_model.train()
        total_loss = 0.0
        num_batches = 0
        for i, instance in enumerate(self.train_batch):
            # clone(): masking must not write into the batch tensor (`.to` may return it as-is).
            labels = instance['summary_input_ids'].to(self.device, dtype=torch.long).clone()
            labels[labels == self.tokenizer.pad_token_id] = -100

            ids = instance['dialogue_input_ids'].to(self.device, dtype=torch.long)
            mask = instance['dialogue_attention_mask'].to(self.device, dtype=torch.long)

            output = self.abstractive_model(input_ids=ids, attention_mask=mask, labels=labels)
            loss = output[0]

            if i % 10 == 0:
                wandb.log({"Training Loss": loss.item()})

            if i % 500 == 0:
                print(f'Epoch: {epoch}, Loss:  {loss.item()}')

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Mean over batches actually seen; max() guards an empty loader.
        print(f"Epoch {epoch + 1} Loss: {total_loss / max(num_batches, 1)}")

    def validation(self):
        """For each of VAL_EPOCHS passes, generate predictions and save them with references to CSV."""
        for epoch in range(self.config.VAL_EPOCHS):
            predicted_summ, actual_summ = self.validate_model(epoch)

            df = pd.DataFrame({'Generated Text': predicted_summ, 'Actual Text': actual_summ})
            df.to_csv(f'/content/gdrive/My Drive/{self.config.SUMMARY_TYPE}_{epoch}_predictions_t5_small_ex_final.csv')

    def validate_model(self, epoch):
        """Generate a summary for every validation example.

        Returns:
            (predicted_summ, actual_summ): parallel lists of decoded generated and
            reference summaries, in ``self.val_batch`` order. ``epoch`` is unused.
        """
        self.abstractive_model.eval()
        predicted_summ, actual_summ = [], []
        with torch.no_grad():
            for i, instance in enumerate(self.val_batch):
                summary_inp_ids = instance['summary_input_ids'].to(self.device, dtype=torch.long)
                ids = instance['dialogue_input_ids'].to(self.device, dtype=torch.long)
                mask = instance['dialogue_attention_mask'].to(self.device, dtype=torch.long)

                # NOTE(review): config.NO_REPEAT_N_GRAMS is defined but not passed here.
                generated = self.abstractive_model.generate(input_ids=ids,
                                                            attention_mask=mask,
                                                            max_length=self.config.SUMMARY_LEN,
                                                            num_beams=2,
                                                            repetition_penalty=2.5,
                                                            length_penalty=1.0,
                                                            early_stopping=True)

                predicted_summ.extend(
                    self.tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for gen_id in generated)
                actual_summ.extend(
                    self.tokenizer.decode(summ, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for summ in summary_inp_ids)

                if i % 100 == 0:
                    print(f'Completed {i}')

        return predicted_summ, actual_summ




def _main():
    """Start a full run: Summarizer.__init__ trains and validates immediately."""
    Summarizer()


if __name__ == "__main__":
    # Notebook kernels execute with __name__ == "__main__", so running this cell starts training.
    _main()

# Computing Rouge Scores

In [None]:
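# Prediction CSVs written by Summarizer.validation() have two columns: 'Generated Text' (model output) and 'Actual Text' (decoded reference summary).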

In [None]:
!pip install datasets
!pip install transformers
!pip install transformers[sentencepiece]
!pip install tqdm
!pip install rouge

In [None]:
from datasets import load_dataset
import pandas as pd
import tqdm
from tqdm.notebook import tqdm_notebook
from rouge import Rouge

In [None]:
# Reference summaries for the validation split, in dataset order (the Summarizer's
# validation DataLoader is not shuffled, so predictions were written in this order).
val_data = load_dataset('GEM/xmediasum', split='validation')
actual_summaries = [row['summary'] for row in val_data]
print(len(actual_summaries))
actual_summaries[0]

In [None]:
# Predictions exported by Summarizer.validation() (presumably copied from Drive to /content).
# Rows are assumed to line up one-to-one, in order, with `actual_summaries`.
finetune_data = pd.read_csv('/content/Extractive_0_predictions_t5_small_ex_final.csv')
# read_csv turns empty cells (e.g. an empty generation) into NaN floats; map them back
# to '' so every entry is a str. The column is read directly, so no loop variable
# (previously `sum`) is left shadowing a builtin in the notebook namespace.
finetune_summaries = finetune_data['Generated Text'].fillna('').astype(str).tolist()

In [None]:
def get_single_rouge_scores(idx, rouge=None):
  """ROUGE scores of finetune_summaries[idx] against actual_summaries[idx].

  The reference is ASCII-folded (non-ASCII characters dropped) and every
  'Summary: ' occurrence is removed before scoring.

  Args:
    idx: index into the global ``actual_summaries`` / ``finetune_summaries`` lists.
    rouge: optional ``Rouge`` instance to reuse; a new one is created when omitted.

  Returns:
    The per-sample dict from ``Rouge.get_scores``, keyed 'rouge-1' / 'rouge-2' /
    'rouge-l', each with 'r', 'p', 'f'. If the prediction is not a str (e.g. NaN
    read back from CSV) or either text is blank, all scores are 0.0.
  """
  if rouge is None:
    rouge = Rouge()
  actual_summary = actual_summaries[idx]
  actual_summary = actual_summary.encode('ascii', 'ignore').decode('ascii').replace('Summary: ', '')
  generated_summary = finetune_summaries[idx]
  if not isinstance(generated_summary, str) or not generated_summary.strip() or not actual_summary.strip():
    # Rouge.get_scores raises on empty texts; score such samples as zero overlap instead.
    return {metric: {'r': 0.0, 'p': 0.0, 'f': 0.0} for metric in ('rouge-1', 'rouge-2', 'rouge-l')}
  return rouge.get_scores(generated_summary, actual_summary)[0]

In [None]:
def get_score(rouge, param):
  """Mean of one ROUGE statistic over every validation sample.

  Args:
    rouge: metric key of the per-sample scores: 'rouge-1', 'rouge-2' or 'rouge-l'.
    param: statistic key: 'r', 'p' or 'f'.

  Returns:
    The average of ``get_single_rouge_scores(i)[rouge][param]`` over all
    ``len(actual_summaries)`` samples.

  Raises:
    ValueError: if there are no references, or the number of generated summaries
      differs from the number of references (rows would be misaligned).
  """
  n_samples = len(actual_summaries)
  if n_samples == 0:
    raise ValueError('actual_summaries is empty; nothing to score')
  if len(finetune_summaries) != n_samples:
    raise ValueError(
        f'{len(finetune_summaries)} generated summaries vs {n_samples} references; '
        'rows must align one-to-one')
  total = 0.0
  for i in tqdm_notebook(range(n_samples), desc=f'{param}'):
    total += get_single_rouge_scores(i)[rouge][param]
  return total / n_samples

In [None]:
# Score every sample once and accumulate all nine statistics in a single pass
# (calling get_score() per statistic would re-run ROUGE over the whole set 9 times).
# Sums are accumulated in the same sample order, so the averages are unchanged.
ROUGE_METRICS = (('rouge-1', 'Rouge-1 Scores'), ('rouge-2', 'Rouge-2 Scores'), ('rouge-l', 'Rouge-l Scores'))
ROUGE_STATS = ('r', 'p', 'f')

n_samples = len(actual_summaries)
if len(finetune_summaries) != n_samples:
  raise ValueError(
      f'{len(finetune_summaries)} generated summaries vs {n_samples} references; '
      'rows must align one-to-one')

rouge_totals = {metric: dict.fromkeys(ROUGE_STATS, 0.0) for metric, _ in ROUGE_METRICS}
for i in tqdm_notebook(range(n_samples), desc='rouge'):
  sample_scores = get_single_rouge_scores(i)
  for metric, _ in ROUGE_METRICS:
    for stat in ROUGE_STATS:
      rouge_totals[metric][stat] += sample_scores[metric][stat]

for position, (metric, title) in enumerate(ROUGE_METRICS):
  print(title if position == 0 else '\n' + title)
  for stat in ROUGE_STATS:
    print(f"{stat} : {rouge_totals[metric][stat] / n_samples}")