# 0. Install Packages

In [3]:
!pip install datasets evaluate
!pip install rouge_score bert-score
!pip install openai
!pip install accelerate bitsandbytes

Collecting datasets
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/542.0 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━

In [4]:
import random
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import datasets
from datasets import load_dataset, load_metric
import pandas as pd
import numpy as np
import re
import os
import accelerate

data_path = "/content/drive/MyDrive/Colab_Datasets/summarization_practice"

# 1. Download Datasets

In [None]:
# Step 1: Download datasets
cnn_dm_datasets = load_dataset("cnn_dailymail", "1.0.0")
xsum_dataset = load_dataset("xsum", trust_remote_code=True)

print(xsum_dataset['train'][0])
print(cnn_dm_datasets['train'][0])

Generating train split:   0%|          | 0/204045 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11332 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11334 [00:00<?, ? examples/s]

{'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places below his number one movie on the UK box office char

In [None]:
# sample 2000 records
xsum_dataset['train'].shuffle(seed=42)
xsum_sampled = xsum_dataset['train'][:2000]
xsum_sampled.keys()

dict_keys(['document', 'summary', 'id'])

In [None]:
# convert to dataframe and save as pkl files
xsum_df = pd.DataFrame.from_dict(xsum_sampled)
xsum_df.shape

(2000, 3)

In [None]:
xsum_df['doc_length'] = xsum_df['document'].apply(lambda x: len(x.split()))
xsum_df['doc_length'].describe()

count    2000.000000
mean      375.377500
std       286.650425
min        11.000000
25%       180.000000
50%       299.500000
75%       495.000000
max      2694.000000
Name: doc_length, dtype: float64

In [None]:
xsum_df.to_pickle(os.path.join(data_path, 'xsum_sample.pkl'))

In [None]:
# sample 2000 records
cnn_dm_datasets['train'].shuffle(seed=42)
cnn_dm_sampled = cnn_dm_datasets['train'][:2000]
cnn_dm_sampled.keys()

dict_keys(['article', 'highlights', 'id'])

In [None]:
# convert to dataframe and save as pkl files
cnn_dm_df = pd.DataFrame.from_dict(cnn_dm_sampled)
cnn_dm_df.shape

(2000, 3)

In [None]:
cnn_dm_df.to_pickle(os.path.join(data_path, 'cnn_dm_sample.pkl'))

# 2. Clean up text

In [None]:
# Xsum
file_name = 'xsum_sample.pkl'
xsum = pd.read_pickle(os.path.join(data_path, file_name))
xsum.shape

(2000, 4)

In [None]:
xsum.head()

Unnamed: 0,document,summary,id,doc_length,summary_length
0,"The full cost of damage in Newton Stewart, one...",Clean-up operations are continuing across the ...,35232142,400,18
1,A fire alarm went off at the Holiday Inn in Ho...,Two tourist buses have been destroyed by fire ...,40143035,155,17
2,Ferrari appeared in a position to challenge un...,Lewis Hamilton stormed to pole position at the...,35951548,887,17
3,"John Edward Bates, formerly of Spalding, Linco...",A former Lincolnshire Police officer carried o...,36266422,269,22
4,Patients and staff were evacuated from Cerahpa...,An armed man who locked himself into a room at...,38826984,171,25


In [None]:
# simple clean up text
def clean_xsum_doc(doc):
  # remove new line
  doc_new = doc.replace("\n", " ")

  return doc_new

def clean_cnn_doc(doc):
  # remove new line
  doc_new = doc.replace("\n", " ")

  # remove weird characters
  doc_new = ''.join(i for i in doc_new if ord(i)<128)

  # remove news starter (e.g., "JACKSONVILLE, Florida (CNN)   -- ")
  doc_new = re.sub(".*\(CNN\)\s+--\s+", "", doc_new)

  return doc_new

def clean_cnn_summary(doc):
  # remove new line
  doc_new = doc.replace("\n", " ")

  # remove weird characters
  doc_new = ''.join(i for i in doc_new if ord(i)<128)

  # remove news starter (e.g., "NEW: ")
  doc_new = doc_new.replace("NEW: ", "")

  return doc_new

In [None]:
xsum['doc_cleaned'] = xsum['document'].apply(clean_xsum_doc)

In [None]:
# estimate token length needed
xsum['doc_length'] = xsum['doc_cleaned'].apply(lambda x: len(x.split()))
xsum['summary_length'] = xsum['summary'].apply(lambda x: len(x.split()))

In [None]:
xsum['doc_length'].describe()

count    2000.000000
mean      375.377500
std       286.650425
min        11.000000
25%       180.000000
50%       299.500000
75%       495.000000
max      2694.000000
Name: doc_length, dtype: float64

In [None]:
xsum['summary_length'].describe()

count    2000.000000
mean       21.149000
std         5.228524
min         1.000000
25%        18.000000
50%        21.000000
75%        24.000000
max        55.000000
Name: summary_length, dtype: float64

In [None]:
# pick one record as one-shot example
xsum_doc_train = xsum.loc[3,'doc_cleaned']
xsum_ref_train = xsum.loc[3,'summary']

# pick one record to predict
xsum_doc = xsum.loc[0,'doc_cleaned']
xsum_ref = xsum.loc[0,'summary']

In [None]:
file_name = 'cnn_dm_sample.pkl'
cnn_dm = pd.read_pickle(os.path.join(data_path, file_name))
cnn_dm.shape

(2000, 3)

In [None]:
cnn_dm['doc_cleaned'] = cnn_dm['article'].apply(clean_cnn_doc)
cnn_dm['summary_cleaned'] = cnn_dm['highlights'].apply(clean_cnn_summary)

Unnamed: 0,article,highlights,id,doc_cleaned,summary_cleaned
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets £20M f...,42c027e4ff9730fbb3de84c1af0d2c506e41c3e4,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets 20M fo...
1,Editor's note: In our Behind the Scenes series...,Mentally ill inmates in Miami are housed on th...,ee8871b15c50d0db17b0179a6d2beab35065f1e9,The ninth floor of the Miami-Dade pretrial det...,Mentally ill inmates in Miami are housed on th...
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","NEW: ""I thought I was going to die,"" driver sa...",06352019a19ae31e527f37f7571c6dd7f0c5da37,Drivers who were on the Minneapolis bridge whe...,"""I thought I was going to die,"" driver says . ..."
3,WASHINGTON (CNN) -- Doctors removed five small...,"Five small polyps found during procedure; ""non...",24521a2abb2e1f5e34e6824e0f9e56904a2b0e88,Doctors removed five small polyps from Preside...,"Five small polyps found during procedure; ""non..."
4,(CNN) -- The National Football League has ind...,"NEW: NFL chief, Atlanta Falcons owner critical...",7fe70cc8b12fab2d0a258fababf7d9c6b5e1262a,The National Football League has indefinitely ...,"NFL chief, Atlanta Falcons owner critical of M..."


In [None]:
cnn_dm.head()

Unnamed: 0,article,highlights,id,doc_cleaned,summary_cleaned,doc_length,summary_length
0,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets £20M f...,42c027e4ff9730fbb3de84c1af0d2c506e41c3e4,"LONDON, England (Reuters) -- Harry Potter star...",Harry Potter star Daniel Radcliffe gets 20M fo...,454,41
1,Editor's note: In our Behind the Scenes series...,Mentally ill inmates in Miami are housed on th...,ee8871b15c50d0db17b0179a6d2beab35065f1e9,The ninth floor of the Miami-Dade pretrial det...,Mentally ill inmates in Miami are housed on th...,636,49
2,"MINNEAPOLIS, Minnesota (CNN) -- Drivers who we...","NEW: ""I thought I was going to die,"" driver sa...",06352019a19ae31e527f37f7571c6dd7f0c5da37,Drivers who were on the Minneapolis bridge whe...,"""I thought I was going to die,"" driver says . ...",736,42
3,WASHINGTON (CNN) -- Doctors removed five small...,"Five small polyps found during procedure; ""non...",24521a2abb2e1f5e34e6824e0f9e56904a2b0e88,Doctors removed five small polyps from Preside...,"Five small polyps found during procedure; ""non...",410,27
4,(CNN) -- The National Football League has ind...,"NEW: NFL chief, Atlanta Falcons owner critical...",7fe70cc8b12fab2d0a258fababf7d9c6b5e1262a,The National Football League has indefinitely ...,"NFL chief, Atlanta Falcons owner critical of M...",969,43


In [None]:
# estimate token length needed
cnn_dm['doc_length'] = cnn_dm['doc_cleaned'].apply(lambda x: len(x.split()))
cnn_dm['summary_length'] = cnn_dm['summary_cleaned'].apply(lambda x: len(x.split()))

In [None]:
cnn_dm['doc_length'].describe()

count    2000.000000
mean      596.532500
std       291.123727
min        18.000000
25%       368.750000
50%       553.000000
75%       793.000000
max      1829.000000
Name: doc_length, dtype: float64

In [None]:
cnn_dm['summary_length'].describe()

count    2000.000000
mean       42.950000
std         7.634241
min        11.000000
25%        37.000000
50%        43.000000
75%        49.000000
max        69.000000
Name: summary_length, dtype: float64

In [None]:
dashline = '-'.join('' for x in range(50))

# 3.1 Summarize with FLAN-T5

In [None]:
# load cnn_dm_sample, calculate token length, split into 3 sets, each with 100 random samples
model_name='google/flan-t5-base'
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)



config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [None]:
cnn = pd.read_pickle(os.path.join(data_path, 'cnn_dm_sample.pkl'))
cnn.shape

(2000, 3)

In [None]:
cnn.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2000 entries, 0 to 1999
Data columns (total 3 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   article     2000 non-null   object
 1   highlights  2000 non-null   object
 2   id          2000 non-null   object
dtypes: object(3)
memory usage: 47.0+ KB


In [None]:
cnn['article'] = cnn['article'].str.strip()
cnn['prompt'] = cnn['article'].apply(lambda x: "summarize: "+str(x).replace("\n"," "))

In [None]:
cnn['token_length'] = cnn['prompt'].apply(lambda x: len(tokenizer.tokenize(x)))
cnn['token_length'].describe()

Token indices sequence length is longer than the specified maximum sequence length for this model (637 > 512). Running this sequence through the model will result in indexing errors


count    2000.000000
mean      852.977500
std       414.080449
min        36.000000
25%       528.000000
50%       786.000000
75%      1118.000000
max      2508.000000
Name: token_length, dtype: float64

In [None]:
cnn.to_pickle(os.path.join(data_path, 'cnn_dm_sample.pkl'))

In [None]:
cnn_short = cnn[cnn['token_length']<=512]
print(len(cnn_short))

467


In [None]:
cnn_short_sample = cnn_short.sample(n=100, random_state=42)
print(len(cnn_short_sample))
cnn_short_sample.to_pickle(os.path.join(data_path, 'cnn_short.pkl'))

In [None]:
cnn_mid = cnn[(cnn['token_length']>512) & (cnn['token_length']<=1024)]
print(len(cnn_mid))
cnn_mid_sample = cnn_mid.sample(n=100)
print(len(cnn_mid_sample))
cnn_mid_sample.to_pickle(os.path.join(data_path, 'cnn_mid.pkl'))

903
100


In [None]:
cnn_long = cnn[cnn['token_length']>1024]
print(len(cnn_long))
cnn_long_sample = cnn_long.sample(n=100)
print(len(cnn_long_sample))
cnn_long_sample.to_pickle(os.path.join(data_path, 'cnn_long.pkl'))

630
100


In [None]:
xsum = pd.read_pickle(os.path.join(data_path, 'xsum_sample.pkl'))
xsum.shape

(2000, 4)

In [None]:
xsum.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2000 entries, 0 to 1999
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype 
---  ------      --------------  ----- 
 0   document    2000 non-null   object
 1   summary     2000 non-null   object
 2   id          2000 non-null   object
 3   doc_length  2000 non-null   int64 
dtypes: int64(1), object(3)
memory usage: 62.6+ KB


In [None]:
xsum['document'] = xsum['document'].str.strip()
xsum['prompt'] = xsum['document'].apply(lambda x: "summarize: "+str(x).replace("\n"," "))

In [None]:
xsum['token_length'] = xsum['prompt'].apply(lambda x: len(tokenizer.tokenize(x)))
xsum['token_length'].describe()

count    2000.000000
mean      527.085000
std       406.521732
min        14.000000
25%       255.000000
50%       417.500000
75%       689.500000
max      3968.000000
Name: token_length, dtype: float64

In [None]:
xsum_short = xsum[xsum['token_length']<=512]
print(len(xsum_short))
xsum_short_sample = xsum_short.sample(n=100)
print(len(xsum_short_sample))
xsum_short_sample.to_pickle(os.path.join(data_path, 'xsum_short.pkl'))

1227
100


In [None]:
xsum_mid = xsum[(xsum['token_length']>512) & (xsum['token_length']<=1024)]
print(len(xsum_mid))
xsum_mid_sample = xsum_mid.sample(n=100)
print(len(xsum_mid_sample))
xsum_mid_sample.to_pickle(os.path.join(data_path, 'xsum_mid.pkl'))

554
100


In [None]:
xsum_long = xsum[xsum['token_length']>1024]
print(len(xsum_long))
xsum_long_sample = xsum_long.sample(n=100)
print(len(xsum_long_sample))
xsum_long_sample.to_pickle(os.path.join(data_path, 'xsum_long.pkl'))

219
100


In [None]:
# summarize with FLAN-T5 (zero shot)
def t5_summarizer(prompt):
  tokenized_text = tokenizer.encode(prompt, return_tensors="pt")

  # summmarize
  summary_ids = model.generate(tokenized_text,
                                num_beams=4,
                                no_repeat_ngram_size=2,
                                min_length=20,
                                max_length=80,
                                early_stopping=True)

  output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

  return output

In [None]:
cnn_short_sample['prediction_t5'] = cnn_short_sample['prompt'].apply(t5_summarizer)
cnn_short_sample.head()

Unnamed: 0,article,highlights,id,prompt,token_length,prediction_t5
64,"PARIS, France (CNN) -- Interpol on Monday took...",Man posted photos on the Internet of himself s...,5ebd041d89a2ba41b387c30293f0657eef746910,"summarize: PARIS, France (CNN) -- Interpol on ...",335,Interpol has taken the unprecendent step of ma...
1705,(CNN) -- Electronics giant Sony launched its e...,PlayStation Home can be downloaded free of cha...,ccee27b87deb37c76b1c76042629af98af24a68b,summarize: (CNN) -- Electronics giant Sony lau...,376,"Sony's new social-networking site, PlayStation..."
1610,(CNN) -- Polaroid Corp. announced it was filin...,"""Our operations are strong,"" Polaroid CEO says...",75e6c61d7fb9388bf5dc9faa4cbe963801625aec,summarize: (CNN) -- Polaroid Corp. announced i...,200,"Petters Group Worldwide, which has owned Polar..."
1034,(CNN) -- A strong earthquake measuring 6.1 in ...,Strong quake measuring 6.1 in magnitude strike...,5474a9f81601abaa7ec81c6a9e1c5b7140acdc7b,summarize: (CNN) -- A strong earthquake measur...,494,A strong earthquake measuring 6.1 in magnitude...
1023,"COLOMBO, Sri Lanka (CNN) -- A Sri Lankan gover...",Minister's bodyguards hurt in blast in souther...,b3fe558375557990a10b9eaffb80545a80cc7c5a,"summarize: COLOMBO, Sri Lanka (CNN) -- A Sri L...",270,Maithripala Sirisena's bodyguards were hurt in...


In [None]:
cnn_short_sample.to_pickle(os.path.join(data_path, 'cnn_short_t5.pkl'))

In [None]:
%time
xsum_short_sample['prediction_t5'] = xsum_short_sample['prompt'].apply(t5_summarizer)
xsum_short_sample.head()

CPU times: user 4 µs, sys: 0 ns, total: 4 µs
Wall time: 7.15 µs


Unnamed: 0,document,summary,id,doc_length,prompt,token_length,prediction_t5
1710,Pompey beat struggling Newport County at home ...,"Portsmouth midfielder Danny Rose believes ""exp...",39407927,205,summarize: Pompey beat struggling Newport Coun...,291,Portsmouth manager Mark Rose says the gaffer h...
479,"Farooq Shah, 21, of Station Road, Forest Gate,...",A man has been sentenced to life in prison for...,28064995,372,"summarize: Farooq Shah, 21, of Station Road, F...",508,A man has been jailed for life for the murder ...
1723,"Yu Muchun, 20, and Tang Wentian, 21, were jail...","Two young drivers have been jailed over a ""Fas...",32825212,233,"summarize: Yu Muchun, 20, and Tang Wentian, 21...",329,Two Chinese men have been jailed for dangerous...
1625,"Shaun Woodburn, 30, died after a disturbance i...",A fourth person has been arrested following th...,39373250,132,"summarize: Shaun Woodburn, 30, died after a di...",179,A 17-year-old has been charged with the murder...
483,"The show was part of an economic event, ""Make ...",Bollywood actors Amitabh Bachchan and Aamir Kh...,35574186,70,summarize: The show was part of an economic ev...,97,A fire has broken out at a music festival in t...


In [None]:
xsum_short_sample.to_pickle(os.path.join(data_path, 'xsum_short_t5.pkl'))

In [None]:
%time
cnn_mid_sample['prediction_t5'] = cnn_mid_sample['prompt'].apply(t5_summarizer)
cnn_mid_sample.head()

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.91 µs


Unnamed: 0,article,highlights,id,prompt,token_length,prediction_t5
36,WASHINGTON (CNN) -- Some Democrats appear to b...,NEW: Majority Leader says a number of Dems are...,9e141ebc1c06d483bf85f4e5c8b3c7bc2d00ea2f,summarize: WASHINGTON (CNN) -- Some Democrats ...,958,House resolution labeling Turkey's treatment o...
1178,"JOHANNESBURG, South Africa (CNN) -- Ten South ...","10 S. African ministers, deputy president resi...",4f88a57f0fb8f63c6f668a718d19d557fdd9381b,"summarize: JOHANNESBURG, South Africa (CNN) --...",1003,Ten South African ministers and the deputy pre...
332,(CNN) -- The toddler whose body washed ashore ...,"Woman, boyfriend arrested after a tip led to s...",c43c7253f38ff2cd79770034ed9af3567cfaa811,summarize: (CNN) -- The toddler whose body was...,632,NEW: DNA analysis is still in progress to conf...
1613,(CNN) -- A woman accused of killing her 2-year...,Kimberly Dawn Trenor pleads guilty to tamperin...,1f4603ad3ef986c557014f69422a92e345e3c9a8,summarize: (CNN) -- A woman accused of killing...,609,Kimberly Dawn Trenor is scheduled to go on tri...
1829,"LONDON, England (CNN) -- Identifying the world...","Travelers rank Hong Kong, Singapore and Seoul ...",dd9072a39bc7bf907c5729da8356dee5ff947c12,"summarize: LONDON, England (CNN) -- Identifyin...",571,Identifying the world's finest airports is eas...


In [None]:
cnn_mid_sample.to_pickle(os.path.join(data_path, 'cnn_mid_t5.pkl'))

In [None]:
xsum_mid_sample['prediction_t5'] = xsum_mid_sample['prompt'].apply(t5_summarizer)
xsum_mid_sample.head()

Unnamed: 0,document,summary,id,doc_length,prompt,token_length,prediction_t5
1588,The four-minute video shows the collision of t...,A virtual reality video simulating a drink dri...,38050910,540,summarize: The four-minute video shows the col...,728,Diageo has launched a virtual reality campaign...
1321,"Morgan, 30, opted out of this month's tour bec...","Eoin Morgan will ""definitely"" captain England'...",37647749,386,"summarize: Morgan, 30, opted out of this month...",564,England's captain Morgan will be captain in In...
740,"It estimates poor diets are causing around 70,...",An extra 20% tax on sugary drinks should be in...,33479118,491,summarize: It estimates poor diets are causing...,635,The British Medical Association (BMA) has call...
1173,"The men's four, men's and women's pair and lig...","Great Britain will take a strong team, that in...",39887571,286,"summarize: The men's four, men's and women's p...",574,British rowers have been selected to compete a...
889,Media playback is not supported on this device...,Leinster moved top of the Pro12 table with an ...,39049331,319,summarize: Media playback is not supported on ...,571,Leinster extended their lead at the top of the...


In [None]:
xsum_mid_sample.to_pickle(os.path.join(data_path, 'xsum_mid_t5.pkl'))

In [None]:
xsum_mid_sample.loc[740,'summary']

'An extra 20% tax on sugary drinks should be introduced to tackle the obesity crisis, the British Medical Association says.'

In [None]:
xsum_mid_sample.loc[740,'prediction_t5']

'The British Medical Association (BMA) has called for a tax of at least 20% on sugar in food and drinks.'

# 3.2 Summarize with Mistral-7b-instruct

In [28]:
# use deepinfra api
import openai
from openai import OpenAI

from google.colab import userdata
api_key = userdata.get('Deepinfra_api_auto')

client = OpenAI(api_key=api_key, base_url="https://api.deepinfra.com/v1/openai")

## Zero Shot

In [29]:
def mistral_summarizer_zero(doc):
  model_id = "mistralai/Mistral-7B-Instruct-v0.2"

  messages = [
    {"role": "system", "content": "Summarize content you are provided with."},
    {"role": "user", "content": f"Summarize: {doc}\n\nSummary:"},
    ]

  chat_completion = client.chat.completions.create(model=model_id,
        messages=messages,
        stream=False,
        max_tokens=80)

  summary = chat_completion.choices[0].message.content
  prompt_tokens = chat_completion.usage.prompt_tokens
  completion_tokens = chat_completion.usage.completion_tokens

  return (summary, prompt_tokens, completion_tokens)

In [None]:
# test
doc = "Jupiter is the fifth planet from the Sun and the largest in the Solar System. It is a gas giant with a mass one-thousandth that of the Sun, but two-and-a-half times that of all the other planets in the Solar System combined. Jupiter is one of the brightest objects visible to the naked eye in the night sky, and has been known to ancient civilizations since before recorded history. It is named after the Roman god Jupiter.[19] When viewed from Earth, Jupiter can be bright enough for its reflected light to cast visible shadows,[20] and is on average the third-brightest natural object in the night sky after the Moon and Venus."
ref = "Jupiter is a big planet in our Solar System that is fifth from the Sun. It is the largest planet and is made of gas. Jupiter is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god."
output = mistral_summarizer_zero(doc)
print(output[0])

 Jupiter is the fifth planet from the Sun and the largest in the Solar System, boasting a mass one-thousandth that of the Sun and twice the mass of all other planets combined. Known since ancient civilizations, Jupiter is a gas giant and the third-brightest natural object in the night sky after the Moon and Venus. Its brightness can cast visible


In [None]:
print(output[1:])

(175, 80)


In [None]:
cnn_short_sample = pd.read_pickle(os.path.join(data_path, 'cnn_short.pkl'))
cnn_short_sample.head()

Unnamed: 0,article,highlights,id,prompt,token_length
64,"PARIS, France (CNN) -- Interpol on Monday took...",Man posted photos on the Internet of himself s...,5ebd041d89a2ba41b387c30293f0657eef746910,"summarize: PARIS, France (CNN) -- Interpol on ...",335
1705,(CNN) -- Electronics giant Sony launched its e...,PlayStation Home can be downloaded free of cha...,ccee27b87deb37c76b1c76042629af98af24a68b,summarize: (CNN) -- Electronics giant Sony lau...,376
1610,(CNN) -- Polaroid Corp. announced it was filin...,"""Our operations are strong,"" Polaroid CEO says...",75e6c61d7fb9388bf5dc9faa4cbe963801625aec,summarize: (CNN) -- Polaroid Corp. announced i...,200
1034,(CNN) -- A strong earthquake measuring 6.1 in ...,Strong quake measuring 6.1 in magnitude strike...,5474a9f81601abaa7ec81c6a9e1c5b7140acdc7b,summarize: (CNN) -- A strong earthquake measur...,494
1023,"COLOMBO, Sri Lanka (CNN) -- A Sri Lankan gover...",Minister's bodyguards hurt in blast in souther...,b3fe558375557990a10b9eaffb80545a80cc7c5a,"summarize: COLOMBO, Sri Lanka (CNN) -- A Sri L...",270


In [None]:
cnn_short_sample['output_mistral'] = cnn_short_sample['article'].apply(mistral_summarizer_zero)
cnn_short_sample.head()

Unnamed: 0,article,highlights,id,prompt,token_length,output_mistral
64,"PARIS, France (CNN) -- Interpol on Monday took...",Man posted photos on the Internet of himself s...,5ebd041d89a2ba41b387c30293f0657eef746910,"summarize: PARIS, France (CNN) -- Interpol on ...",335,"( Interpol, an international police agency bas..."
1705,(CNN) -- Electronics giant Sony launched its e...,PlayStation Home can be downloaded free of cha...,ccee27b87deb37c76b1c76042629af98af24a68b,summarize: (CNN) -- Electronics giant Sony lau...,376,( Sony released its new social networking site...
1610,(CNN) -- Polaroid Corp. announced it was filin...,"""Our operations are strong,"" Polaroid CEO says...",75e6c61d7fb9388bf5dc9faa4cbe963801625aec,summarize: (CNN) -- Polaroid Corp. announced i...,200,"( Polaroid Corporation, based in Minnesota, an..."
1034,(CNN) -- A strong earthquake measuring 6.1 in ...,Strong quake measuring 6.1 in magnitude strike...,5474a9f81601abaa7ec81c6a9e1c5b7140acdc7b,summarize: (CNN) -- A strong earthquake measur...,494,( A 6.1 magnitude earthquake hit southern Iran...
1023,"COLOMBO, Sri Lanka (CNN) -- A Sri Lankan gover...",Minister's bodyguards hurt in blast in souther...,b3fe558375557990a10b9eaffb80545a80cc7c5a,"summarize: COLOMBO, Sri Lanka (CNN) -- A Sri L...",270,"( A Sri Lankan government minister, Maithripal..."


In [None]:
cnn_short_sample['summary_mistral'] = cnn_short_sample['output_mistral'].apply(lambda x: x[0])
cnn_short_sample['completion_tokens'] = cnn_short_sample['output_mistral'].apply(lambda x: x[2])
cnn_short_sample.head()

Unnamed: 0,article,highlights,id,prompt,token_length,output_mistral,summary_mistral,completion_tokens
64,"PARIS, France (CNN) -- Interpol on Monday took...",Man posted photos on the Internet of himself s...,5ebd041d89a2ba41b387c30293f0657eef746910,"summarize: PARIS, France (CNN) -- Interpol on ...",335,"( Interpol, an international police agency bas...","Interpol, an international police agency base...",80
1705,(CNN) -- Electronics giant Sony launched its e...,PlayStation Home can be downloaded free of cha...,ccee27b87deb37c76b1c76042629af98af24a68b,summarize: (CNN) -- Electronics giant Sony lau...,376,( Sony released its new social networking site...,"Sony released its new social networking site,...",80
1610,(CNN) -- Polaroid Corp. announced it was filin...,"""Our operations are strong,"" Polaroid CEO says...",75e6c61d7fb9388bf5dc9faa4cbe963801625aec,summarize: (CNN) -- Polaroid Corp. announced i...,200,"( Polaroid Corporation, based in Minnesota, an...","Polaroid Corporation, based in Minnesota, ann...",80
1034,(CNN) -- A strong earthquake measuring 6.1 in ...,Strong quake measuring 6.1 in magnitude strike...,5474a9f81601abaa7ec81c6a9e1c5b7140acdc7b,summarize: (CNN) -- A strong earthquake measur...,494,( A 6.1 magnitude earthquake hit southern Iran...,A 6.1 magnitude earthquake hit southern Iran ...,80
1023,"COLOMBO, Sri Lanka (CNN) -- A Sri Lankan gover...",Minister's bodyguards hurt in blast in souther...,b3fe558375557990a10b9eaffb80545a80cc7c5a,"summarize: COLOMBO, Sri Lanka (CNN) -- A Sri L...",270,"( A Sri Lankan government minister, Maithripal...","A Sri Lankan government minister, Maithripala...",80


In [None]:
cnn_short_sample['completion_tokens'].describe()

count    100.0
mean      80.0
std        0.0
min       80.0
25%       80.0
50%       80.0
75%       80.0
max       80.0
Name: completion_tokens, dtype: float64

In [None]:
cnn_short_sample.to_pickle(os.path.join(data_path, 'cnn_short_mistral_zero.pkl'))

In [30]:
# Repeat with cnn mid dataset
cnn_mid_sample = pd.read_pickle(os.path.join(data_path, 'cnn_mid.pkl'))
cnn_mid_sample['output_mistral'] = cnn_mid_sample['article'].apply(mistral_summarizer_zero)
cnn_mid_sample['summary_mistral'] = cnn_mid_sample['output_mistral'].apply(lambda x: x[0])
cnn_mid_sample['completion_tokens'] = cnn_mid_sample['output_mistral'].apply(lambda x: x[2])
cnn_mid_sample.to_pickle(os.path.join(data_path, 'cnn_mid_mistral_zero.pkl'))

## One Shot

In [None]:
# provide 10 examples to randomly choose as examples
cnn_short = pd.read_pickle(os.path.join(data_path, 'cnn_short.pkl'))
cnn_short['doc'] = cnn_short['article'].str.replace("\n", " ")
cnn_short['summary'] = cnn_short['highlights'].str.replace("\n", " ")

# Step 1: save 10 random examples
cnn_short_examples = cnn_short.sample(n=10, random_state=42)

# Step 2: Save what's not sampled as _remained
cnn_short_remained = cnn_short.drop(cnn_short_examples.index)


In [None]:
len(cnn_short_remained)

90

In [None]:
cnn_short_examples.head()

Unnamed: 0,article,highlights,id,prompt,token_length,doc,summary
1431,"ISLAMABAD, Pakistan (CNN) -- Pakistan has inde...",NEW: NATO force expects no impact on ability t...,8e33acabd22582a7b2373b68c9cd456198ae0e97,"summarize: ISLAMABAD, Pakistan (CNN) -- Pakist...",489,"ISLAMABAD, Pakistan (CNN) -- Pakistan has inde...",NEW: NATO force expects no impact on ability t...
414,"AMSTERDAM, Holland -- Ajax lost ground on Dutc...",Second-placed Ajax held 2-2 at home by Vitesse...,f489b07406c653968a0ed21c00e28c9124e4a49b,"summarize: AMSTERDAM, Holland -- Ajax lost gro...",420,"AMSTERDAM, Holland -- Ajax lost ground on Dutc...",Second-placed Ajax held 2-2 at home by Vitesse...
1530,(CNN) -- German sailors foiled an attempt by p...,German government later ordered pirates releas...,a5c1959a4d1ab2a5f26b2737bc943ebe19d5cd78,summarize: (CNN) -- German sailors foiled an a...,437,(CNN) -- German sailors foiled an attempt by p...,German government later ordered pirates releas...
1021,"(CNN) -- The recent snowstorm in China, which ...","""The warm air is very active this year"", said ...",14fcaef3a144b1f60726e365a4aeccd39e9f2bc1,summarize: (CNN) -- The recent snowstorm in Ch...,484,"(CNN) -- The recent snowstorm in China, which ...","""The warm air is very active this year"", said ..."
1051,(CNN) -- Maoist insurgents killed a dozen sol...,Ambush kills 12 soldiers and two civilians in ...,6f95a187921b94549a5467b3d39a71089412e9d9,summarize: (CNN) -- Maoist insurgents killed ...,283,(CNN) -- Maoist insurgents killed a dozen sol...,Ambush kills 12 soldiers and two civilians in ...


In [None]:
def mistral_summarizer_one(doc):

  # Randomly select one record
  random_record = cnn_short_examples.sample(n=1).iloc[0]
  example_doc = random_record['doc']
  example_summary = random_record['summary']

  messages = [
    {"role": "system", "content": "Summarize content you are provided with."},
    {"role": "user", "content": f"Summarize: {example_doc}\n\nSummary: {example_summary}\n\n\nSummarize: {doc}\n\nSummary:"},
    ]

  model_id = "mistralai/Mistral-7B-Instruct-v0.2"

  chat_completion = client.chat.completions.create(model=model_id,
        messages=messages,
        stream=False,
        max_tokens=80)

  summary = chat_completion.choices[0].message.content
  prompt_tokens = chat_completion.usage.prompt_tokens
  completion_tokens = chat_completion.usage.completion_tokens

  return (summary, prompt_tokens, completion_tokens)

In [None]:
cnn_short_remained['output_mistral'] = cnn_short_remained['doc'].apply(mistral_summarizer_one)

In [None]:
cnn_short_remained['summary_mistral'] = cnn_short_remained['output_mistral'].apply(lambda x: x[0])
cnn_short_remained['completion_tokens'] = cnn_short_remained['output_mistral'].apply(lambda x: x[2])

In [None]:
cnn_short_remained['completion_tokens'].describe()

count    90.000000
mean     78.922222
std       5.156490
min      35.000000
25%      80.000000
50%      80.000000
75%      80.000000
max      80.000000
Name: completion_tokens, dtype: float64

In [None]:
cnn_short_remained.to_pickle(os.path.join(data_path, 'cnn_short_mistral_one.pkl'))

# 4.1 Evalaute using ROUGE

In [7]:
from evaluate import load
# Load the ROUGE metric
import evaluate
rouge = evaluate.load('rouge')

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

In [15]:
# cnn short rouge scores: t5
cnn_short_t5 = pd.read_pickle(os.path.join(data_path, 'cnn_short_t5.pkl'))

predictions = list(cnn_short_t5['prediction_t5'])

cnn_short_t5['reference'] = cnn_short_t5['highlights'].str.replace("\n", "")
references = list(cnn_short_t5['reference'])

results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(results)

{'rouge1': 0.30203216777065534, 'rouge2': 0.13107502119161152, 'rougeL': 0.23108419742485378, 'rougeLsum': 0.2304648321400154}


In [16]:
# cnn short rouge scores (zero mistral)
cnn_short_mis0 = pd.read_pickle(os.path.join(data_path, 'cnn_short_mistral_zero.pkl'))

predictions = list(cnn_short_mis0['summary_mistral'])

cnn_short_mis0['reference'] = cnn_short_mis0['highlights'].str.replace("\n", "")
references = list(cnn_short_mis0['reference'])

results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(results)

{'rouge1': 0.40138339851826954, 'rouge2': 0.15458056272410037, 'rougeL': 0.26345439546566635, 'rougeLsum': 0.2633462855293042}


In [17]:
# cnn short rouge scores (one shot mistral)
cnn_short_mis1 = pd.read_pickle(os.path.join(data_path, 'cnn_short_mistral_one.pkl'))

predictions = list(cnn_short_mis1['summary_mistral'])
references = list(cnn_short_mis1['summary'])

results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(results)

{'rouge1': 0.413343327409631, 'rouge2': 0.15913312350277548, 'rougeL': 0.2773383705572432, 'rougeLsum': 0.2773193991439024}


In [24]:
# cnn mid rouge scores: T5
cnn_mid_t5 = pd.read_pickle(os.path.join(data_path, 'cnn_mid_t5.pkl'))

predictions = list(cnn_mid_t5['prediction_t5'])

cnn_mid_t5['reference'] = cnn_mid_t5['highlights'].str.replace("\n", "")
references = list(cnn_mid_t5['reference'])

results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(results)

{'rouge1': 0.3252175442409779, 'rouge2': 0.1300822580777003, 'rougeL': 0.22777034237856425, 'rougeLsum': 0.22872513466599648}


In [31]:
# cnn mid rouge scores (zero mistral)
cnn_mid_mis0 = pd.read_pickle(os.path.join(data_path, 'cnn_mid_mistral_zero.pkl'))

predictions = list(cnn_mid_mis0['summary_mistral'])

cnn_mid_mis0['reference'] = cnn_mid_mis0['highlights'].str.replace("\n", "")
references = list(cnn_mid_mis0['reference'])

results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print(results)

{'rouge1': 0.3811779793051331, 'rouge2': 0.13907859666383177, 'rougeL': 0.24401086093354382, 'rougeLsum': 0.24475904879089894}


# 4.2 Evaluate using BERTScore

In [18]:
from bert_score import BERTScorer
scorer = BERTScorer(model_type='bert-base-uncased')



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [19]:
# cnn short BERTScore: T5
predictions = list(cnn_short_t5['prediction_t5'])

cnn_short_t5['reference'] = cnn_short_t5['highlights'].str.replace("\n", "")
references = list(cnn_short_t5['reference'])

P, R, F1 = scorer.score(predictions, references)
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

BERTScore Precision: 0.6077, Recall: 0.5451, F1: 0.5725


In [20]:
# cnn short BERTScore (zero shot mistral)
predictions = list(cnn_short_mis0['summary_mistral'])

cnn_short_mis0['reference'] = cnn_short_mis0['highlights'].str.replace("\n", "")
references = list(cnn_short_mis0['reference'])

P, R, F1 = scorer.score(predictions, references)
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

BERTScore Precision: 0.5886, Recall: 0.6413, F1: 0.6132


In [21]:
# cnn short BERTScore (one shot mistral)
predictions = list(cnn_short_mis1['summary_mistral'])
references = list(cnn_short_mis1['summary'])

P, R, F1 = scorer.score(predictions, references)
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

BERTScore Precision: 0.5995, Recall: 0.6498, F1: 0.6229


In [32]:
# cnn mid BERTScore: T5
predictions = list(cnn_mid_t5['prediction_t5'])

cnn_mid_t5['reference'] = cnn_mid_t5['highlights'].str.replace("\n", "")
references = list(cnn_mid_t5['reference'])

P, R, F1 = scorer.score(predictions, references)
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

BERTScore Precision: 0.6044, Recall: 0.5628, F1: 0.5809


In [33]:
# cnn mid BERTScore (zero shot mistral)
predictions = list(cnn_mid_mis0['summary_mistral'])

cnn_mid_mis0['reference'] = cnn_mid_mis0['highlights'].str.replace("\n", "")
references = list(cnn_mid_mis0['reference'])

P, R, F1 = scorer.score(predictions, references)
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")

BERTScore Precision: 0.5791, Recall: 0.6178, F1: 0.5974


# 4.3 Evaluate using GPT-3.5

References:

* open ai summarization evaluation cookbook: https://cookbook.openai.com/examples/evaluation/how_to_eval_abstractive_summarization
* Standform HELM

<b>Log Probability:</b>

To simplify, a logprob is log(p), where p = probability of a token occurring at a specific position based on the previous tokens in the context. Some key points about logprobs:

Higher log probabilities suggest a higher likelihood of the token in that context. This allows users to gauge the model's confidence in its output or explore alternative responses the model considered.
Logprob can be any negative number or 0.0. 0.0 corresponds to 100% probability.


In [None]:
from openai import OpenAI

from google.colab import userdata
api_key = userdata.get('openai_practice')

client = OpenAI(api_key = api_key)

In [None]:
# get completion
def get_completion(messages, max_tokens, model="gpt-3.5-turbo"):
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0, # this is the degree of randomness of the model's output
        logprobs=True,
        # top_p=1
    )

    # summary = response.choices[0].message.content
    # prompt_tokens = response.usage.prompt_tokens
    # completion_tokens = response.usage.completion_tokens

    # return {"summary": summary, "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens}
    return response

In [None]:
# evaluate Faithfulness

def eval_faithful(doc, summary):
    system_content = """
    You will be given one summary written for an article. Your task is to determine if all the information expressed by the summary can be inferred from the article.
    Please make sure you read and understand these instructions very carefully.
    Please keep this document open while reviewing, and refer to it as needed.

    Instruction:

    1. Read the article carefully and identify the main facts and details it presents.
    2. Read the summary and compare it to the article. Check if the summary contains any factual errors that are not supported by the article.
    3. Respond 'Yes' or 'No'.
    """

    user_content = f"""
    article:

    {doc}

    summary:

    {summary}

    Can all the information expressed by the summary be inferred from the article?
    """

    messages = [
    {
      "role": "system",
      "content": system_content
    },
    {
      "role": "user",
      "content": user_content
    }
    ]

    response = get_completion(messages, max_tokens=5)

    eval = response.choices[0].message.content
    log_probability = response.choices[0].logprobs.content[0].logprob
    linear_probability = np.round(np.exp(log_probability)*100,2)

    return (eval, log_probability, linear_probability)


In [None]:
# test
doc = "Jupiter is the fifth planet from the Sun and the largest in the Solar System. It is a gas giant with a mass one-thousandth that of the Sun, but two-and-a-half times that of all the other planets in the Solar System combined. Jupiter is one of the brightest objects visible to the naked eye in the night sky, and has been known to ancient civilizations since before recorded history. It is named after the Roman god Jupiter.[19] When viewed from Earth, Jupiter can be bright enough for its reflected light to cast visible shadows,[20] and is on average the third-brightest natural object in the night sky after the Moon and Venus."
ref = "Jupiter is a big planet in our Solar System that is fifth from the Sun. It is the largest planet and is made of gas. Jupiter is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god."
# perturb the ref summary: replace Jupiter with Mars
perturbed_ref_1 = "Mars is a big planet in our Solar System that is fifth from the Sun. It is the largest planet and is made of gas. Mars is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god."
# perturb the ref summary: replace 5th with 7th
perturbed_ref_2 = "Jupiter is the largest planet in our Solar System that is 7th from the Sun. It is made of gas. Jupiter is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god."
# perturb the ref summary: add new info
perturbed_ref_3 = "Jupiter is a big planet in our Solar System that is fifth from the Sun. It is the largest planet and is mostly made of hydrogen and helium. Jupiter is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god."

print(len(doc.split()))
print(len(ref.split()))

111
56


In [None]:
x,y = eval_faithful(doc, ref)
print(x, y)

Yes -0.0062249014


In [None]:
x,y = eval_faithful(doc, perturbed_ref_1)
print(x, y)

No -0.0020199977


In [None]:
x,y = eval_faithful(doc, perturbed_ref_2)
print(x, y)

No -0.011451195


In [None]:
x,y = eval_faithful(doc, perturbed_ref_3)
print(x, y)

No -0.0039694053


In [None]:
# messages = [
#     {
#       "role": "system",
#       "content": "Summarize content you are provided with for a second-grade student."
#     },
#     {
#       "role": "user",
#       "content": doc
#     }
#   ]

In [None]:
test = get_completion(messages, max_tokens=5)
print(test)


ChatCompletion(id='chatcmpl-9LfIpBF6ChOBFgVDGjVGvwotGoODy', choices=[Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token='Yes', bytes=[89, 101, 115], logprob=-0.022179779, top_logprobs=[])]), message=ChatCompletionMessage(content='Yes', role='assistant', function_call=None, tool_calls=None))], created=1714950683, model='gpt-3.5-turbo-0125', object='chat.completion', system_fingerprint='fp_3b956da36b', usage=CompletionUsage(completion_tokens=1, prompt_tokens=380, total_tokens=381))


