<a href="https://colab.research.google.com/github/tuanng007/medical_qa_gpt2/blob/main/Medical_QA_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# %%
# @title 1. Install dependencies and setting Kaggle API
!pip install -q transformers datasets torch accelerate kaggle gradio sentencepiece pandas fuzzywuzzy[speedup] evaluate rouge-score

import os
import json
from zipfile import ZipFile
import pandas as pd
from google.colab import files
import random
import torch
import evaluate
import pandas as pd
import re
from datasets import load_dataset, Dataset, DatasetDict

# Upload kaggle.json
# Make a Kaggle API key available under /root/.kaggle, asking the user to
# upload one through the Colab file picker only if none is present yet.
# NOTE: the prompt below is printed even when the key already exists and the
# upload step is skipped.
print("Please upload your's file 'kaggle.json' .")
if not os.path.exists('/root/.kaggle/kaggle.json'):
    # files.upload() is used here as a mapping keyed by uploaded file name.
    uploaded = files.upload()
    if 'kaggle.json' in uploaded:
        print("Uploaded 'kaggle.json'. Configuring ...")
        # Move the uploaded key into ~/.kaggle and make it readable/writable
        # by the owner only (mode 600).
        !mkdir -p ~/.kaggle
        !mv kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
        # Point KAGGLE_CONFIG_DIR at the key's directory (presumably read by
        # the `kaggle` command used in the download cell).
        os.environ['KAGGLE_CONFIG_DIR'] = "/root/.kaggle"
        print("Kaggle API key configured.")
    else:
        # The upload did not contain a file named exactly 'kaggle.json'.
        print("Error: Not found file 'kaggle.json' has uploaded.")
else:
    # Key already present from a previous run: only (re)export the config dir.
    print("Kaggle.json 'existed. skip the upload step.")
    os.environ['KAGGLE_CONFIG_DIR'] = "/root/.kaggle"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# %%

In [None]:
# @title 2. Import dataset MedQuAD from Kaggle

# Define the Kaggle dataset path
dataset_id = "gpreda/medquad"

# Dataset Directory
# Local folder that receives the downloaded files; created if missing.
output_dir = "medquad_data"
os.makedirs(output_dir, exist_ok=True)

print(f"Downloading dataset: {dataset_id}...")
# import dataset from Kaggle to the created folder
# IPython expands {dataset_id} / {output_dir} from the Python variables above
# before handing the command to the shell.
!kaggle datasets download -d {dataset_id} -p {output_dir} --unzip

print(f"Dataset {dataset_id} downloaded and unzip to folder: {output_dir}")

# List the files in the folder to confirm
print("\nFiles in folder dataset:")
!ls {output_dir}

# Path of the CSV that the preprocessing cell reads (via `csv_file_path`).
csv_file_path = os.path.join(output_dir, 'medquad.csv')

if not os.path.exists(csv_file_path):
    # The expected file name was not found: show the folder again so the user
    # can see what was actually extracted.
    print(f"Error: Not found {csv_file_path}. Please check the file name in the folder {output_dir}.")
    !ls {output_dir}
else:
    print(f"Found {csv_file_path}")

In [None]:
# @title 3. Import necessary dependencies and settings GPU
from transformers import (
    GPT2TokenizerFast,
    GPT2ForQuestionAnswering,
    Trainer,
    TrainingArguments,
    default_data_collator
)

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
_cuda_ready = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ready else "cpu")
print(f"Using device: {device}")

In [None]:
# @title 4 Preprocessing MedQuAD

import html
from bs4 import BeautifulSoup

processed_data_generative = []
question_col = 'question'
answer_col = 'answer'

# Function handling spaces and new lines
def normalize_whitespace(text):
    """Return *text* with every whitespace run collapsed to a single space.

    Non-string input is converted with ``str()`` first. Literal two-character
    ``\\n`` sequences (a backslash followed by ``n``) are treated as spaces
    too, and the result has no leading or trailing whitespace.
    """
    as_text = text if isinstance(text, str) else str(text)
    # Escaped newlines that survived as literal backslash-n become spaces.
    without_escaped_newlines = as_text.replace('\\n', ' ')
    collapsed = re.sub(r'\s+', ' ', without_escaped_newlines)
    return collapsed.strip()

