<a href="https://colab.research.google.com/github/shreetishresthanp/medical_summaries_evaluation/blob/main/summaries_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Set up Environment and Install Dependencies

In [None]:
# Verify GPU setup: print the `nvidia-smi` report for the current runtime.
# NOTE(review): on a CPU-only runtime this command is presumably unavailable; the
# device-selection cell further down falls back to CPU via torch.cuda.is_available().
! nvidia-smi

In [None]:
# Transformers installation
! pip install transformers[torch] datasets evaluate rouge_score
# Install dependencies
! pip install torch
! pip install bert_score
! pip install textstat

In [None]:
# Optional: authenticate with the Hugging Face Hub using an access token.
# NOTE(review): presumably only required for gated/private repositories — confirm
# whether the model used below needs it; otherwise this cell can be skipped.
from huggingface_hub import notebook_login

notebook_login()

#### Load dataset

In [None]:
from datasets import load_dataset

# Read the local evaluation file `test_data.csv` through the generic CSV loader.
# NOTE(review): split="train" is presumably needed because un-split local files are
# exposed under the "train" split — confirm against the `datasets` documentation.
med_dataset = load_dataset(
    "csv",
    data_files="test_data.csv",
    split="train",
)
med_dataset

In [None]:
# Display the first record of the dataset as an example.
# Later cells read its "abstract_text" and "target_text" columns.
med_dataset[0]

#### Generate Summaries

In [None]:
import torch

# Probe CUDA once and derive both device identifiers from the same answer.
has_gpu = torch.cuda.is_available()
device = "cuda" if has_gpu else "cpu"
# Integer form handed to transformers.pipeline below: 0 when a GPU is present, -1 otherwise.
pipeline_device = 0 if has_gpu else -1

In [None]:
# Hub identifier of the summarisation checkpoint under evaluation.
model_id = "Falconsai/medical_summarization"

# Instruction text prepended verbatim to every abstract before summarisation.
# Split across adjacent literals for readability only; the runtime value is unchanged
# (including the trailing " : " separator).
prompt = (
    "Generate a plain language summary that is easy to read "
    "highlighting key points and removing unnecessary details "
    "that can be easily understood by non-medical people : "
)

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the checkpoint with its sequence-to-sequence generation head.
# The generic AutoModel class returns only the bare backbone, which cannot generate
# summaries, so `model` would be unusable for the task this notebook performs.
# NOTE(review): the summarisation pipeline cell below loads the weights again from
# `model_id`; `model` is only needed if it is used directly elsewhere.
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
from transformers import pipeline

# Inputs are fed to the summariser in chunks so that only one chunk of prompts/outputs
# is in flight at a time (helps avoid running out of memory in Colab with larger models).
batch_size = 16  # Adjust this value based on available memory
summarized_outputs = []
summarizer = pipeline("summarization", model=model_id, tokenizer=tokenizer, device=pipeline_device)

# Read the column once. Indexing the dataset by column name inside the loop may
# re-materialise the whole column on every iteration (depends on the `datasets` version).
abstracts = med_dataset["abstract_text"]

for i in range(0, len(abstracts), batch_size):
    # Prepend the instruction prompt to every abstract of this chunk.
    inputs_batch = [prompt + doc for doc in abstracts[i:i + batch_size]]
    # do_sample=False keeps generation deterministic; length bounds are in tokens.
    outputs_batch = summarizer(inputs_batch, min_length=20, max_length=150, do_sample=False)
    summarized_outputs.extend(output["summary_text"] for output in outputs_batch)

#### Write outputs to file

In [None]:
import csv

