<a href="https://colab.research.google.com/github/sdeshmukh99/Generative-AI-Showcase/blob/main/Showcase_01/Fine_Tuning_GPT_2_on_MedQuAD_for_Medical_Question_Answering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### 1. Installing Dependencies

In [None]:
%%capture
!pip -q uninstall pyarrow -y
!pip -q install pyarrow==15.0.2
!pip -q install datasets
!pip -q install accelerate
!pip -q install transformers

### 2. Import Required Packages

In [None]:
# Import required packages (run this again after runtime restarts)
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

import warnings

# Silence all warning categories for the remainder of the session
warnings.simplefilter('ignore')

# Directory that receives the fine-tuned model and tokenizer files
model_output_path = "/content/gpt_model"

### 3. Download the Dataset

In [None]:
# Download the dataset
!wget -q https://cdn.iisc.talentsprint.com/AIandMLOps/MiniProjects/Datasets/MedQuAD.csv
!ls | grep ".csv"


MedQuAD.csv
MedQuAD.csv.1


### 4. Data Preprocessing:

##### 4.1: Read the MedQuAD.csv dataset


In [None]:
# Load the downloaded MedQuAD CSV into a DataFrame
data = pd.read_csv(filepath_or_buffer='MedQuAD.csv')


##### 4.2: Handle Missing Values and Duplicates


In [None]:
# Report the number of nulls per column before any rows are removed
print("Missing values in each column:")
print(data.isna().sum())

# Keep only rows that have both a question and an answer, then drop rows whose
# question/answer pair repeats an earlier row
qa_columns = ['Question', 'Answer']
data = data.dropna(subset=qa_columns).drop_duplicates(subset=qa_columns)


Missing values in each column:
Focus             14
CUI              565
SemanticType     597
SemanticGroup    565
Question           0
Answer             5
dtype: int64


##### 4.3: Display Focus Categories

In [None]:
# Number of distinct values in the Focus column
total_categories = data['Focus'].nunique()
print(f"Total categories in 'Focus' column: {total_categories}")

# Record count per Focus value; the first 100 entries are the ones shown and reused
focus_counts = data['Focus'].value_counts()
top_100_counts = focus_counts.head(100)
print("Top 100 categories in 'Focus' column and their counts:")
print(top_100_counts)

# Names of those 100 categories, in the same order as the counts above
top_100_focus = top_100_counts.index.tolist()
print("Top 100 Focus categories names:")
print(top_100_focus)


Total categories in 'Focus' column: 5125
Top 100 categories in 'Focus' column and their counts:
Focus
Breast Cancer                 53
Prostate Cancer               43
Stroke                        35
Skin Cancer                   34
Alzheimer's Disease           30
                              ..