# Function preprocessing text
def preprocess_text(text):
    """Clean one question/answer cell for training.

    The value is converted to ``str`` if needed, parsed with BeautifulSoup's
    ``html.parser`` to keep only the text content, HTML-unescaped, passed
    through ``normalize_whitespace`` and finally lowercased.
    """
    raw = text if isinstance(text, str) else str(text)
    visible_text = BeautifulSoup(raw, "html.parser").get_text()
    unescaped = html.unescape(visible_text)
    return normalize_whitespace(unescaped).lower()


# Build `processed_data_generative` (a list of {"id", "question", "answer"}
# dicts) from the MedQuAD CSV. If the CSV path is missing, the file does not
# exist, or any step below raises, a tiny synthetic Q&A list is used instead
# so the remaining cells can still run.
if 'csv_file_path' in globals() and csv_file_path and os.path.exists(csv_file_path):
    try:
        print(f"Reading dataset: {csv_file_path}")
        df_original = pd.read_csv(csv_file_path)

        print("\n--- Dataset (before) ---")
        print(f"Rows and columns: {df_original.shape}")
        # df_original.info()
        print("\nFirst five rows of original data:")
        print(df_original.head())
        print(f"\nInitial number of null questions (NaN): {df_original[question_col].isnull().sum()}")
        print(f"Initial number of null answers (NaN): {df_original[answer_col].isnull().sum()}")


        # --- START OF PREPROCESSING STEPS ---
        # Work on a copy so df_original stays available for the before/after
        # row counts reported below.
        df_processed = df_original.copy()
        initial_rows_before_any_processing = len(df_processed)

        # 1. Remove rows where question or answer are completely NaN (if any)
        df_processed.dropna(subset=[question_col, answer_col], inplace=True)
        rows_after_dropna = len(df_processed)
        print(f"\nRemoved {initial_rows_before_any_processing - rows_after_dropna} rows with NaN values in question or answer.")

        # 2. Apply text preprocessing (including HTML stripping, unescape, normalize whitespace, lowercase)
        print("\nNormalizing text for 'question' and 'answer' columns...")
        df_processed.loc[:, question_col] = df_processed[question_col].apply(preprocess_text)
        df_processed.loc[:, answer_col] = df_processed[answer_col].apply(preprocess_text)
        print("Text normalization complete.")

        # 3. Remove rows where question or answer become empty AFTER normalization
        # (e.g., if initially only contained HTML or whitespace)
        # An empty string converts to False under astype(bool), so the mask
        # keeps only rows with non-empty text.
        df_processed = df_processed[df_processed[question_col].str.strip().astype(bool)]
        df_processed = df_processed[df_processed[answer_col].str.strip().astype(bool)]
        rows_after_empty_filter = len(df_processed)
        print(f"Removed {rows_after_dropna - rows_after_empty_filter} rows with empty question/answer after normalization.")


        # 4. Remove duplicates BASED ON NORMALIZED COLUMNS
        rows_before_final_dedup = len(df_processed)
        df_processed.drop_duplicates(subset=[question_col, answer_col], keep='first', inplace=True)
        rows_after_final_dedup = len(df_processed)
        print(f"Removed {rows_before_final_dedup - rows_after_final_dedup} duplicate rows after text normalization.")

        print("\n--- DESCRIPTION OF DATA AFTER COMPLETE PREPROCESSING ---")
        print(f"Final number of rows and columns after all processing steps: {df_processed.shape}")
        print("\nFirst five rows of completely processed data:")
        print(df_processed.head())



        # Convert the cleaned frame into the list-of-dicts form used by the
        # rest of the notebook. `index` is the row label that df_processed
        # kept from df_original (rows were only dropped, never re-indexed).
        for index, row in df_processed.iterrows():

            processed_data_generative.append({
                "id": str(df_original.index[index]) if index in df_original.index else str(index), # Try to keep original ID if possible
                "question": row[question_col],
                "answer": row[answer_col]
            })

        print(f"\nNumber of valid Q&A pairs to include in training: {len(processed_data_generative)}")

        # Raising here routes an empty result into the synthetic-data
        # fallback of the except clause below.
        if not processed_data_generative:
            raise ValueError("No data was processed from CSV for training.")

    except Exception as e:
        # Any failure while reading/cleaning is reported with a traceback and
        # replaced by two hand-written Q&A pairs.
        print(f"Error reading or processing CSV file: {e}")
        import traceback
        traceback.print_exc()
        processed_data_generative = [
            {"id": "syn_0", "question": "what are the symptoms of flu?", "answer": "symptoms of flu include fever, cough, sore throat, runny or stuffy nose, body aches, headache, chills, and fatigue."},
            {"id": "syn_1", "question": "how to treat a common cold?", "answer": "to treat a common cold, get plenty of rest, drink fluids, and use over-the-counter medications for symptoms."},
        ]
        print("Using synthetic data due to processing error.")
