In [14]:
import pandas as pd
import sys

sys.path.append('../src')
from metrics import calculate_metrics
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM


In [15]:
df = pd.read_csv("/Users/rohitrawat/job-prep/Assignments/accrete-ai/text-summarization/data/processed/news_summary_cleaned_train.csv")

In [16]:
model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device_map='auto')

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [17]:
df['generated_text_flan'] = df['text'].apply(lambda x: summarizer(x, max_length=min(300, len(x)//7), min_length=30, do_sample=False)[0]['summary_text'])

Token indices sequence length is longer than the specified maximum sequence length for this model (738 > 512). Running this sequence through the model will result in indexing errors
Your min_length=30 must be inferior than your max_length=28.
Your min_length=30 must be inferior than your max_length=15.


In [18]:
model_id = "rrrohit/distilbart-cnn-12-6_finetuned"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device_map='auto')

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [19]:
df['generated_text_bart'] = df['text'].apply(lambda x: summarizer(x, max_length=min(300, len(x)//7), min_length=15, do_sample=False)[0]['summary_text'])

In [21]:
df.head()

Unnamed: 0,text,summary,generated_text_flan,generated_text_bart
1457,embarrassing loss vote face board cricket cont...,vinod rai head supreme courtappointed bccis co...,rai speaking launch biographical book cricket ...,bcci committee administrators cao vinod rai ur...
2862,new delhi mar pti missing defence personnel in...,external affairs minister sushma swaraj inform...,new delhi mar pti missing defence personnel in...,external affairs minister sushma swaraj said ...
2068,saharanpur apr pti yoga guru ramdev said today...,yoga guru baba ramdev said patanjali ayurved f...,saharanpur yoga guru ramdev said patanjali ayu...,yoga guru ramdev friday said patanjali ayurved...
2078,supreme court today said aadhaar card cannot m...,supreme court said government cannot stopped u...,supreme court today said set sevenjudge bench ...,supreme court friday said aadhaar card cannot ...
2168,traffic tourist destinations shimla manali dal...,nearly tourists stranded kothi due road blocka...,narkanda jubbal kotkhai kharapathar chopal dis...,traffic tourist destinations shimla manali dal...


# Baseline Flan T5-small model metrics

In [22]:
calculate_metrics(df.summary.tolist(), df.generated_text_flan.tolist())



Unnamed: 0,rouge-1,rouge-2,rouge-l,BERTScore
r,0.326,0.162,0.289,0.63
p,0.253,0.116,0.217,0.608
f,0.263,0.123,0.23,0.616


# Fine-tuned DistilBART model

In [23]:
calculate_metrics(df.summary.tolist(), df.generated_text_bart.tolist())



Unnamed: 0,rouge-1,rouge-2,rouge-l,BERTScore
r,0.529,0.325,0.46,0.739
p,0.535,0.33,0.462,0.745
f,0.524,0.321,0.453,0.741
