### Medical Assistant Bot

In [1]:
import os 
import sys 

root_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
print(root_dir)

original_dataset_path = os.path.join(root_dir, 'data', 'mle_screening_dataset.csv')
print(original_dataset_path)

d:\Study\AI\Project\Project (tatra-labs)\Medical-Assistant
d:\Study\AI\Project\Project (tatra-labs)\Medical-Assistant\data\mle_screening_dataset.csv


### Dataset Preview

In [2]:
import pandas as pd 
import numpy as np 

df = pd.read_csv(original_dataset_path)

df = df.dropna()
df = df.drop_duplicates() 

print(df.shape)
df.info()
df.head()

(16353, 2)
<class 'pandas.core.frame.DataFrame'>
Index: 16353 entries, 0 to 16405
Data columns (total 2 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   question  16353 non-null  object
 1   answer    16353 non-null  object
dtypes: object(2)
memory usage: 383.3+ KB


|   | question | answer |
|---|---|---|
| 0 | What is (are) Glaucoma ? | Glaucoma is a group of diseases that can damag... |
| 1 | What is (are) Glaucoma ? | The optic nerve is a bundle of more than 1 mil... |
| 2 | What is (are) Glaucoma ? | Open-angle glaucoma is the most common form of... |
| 3 | Who is at risk for Glaucoma? ? | Anyone can develop glaucoma. Some people are a... |
| 4 | How to prevent Glaucoma ? | At this time, we do not know how to prevent gl... |


#### Dataset Inspect Result

This medical question-answer dataset originally contains 16,406 (question, answer) pairs; after dropping missing values and exact duplicates, 16,353 remain.

1. **Questions focus on the definitions (e.g. "What is (are) ..."), symptoms (e.g. "What are the symptoms of ..."), prevention (e.g. "How to prevent ...") and treatments (e.g. "What are the treatments for ...") of diseases such as Glaucoma, High Blood Pressure, Tuberculosis, Cyclic Vomiting Syndrome, and others.**

2. **Several pairs share the same question, for example:** 
    * What is (are) Glaucoma ?
    * What is (are) Cyclic Vomiting Syndrome ?
    * ...

3. **How well does each answer address the question?** 

    Take the question *What is (are) Glaucoma ?*
    
    The dataset provides three answers to it; I rate each on a scale of 1 to 10.
    * **First answer:**
        + Glaucoma is a group of diseases that can damage the eye's optic nerve and result in vision loss and blindness. The most common form of the disease is open-angle glaucoma. With early treatment, you can often protect your eyes against serious vision loss. (Watch the video to learn more about glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.)  See this graphic for a quick overview of glaucoma, including how many people it affects, whos at risk, what to do if you have it, and how to learn more.  See a glossary of glaucoma terms. 
        + The answer appears to have been scraped from a website: it contains page noise such as *(Watch the video to learn more about glaucoma. To enlarge the video, click the brackets in the lower right-hand corner. To reduce the video, press the Escape (Esc) button on your keyboard.)*
        + Despite the noise, it addresses glaucoma well. 
        + Score: 8

    * **Second answer:**
        + The optic nerve is a bundle of more than 1 million nerve fibers. It connects the retina to the brain.
        + This defines the optic nerve, not glaucoma. 
        + Score: 1

    * **Third answer:**
        + Open-angle glaucoma is the most common form of glaucoma. In the normal eye, the clear fluid leaves the anterior chamber at the open angle where the cornea and iris meet. When the fluid reaches the angle, it flows through a spongy meshwork, like a drain, and leaves the eye. Sometimes, when the fluid reaches the angle, it passes too slowly through the meshwork drain, causing the pressure inside the eye to build. If the pressure damages the optic nerve, open-angle glaucoma -- and vision loss -- may result.
        + It addresses open-angle glaucoma well, but not glaucoma in general. 
        + Score: 5
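The two observations above (repeated questions and scraped page noise) can be checked programmatically. Below is a minimal pandas sketch, assuming the `question`/`answer` columns of `df`. The helper names and `NOISE_PATTERNS` are hypothetical, and the patterns only cover the boilerplate quoted in the Glaucoma answer, so a real cleanup pass would need a longer list.

```python
import re

import pandas as pd

# Hypothetical, non-exhaustive list of scraped-page boilerplate seen in answers.
NOISE_PATTERNS = [
    # Non-greedy match so the nested "(Esc)" does not end the match early
    r"\(Watch the video to learn more about .*?on your keyboard\.\)",
    r"See this graphic for a quick overview of [^.]*\.",
    r"See a glossary of [^.]*\.",
]


def strip_page_noise(text: str) -> str:
    """Remove known boilerplate sentences and collapse repeated whitespace."""
    for pattern in NOISE_PATTERNS:
        text = re.sub(pattern, " ", text)
    return re.sub(r"\s+", " ", text).strip()


def answers_per_question(df: pd.DataFrame) -> pd.Series:
    """Count how many answers each distinct question has, most answers first."""
    return df.groupby("question").size().sort_values(ascending=False)
```

For example, `(answers_per_question(df) > 1).sum()` gives the number of questions with more than one answer, and `df["answer"].map(strip_page_noise)` yields a cleaned answer column.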



### Analysis 

1. Dataset Source  

    (Assumption 1) I assume the pairs in the current QA dataset come from MedQuAD, a public dataset of 47,457 medical question-answer pairs created from 12 NIH websites (e.g. cancer.gov, niddk.nih.gov, GARD, MedlinePlus Health Topics). The collection covers 37 question types (e.g. Treatment, Diagnosis, Side Effects) associated with diseases, drugs and other medical entities such as tests.

    Strictly speaking, we would need to match each QA pair in the current dataset against MedQuAD and combine the two datasets into one; based on this assumption, however, I use MedQuAD alone here.
    The raw `data/MedQuad_QA.csv` (24 MB) is not good enough to use directly, so I use a cleaned MedQuAD dataset (`data/clean_data.csv`) instead. This is a second assumption.

2. Dataset Enrichment 

    The [Existing-Medical-QA-Datasets](https://github.com/abachaa/Existing-Medical-QA-Datasets) list is a useful pointer to other medical QA datasets.
    The [MASHQA](https://github.com/mingzhu0527/MASHQA?tab=readme-ov-file) repository could enrich our dataset further if we want to improve the solution later.
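The matching step described in item 1 (checking which pairs of the current dataset also appear in MedQuAD, then combining the two) could look roughly like the sketch below. It rests on assumptions: both frames are assumed to have been renamed to hypothetical `question`/`answer` columns, the function names are illustrative, and matching is on lightly normalized text (lowercase, collapsed whitespace), so paraphrased duplicates will not be detected.

```python
import pandas as pd


def normalize(text: pd.Series) -> pd.Series:
    """Lowercase and collapse whitespace so trivially different strings match."""
    return text.str.lower().str.replace(r"\s+", " ", regex=True).str.strip()


def overlap_and_combine(current: pd.DataFrame, medquad: pd.DataFrame):
    """Return the share of distinct `current` pairs found in `medquad`, plus their union."""
    keys = ["question_key", "answer_key"]
    cur = current.assign(question_key=normalize(current["question"]),
                         answer_key=normalize(current["answer"]))
    med = medquad.assign(question_key=normalize(medquad["question"]),
                         answer_key=normalize(medquad["answer"]))
    # indicator=True adds a `_merge` column: "both", "left_only" or "right_only"
    merged = cur[keys].drop_duplicates().merge(
        med[keys].drop_duplicates(), on=keys, how="left", indicator=True)
    overlap = (merged["_merge"] == "both").mean()
    # Union of both sources, keeping one row per normalized (question, answer) pair
    combined = (pd.concat([cur, med], ignore_index=True)
                .drop_duplicates(subset=keys)
                .drop(columns=keys)
                .reset_index(drop=True))
    return overlap, combined
```

An `overlap` close to 1.0 would support Assumption 1; a low value would suggest the current dataset has pairs that MedQuAD alone does not cover.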

In [3]:
medquad_dataset_path = os.path.join(root_dir, 'data', 'clean_data.csv')
medquad_df = pd.read_csv(medquad_dataset_path)

medquad_df = medquad_df.dropna() 
medquad_df = medquad_df.drop_duplicates()

print(medquad_df.shape)

(14442, 4)


### Model Selection 

I chose SciFive for this medical question-answering task.

- Why? 
    * SciFive is based on the T5 architecture and pre-trained on extensive biomedical literature (PubMed abstracts and PMC full-text articles), so it has seen medical terminology and concepts during pre-training.
    * As a text-to-text (encoder-decoder) model, it fits generative question answering directly: the question goes in and the answer text comes out.
    * The large variant has about 770M parameters, which should fit in my GPU's 48 GB of memory.
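To sanity-check the memory claim, here is a back-of-the-envelope estimate. In fp32 each parameter takes 4 bytes, and full fine-tuning with AdamW (the Hugging Face Trainer's default optimizer) also keeps a gradient and two moment buffers per parameter, i.e. roughly 16 bytes per parameter. This sketch deliberately ignores activations, which grow with batch size and with the 1024-token `max_length` used by the dataset class below and can dominate, so treat the result as a lower bound rather than a measurement.

```python
def full_finetune_memory_gb(n_params: float, bytes_per_value: int = 4) -> dict:
    """Rough lower-bound memory (decimal GB) for full fine-tuning with Adam-style optimizers.

    Counts weights, gradients and the two Adam moment buffers only; activations,
    CUDA context and framework overhead are excluded.
    """
    weights = n_params * bytes_per_value
    gradients = n_params * bytes_per_value
    adam_states = 2 * n_params * bytes_per_value  # first and second moments
    gb = 1e9
    return {
        "weights": weights / gb,
        "gradients": gradients / gb,
        "adam_states": adam_states / gb,
        "total": (weights + gradients + adam_states) / gb,
    }


estimate = full_finetune_memory_gb(770e6)
print({name: round(value, 2) for name, value in estimate.items()})
# {'weights': 3.08, 'gradients': 3.08, 'adam_states': 6.16, 'total': 12.32}
```

About 12.3 GB before activations leaves substantial headroom on a 48 GB card, which is consistent with the choice above.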

In [4]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

base_model_path = os.path.join(root_dir, 'models', 'SciFive-large-Pubmed_PMC')

# Load SciFive
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(base_model_path)




I sanitized the dataset based on this tokenizer and the maximum token length.

Total tokens: 4,737,589
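The sanitization code itself is not shown. One plausible version, sketched under assumptions: a row is kept only if both its prompt (with the same `question: ` prefix the dataset class adds) and its response fit within `max_length` tokens, and the token total is counted over the kept rows. Whether the reported total was computed before or after filtering is not stated, so that choice is also an assumption. `filter_by_token_length` is a hypothetical helper; `tokenize` is any callable returning a token sequence, e.g. `lambda s: tokenizer(s)["input_ids"]` with the SciFive tokenizer (which also counts the appended `</s>`).

```python
from typing import Callable, Sequence, Tuple

import pandas as pd


def filter_by_token_length(
    df: pd.DataFrame,
    tokenize: Callable[[str], Sequence],
    max_length: int = 1024,
    prompt_col: str = "prompt",
    response_col: str = "response",
    prefix: str = "question: ",
) -> Tuple[pd.DataFrame, int]:
    """Drop rows whose prompt or response exceeds max_length tokens.

    Returns the filtered frame and the total number of tokens in the kept rows.
    """
    prompt_len = df[prompt_col].map(lambda s: len(tokenize(prefix + s)))
    response_len = df[response_col].map(lambda s: len(tokenize(s)))
    keep = (prompt_len <= max_length) & (response_len <= max_length)
    total_tokens = int(prompt_len[keep].sum() + response_len[keep].sum())
    return df[keep].reset_index(drop=True), total_tokens
```

Filtering instead of relying on `truncation=True` avoids training on answers that are silently cut off mid-sentence.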

In [5]:
sanitized_dataset_path = os.path.join(root_dir, 'data', 'sanitized_data.csv')
sanitized_df = pd.read_csv(sanitized_dataset_path)
print(sanitized_df.shape)

(13999, 4)
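Before fine-tuning, it helps to know how many rows each split receives and how many optimizer steps that implies. The sketch below assumes scikit-learn's rule for a float `test_size` (the test split gets `ceil(test_size * n)` rows), a single GPU, and no gradient accumulation; the batch size and epoch count are taken from the training arguments in the fine-tuning cell below.

```python
import math


def split_sizes(n: int, test_size: float) -> tuple:
    """Mirror scikit-learn's rule for a float test_size: n_test = ceil(test_size * n)."""
    n_test = math.ceil(test_size * n)
    return n - n_test, n_test


n_rows = 13999                                     # rows in sanitized_data.csv
n_train_val, n_test = split_sizes(n_rows, 0.2)     # first split: train+val vs. test
n_train, n_val = split_sizes(n_train_val, 0.2)     # second split: train vs. val

batch_size, epochs = 2, 3                          # per-device batch size and epochs
steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch still counts
total_steps = steps_per_epoch * epochs

print(n_train, n_val, n_test, steps_per_epoch, total_steps)
# 8959 2240 2800 4480 13440
```

With about 13k total steps, the evaluation and checkpoint intervals determine how many full validation passes and checkpoint files the run produces.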


### Load Dataset

In [6]:
import torch 
from torch.utils.data import Dataset, DataLoader 
import torch.nn as nn 

class MedQuQADataset(Dataset):
    def __init__(self, df, tokenizer, max_length=1024):
        self.df = df 
        self.tokenizer = tokenizer 
        self.max_length = max_length 
        
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        question, answer = row['prompt'], row['response']
        # T5-style task prefix on the encoder input
        inputs = self.tokenizer(
            f"question: {question}",
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        labels = self.tokenizer(
            answer,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors="pt"
        )
        label_ids = labels['input_ids'].squeeze(0)
        # Replace padding with -100 so the cross-entropy loss ignores those positions
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100
        
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': label_ids
        }
    
    def __len__(self):
        return len(self.df)


### Model Fine-tuning 

In [8]:
import evaluate
from sklearn.model_selection import train_test_split
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

results_path = os.path.join(root_dir, 'results')
logs_path = os.path.join(root_dir, 'logs')

# 64% train / 16% validation / 20% test
train_val_df, test_df = train_test_split(sanitized_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_val_df, test_size=0.2, random_state=42)

train_dataset = MedQuQADataset(train_df, tokenizer)
val_dataset = MedQuQADataset(val_df, tokenizer)
test_dataset = MedQuQADataset(test_df, tokenizer)

# Metrics: datasets.load_metric is deprecated in favour of the evaluate library
# (and removed in datasets 3.0). F1 is computed by hand below.
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    
    # With predict_with_generate=True, predictions are generated token ids (not logits).
    # Both arrays can contain -100 padding, which the tokenizer cannot decode.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    
    # Decode predictions and labels, stripping surrounding whitespace
    decoded_preds = [pred.strip() for pred in tokenizer.batch_decode(predictions, skip_special_tokens=True)]
    decoded_labels = [label.strip() for label in tokenizer.batch_decode(labels, skip_special_tokens=True)]
    
    # Exact Match (EM)
    em_score = np.mean([pred == ref for pred, ref in zip(decoded_preds, decoded_labels)])
    
    # BLEU: evaluate's BLEU tokenizes raw strings itself; one reference per prediction
    bleu_result = bleu_metric.compute(
        predictions=decoded_preds,
        references=[[ref] for ref in decoded_labels],
    )
    
    # Token-set F1 (each distinct whitespace token counted once)
    f1_scores = []
    for pred, ref in zip(decoded_preds, decoded_labels):
        pred_tokens, ref_tokens = set(pred.split()), set(ref.split())
        true_positives = len(pred_tokens & ref_tokens)
        precision = true_positives / len(pred_tokens) if pred_tokens else 0
        recall = true_positives / len(ref_tokens) if ref_tokens else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0
        f1_scores.append(f1)
    
    # ROUGE: evaluate returns aggregated F-measures as plain floats
    rouge_result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
    
    return {
        "exact_match": float(em_score),
        "bleu": bleu_result["bleu"],
        "f1": float(np.mean(f1_scores)),
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
    }
    
training_args = Seq2SeqTrainingArguments(
    output_dir=results_path,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=10,
    weight_decay=0.01,
    logging_dir=logs_path,
    logging_steps=10,  # Log training loss every 10 steps
    eval_strategy="steps",  # Named evaluation_strategy in transformers < 4.41
    eval_steps=500,  # Evaluate every 500 steps; generating ~2.2k validation answers is slow
    save_strategy="steps",
    save_steps=500,  # Must be a multiple of eval_steps for load_best_model_at_end
    save_total_limit=2,  # Limit checkpoints on disk; the best one is always retained
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,  # Generate answers during evaluation so metrics compare text
    generation_max_length=256,  # Cap on generated answer length during evaluation (tunable)
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

# Final evaluation on the held-out test split (metric names get a "test_" prefix)
test_results = trainer.evaluate(test_dataset, metric_key_prefix="test")

print(test_results)

KeyboardInterrupt: 
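The F1 in `compute_metrics` compares token *sets*, so a word repeated in an answer is counted only once. The SQuAD-style token F1 instead counts overlapping tokens with multiplicity. For comparison, here is a minimal standard-library sketch of that variant. It uses whitespace tokenization only; the official SQuAD evaluation script also lowercases and strips punctuation and articles, which is omitted here. `token_f1` is an illustrative name.

```python
from collections import Counter


def token_f1(prediction: str, reference: str) -> float:
    """SQuAD-style token F1: overlapping tokens counted with multiplicity."""
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    if not pred_tokens or not ref_tokens:
        # Two empty strings count as a match; one empty side scores zero
        return float(pred_tokens == ref_tokens)
    common = Counter(pred_tokens) & Counter(ref_tokens)  # per-token minimum counts
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)
```

For instance, `token_f1("the eye the eye", "the eye")` is 2/3, whereas the set-based F1 scores the same pair as 1.0 because the repetition is invisible to it. For long generated answers that repeat phrases, the multiset version penalizes the repetition.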