else:
    # Distinguish "no path was ever set" from "path set but file missing" in
    # the message only; both cases use the same synthetic data.
    if 'csv_file_path' not in globals() or not csv_file_path:
        print(f"CSV file path ('csv_file_path') was not provided (possibly due to an error in Cell 2). Using synthetic data.")
    else: # csv_file_path has a value but the file doesn't exist
        print(f"File {csv_file_path} does not exist. Using synthetic data.")
    # Fallback if no CSV file
    processed_data_generative = [
        {"id": "syn_0", "question": "what are the symptoms of flu?", "answer": "symptoms of flu include fever, cough, sore throat, runny or stuffy nose, body aches, headache, chills, and fatigue."},
        {"id": "syn_1", "question": "how to treat a common cold?", "answer": "to treat a common cold, get plenty of rest, drink fluids, and use over-the-counter medications for symptoms."},
    ]

# Cap the number of Q&A pairs used for the demo and guarantee that at least
# one sample exists before the dataset is built.
MAX_SAMPLES_TO_USE = 17000
# Fixed seed for the subsample, matching the seeded train/validation split
# that follows, so repeated runs train and evaluate on the same data.
SAMPLING_SEED = 42
if len(processed_data_generative) > MAX_SAMPLES_TO_USE:
    print(f"\nLimiting training data to {MAX_SAMPLES_TO_USE} samples for demo.")
    # A dedicated Random instance keeps the draw reproducible without
    # reseeding the global `random` state used elsewhere.
    processed_data_generative = random.Random(SAMPLING_SEED).sample(
        processed_data_generative, MAX_SAMPLES_TO_USE
    )
elif not processed_data_generative:  # Nothing survived preprocessing or the fallback
    print("WARNING: processed_data_generative is empty after sampling or due to fallback error. Please check.")
    # Recreate synthetic data if completely empty
    processed_data_generative = [
        {"id": "syn_fallback_0", "question": "example question for empty case?", "answer": "example answer for empty case."},
    ]

# At this point the list is guaranteed non-empty: it either held data already
# or was just refilled by the branch above.
print(f"\nFinal number of samples used: {len(processed_data_generative)}")

# Wrap the cleaned Q&A records in a Hugging Face Dataset.
dataset_gen = Dataset.from_pandas(pd.DataFrame(processed_data_generative))

# Hold out 10% of the samples for validation; the fixed seed keeps the split
# identical between runs.
train_test_split_gen = dataset_gen.train_test_split(test_size=0.1, seed=42)
dataset_dict_gen = DatasetDict()
dataset_dict_gen['train'] = train_test_split_gen['train']
dataset_dict_gen['validation'] = train_test_split_gen['test']

print("\nData structure for Generative QA (after normalization and sampling):")
print(dataset_dict_gen)
print("\nExample data sample (from normalized train set):")
if len(dataset_dict_gen["train"]) == 0:
    print("Train set is empty!")
else:
    # Show the first training record as a quick sanity check.
    sample_example_gen = dataset_dict_gen["train"][0]
    print(f"ID: {sample_example_gen['id']}")
    print(f"Question: {sample_example_gen['question']}")
    print(f"Answer: {sample_example_gen['answer']}")

In [None]:
# @title 5. Import Tokenizer and GPT-2
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, DataCollatorForLanguageModeling

# Base checkpoint that gets fine-tuned.
model_name = "gpt2"
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

