In [None]:
import pandas as pd 
from simplet5 import SimpleT5
from rouge import Rouge

In [None]:
# Load up to 3000 rows of the preprocessed dataset, skipping malformed lines.
# `on_bad_lines='skip'` replaces `error_bad_lines=False`, which was deprecated in
# pandas 1.3 and removed in pandas 2.0 (passing it now raises a TypeError).
df = pd.read_csv('./data/prepocessed.csv', delimiter=',',
                 engine='python', on_bad_lines='skip', nrows=3000)

# DataFrame.info() prints its report itself and returns None; wrapping it in
# print() only appends a stray "None" to the output.
df.info()

# Remove unused columns first so dropna() only considers the remaining columns
# when discarding rows with missing values.
drop_cols = ['overview', 'sectionLabel', 'title']
df = df.drop(columns=drop_cols).dropna()

# SimpleT5's train() expects the input column to be named `source_text` and the
# label column `target_text`; here the article text is the input and the
# headline is the target.
df = df.rename(columns={"headline": "target_text", "text": "source_text"})
df.info()

# T5 selects its task from a text prefix; "summarize: " is the prefix used for
# summarization, so it is prepended to every input text.
df['source_text'] = "summarize: " + df['source_text']
print(df.head(1)['source_text'])

In [None]:
# Train / validation / test split (60 / 20 / 20).
# A fixed random_state makes the split reproducible across kernel restarts, so
# the ROUGE scores reported later always refer to the same held-out rows.
SPLIT_SEED = 42

train_data = df.sample(frac=0.60, random_state=SPLIT_SEED)  # 60%
rest_part_40 = df.drop(train_data.index)  # remaining 40%
test_data = rest_part_40.sample(frac=0.50, random_state=SPLIT_SEED)  # 20%
validation_data = rest_part_40.drop(test_data.index)  # 20%

# Splitting by dropping index labels is only correct when labels are unique;
# check that the three parts partition `df` exactly.
assert len(train_data) + len(validation_data) + len(test_data) == len(df), \
    "split does not partition df (duplicate index labels?)"
print("Shapes: ", train_data.shape, validation_data.shape, test_data.shape)


In [None]:
# Fine-tune the pretrained t5-base checkpoint on the training split, evaluating
# on the validation split. Checkpoints are written under ./models/.
train_config = dict(
    source_max_token_len=150,  # token budget for each source_text (article)
    target_max_token_len=64,   # token budget for each target_text (headline)
    outputdir="./models/",
    max_epochs=3,
    use_gpu=False,
)

model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")
model.train(train_df=train_data, eval_df=validation_data, **train_config)

In [None]:
# Reload a fine-tuned checkpoint from ./models/ for inference on CPU.
# NOTE(review): this directory name embeds the epoch number and losses of one
# specific training run; re-training will likely produce a different name, so
# update this constant after each run.
CHECKPOINT_DIR = "./models/simplet5-epoch-0-train-loss-2.7284-val-loss-2.2741"

model_t5 = SimpleT5()
model_t5.load_model("t5", CHECKPOINT_DIR, use_gpu=False)

In [None]:


def predict_summary(row):
    """Generate a summary for one DataFrame row with the fine-tuned model.

    Args:
        row: a row (Series) with a ``source_text`` column, i.e. the input text
            already carrying the "summarize: " prefix added during preprocessing.

    Returns:
        The first generated summary as a plain string. ``SimpleT5.predict``
        returns a list of candidate strings, whereas ROUGE needs one string
        per example.
    """
    input_text = row["source_text"]
    candidates = model_t5.predict(input_text, max_length=512)
    return candidates[0]


# Generate only for the held-out test rows. Generation is the expensive step, so
# running it over all of `df` would waste work on train/validation rows.
# assign() returns a new frame rather than writing a column into a sampled slice.
test_data = test_data.assign(
    predicted_summary=test_data.apply(predict_summary, axis=1)
)

# Rouge.get_scores takes (hypotheses, references): model outputs first, gold
# headlines second. Swapping them exchanges the reported precision and recall.
rouge = Rouge()
hyps = test_data["predicted_summary"].tolist()
refs = test_data["target_text"].tolist()
scores = rouge.get_scores(hyps, refs)                # per-example scores
avg_scores = rouge.get_scores(hyps, refs, avg=True)  # averages over the test set
print(avg_scores)



In [None]:
# predicted_series = pd.Series(['The cat in the hat.', 'I like green eggs and ham.'])
# reference_series = pd.Series(['The cat in the hat is good.', 'I do not like them, Sam-I-Am.'])

# # Create a Rouge instance
# rouge = Rouge()

# # Get ROUGE-1, ROUGE-2 and ROUGE-L scores (f, p, r) for each prediction/reference pair
# scores = rouge.get_scores(predicted_series, reference_series)
# print(scores)