In [59]:
import pandas as pd
import sys

sys.path.append('../src')
from metrics import calculate_metrics


In [60]:
df = pd.read_csv('/Users/rohitrawat/job-prep/Assignments/accrete-ai/text-summarization/data/processed/news_summary_cleaned_train.csv').sample(10).reset_index(drop=True)
df.head(3)

Unnamed: 0,text,summary
0,junior home minister kiren rijiju refueled rag...,minister state home kiren rijiju shared video ...
1,six persons arrested indore police allegedly d...,six people arrested indore police allegedly du...
2,charges counter charges flew lok sabha monday ...,minister state home affairs kiren rijiju discu...


In [61]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)



In [62]:
from transformers import pipeline

summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device_map='auto')

Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


In [63]:
import concurrent.futures

# Define a function to generate the summary for a given text
def generate_summary(text):
    return summarizer(text, max_length=min(300, len(text)), min_length=30, do_sample=False)[0]['summary_text']

# Use multithreading to generate summaries for each text in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
    # Submit the tasks to the executor
    futures = [executor.submit(generate_summary, text) for text in df['text']]
    
    # Retrieve the results in the order they were submitted
    for i, future in enumerate(futures):
        generated_summary = future.result()
        df.loc[i, 'generated_summary'] = generated_summary

Your max_length is set to 300, but your input_length is only 175. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=87)
Token indices sequence length is longer than the specified maximum sequence length for this model (532 > 512). Running this sequence through the model will result in indexing errors
Your max_length is set to 300, but your input_length is only 162. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=81)
Your max_length is set to 300, but your input_length is only 242. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=121)
Your max_length is set to 300, but your input_length is only 238. S

In [64]:
text = df['text'].tolist()
summary = df['summary'].tolist()

In [65]:
df

Unnamed: 0,text,summary,generated_summary
0,junior home minister kiren rijiju refueled rag...,minister state home kiren rijiju shared video ...,junior home minister rijiju refueled raging de...
1,six persons arrested indore police allegedly d...,six people arrested indore police allegedly du...,indore police allegedly duping customers selli...
2,charges counter charges flew lok sabha monday ...,minister state home affairs kiren rijiju discu...,lok sabha adjourned till pm lunch yadav speak ...
3,highly placed sources said election commission...,election commission barred madhya pradesh mini...,sources say election commission disqualified m...
4,gujarat congress mlas lodged karnataka resort ...,gujarati chef arranged gujarat mlas flown beng...,gujarat congress mlas lodged karnataka resort ...
5,make indian soldiers conversant chinese langua...,visvabharati university started certificate le...,indian soldiers conversant chinese language vi...
6,absconding businessman vijay mallya granted ba...,absconding businessman liquor baron vijay mall...,india vs pakistan champions trophy game birmin...
7,sunil rastogi came delhi sampark kranti expres...,yearold serial rapist accused raping hundreds ...,sunil rastogi came delhi sampark kranti expres...
8,two persons injured group residents opened fir...,two people injured residents ghaziabad opened ...,loni town ghaziabad wee hours friday local pan...
9,pakistans maritime authorities today arrested ...,amid ongoing tension india pakistan latter wed...,pakistan maritime authorities today arrested i...


# Metrics on baseline model

In [66]:
text = df['text'].tolist()
summary = df['summary'].tolist()
generate_summary = df['generated_summary'].tolist()

calculate_metrics(text, summary)



Unnamed: 0,rouge-1,rouge-2,rouge-l,BERTScore
r,0.751001,0.381519,0.54219,0.697909
p,0.173517,0.066975,0.131312,0.508487
f,0.268309,0.108989,0.201681,0.586063


In [67]:
calculate_metrics(text, generate_summary)



Unnamed: 0,rouge-1,rouge-2,rouge-l,BERTScore
r,0.966766,0.909685,0.925655,0.7673
p,0.198063,0.148787,0.195162,0.496373
f,0.303089,0.234427,0.297891,0.598088