# Padding is requested during tokenization, so a pad token must exist; when
# the tokenizer has none, the EOS token is reused for that role.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# No new token is added to the vocabulary above, so the embedding matrix is
# left at its pretrained size.
model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)

print("Tokenizer và GPT2LMHeadModel đã được tải.")
for _role, _token, _token_id in (
    ("EOS", tokenizer.eos_token, tokenizer.eos_token_id),
    ("PAD", tokenizer.pad_token, tokenizer.pad_token_id),
):
    print(f"{_role} token: '{_token}', ID: {_token_id}")

# Token appended after each answer to mark the end of a training example.
separator_token = tokenizer.eos_token

In [None]:
# @title 6. Tokenization for trainer (Generative QA - Padding to max_length)

# Every example is padded/truncated to the tokenizer's maximum model length.
effective_max_length = tokenizer.model_max_length

print(f"Effective_max_length: {effective_max_length}")

def preprocess_function_generative(examples):
    """Tokenize a batch of Q&A pairs for causal-LM fine-tuning.

    Each pair is rendered as ``"Question: <q> Answer: <a><eos>"``, then
    truncated and padded to ``effective_max_length``.

    Args:
        examples: Batched mapping with parallel ``"question"`` and
            ``"answer"`` sequences (as passed by ``Dataset.map`` with
            ``batched=True``).

    Returns:
        The tokenizer output with an added ``"labels"`` entry that is a copy
        of ``"input_ids"``.
    """
    eos = tokenizer.eos_token
    prompts = [
        f"Question: {q} Answer: {a}{eos}"
        for q, a in zip(examples["question"], examples["answer"])
    ]

    encoded = tokenizer(
        prompts,
        truncation=True,
        padding="max_length",
        max_length=effective_max_length,
        return_attention_mask=True,
    )
    # Causal-LM training predicts the input itself, so labels mirror the ids.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded


# Tokenize both splits; the original text columns are dropped so only the
# tokenizer outputs and labels remain.
column_names_gen_train = dataset_dict_gen["train"].column_names
tokenized_datasets_gen = dataset_dict_gen.map(
    preprocess_function_generative,
    remove_columns=column_names_gen_train,
    batched=True,
)

# Report the shape of the first tokenized example as a sanity check.
print("\nPadded to max_length:")
sample_tokenized_gen = tokenized_datasets_gen["train"][0]
print(f"Keys: {sample_tokenized_gen.keys()}")
print(f"Length of Input IDs: {len(sample_tokenized_gen['input_ids'])}")
print(f"Length of Labels: {len(sample_tokenized_gen['labels'])}")
print(f"Length of Attention Mask: {len(sample_tokenized_gen['attention_mask'])}")

# A second example (when there is one) shows that lengths are uniform.
if len(tokenized_datasets_gen["train"]) >= 2:
    sample2_tokenized_gen = tokenized_datasets_gen["train"][1]
    print(f"Length of Input IDs (sample 2): {len(sample2_tokenized_gen['input_ids'])}")

In [None]:
# @title 7. Training (Fine-tuning) GPT 2 (Generative QA)

import os
# Disable Weights & Biases logging before the Trainer is created.
os.environ["WANDB_DISABLED"] = "true"

from transformers import Trainer, TrainingArguments


# mlm=False selects causal (next-token) language modeling rather than
# masked-token prediction.
data_collator_gen = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Evaluation and checkpointing both happen once per epoch, which is what
# allows load_best_model_at_end to pick a checkpoint.
training_args_gen = TrainingArguments(
    output_dir="./results_medquad_gpt2_generative",
    eval_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",
    load_best_model_at_end=True,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    logging_steps=50,
    report_to="none",
)

trainer_gen = Trainer(
    model=model,
    args=training_args_gen,
    train_dataset=tokenized_datasets_gen["train"],
    eval_dataset=tokenized_datasets_gen["validation"],
    data_collator=data_collator_gen,
)

# Directory that receives the fine-tuned weights and tokenizer files.
output_model_dir_gen = "./fine_tuned_medquad_gpt2_generative_final"

print("\nStarted training Generative QA...")
try:
    trainer_gen.train()