In [None]:
prediction = test.choices[0].message.content
print(prediction)

Yes


In [None]:
log_probability = test.choices[0].logprobs.content[0].logprob
print(log_probability)

-0.022179779


In [None]:
predictions = [prediction]
references = [ref]
results = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
print("gpt3.5 summary:", prediction)
print("ref summery:", ref )
print(results)

gpt3.5 summary: Jupiter is a big planet in our Solar System, and it is the fifth planet from the Sun. It is the largest planet and is made mostly of gas. Jupiter is very bright in the night sky and has been known to people for a long time. It is named after a Roman god. When we look at Jupiter from Earth, it can be so bright that it can even cast shadows. It is one of the brightest things we can see in the sky at night, after the Moon and
ref summery: Jupiter is a big planet in our Solar System that is fifth from the Sun. It is the largest planet and is made of gas. Jupiter is very bright in the sky and can even cast shadows on Earth. It has been known to people for a long time and is named after a Roman god.
{'rouge1': 0.7482993197278911, 'rouge2': 0.6068965517241379, 'rougeL': 0.6394557823129252, 'rougeLsum': 0.6394557823129252}


In [None]:
# BERTScore calculation
scorer = BERTScorer(model_type='bert-base-uncased')
P, R, F1 = scorer.score([prediction], [ref])
print(f"BERTScore Precision: {P.mean():.4f}, Recall: {R.mean():.4f}, F1: {F1.mean():.4f}")



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BERTScore Precision: 0.7804, Recall: 0.9054, F1: 0.8383
