In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Always report UTF-8 as the preferred text encoding.

    Drop-in replacement for ``locale.getpreferredencoding``; ``do_setlocale``
    is accepted only for signature compatibility and is ignored.
    """
    forced_encoding = "UTF-8"
    return forced_encoding


# Monkey-patch the locale module so subsequent lookups of the preferred
# encoding go through the function above.
locale.getpreferredencoding = getpreferredencoding

In [None]:
# Install the third-party packages used below (shell commands).
# NOTE(review): versions are not pinned, so results may differ between runs.
!pip install transformers
!pip install datasets
!pip install sentencepiece
!pip install rouge
!pip install wandb
!pip install bert-extractive-summarizer
# -U upgrades sentence-transformers if an older version is already installed.
!pip install -U sentence-transformers

In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from tqdm.notebook import tqdm_notebook
import time
from torch import cuda
import csv
from rouge import Rouge
import wandb
from sentence_transformers import SentenceTransformer
from summarizer.sbert import SBertSummarizer
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation/configuration warnings from
# transformers/datasets that may point at real problems — consider narrowing.
warnings.filterwarnings('ignore')

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device : {device}')

In [None]:
# Show the GPU status, then authenticate with Weights & Biases, which the
# training cells log to via wandb.init / wandb.watch / wandb.log.
!nvidia-smi
!wandb login

In [None]:
class XMediaData(Dataset):
  """GEM/xmediasum split exposed as (source, target) prompt strings.

  Each item is ``("Summarize: <dialogue>", "Summary: <summary>")``. When
  ``extracive`` is true, the dialogue is first condensed with an SBERT
  extractive summarizer ('all-MiniLM-L6-v2') before being formatted.

  Args:
    split_type: split name passed to ``load_dataset`` (e.g. 'train').
    extracive: whether to run the extractive summarizer on each dialogue
      (the misspelled parameter name is kept for backward compatibility).
    ratio: forwarded as the summarizer's ``ratio`` argument; unused when
      ``extracive`` is false.
    cache: memoise built items by index. Extractive summarization is
      expensive and the same items are requested many times (tokenization,
      generation, ROUGE scoring), so caching avoids re-running it.
  """

  def __init__(self, split_type, extracive, ratio, cache=True):
    self.data = load_dataset('GEM/xmediasum', split=split_type)
    self.extractive = extracive
    self.ratio = ratio
    self.cache = cache
    self._items = {}
    if self.extractive:
      self.extractive_model = SBertSummarizer('all-MiniLM-L6-v2')

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    # Tolerate instances created without __init__ (no _items / cache yet).
    items = self.__dict__.setdefault('_items', {})
    use_cache = getattr(self, 'cache', True)
    if use_cache and idx in items:
      return items[idx]

    record = self.data[idx]  # fetch the row once instead of re-indexing per field
    dialogue = record['dialogue']
    if self.extractive:
      dialogue = self.extractive_model(dialogue, ratio=self.ratio)
    item = (f"Summarize: {dialogue}", f"Summary: {record['summary']}")

    if use_cache:
      items[idx] = item
    return item

In [None]:
# Preprocessing switches shared by both splits: when `extractive` is True each
# dialogue is condensed by the SBERT summarizer, with `ratio` forwarded to it.
extractive = True
ratio = 0.75
train_data, val_data = (
    XMediaData(split_name, extractive, ratio)
    for split_name in ('train', 'validation')
)

In [None]:
# Report the number of examples in each split (train first, then validation).
for split_dataset in (train_data, val_data):
    print(len(split_dataset))

In [None]:
wandb.init(project="abstractive_dialogue_summarizer")

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(device)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

wandb.watch(model, log="all")

# Materialise every (source, target) pair exactly once. With extractive=True each
# train_data[idx] call runs the SBERT summarizer, so indexing the dataset
# separately for sources and for targets doubled that cost.
train_pairs = [train_data[ex] for ex in tqdm_notebook(range(len(train_data)), desc='train pairs')]
train_sources = [source for source, _ in train_pairs]
train_targets = [target for _, target in train_pairs]

# The encodings stay on the CPU: the training loop moves one row at a time to
# `device`, so the whole padded training set does not have to sit in GPU memory.
input_ids = tokenizer.batch_encode_plus(train_sources, max_length=512, truncation=True, padding='longest', return_tensors='pt')
output_ids = tokenizer.batch_encode_plus(train_targets, max_length=512, truncation=True, padding='longest', return_tensors='pt')

# NOTE(review): the training loop iterates over every example of the train
# split, i.e. full fine-tuning rather than a few-shot subset.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)


In [None]:
# Fine-tune one example at a time for 3 epochs.
# from_pretrained() returns the model in eval mode (dropout off); enable training mode.
model.train()
for epoch in tqdm_notebook(range(3), desc='Epoch'):
    total_loss = 0.0
    for i in tqdm_notebook(range(len(train_data)), desc= 'Trained'):
        input_seq = input_ids['input_ids'][i].unsqueeze(0).to(device)
        input_mask = input_ids['attention_mask'][i].unsqueeze(0).to(device)
        output_seq = output_ids['input_ids'][i].unsqueeze(0).to(device)
        output_mask = output_ids['attention_mask'][i].unsqueeze(0).to(device)
        # Sequences were padded to the longest example of the whole split, so the
        # pad positions must be ignored: mask them out of attention, and set the
        # label pads to -100 so the cross-entropy loss skips them.
        labels = output_seq.masked_fill(output_mask == 0, -100)

        optimizer.zero_grad()

        outputs = model(input_ids=input_seq, attention_mask=input_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        wandb.log({"Training Loss": loss_value})
        total_loss += loss_value

    print(f"Epoch {epoch + 1} Loss: {total_loss/len(train_data)}")

# Back to eval mode so later generation runs without dropout.
model.eval()


In [None]:
# Generate one summary per validation example. Keys are example indices, so the
# dict keeps val_data order and never collapses duplicate (dialogue, summary)
# pairs — later cells rely on position i matching val_data[i].
summaries = {}
model.eval()  # no dropout during generation
with torch.no_grad():
  for i in tqdm_notebook(range(len(val_data)), desc = 'Generated Summaries'):
    # val_data[i] is a (source, target) tuple: encode ONLY the source. Passing the
    # whole tuple would feed the reference summary to the model as a text pair.
    source_text = val_data[i][0]
    encoded = tokenizer.encode_plus(source_text, max_length=512, truncation=True, return_tensors='pt').to(device)
    summary_ids = model.generate(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'], num_beams=4, max_length=128, early_stopping=True)
    summaries[i] = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

In [None]:
# Write one generated summary per line, with non-ASCII characters dropped and the
# "Summary: " prefix removed. `with` guarantees the file is closed even if a
# write fails part-way.
with open('few_shot_extractive.txt', 'w') as file:
    for v in summaries.values():
        file.write(v.encode('ascii', 'ignore').decode('ascii').replace('Summary: ', ''))
        file.write('\n')

In [None]:
# Cleaned generations in summaries' insertion order: non-ASCII characters are
# dropped and the "Summary: " prefix is removed, matching the reference cleanup.
generated_summaries = [
    text.encode('ascii', 'ignore').decode('ascii').replace('Summary: ', '')
    for text in summaries.values()
]

In [None]:
def get_single_rouge_scores(idx):
  """Return ROUGE scores for one validation example.

  Compares generated_summaries[idx] against the reference summary of
  val_data[idx]. The reference is ASCII-filtered and stripped of the
  "Summary: " prefix, like the generations.

  Args:
    idx: index into both val_data and generated_summaries.

  Returns:
    dict mapping 'rouge-1', 'rouge-2' and 'rouge-l' to {'r', 'p', 'f'} scores.
    If the rouge package rejects the pair with ValueError (it does so for an
    empty hypothesis or reference, e.g. an empty generation), every score is
    0.0 instead of aborting the whole evaluation.
  """
  rouge = Rouge()
  actual_summary = val_data[idx][1]
  actual_summary = actual_summary.encode('ascii', 'ignore').decode('ascii').replace('Summary: ', '')
  generated_summary = generated_summaries[idx]
  try:
    return rouge.get_scores(generated_summary, actual_summary)[0]
  except ValueError:
    # An empty summary shares nothing with the other side: score it as zero.
    return {metric: {'r': 0.0, 'p': 0.0, 'f': 0.0}
            for metric in ('rouge-1', 'rouge-2', 'rouge-l')}

In [None]:
def get_score(rouge, param):
  """Average one ROUGE statistic over all generated summaries.

  Args:
    rouge: metric key, one of 'rouge-1', 'rouge-2', 'rouge-l'.
    param: statistic within that metric, one of 'r', 'p', 'f'.

  Returns:
    Mean of get_single_rouge_scores(i)[rouge][param] over every index i.
  """
  count = len(generated_summaries)
  per_example = (get_single_rouge_scores(idx)[rouge][param]
                 for idx in tqdm_notebook(range(count), desc=f'{param}'))
  return sum(per_example) / count

In [None]:
# Score every validation example ONCE, then average each statistic. Calling
# get_score() nine times recomputed every example's ROUGE (plus its val_data
# lookup, which may re-run the extractive summarizer) nine times over.
all_scores = [get_single_rouge_scores(i)
              for i in tqdm_notebook(range(len(generated_summaries)), desc='ROUGE')]

report_sections = (
    ('Rouge-1 Scores', 'rouge-1'),
    ('\nRouge-2 Scores', 'rouge-2'),
    ('\nRouge-l Scores', 'rouge-l'),
)
for title, metric in report_sections:
  print(title)
  for param in ('r', 'p', 'f'):
    mean_score = sum(scores[metric][param] for scores in all_scores) / len(all_scores)
    print(f"{param} : {mean_score}")