except Exception as e:
    # Keep the notebook alive (e.g. after an out-of-memory error), but do NOT
    # save: the weights would be the un-fine-tuned model stored under a
    # "fine_tuned" name.
    print(f"Error: {e}")
    print("Try reduce batch size if OOM.")
    print(f"Training failed; nothing was saved to {output_model_dir_gen}.")
else:
    print("Training completed.")
    # Save model and tokenizer only after a successful training run.
    trainer_gen.save_model(output_model_dir_gen)
    tokenizer.save_pretrained(output_model_dir_gen)
    print(f"Model and tokenizer (Generative) have been saved: {output_model_dir_gen}")

In [None]:
# @title 8. Model Evaluation (ROUGE)
import evaluate
import numpy as np
from tqdm import tqdm

# Score the model on the validation split: generate an answer for every
# validation question and compare it with the reference answer using ROUGE.
rouge_metric = evaluate.load("rouge")

eval_model = trainer_gen.model
eval_tokenizer = tokenizer

validation_dataset = dataset_dict_gen["validation"]
# small_validation_dataset = validation_dataset.select(range(100))

predictions = []
references = []

print(f"Started generating answers base on validation ({len(validation_dataset)} ) for evaluate...")
eval_model.eval()
eval_model.to(device)

for example in tqdm(validation_dataset):
    question = example["question"]
    reference_answer = example["answer"]

    # Same prompt layout as the training text, stopping right after "Answer:".
    prompt = f"Question: {question} Answer:"
    inputs = eval_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(device)
    prompt_length = inputs.input_ids.shape[1]

    try:
        with torch.no_grad():
            output_sequences = eval_model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=150,  # budget for the answer only, independent of prompt length
                num_beams=3,
                no_repeat_ngram_size=2,
                early_stopping=True,
                eos_token_id=eval_tokenizer.eos_token_id,
                pad_token_id=eval_tokenizer.pad_token_id
            )

        # The returned sequence starts with the prompt tokens. Dropping them by
        # position isolates the generated answer; searching the decoded text
        # for "Answer:" would latch onto any later "Answer:" the model emits.
        generated_ids = output_sequences[0][prompt_length:]
        predicted_answer = eval_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    except Exception as e:
        # A failed generation is scored as an empty prediction so predictions
        # and references stay aligned.
        print(f"Exception: '{question}'. ERROR: {e}")
        predicted_answer = ""

    # Empty strings are replaced by a single space before scoring.
    predictions.append(predicted_answer if predicted_answer else " ")
    references.append(reference_answer if reference_answer else " ")


if predictions and references:
    print("\nCalculating ROUGE scores...")
    rouge_results = rouge_metric.compute(predictions=predictions, references=references)
    print("\nROUGE reusults:")
    for key, value in rouge_results.items():
        print(f"{key}: {value*100:.2f}")
else:
    print("There is no prediction or reference to calculate ROUGE.")



In [None]:
# @title 9. Gradio interface (Generative QA use Sentence Transformers)

import gradio as gr
import torch
import pandas as pd
from sentence_transformers import SentenceTransformer, util

# Prepare everything the Gradio handler needs: a Sentence Transformer for
# semantic lookup, the knowledge base with its question embeddings, and the
# generative model/tokenizer pair.
model_st_name = 'all-mpnet-base-v2'
print(f"Downloading  Sentence Transformer: {model_st_name}...")
try:
    st_model = SentenceTransformer(model_st_name)
    print(" Sentence Transformer download completed.")
except Exception as e:
    # Without the encoder the semantic lookup is disabled (st_model is None).
    print(f"Exception: {e}. .")
    st_model = None


# Knowledge base: the preprocessed Q&A pairs when they exist in this session,
# otherwise three hand-written entries.
if 'processed_data_generative' not in globals() or not processed_data_generative:
    print("Warning: processed_data_generative null.")
    knowledge_base_df_gen = pd.DataFrame([
        {"id": "syn_kb_0", "question": "What are the symptoms of flu?", "answer": "Flu is a common respiratory illness caused by influenza viruses."},
        {"id": "syn_kb_1", "question": "How can one treat a common cold?", "answer": "For a common cold, it's advisable to get plenty of rest, drink fluids, and use over-the-counter medications for symptoms."},
        {"id": "syn_kb_2", "question": "Tell me about MRI scans.", "answer": "MRI, or Magnetic Resonance Imaging, is a medical imaging technique used to form pictures of the anatomy and physiological processes of the body."}
    ])
