<a href="https://colab.research.google.com/github/shanvelc/genao/blob/main/Shan_M5_NB_MiniProject_1_Medical_Q%26A_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Advanced Certification Program in Computational Data Science
## A programme by IISc and TalentSprint
### Mini-Project: Medical Q&A using GPT2

## Learning Objectives

At the end of the experiment, you will be able to:

* perform data preprocessing, EDA and feature extraction on the Medical Q&A dataset
* load a pre-trained tokenizer
* fine-tune a GPT-2 language model for medical question answering

## Dataset Description

The dataset used in this project is the *Medical Question Answering Dataset* ([MedQuAD](https://github.com/abachaa/MedQuAD/tree/master)). It contains medical question-answer pairs along with additional information such as the question type, the question *focus*, and its UMLS (Unified Medical Language System) details: the Concept Unique Identifier (*CUI*) and the Semantic *Type* and *Group*.

To learn more about how this data was collected and constructed, refer to this [paper](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3119-4).

The data has been extracted into CSV format with the following features:

- **Focus**: the question focus
- **CUI**: concept unique identifier
- **SemanticType**
- **SemanticGroup**
- **Question**
- **Answer**

## Part-A: Grading = 10 Points

## Information

Healthcare professionals often have to consult medical literature and documents when seeking answers to medical queries. Medical databases and search engines are powerful sources of up-to-date medical knowledge. However, the existing documentation is large, which makes it difficult for professionals to retrieve answers quickly in a clinical setting. The problem with search engines and information retrieval systems is that they return a list of documents rather than answers. Instead, healthcare professionals can use question-answering systems to retrieve short sentences or paragraphs in response to medical queries. The main advantage of such systems is that they generate answers and provide hints within a few seconds.

### Problem Statement

Fine-tune a GPT-2 model on the medical question-answering dataset to generate responses to medical queries.

Please refer to ***M6 Assignment-1 Fine-tune GPT2*** to get familiar with how to load the pre-trained GPT-2 tokenizer and model.

### Import required packages

In [1]:
!pip -q install -U accelerate
!pip -q install -U transformers
!pip -q install torch

In [2]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import warnings
warnings.filterwarnings('ignore')

In [3]:
#@title Download the dataset
!wget -q https://cdn.iisc.talentsprint.com/AIandMLOps/MiniProjects/Datasets/MedQuAD.csv
!ls | grep ".csv"

MedQuAD.csv


**Exercise 1: Read the MedQuAD.csv dataset**

**Hint:** pd.read_csv()

In [4]:
df = pd.read_csv("MedQuAD.csv")
df.shape

(16412, 6)

In [5]:
df.head()

|   | Focus | CUI | SemanticType | SemanticGroup | Question | Answer |
|---|---|---|---|---|---|---|
| 0 | Adult Acute Lymphoblastic Leukemia | C0751606 | T191 | Disorders | What is (are) Adult Acute Lymphoblastic Leukem... | Key Points - Adult acute lymphoblastic leukemi... |
| 1 | Adult Acute Lymphoblastic Leukemia | C0751606 | T191 | Disorders | What are the symptoms of Adult Acute Lymphobla... | Signs and symptoms of adult ALL include fever,... |
| 2 | Adult Acute Lymphoblastic Leukemia | C0751606 | T191 | Disorders | How to diagnose Adult Acute Lymphoblastic Leuk... | Tests that examine the blood and bone marrow a... |
| 3 | Adult Acute Lymphoblastic Leukemia | C0751606 | T191 | Disorders | What is the outlook for Adult Acute Lymphoblas... | Certain factors affect prognosis (chance of re... |
| 4 | Adult Acute Lymphoblastic Leukemia | C0751606 | T191 | Disorders | Who is at risk for Adult Acute Lymphoblastic L... | Previous chemotherapy and exposure to radiatio... |


### Pre-processing and EDA

**Exercise 2: Perform the operations below on the dataset [0.5 Mark]**

- Handle missing values
- Remove duplicates from data considering `Question` and `Answer` columns

- **Handle missing values**

In [17]:
# Calculate missing values per column
missing_values = df.isnull().sum()

# Calculate the total missing values across all columns
total_missing = missing_values.sum()

# Append the total to the Series
missing_values_with_total = pd.concat([missing_values, pd.Series(total_missing, index=['Total'])])

# Display the result
print(missing_values_with_total)

Focus              14
CUI               565
SemanticType      597
SemanticGroup     565
Question            0
Answer              5
Total            1746
dtype: int64


In [None]:
print(df.isnull().sum())

In [21]:
np_missing_values = df.isnull().sum().sum()

In [25]:
# Drop missing values
# Drop rows with any missing values and create a clean copy
df_cleaned = df.dropna().copy()

# Check if there are any remaining missing values
total_missing = df_cleaned.isnull().sum().sum()

# Display the number of missing values and the shape of the cleaned DataFrame
print(f"Total missing values after cleaning: {total_missing}")
print(f"Shape of the cleaned DataFrame: {df_cleaned.shape}")

Total missing values after cleaning: 0
Shape of the cleaned DataFrame: (15810, 6)


- **Remove duplicates from data considering `Question` and `Answer` columns**

In [31]:
# Check duplicates
# Find duplicates based on 'Question' and 'Answer', keeping all occurrences of duplicates
df_duplicates = df_cleaned[df_cleaned.duplicated(subset=['Question', 'Answer'], keep=False)].copy()

# Format 'Question' column for better readability by adding line breaks after question marks
df_duplicates['Question'] = df_duplicates['Question'].str.replace(r'\?+', '?\n', regex=True)

# Display the first few duplicate rows with relevant info
print("Duplicate Rows Based on 'Question' and 'Answer':")
print(df_duplicates[['Question', 'Answer']].head())

# Show the number of duplicate rows
print(f"\nTotal number of duplicate rows: {df_duplicates.shape[0]}")


Duplicate Rows Based on 'Question' and 'Answer':
                                         Question  \
12539  What are the treatments for Acromegaly ?\n   
12540  What are the treatments for Acromegaly ?\n   
12649        What is (are) Causes of Diabetes ?\n   
12650          What causes Causes of Diabetes ?\n   
12651          What causes Causes of Diabetes ?\n   

                                                  Answer  
12539  Currently, treatment options include surgical ...  
12540  Currently, treatment options include surgical ...  
12649  Diabetes is a complex group of diseases with a...  
12650  Type 1 diabetes is caused by a lack of insulin...  
12651  Type 2 diabetesthe most common form of diabete...  

Total number of duplicate rows: 80


In [32]:
# Drop duplicates
# Remove duplicates considering only the 'Question' and 'Answer' columns
df_cleaned_no_duplicates = df_cleaned.drop_duplicates(subset=['Question', 'Answer']).copy()

# Display the shape of the DataFrame after removing duplicates
print(f"Shape of the DataFrame after removing duplicates: {df_cleaned_no_duplicates.shape}")


Shape of the DataFrame after removing duplicates: (15762, 6)


In [33]:
# Check duplicates
# Check for any remaining duplicates in the 'Question' and 'Answer' columns
remaining_duplicates = df_cleaned_no_duplicates.duplicated(subset=['Question', 'Answer']).sum()

# Print the number of duplicates found
print(f"Number of remaining duplicates: {remaining_duplicates}")

Number of remaining duplicates: 0


**Exercise 3: Display the category name and the number of records for the top 100 categories of the `Focus` column [1 Mark]**

In [57]:
# Get the top 100 categories based on their frequency in the 'Focus' column
top_100_focus = df_cleaned_no_duplicates['Focus'].value_counts().head(100)

# Display the category name and the number of records for the top 100 categories
print("Top 100 Categories in 'Focus' Column and their Record Counts:")
print(top_100_focus)

Top 100 Categories in 'Focus' Column and their Record Counts:
Focus
Breast Cancer                          53
Prostate Cancer                        43
Stroke                                 35
Skin Cancer                            34
Alzheimer's Disease                    30
                                       ..
Alzheimer's Caregiving                 11
Polycythemia Vera                      11
Diabetes, Heart Disease, and Stroke    11
Pelizaeus-Merzbacher disease           10
Peutz-Jeghers syndrome                 10
Name: count, Length: 100, dtype: int64


In [62]:
# Top 100 Focus categories by frequency
top_100_focus = df_cleaned_no_duplicates['Focus'].value_counts().head(100)
top_100_focus_cat_index = top_100_focus.index

# Keep only the rows whose Focus is one of the top 100 categories
df_top_100_focus_cat = df_cleaned_no_duplicates[df_cleaned_no_duplicates['Focus'].isin(top_100_focus_cat_index)]

# Extract only the category names
top_100_focus_names = top_100_focus.index.tolist()

# Display the top 100 category names
print("Top 100 Focus Category Names:")
for index, name in enumerate(top_100_focus_names, start=1):
    print(f"{index}. {name}")


Top 100 Focus Category Names:
1. Breast Cancer
2. Prostate Cancer
3. Stroke
4. Skin Cancer
5. Alzheimer's Disease
6. Colorectal Cancer
7. Lung Cancer
8. Heart Failure
9. Heart Attack
10. High Blood Cholesterol
11. High Blood Pressure
12. Parkinson's Disease
13. Leukemia
14. Osteoporosis
15. Shingles
16. Hemochromatosis
17. Age-related Macular Degeneration
18. Diabetes
19. Gum (Periodontal) Disease
20. Diabetic Retinopathy
21. Psoriasis
22. Kidney Disease
23. Dry Mouth
24. COPD
25. Cataract
26. Balance Problems
27. Gout
28. Wilson Disease
29. Medicare and Continuing Care
30. Prescription and Illicit Drug Abuse
31. Glaucoma
32. Rheumatoid Arthritis
33. Neuroblastoma
34. Short Bowel Syndrome
35. Problems with Taste
36. Narcolepsy
37. Endometrial Cancer
38. Osteoarthritis
39. Kidney Dysplasia
40. Problems with Smell
41. Dry Eye
42. Pituitary Tumors
43. Anxiety Disorders
44. Urinary Tract Infections in Children
45. Peripheral Arterial Disease (P.A.D.)
46. Surviving Cancer
47. Amyloidosis an

### Create Training and Validation set

**Exercise 4: Create training and validation set [2 Marks]**

- Consider 4 samples per `Focus` category for each of the top 100 categories in the dataset (this gives 400 samples for training)

- Consider 1 sample per `Focus` category (different from the training set) for each of the top 100 categories in the dataset (this gives 100 samples for validation)

In [63]:
df_top_100_focus_cat.shape

(1532, 6)

In [67]:
print("Shape of top 100 focus categories:", df_top_100_focus_cat.shape)
# Create training samples: 4 samples per Focus category
def sample_training(x):
    # Sample 4 if available, otherwise take all available samples
    return x.sample(n=4, random_state=42) if len(x) >= 4 else x

# Apply the sampling function per Focus category; the result is indexed by (Focus, original row index)
df_samples_training = df_top_100_focus_cat.groupby('Focus').apply(sample_training)

# Get the original row indices of the sampled training data (second index level)
sampled_index = df_samples_training.index.get_level_values(1)

# Display the first 20 training samples
#print("First 20 Training Samples:")
#print(df_samples_training.head(20))

print("Shape of Training Samples:", df_samples_training.shape)

# Create remaining DataFrame by dropping sampled training indices
df_remaining = df_top_100_focus_cat[~df_top_100_focus_cat.index.isin(sampled_index)]
print("Shape of DataFrame excluding Training Samples:", df_remaining.shape)

# Create validation samples: 1 sample per Focus category
def sample_validation(x):
    # Sample 1 sample per Focus category
    return x.sample(n=1, random_state=42)

# Apply the sampling function for validation samples
df_samples_validation = df_remaining.groupby('Focus').apply(sample_validation).reset_index(drop=True)
print("Shape of Validation Samples:", df_samples_validation.shape)

Shape of top 100 focus categories: (1532, 6)
Shape of Training Samples: (400, 6)
Shape of DataFrame excluding Training Samples: (1132, 6)
Shape of Validation Samples: (100, 6)
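
The split above stays leak-free because the sampled training rows are dropped before the validation rows are drawn. As a hedged alternative sketch on made-up data (the toy frame and its values are assumptions, not MedQuAD rows), `DataFrameGroupBy.sample`, available in pandas 1.1 and later, performs the same per-category sampling while keeping the original row index, so the no-overlap check becomes a single index intersection:

```python
import pandas as pd

# Toy stand-in for df_top_100_focus_cat: 3 categories with 6 rows each (made-up values).
toy = pd.DataFrame({
    "Focus": ["A"] * 6 + ["B"] * 6 + ["C"] * 6,
    "Question": [f"question {i}" for i in range(18)],
    "Answer": [f"answer {i}" for i in range(18)],
})

# 4 training rows per category; the original index labels are preserved.
train = toy.groupby("Focus").sample(n=4, random_state=42)

# Remove the training rows, then draw 1 validation row per category from what is left.
remaining = toy.drop(train.index)
val = remaining.groupby("Focus").sample(n=1, random_state=42)

# The two sets must not share any row.
overlap = train.index.intersection(val.index)
print(train.shape, val.shape, len(overlap))  # (12, 3) (3, 3) 0
```

Unlike the `sample_training` helper, `groupby(...).sample(n=4)` raises an error when a category has fewer than 4 rows. That is not a concern here, since every top 100 category has at least 10 records according to the counts shown earlier.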


### Pre-process `Question` and `Answer` text

**Exercise 5: Perform the tasks below: [1.5 Marks]**

- Combine `Question` and `Answer` for train and validation data as shown below:
    - sequence = *'\<question\>' + question-text + '\<answer\>' + answer-text*

- Join the combined text using '\n' into a single string for training and validation separately

- Save the training and validation strings as separate text files

- **Combine Question and Answer for train and val data**

In [68]:
question_answer_training = '<question>' + df_samples_training['Question']+ '<answer>'+df_samples_training['Answer']

question_answer_validation = '<question>' + df_samples_validation['Question']+ '<answer>'+df_samples_validation['Answer']
print(type(question_answer_validation))
print(question_answer_training.head(5))
print(question_answer_validation.head(5))

<class 'pandas.core.series.Series'>
Focus                           
21-hydroxylase deficiency  6079     <question>What are the treatments for 21-hydro...
                           6074     <question>What is (are) 21-hydroxylase deficie...
                           11451    <question>Is 21-hydroxylase deficiency inherit...
                           11452    <question>What are the treatments for 21-hydro...
Abdominal Adhesions        12810    <question>What to do for Abdominal Adhesions ?...
dtype: object
0    <question>What are the symptoms of 21-hydroxyl...
1    <question>What causes Abdominal Adhesions ?<an...
2    <question>What are the treatments for Adrenal ...
3    <question>What are the symptoms of Age-related...
4    <question>What is (are) Alagille Syndrome ?<an...
dtype: object


- **Join the combined text using '\n' into a single string for training and validation separately**

In [69]:
training_string = question_answer_training.str.cat(sep = '\n')
validation_string = question_answer_validation.str.cat(sep = '\n')
print("Training String\n\n", training_string)
print("\n\nValidation String\n\n",validation_string)

Training String

 <question>What are the treatments for 21-hydroxylase deficiency ?<answer>What is the goal for treating 21-hydroxylase-deficient congenital adrenal hyperplasia? The objectives for treating 21-hydroxylase deficiency differ with age. In childhood, the overall goal is to replace cortisol. Obtaining hormonal balance is important and patients growth velocity and bone age is monitored. Routine analysis of blood, urine, and/or saliva may also be necessary. Corrective surgery is frequently required for females born with abnormal genitalia. In late childhood and adolescence, maintaining hormonal balance is equally important. Overtreatment may result in obesity and delayed menarche/puberty, whereas under-replacement will result in sexual precocity. Also, it is important that teens and young adults with 21-hydroxylase deficiency be successfully transitioned to adult care facilities. Follow-up of adult patients should involve multidisciplinary clinics. Problems in adult women incl

- **Save the training and validation strings as text files**

In [70]:
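# Write with an explicit encoding so non-ASCII characters in the answers are saved consistently
with open('training_file.txt', 'w', encoding='utf-8') as trained_file:
    trained_file.write(training_string)

with open('validation_file.txt', 'w', encoding='utf-8') as validation_file:
    validation_file.write(validation_string)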
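
To make the target format of Exercise 5 easy to verify, here is a minimal end-to-end sketch on two placeholder pairs. The helper name `build_sequence`, the placeholder strings, and the temporary file name are made up for illustration and are not part of the notebook.

```python
import os
import tempfile


def build_sequence(question: str, answer: str) -> str:
    """Wrap one pair in the <question>/<answer> markers used by this project."""
    return f"<question>{question}<answer>{answer}"


# Placeholder pairs, not taken from MedQuAD.
pairs = [
    ("What is (are) condition X ?", "Condition X is a placeholder."),
    ("What causes condition X ?", "Placeholder cause."),
]

# One sequence per pair, joined with newlines into a single string.
corpus = "\n".join(build_sequence(q, a) for q, a in pairs)

# Write the string to a text file and read it back
# (a temporary directory keeps the sketch free of side effects).
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_training_file.txt")
    with open(path, "w", encoding="utf-8") as f:
        f.write(corpus)
    with open(path, encoding="utf-8") as f:
        round_trip = f.read()

print(corpus)
```

The printed string contains one `<question>...<answer>...` sequence per line, which is the same layout the notebook writes to `training_file.txt` and `validation_file.txt`.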

**Exercise 6: Load pre-trained GPT2Tokenizer [0.5 Mark]**

- Use checkpoint = "gpt2"

In [None]:
# YOUR CODE HERE

**Exercise 7: Tokenize train and validation data and form TextDataset objects [0.5 Mark]**

- Use the loaded pre-trained tokenizer
- Use training and validation data saved in text files

In [None]:
# YOUR CODE HERE

**Exercise 8: Create a DataCollator object [0.5 Mark]**

In [None]:
# YOUR CODE HERE
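
Exercises 6 to 8 chain together: the tokenizer feeds both datasets and the collator. The sketch below is one possible arrangement, not the reference solution. The helper name `build_datasets` and the `block_size=128` default are assumptions, and `TextDataset` (the class the exercise names) is deprecated in transformers, so this presumes a library version that still ships it.

```python
def build_datasets(train_path, val_path, checkpoint="gpt2", block_size=128):
    """Load the tokenizer and wrap the two text files for causal LM training."""
    # Imported here so the helper is self-contained;
    # the notebook already imports these names at the top.
    from transformers import (
        DataCollatorForLanguageModeling,
        GPT2Tokenizer,
        TextDataset,
    )

    # Exercise 6: pre-trained tokenizer for the "gpt2" checkpoint.
    tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)

    # Exercise 7: TextDataset tokenizes a file and cuts it into block_size-token chunks.
    train_dataset = TextDataset(
        tokenizer=tokenizer, file_path=train_path, block_size=block_size
    )
    val_dataset = TextDataset(
        tokenizer=tokenizer, file_path=val_path, block_size=block_size
    )

    # Exercise 8: mlm=False selects causal (next-token) language modeling,
    # which is the objective GPT-2 is trained with.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    return tokenizer, train_dataset, val_dataset, data_collator
```

In the notebook this could be called as `tokenizer, train_dataset, val_dataset, data_collator = build_datasets('training_file.txt', 'validation_file.txt')`.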

**Exercise 9: Load pre-trained GPT2LMHeadModel [0.5 Mark]**

In [None]:
# YOUR CODE HERE

**Exercise 10: Fine-tune GPT2 Model [1 Mark]**

- Specify training arguments and create a TrainingArguments object (Use 30 epochs)

- Train a GPT-2 model using the provided training arguments

- Save the resulting trained model and tokenizer to a specified output directory

In [None]:
# Set up the training arguments

# YOUR CODE HERE

In [None]:
# Train the model
# YOUR CODE HERE

# Save the model
# YOUR CODE HERE

# Save the tokenizer
# YOUR CODE HERE
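
The steps of Exercises 9 and 10 can be sketched as a single helper. This is a hedged outline rather than the expected solution: the function name `fine_tune`, the output directory name `gpt2-medquad`, the batch sizes, and the logging settings are all assumptions. Only the `gpt2` checkpoint and the 30 epochs come from the exercise text.

```python
def fine_tune(tokenizer, train_dataset, val_dataset, data_collator,
              output_dir="gpt2-medquad", checkpoint="gpt2", num_train_epochs=30):
    """Fine-tune GPT-2 and save the model and tokenizer to output_dir."""
    # Imported here so the helper is self-contained;
    # the notebook already imports these names at the top.
    from transformers import GPT2LMHeadModel, Trainer, TrainingArguments

    # Exercise 9: pre-trained GPT-2 with a language-modeling head.
    model = GPT2LMHeadModel.from_pretrained(checkpoint)

    # Exercise 10: training configuration.
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,  # the exercise asks for 30
        per_device_train_batch_size=8,      # assumed; lower it if GPU memory runs out
        per_device_eval_batch_size=8,       # assumed
        logging_steps=50,                   # assumed
        save_total_limit=1,                 # keep only the latest checkpoint on disk
        report_to="none",                   # disable external experiment trackers
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    )
    trainer.train()

    # Save the fine-tuned weights and the tokenizer files side by side.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return trainer
```

Saving the tokenizer next to the model means both can later be reloaded from the same directory with `from_pretrained(output_dir)`.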

**Exercise 11: Test Model with user input prompts [1 Mark]**

- Create `generate_response()` function that takes a trained *model*, *tokenizer*, and a *prompt* string as input and generates a response using the GPT-2 model

- Test it with some user input prompts

In [None]:
# YOUR CODE HERE
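
A minimal sketch of what `generate_response()` could look like is given below. The `max_length=100` default and the use of greedy decoding are assumptions, not requirements of the exercise. The function only relies on the standard `encode`, `generate`, and `decode` calls.

```python
def generate_response(model, tokenizer, prompt, max_length=100):
    """Generate a continuation of `prompt` and return it as text."""
    # Encode the prompt as a batch of one sequence of token ids (PyTorch tensor)
    # and move it to the device the model lives on.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        input_ids,
        max_length=max_length,                # counts prompt tokens plus generated tokens
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token of its own
    )

    # generate() returns a batch; decode the first (and only) sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

Because the training text wraps each pair as `<question>` + question + `<answer>` + answer, formatting the prompt the same way (`'<question>' + user_text + '<answer>'`) is likely to steer the fine-tuned model better than the bare question. This is a design choice rather than something the exercise prescribes. Passing `do_sample=True` together with `top_k` or `top_p` to `generate()` is an option if greedy output turns out to be repetitive.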

In [None]:
# Load the fine-tuned model and tokenizer

# YOUR CODE HERE

In [None]:
# Response from model

# YOUR CODE HERE

In [None]:
# Testing with given prompt 1

# YOUR CODE HERE

In [None]:
# Testing with given prompt 2

# YOUR CODE HERE

**Exercise 12: Compare the performance of a *GPT2 model* with the *GPT2 model fine-tuned* on MedQuAD data [1 Mark]**

- Load another pre-trained GPT2LMHeadModel and do not fine-tune it

- To generate a response using the untuned model, pass it as a parameter to the `generate_response()` function

- Test both models (fine-tuned and untuned) with the user input prompts below:

    - "What precautions to take for a healthy life?"
    - "What to do after being diagnosed with cancer?"
    - "What to do when feeling sick?"

In [None]:
# Load a pre-trained GPT2 model, do not finetune it with MedQuAD data

# YOUR CODE HERE
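
For the side-by-side comparison, one option is to collect every response in a nested dictionary instead of running six separate cells. The helper `compare_models` below is hypothetical. Its `generate_fn` argument stands for any function with the `generate_response(model, tokenizer, prompt)` signature from Exercise 11, and the untuned model is assumed to be loaded with `GPT2LMHeadModel.from_pretrained("gpt2")`.

```python
# The three prompts listed in Exercise 12.
PROMPTS = [
    "What precautions to take for a healthy life?",
    "What to do after being diagnosed with cancer?",
    "What to do when feeling sick?",
]


def compare_models(models, tokenizer, prompts, generate_fn):
    """Return {prompt: {model_name: response}} for side-by-side reading.

    models      -- dict mapping a label (e.g. "fine-tuned") to a model object
    generate_fn -- callable taking (model, tokenizer, prompt) and returning text
    """
    results = {}
    for prompt in prompts:
        results[prompt] = {
            name: generate_fn(model, tokenizer, prompt)
            for name, model in models.items()
        }
    return results
```

Both models can share one tokenizer here, because the fine-tuning text files were tokenized with the stock `gpt2` tokenizer and nothing was added to the vocabulary.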

In [None]:
# Testing with finetuned model: prompt 1

# YOUR CODE HERE

In [None]:
# Testing with untuned model: prompt 1

# YOUR CODE HERE

In [None]:
# Testing with finetuned model: prompt 2

# YOUR CODE HERE

In [None]:
# Testing with untuned model: prompt 2

# YOUR CODE HERE

In [None]:
# Testing with finetuned model: prompt 3

# YOUR CODE HERE

In [None]:
# Testing with untuned model: prompt 3

# YOUR CODE HERE