MECP2 duplication syndrome    11
Holt-Oram syndrome            11
Ehlers-Danlos syndrome        11
Hearing Loss                  10
Wilson disease                10
Name: count, Length: 100, dtype: int64
Top 100 Focus categories names:
['Breast Cancer', 'Prostate Cancer', 'Stroke', 'Skin Cancer', "Alzheimer's Disease", 'Lung Cancer', 'Colorectal Cancer', 'High Blood Cholesterol', 'Heart Attack', 'Heart Failure', 'High Blood Pressure', "Parkinson's Disease", 'Leukemia', 'Osteoporosis', 'Shingles', 'Age-related Macular Degeneration', 'Diabetes', 'Hemochromatosis', 'Diabetic Retinopathy', 'Psoriasis', 'Gum (Periodontal) Disease', 'Kidney Disease', 'COPD', 'Cataract', 'Balance Problems', 'Dry Mo

### Exercise 5: Create Training and Validation Sets

In [None]:
# Build the splits category by category: each Focus in the top-100 list contributes
# at most two training rows and at most one validation row
train_data_list, val_data_list = [], []

for focus in top_100_focus:
    # Shuffle this category's rows with a fixed seed so the split is repeatable
    shuffled = (
        data.loc[data['Focus'] == focus]
        .sample(frac=1, random_state=42)
        .reset_index(drop=True)
    )
    n_rows = len(shuffled)

    if n_rows >= 3:
        # Two rows for training (reduced from four to shorten training), third for validation
        train_part, val_part = shuffled.iloc[:2], shuffled.iloc[2:3]
    elif n_rows == 2:
        # One row each for training and validation
        train_part, val_part = shuffled.iloc[:-1], shuffled.iloc[-1:]
    else:
        # Too few rows to hold one out: everything goes to training
        train_part = shuffled
        val_part = pd.DataFrame(columns=shuffled.columns)

    train_data_list.append(train_part)
    val_data_list.append(val_part)

train_data = pd.concat(train_data_list, ignore_index=True)
val_data = pd.concat(val_data_list, ignore_index=True)

print(f"Training data samples: {len(train_data)}")
print(f"Validation data samples: {len(val_data)}")


Training data samples: 200
Validation data samples: 100


### Exercise 6: Pre-process Question and Answer text


In [None]:
def _format_qa(frame):
    """Return one '<question> ... <answer> ... <end>' string per row of `frame`.

    Carriage returns and line feeds inside a question or answer are replaced by a
    single space. The strings are written to a text file with one example per line,
    so an embedded line break would otherwise split a single example into several
    truncated ones.

    Args:
        frame: DataFrame with string columns 'Question' and 'Answer'.

    Returns:
        A Series of single-line training strings aligned with `frame`'s index.
    """
    combined = '<question> ' + frame['Question'] + ' <answer> ' + frame['Answer'] + ' <end>'
    return combined.str.replace(r'[\r\n]+', ' ', regex=True)


# Combine Questions and Answers for train and val data
train_data['Combined'] = _format_qa(train_data)
val_data['Combined'] = _format_qa(val_data)

# Create training and validation text: one example per line
train_text = '\n'.join(train_data['Combined'].tolist())
val_text = '\n'.join(val_data['Combined'].tolist())

# Save the training and validation data as text files
with open('train_data.txt', 'w', encoding='utf-8') as f:
    f.write(train_text)

with open('val_data.txt', 'w', encoding='utf-8') as f:
    f.write(val_text)


### Exercise 7: Load pre-trained GPT2Tokenizer and GPT2LMHeadModel


In [None]:
# Start from the stock GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Register the three markers used to delimit questions, answers and example ends
special_tokens_dict = dict(additional_special_tokens=['<question>', '<answer>', '<end>'])
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print(f"We have added {num_added_toks} special tokens")

# Reuse the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token

# Load the pre-trained GPT-2 language model
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Resize the embedding table so it has a row for every tokenizer entry,
# including the markers added above
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

# Turn on gradient checkpointing to lower memory use during training
model.gradient_checkpointing_enable()

We have added 3 special tokens


### Exercise 8: Tokenize Train and Validation Data


In [None]:
def tokenize_function(examples):
    """Tokenize a batch of rows, truncating each text to at most 256 tokens.

    The 256-token cap was chosen to reduce training time.
    """
    return tokenizer(examples['text'], max_length=256, truncation=True)


# Read the two text files as the 'train' and 'validation' splits
data_files = {'train': 'train_data.txt', 'validation': 'val_data.txt'}
datasets = load_dataset('text', data_files=data_files)

# Tokenize both splits in batches and drop the raw 'text' column afterwards
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=['text'],
)


Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

### Exercise 9: Create a DataCollator object


In [None]:
# Collator for causal language modelling (mlm=False) that returns PyTorch tensors
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    return_tensors='pt',
)


### Exercise 10: Fine-tune GPT2 Model


In [None]:
# Set up the training arguments (ensure these are correctly configured)
training_args = TrainingArguments(
    output_dir=model_output_path,
    overwrite_output_dir=True,
    num_train_epochs=20,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='steps',
    logging_steps=50,
    weight_decay=0.01,
    learning_rate=5e-5,
    warmup_steps=50,
    fp16=True,
    prediction_loss_only=True,
    load_best_model_at_end=True,
    metric_for_best_model='loss',
    greater_is_better=False,
    save_safetensors=False,  # Add this line to save in pytorch_model.bin format
)