else:
    knowledge_base_df_gen = pd.DataFrame(processed_data_generative)

# Parallel lists: corpus_questions[i] pairs with corpus_answers[i], and
# corpus_embeddings is computed from corpus_questions in the same order.
corpus_questions = []
corpus_answers = []
corpus_embeddings = None

if st_model and not knowledge_base_df_gen.empty:
    corpus_questions = knowledge_base_df_gen['question'].astype(str).tolist()
    corpus_answers = knowledge_base_df_gen['answer'].astype(str).tolist()
    if corpus_questions:
        print(f"Calculating embeddings for {len(corpus_questions)} questions in corpus...")
        try:
            if torch.cuda.is_available():
                st_model.to(device)
                print(f"Sentence Transformer model moved to {device}.")

            # Only the questions are embedded; answers are looked up by index.
            corpus_embeddings = st_model.encode(corpus_questions, convert_to_tensor=True, show_progress_bar=True)
            print("Done! Embeddings for corpus.")
        except Exception as e:
            print(f"Exception calculate embeddings cho corpus: {e}")
            corpus_embeddings = None # Set to None if an error occurred
    else:
        print("No questions in corpus for calculate embeddings.")
else:
    if not st_model:
        print("Can not calculate embeddings cause Sentence Transformer has not download yet.")
    if knowledge_base_df_gen.empty:
        print("Knowledge base null, Fail embeddings.")

# Generative model: reuse the in-memory trainer's model when this session
# trained one, otherwise try to reload the saved fine-tuned checkpoint.
if 'trainer_gen' not in globals() or not hasattr(trainer_gen, 'model'):
    print("CẢNH BÁO: trainer_gen or trainer_gen.model not found..")

    try:
        output_model_dir_gen = "./fine_tuned_medquad_gpt2_generative_final"
        gen_qa_model = GPT2LMHeadModel.from_pretrained(output_model_dir_gen)
        gen_qa_tokenizer = GPT2TokenizerFast.from_pretrained(output_model_dir_gen)
        gen_qa_model.to(device)
        print(f"Redownload model {output_model_dir_gen}")
    except Exception as e:
        # Leave both as None; the Gradio handler checks for this.
        print(f"Exception model generative QA: {e}.")
        gen_qa_model = None
        gen_qa_tokenizer = None
else:
    gen_qa_model = trainer_gen.model
    gen_qa_tokenizer = tokenizer

# Put the model in evaluation mode for inference.
if gen_qa_model:
    gen_qa_model.eval()


def find_most_similar_qa_st(user_question):
    """Return the knowledge-base Q&A pair closest in meaning to a question.

    Uses the module-level ``st_model`` to embed ``user_question`` and compares
    it by cosine similarity against ``corpus_embeddings`` (the embeddings of
    ``corpus_questions``).

    Args:
        user_question: Free-text question typed by the user.

    Returns:
        ``(question, answer, score)`` for the best match, or
        ``(None, None, 0.0)`` when the encoder or the corpus embeddings are
        unavailable or the lookup fails.
    """
    if st_model is None or corpus_embeddings is None or not corpus_questions:
        return None, None, 0.0

    try:
        query_embedding = st_model.encode(user_question, convert_to_tensor=True)
        # cos_sim returns a 1 x N matrix for a single query; take its only row.
        scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
        best_idx = int(torch.argmax(scores))
        best_score = float(scores[best_idx])
    except Exception as e:
        # The reference lookup is optional: report the failure and let the
        # caller continue with generation only.
        print(f"Error during semantic search: {e}")
        return None, None, 0.0

    return corpus_questions[best_idx], corpus_answers[best_idx], best_score