# Persist source abstract, reference summary and generated summary side by side.
# newline="" is what the csv module requires of the file object: without it, embedded
# newlines are mishandled and extra blank rows appear on platforms that translate "\n".
# An explicit encoding keeps the file independent of the platform default.
with open("summarized_outputs.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["abstract_text", "target_text", "generated_text"])
    # Each column is read once and zipped, instead of re-indexing the dataset per row.
    # The three sequences are expected to have equal length (one summary per record).
    writer.writerows(
        zip(med_dataset["abstract_text"], med_dataset["target_text"], summarized_outputs)
    )

### Evaluation

In [None]:
import evaluate
import textstat

In [None]:
# Load the three automatic metrics used in the sections below, in one pass.
rouge, bleu, bertscore = (
    evaluate.load(metric_name) for metric_name in ("rouge", "bleu", "bertscore")
)

#### ROUGE Score

In [None]:
# ROUGE over the whole test set, aggregated (use_aggregator=True) so that each
# ROUGE variant yields one number; generated summaries are compared to "target_text".
agg_rouge_scores = rouge.compute(
    predictions=summarized_outputs,
    references=med_dataset["target_text"],
    use_stemmer=True,
    use_aggregator=True,
)
agg_rouge_scores

In [None]:
from statistics import mean

# Collapse the aggregated ROUGE-1 / ROUGE-2 / ROUGE-L values into one headline number
# (rougeLsum is deliberately not part of this average).
rouge_variants = ("rouge1", "rouge2", "rougeL")
avg_of_rouge = mean([agg_rouge_scores[variant] for variant in rouge_variants])
print("Average of Rouge Scores (ROUGE-1, ROUGE-2, and ROUGE-L): ", avg_of_rouge)

In [None]:
import pandas as pd

# Per-record ROUGE: use_aggregator=False keeps one score per (prediction, reference) pair.
rouge_scores = rouge.compute(
    predictions=summarized_outputs,
    references=med_dataset["target_text"],
    use_stemmer=True,
    use_aggregator=False,
)

# One row per record; rougeLsum is dropped so the row-wise mean covers only
# ROUGE-1, ROUGE-2 and ROUGE-L.
metric_df = pd.DataFrame(rouge_scores).drop(columns=["rougeLsum"])
metric_df = metric_df.assign(avg_rouge_score=metric_df.mean(axis=1))
metric_df.head()


In [None]:
# Distribution of the per-record average ROUGE score, drawn as a box plot
import matplotlib.pyplot as plt

ax = metric_df.boxplot(column='avg_rouge_score')
ax.set_title('Boxplot of Avg Rouge Score (ROUGE-1, ROUGE-2, and ROUGE-L)')
ax.set_ylabel('Avg Rouge Score (0 - 1)')
plt.show()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Scatter of every record's average ROUGE score against its row position in metric_df
record_positions = range(len(metric_df))
sns.scatterplot(data=metric_df, x=record_positions, y="avg_rouge_score")

#### BLEU Score

In [None]:
# BLEU of the generated summaries against the reference summaries in "target_text"
bleu_score = bleu.compute(
    predictions=summarized_outputs,
    references=med_dataset["target_text"],
)
bleu_score

In [None]:
# Report BLEU scaled by 100, rounded to two decimals
bleu_pct = bleu_score['bleu'] * 100
print(f"BLEU Score: {bleu_pct:.2f}")

#### BERT Score

In [None]:
# BERTScore of generated vs. reference summaries, scored as English text (lang="en")
bert_score = bertscore.compute(
    predictions=summarized_outputs,
    references=med_dataset["target_text"],
    lang="en",
)
bert_score

In [None]:
# Build a per-record table of BERTScore precision / recall / F1, expressed in percent.
# errors="ignore" makes the drop tolerant of 'hashcode' being absent already: a later
# cell removes that key from `bert_score`, after which a plain drop raises KeyError
# when this cell is re-run.
score_cols = ["precision", "recall", "f1"]
bert_metric_df = pd.DataFrame(bert_score).drop(columns=["hashcode"], errors="ignore")
# Scale the raw scores by 100 and round to two decimals in a single vectorised step.
bert_metric_df[score_cols] = (bert_metric_df[score_cols] * 100).round(2)
bert_metric_df.head()

In [None]:
# Mean of each BERTScore column (values in bert_metric_df are already percentages)
precision_mean, recall_mean, f1_mean = (
    bert_metric_df[col].mean() for col in ("precision", "recall", "f1")
)
for label, value in (
    ("Precision", precision_mean),
    ("Recall", recall_mean),
    ("F1", f1_mean),
):
    print(f"{label} Mean: {value:.2f}")

In [None]:
# Box plot of the per-record BERTScore precision values (in percent)
ax = bert_metric_df.boxplot(column='precision')
ax.set_title('Boxplot of Precision Score (BERT)')
ax.set_ylabel('Score (%)')
plt.show()

In [None]:
# Box plot of the per-record BERTScore recall values (in percent)
ax = bert_metric_df.boxplot(column='recall')
ax.set_title('Boxplot of Recall (BERT)')
ax.set_ylabel('Score (%)')
plt.show()

In [None]:
# Box plot of the per-record BERTScore F1 values (in percent)
ax = bert_metric_df.boxplot(column='f1')
ax.set_title('Boxplot of F1 Accuracy (BERT)')
ax.set_ylabel('Score (%)')
plt.show()

In [None]:
# Remove the 'hashcode' entry so only the score lists remain in `bert_score`.
# pop() with a default keeps the cell idempotent: re-running it once the key is gone
# is a no-op, whereas `del` raises KeyError. The trailing ';' suppresses the cell output.
bert_score.pop('hashcode', None);

In [None]:
# Side-by-side box plots of every BERTScore component.
# 'hashcode' is filtered out here rather than relying on another cell having deleted
# it, so this cell gives the same figure regardless of execution order.
numeric_scores = {name: vals for name, vals in bert_score.items() if name != "hashcode"}
# NOTE(review): newer Matplotlib releases rename `labels` to `tick_labels`; the original
# keyword is kept for compatibility with the installed version — confirm before changing.
plt.boxplot(list(numeric_scores.values()), labels=list(numeric_scores.keys()))
plt.show()

#### Flesch-Kincaid Grade Level and Flesch Reading Ease (Readability Metrics)

In [None]:
# Score every generated summary with two textstat readability formulas.
flesch_kincaid_grades = list(map(textstat.flesch_kincaid_grade, summarized_outputs))
flesch_reading_ease = list(map(textstat.flesch_reading_ease, summarized_outputs))

# Keyed by metric name so both series can be plotted together later.
readability_scores = dict(
    flesch_kincaid_grade=flesch_kincaid_grades,
    flesch_reading_ease=flesch_reading_ease,
)

In [None]:
from statistics import mean

# Report the arithmetic mean of each readability series over all generated summaries.
for caption, scores in (
    ("Average of Flesch-Kincaid Grade Level (FKGL): ", flesch_kincaid_grades),
    ("Average of Flesch Reading Ease (FRE): ", flesch_reading_ease),
):
    print(caption, mean(scores))

In [None]:
# One box per readability metric, labelled with the metric's dictionary key
metric_names = list(readability_scores.keys())
metric_series = list(readability_scores.values())
plt.boxplot(metric_series, labels=metric_names)
plt.show()

#### Human Evaluation Metrics

In [None]:
# Manually collected human-evaluation ratings: one list per question (q_1 .. q_5),
# 30 ratings each. Values are wrapped six per row purely for readability.
q_1 = [
    2, 2, 3, 3, 3, 4,
    1, 1, 2, 2, 3, 4,
    2, 3, 3, 2, 3, 4,
    2, 3, 3, 2, 3, 4,
    4, 4, 4, 3, 3, 3,
]
q_2 = [
    1, 2, 3, 3, 3, 3,
    1, 1, 2, 1, 2, 3,
    1, 2, 2, 1, 3, 3,
    1, 3, 4, 1, 3, 3,
    3, 4, 4, 2, 3, 3,
]
q_3 = [
    2, 2, 2, 1, 2, 3,
    2, 2, 3, 2, 2, 3,
    2, 2, 3, 2, 2, 3,
    2, 2, 3, 1, 2, 4,
    3, 3, 3, 2, 3, 3,
]
q_4 = [
    1, 2, 3, 3, 3, 4,
    1, 1, 2, 1, 3, 3,
    1, 2, 2, 2, 3, 4,
    1, 2, 2, 2, 2, 3,
    2, 3, 3, 2, 3, 3,
]
q_5 = [
    1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 1,
    1, 1, 2, 1, 1, 2,
]

In [None]:
# Assemble the rating lists into a table with one column per question (q1 .. q5).
d = {
    f"q{number}": ratings
    for number, ratings in enumerate((q_1, q_2, q_3, q_4, q_5), start=1)
}
df = pd.DataFrame(data=d)
df.head()

In [None]:
# One box per question column, labelled with the column name
question_labels = df.columns
plt.boxplot(df.to_numpy(), labels=question_labels)
plt.show()

In [None]:
# Mean rating per question; the individual mean_q* names stay bound for later use.
mean_q1, mean_q2, mean_q3, mean_q4, mean_q5 = (
    df[f"q{number}"].mean() for number in range(1, 6)
)
for number, value in enumerate((mean_q1, mean_q2, mean_q3, mean_q4, mean_q5), start=1):
    print(f"Average of Question {number}: ", value)

In [None]:
# Standard deviation of the ratings per question, computed with pandas' Series.std();
# the individual std_* names stay bound for later use.
std_1, std_2, std_3, std_4, std_5 = (
    df[f"q{number}"].std() for number in range(1, 6)
)
for number, value in enumerate((std_1, std_2, std_3, std_4, std_5), start=1):
    print(f"Standard Deviation of Question {number}: ", value)

In [None]:
# `statistics.median` is not called in this cell (pandas' Series.median() is used);
# the import is kept so the name stays available to any later cell.
from statistics import median

# Median rating per question; the individual median_q* names stay bound for later use.
median_q1, median_q2, median_q3, median_q4, median_q5 = (
    df[f"q{number}"].median() for number in range(1, 6)
)
for number, value in enumerate(
    (median_q1, median_q2, median_q3, median_q4, median_q5), start=1
):
    print(f"Median of Question {number}: ", value)

In [None]:
from scipy import stats

# Interquartile range of the ratings per question via scipy.stats.iqr;
# the individual iqr_q* names stay bound for later use.
iqr_q1, iqr_q2, iqr_q3, iqr_q4, iqr_q5 = (
    stats.iqr(df[f"q{number}"]) for number in range(1, 6)
)
for number, value in enumerate((iqr_q1, iqr_q2, iqr_q3, iqr_q4, iqr_q5), start=1):
    print(f"IQR of Question {number}: ", value)