# Train the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

# Save the model
trainer.save_model(model_output_path)

# Save the tokenizer
tokenizer.save_pretrained(model_output_path)

# List the contents of the model directory to verify saved files
print("Contents of /content/gpt_model:")
!ls /content/gpt_model

# Expected Output
# added_tokens.json       config.json            special_tokens_map.json
# pytorch_model.bin       merges.txt             tokenizer_config.json
# training_args.bin       vocab.json



`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch,Training Loss,Validation Loss
1,42.3439,3.979216
2,2.8878,2.556852
3,2.4403,2.466745
4,2.25,2.372114
5,2.096,2.358772
6,1.9813,2.357291
7,1.8623,2.367722
8,1.7672,2.385193


Contents of /content/gpt_model:
added_tokens.json  checkpoint-300  generation_config.json   tokenizer_config.json
checkpoint-100	   checkpoint-350  merges.txt		    training_args.bin
checkpoint-150	   checkpoint-400  pytorch_model.bin	    vocab.json
checkpoint-200	   checkpoint-50   runs
checkpoint-250	   config.json	   special_tokens_map.json


### Exercise 11: Test Model with user input prompts


In [None]:
# Reload the tokenizer and weights that were saved to the training output directory
fine_tuned_tokenizer = GPT2Tokenizer.from_pretrained(model_output_path)
fine_tuned_model = GPT2LMHeadModel.from_pretrained(model_output_path)

# Run on the GPU when one is available, otherwise on the CPU, in evaluation mode
inference_device = 'cuda' if torch.cuda.is_available() else 'cpu'
fine_tuned_model.to(inference_device)
fine_tuned_model.eval()

# Fall back to the end-of-sequence token for padding if the tokenizer has no pad token
if fine_tuned_tokenizer.pad_token is None:
    fine_tuned_tokenizer.pad_token = fine_tuned_tokenizer.eos_token
    fine_tuned_tokenizer.pad_token_id = fine_tuned_tokenizer.eos_token_id

# Keep the model configuration in agreement with the tokenizer's padding id
fine_tuned_model.config.pad_token_id = fine_tuned_tokenizer.pad_token_id

# Function to generate response
def generate_response(model, tokenizer, prompt, max_length=200, stop_token='<end>'):
    """Generate a beam-search continuation of `prompt` and return it as text.

    Args:
        model: Language model exposing `.device` and `.generate(...)`.
        tokenizer: Tokenizer that matches `model`.
        prompt: Text to continue.
        max_length: Forwarded to `model.generate` as `max_length`. In Hugging Face
            `generate` this limit counts the prompt tokens as well as new ones.
        stop_token: Extra token that ends generation when the tokenizer knows it.
            The training text closes every answer with '<end>', so stopping there
            keeps the reply to a single answer instead of running on to
            `max_length`. Pass None to stop only at the tokenizer's EOS token.

    Returns:
        The decoded first returned sequence (prompt included) with special tokens
        removed.
    """
    # Encode the prompt; the attention mask is passed on explicitly to generate()
    inputs = tokenizer(
        prompt,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512  # Upper bound on prompt length in tokens
    )
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    # Always stop at EOS; additionally stop at `stop_token` when it is in the
    # vocabulary (a tokenizer without it, e.g. stock GPT-2, is left unaffected)
    stop_ids = [tokenizer.eos_token_id]
    if stop_token is not None and stop_token in tokenizer.get_vocab():
        stop_id = tokenizer.convert_tokens_to_ids(stop_token)
        if stop_id not in stop_ids:
            stop_ids.append(stop_id)

    # Generate output sequences
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        num_beams=2,
        no_repeat_ngram_size=2,
        early_stopping=True,
        eos_token_id=stop_ids,
        pad_token_id=tokenizer.eos_token_id  # Ensure pad_token_id is set
    )

    # Decode the first returned sequence
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Testing with sample prompts.
# Each question is wrapped in the same '<question> ... <answer>' markers used in the
# training text, and every response is printed under its own prompt number.
test_questions = [
    "What are the symptoms of diabetes?",
    "How can I lower my blood pressure?",
    "What precautions to take for a healthy life?",
    "What to do when feeling sick?",
]

for prompt_number, question in enumerate(test_questions, 1):
    prompt = f"<question> {question} <answer>"
    response = generate_response(fine_tuned_model, fine_tuned_tokenizer, prompt)
    header = f"Finetuned Model Response to Prompt {prompt_number}:"
    # Separate consecutive responses with a blank line (none before the first)
    print(header if prompt_number == 1 else "\n" + header)
    print(response)


Finetuned Model Response to Prompt 1:
 What are the symptoms of diabetes?  Diabetes is a leading cause of death in the United States. People with diabetes have a variety of signs and symptoms. They may have trouble eating, feeling tired, or feeling hungry. Diabetes can also affect how your body absorbs nutrients from the food you eat. In some cases, your blood sugar levels can rise too fast. This can lead to problems such as high blood pressure, diabetes, and heart disease.  How might diabetes affect your health? Diabetes affects your heart, blood vessels, kidneys, liver, pancreas and other organs. Your body makes insulin, a hormone that helps the body digest fats and sugars. Insulin is found in many foods, including fruits, vegetables, whole grains, legumes, dairy products, meats, poultry, fish, nuts, pulses, fruits and vegetables. It is produced by the pancrotate, the part of the brain responsible for digesting fats. The body also produces insulin-like growth factors

Finetuned Model

### Exercise 12: Compare the performance of a GPT2 model with the GPT2 model fine-tuned on MedQuAD data


In [None]:
# Baseline for comparison: the stock GPT-2 tokenizer and weights, without fine-tuning
untuned_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
untuned_model = GPT2LMHeadModel.from_pretrained('gpt2')

# Run on the GPU when one is available, otherwise on the CPU, in evaluation mode
untuned_model.to('cuda' if torch.cuda.is_available() else 'cpu')
untuned_model.eval()

# Fall back to the end-of-sequence token for padding if the tokenizer has no pad token
if untuned_tokenizer.pad_token is None:
    untuned_tokenizer.pad_token = untuned_tokenizer.eos_token
    untuned_tokenizer.pad_token_id = untuned_tokenizer.eos_token_id

# Keep the model configuration in agreement with the tokenizer's padding id
untuned_model.config.pad_token_id = untuned_tokenizer.pad_token_id

# Plain-text questions (no '<question>'/'<answer>' markers) for the baseline model
prompts = [
    "What are the symptoms of diabetes?",
    "How can I lower my blood pressure?",
    "What precautions to take for a healthy life?",
    "What to do when feeling sick?"
]

# Print the baseline model's response to each question in turn
print("Untuned Model Responses:\n")
for i, prompt_text in enumerate(prompts, start=1):
    with torch.no_grad():  # No gradients are needed for inference
        response = generate_response(untuned_model, untuned_tokenizer, prompt_text)
    print(f"Prompt {i}: {prompt_text}")
    print(f"Response:\n{response}\n")

Untuned Model Responses:

Prompt 1: What are the symptoms of diabetes?
Response:
What are the symptoms of diabetes?

Diabetes is the most common cause of death in the United States. It is caused by a variety of causes, including: diabetes mellitus (DM), type 2 diabetes (T2D), high blood pressure (HBP), and high cholesterol (CHC).
 (Diagnosis) is an important part of the treatment of this disease.
. Diabetes is also the cause for many other health problems, such as heart disease, cancer, diabetes, and cancer. The most commonly diagnosed diabetes is Type 1 diabetes. This is a condition that affects the blood sugar levels of people with diabetes and can lead to heart attacks, strokes, heart failure, or death. People with Type 2 Diabetes are more likely to have a heart attack or stroke. If you have Type 3 Diabetes, your risk of having a stroke is higher than if you had a normal blood glucose level. Your risk for having Type 4 Diabetes can be even higher if your blood sugars are

Prompt 2: 