def answer_question_gradio_generative(user_question):
    """Gradio handler: generate an answer and attach the closest known Q&A.

    Args:
        user_question: Text entered in the Gradio textbox.

    Returns:
        A display string. On success it contains the generated answer followed
        by the reference block (when a similar Q&A was found); otherwise it is
        a short prompt or error message.
    """
    print(f"\nReceived question from Gradio (Generative): {user_question}")
    if not user_question or not user_question.strip():
        return "Please enter a question."

    # Step 1: Find a similar Q&A in the "database" using Sentence Transformers
    similar_q, similar_a, similarity_score = find_most_similar_qa_st(user_question)
    similarity_info = ""
    if similar_q:
        similarity_info = (
            f"\n\n--- Reference Information from Data (Semantic Similarity: {similarity_score:.2f}) ---\n"
            f"Most similar question:\n{similar_q}\n"
            f"Corresponding answer:\n{similar_a}\n"
            f"-------------------------------------------------------------------"
        )
        print(f"Found similar Q&A (semantically) with similarity score {similarity_score:.2f}")

    # Step 2: Generate an answer using the generative QA model
    if gen_qa_model is None or gen_qa_tokenizer is None:
        return "Error: Generative QA model is not loaded."

    prompt = f"Question: {user_question} Answer:"

    try:
        # Truncate very long questions (same 512-token limit as the evaluation
        # cell) and place the tensors on whatever device the model lives on.
        inputs = gen_qa_tokenizer(
            prompt, return_tensors="pt", max_length=512, truncation=True
        ).to(gen_qa_model.device)
        prompt_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_qa_model.generate(
                input_ids=inputs.input_ids,
                # Passed explicitly because the pad token may equal the EOS
                # token, so padding cannot be inferred from the ids alone.
                attention_mask=inputs.attention_mask,
                # Budget for the answer only; a total-length cap would leave
                # no room to generate when the question itself is long.
                max_new_tokens=150,
                num_beams=1,
                no_repeat_ngram_size=2,  # Helps avoid phrase repetition
                eos_token_id=gen_qa_tokenizer.eos_token_id,
                pad_token_id=gen_qa_tokenizer.pad_token_id,
                temperature=0.3,  # Controls randomness
                do_sample=True,  # Required for temperature/top_p to take effect
                top_p=0.95
            )

        # The output begins with the prompt tokens; drop them by position so
        # only the newly generated text is decoded. Special tokens (EOS/pad)
        # are removed by the decoder.
        generated_ids = outputs[0][prompt_length:]
        generated_answer_text = gen_qa_tokenizer.decode(
            generated_ids, skip_special_tokens=True
        ).strip()
        if not generated_answer_text:
            generated_answer_text = "Unable to generate an answer."

        print(f"Generated answer: {generated_answer_text}")

        final_response = f"Medical Chatbot (GPT-2 Generative) Answer:\n{generated_answer_text}"
        final_response += similarity_info

        return final_response

    except Exception as e:
        # UI boundary: show the error to the user instead of crashing the app,
        # and still include the reference information if any was found.
        print(f"Error generating answer: {e}")
        import traceback
        traceback.print_exc()
        error_response = f"An error occurred while generating the answer. {e}"
        error_response += similarity_info
        return error_response

# Create Gradio Interface
# Create Gradio Interface
# One text box in, one text area out; every submission is handled by
# answer_question_gradio_generative.
iface_gen = gr.Interface(
    fn=answer_question_gradio_generative,
    inputs=gr.Textbox(lines=3, placeholder="Enter your medical question here... Example: 'What are common treatments for high blood pressure?'"),
    outputs=gr.Text(label="Answer and Reference Information"),
    title="Medical Q&A Chatbot (GPT-2 Generative + Semantic Search)",
    description="Enter a medical question. The chatbot will generate an answer and display the most semantically similar Q&A pair from the learned dataset.",
    # Clickable sample questions shown under the input box.
    examples=[
        ["What are the symptoms of flu?"],
        ["How to treat a common cold?"],
        ["What is an MRI?"],
        ["Tell me about diabetes type 2."],
        ["What are common treatments for high blood pressure?"]
    ],
    allow_flagging='never' # Disable flagging if not needed
)


print("\nĐang khởi chạy giao diện Gradio cho Generative QA (có tìm kiếm ngữ nghĩa)...")
iface_gen.launch(debug=True, share=True) # Run with debug=True when detailed logs are needed
# iface_gen.launch(share=True)