<a href="https://colab.research.google.com/github/surabhi13gupta/CDS/blob/main/Module%206/MP%201/Surabhi_M6_NB_MiniProject_1_Medical_Q%26A_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Advanced Certification Program in Computational Data Science
## A programme by IISc and TalentSprint
### Mini-Project: Medical Q&A using GPT2

## Learning Objectives

At the end of the experiment, you will be able to:

* perform data preprocessing, EDA and feature extraction on the Medical Q&A dataset
* load a pre-trained tokenizer
* finetune a GPT-2 language model for medical question-answering

## Dataset Description

The dataset used in this project is the *Medical Question Answering Dataset* ([MedQuAD](https://github.com/abachaa/MedQuAD/tree/master)). It includes medical question-answer pairs along with additional information such as the question type, the question *focus*, and its UMLS (Unified Medical Language System) details: the Concept Unique Identifier (*CUI*) and the Semantic *Type* and *Group*.

To learn more about how this data was collected and constructed, refer to this [paper](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-019-3119-4).

The data has been extracted into CSV format with the following features:

- **Focus**: the question focus
- **CUI**: concept unique identifier
- **SemanticType**
- **SemanticGroup**
- **Question**
- **Answer**

## Part-A: Grading = 10 Points

## Information

Healthcare professionals often have to refer to medical literature and documents when seeking answers to medical queries. Medical databases and search engines are powerful sources of up-to-date medical knowledge. However, the existing documentation is large, which makes it difficult for professionals to retrieve answers quickly in a clinical setting. The problem with search engines and information retrieval engines is that they return a list of documents rather than answers. Instead, healthcare professionals can use question-answering systems to retrieve short sentences or paragraphs in response to medical queries. The main advantage of such systems is that they generate answers and provide hints within a few seconds.

### Problem Statement

Fine-tune the GPT-2 model on the medical question-answering dataset to generate responses to medical queries.

Please refer to ***M6 Assignment-1 Fine-tune GPT2*** to get familiar with how to load the pre-trained GPT-2 tokenizer and model.

### Import required packages

In [1]:
!pip -q install -U accelerate
!pip -q install -U transformers
!pip -q install torch

In [2]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

import warnings
warnings.filterwarnings('ignore')

In [3]:
#@title Download the dataset
!wget -q https://cdn.iisc.talentsprint.com/AIandMLOps/MiniProjects/Datasets/MedQuAD.csv
!ls | grep ".csv"

MedQuAD.csv


**Exercise 1: Read the MedQuAD.csv dataset**

**Hint:** pd.read_csv()

In [4]:
df = pd.read_csv("MedQuAD.csv")
df.shape

(16412, 6)

In [5]:
df.head()

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
0,Adult Acute Lymphoblastic Leukemia,C0751606,T191,Disorders,What is (are) Adult Acute Lymphoblastic Leukem...,Key Points - Adult acute lymphoblastic leukemi...
1,Adult Acute Lymphoblastic Leukemia,C0751606,T191,Disorders,What are the symptoms of Adult Acute Lymphobla...,"Signs and symptoms of adult ALL include fever,..."
2,Adult Acute Lymphoblastic Leukemia,C0751606,T191,Disorders,How to diagnose Adult Acute Lymphoblastic Leuk...,Tests that examine the blood and bone marrow a...
3,Adult Acute Lymphoblastic Leukemia,C0751606,T191,Disorders,What is the outlook for Adult Acute Lymphoblas...,Certain factors affect prognosis (chance of re...
4,Adult Acute Lymphoblastic Leukemia,C0751606,T191,Disorders,Who is at risk for Adult Acute Lymphoblastic L...,Previous chemotherapy and exposure to radiatio...


### Pre-processing and EDA

**Exercise 2: Perform the following operations on the dataset [0.5 Mark]**

- Handle missing values
- Remove duplicates from data considering `Question` and `Answer` columns

- **Handle missing values**

In [6]:
# YOUR CODE HERE
print("Check if any missing values")
print("=======================================================================")
print(df.isna().sum())

Check if any missing values
Focus             14
CUI              565
SemanticType     597
SemanticGroup    565
Question           0
Answer             5
dtype: int64


In [7]:
print("List down rows where Focus is null")
print("======================================================================")
df[df.Focus.isna()]

List down rows where Focus is null


Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
16234,,,,,how vaccines prevent disease,Why Are Childhood Vaccines So Important? It is...
16235,,,,,Who is at risk for ? ?,Measles: Make Sure Your Child Is Protected wit...
16236,,,,,How to prevent ?,Vaccines and Preventable Diseases On this Page...
16237,,,,,what diseases are vaccine preventable,List of Vaccine-Preventable Diseases The follo...
16386,,,,,What is (are) ?,On this Page General Information about VISA/VR...
16387,,,,,what is staphylococcus aureus?,On this Page General Information about VISA/VR...
16388,,,,,how can the spread of visa and vrsa be prevented?,On this Page General Information about VISA/VR...
16389,,,,,what is cdc doing to address visa and vrsa?,On this Page General Information about VISA/VR...
16390,,,,,What is (are) ?,On this Page General Information What is vanco...
16391,,,,,what is vancomycin-resistant enterococci?,On this Page General Information What is vanco...


In [8]:
print("List down rows where Answer is null")
print("======================================================================")
df[df.Answer.isna()]

List down rows where Answer is null


Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
2263,"Emery-Dreifuss muscular dystrophy, dominant type",C0410189,T047,Disorders,What is (are) Emery-Dreifuss muscular dystroph...,
2264,"Emery-Dreifuss muscular dystrophy, X-linked",C0410189,T047,Disorders,What is (are) Emery-Dreifuss muscular dystroph...,
2400,Familial HDL deficiency,C2931838,T047,Disorders,What is (are) Familial HDL deficiency ?,
2876,HELLP syndrome,C0162739,T047,Disorders,What is (are) HELLP syndrome ?,
6021,X-linked lymphoproliferative syndrome,C0549463,T191,Disorders,What is (are) X-linked lymphoproliferative syn...,


In [9]:
orig = df.shape[0]
print("Number of rows before dropping null values present in Focus and Answer columns : ", orig)
print("======================================================================")
df = df.dropna(subset=['Focus', 'Answer'])
print("Number of rows after dropping null values present in Focus and Answer columns : ", df.shape[0])
print("======================================================================")
print("Rows removed: ", orig-df.shape[0])

Number of rows before dropping null values present in Focus and Answer columns :  16412
Number of rows after dropping null values present in Focus and Answer columns :  16393
Rows removed:  19


In [10]:
print("Check if any missing values")
print("=======================================================================")
print(df.isna().sum())

Check if any missing values
Focus              0
CUI              551
SemanticType     583
SemanticGroup    551
Question           0
Answer             0
dtype: int64


In [11]:
print("List down rows where SemanticGroup is null")
print("======================================================================")
df[df.SemanticGroup.isna()]["Focus"].drop_duplicates()

List down rows where SemanticGroup is null


Unnamed: 0,Focus
11553,A1C
11559,Acupuncture
11565,Adoption
11568,Advance Directives
11569,African American Health
...,...
16385,Typhoid Fever
16396,Parasites - Trichuriasis (also known as Whipwo...
16401,Yellow Fever Vaccination
16402,Yersinia


In [12]:
df[df.Focus=="Parasites - Zoonotic Hookworm"]

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
16407,Parasites - Zoonotic Hookworm,,,,What is (are) Parasites - Zoonotic Hookworm ?,"There are many different species of hookworms,..."
16408,Parasites - Zoonotic Hookworm,,,,Who is at risk for Parasites - Zoonotic Hookwo...,Dog and cat hookworms are found throughout the...
16409,Parasites - Zoonotic Hookworm,,,,How to diagnose Parasites - Zoonotic Hookworm ?,Cutaneous larva migrans (CLM) is a clinical di...
16410,Parasites - Zoonotic Hookworm,,,,What are the treatments for Parasites - Zoonot...,The zoonotic hookworm larvae that cause cutane...
16411,Parasites - Zoonotic Hookworm,,,,How to prevent Parasites - Zoonotic Hookworm ?,Wearing shoes and taking other protective meas...


In [13]:
df[df.Focus=="Yellow Fever Vaccination"]

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
16401,Yellow Fever Vaccination,,,,What is (are) Yellow Fever Vaccination ?,If you continue to live or travel in yellow fe...


For these `Focus` values, `SemanticGroup` and `CUI` are missing in every row, so nothing can be filled in from other rows of the same focus. Drop all 551 rows.
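Whether a missing value could be recovered from another row of the same `Focus` can be checked directly rather than by inspecting individual categories. The following is a minimal sketch on toy data; the frame `toy` and the helper `fillable_mask` are illustrative names, not part of the notebook.

```python
import pandas as pd

# Toy frame: category "A" has one known CUI that could be propagated, "B" has none.
toy = pd.DataFrame({
    "Focus": ["A", "A", "B", "B"],
    "CUI": ["C001", None, None, None],
})

def fillable_mask(frame, key, col):
    """Mark rows where `col` is null but the row's `key` group holds a non-null value."""
    # "count" returns the number of non-null values in each group, broadcast to every row.
    group_has_value = frame.groupby(key)[col].transform("count") > 0
    return frame[col].isna() & group_has_value

mask = fillable_mask(toy, "Focus", "CUI")
print(mask.tolist())
```

If the mask is `False` everywhere for a column, no row can be filled from its own group and dropping is the remaining option.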

In [14]:
orig = df.shape[0]
print("Number of rows before dropping null values present in SemanticGroup and CUI columns : ", orig)
print("======================================================================")
df = df.dropna(subset=['SemanticGroup', 'CUI'])
print("Number of rows after dropping null values present in SemanticGroup and CUI columns : ", df.shape[0])
print("======================================================================")
print("Rows removed: ", orig-df.shape[0])

Number of rows before dropping null values present in SemanticGroup and CUI columns :  16393
Number of rows after dropping null values present in SemanticGroup and CUI columns :  15842
Rows removed:  551


In [15]:
print("Check if any missing values")
print("=======================================================================")
print(df.isna().sum())

Check if any missing values
Focus             0
CUI               0
SemanticType     32
SemanticGroup     0
Question          0
Answer            0
dtype: int64


In [16]:
print("List down rows where SemanticType is null")
print("======================================================================")
df[df.SemanticType.isna()]

List down rows where SemanticType is null


Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
1124,Autosomal recessive hyper IgE syndrome,C0022398,,Disorders,What is (are) Autosomal recessive hyper IgE sy...,Autosomal recessive hyper IgE syndrome (AR-HIE...
1125,Autosomal recessive hyper IgE syndrome,C0022398,,Disorders,What are the symptoms of Autosomal recessive h...,What are the signs and symptoms of Autosomal r...
1154,Baraitser-Winter syndrome,C0796084,,Disorders,What are the symptoms of Baraitser-Winter synd...,What are the signs and symptoms of Baraitser-W...
4660,Periventricular heterotopia,C3714789,,Disorders,What is (are) Periventricular heterotopia ?,Periventricular heterotopia is a condition in ...
4661,Periventricular heterotopia,C3714789,,Disorders,What are the symptoms of Periventricular heter...,What are the signs and symptoms of periventric...
4662,Periventricular heterotopia,C3714789,,Disorders,How to diagnose Periventricular heterotopia ?,What are the recommended evaluations for patie...
4663,Periventricular heterotopia,C3714789,,Disorders,What are the treatments for Periventricular he...,How might periventricular nodular heterotopia ...
6568,autosomal dominant hyper-IgE syndrome,C0022398,,Disorders,What is (are) autosomal dominant hyper-IgE syn...,Autosomal dominant hyper-IgE syndrome (AD-HIES...
6569,autosomal dominant hyper-IgE syndrome,C0022398,,Disorders,How many people are affected by autosomal domi...,"This condition is rare, affecting fewer than 1..."
6570,autosomal dominant hyper-IgE syndrome,C0022398,,Disorders,What are the genetic changes related to autoso...,Mutations in the STAT3 gene cause most cases o...


In [17]:
df[df.Focus=="Periventricular heterotopia"]

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
4660,Periventricular heterotopia,C3714789,,Disorders,What is (are) Periventricular heterotopia ?,Periventricular heterotopia is a condition in ...
4661,Periventricular heterotopia,C3714789,,Disorders,What are the symptoms of Periventricular heter...,What are the signs and symptoms of periventric...
4662,Periventricular heterotopia,C3714789,,Disorders,How to diagnose Periventricular heterotopia ?,What are the recommended evaluations for patie...
4663,Periventricular heterotopia,C3714789,,Disorders,What are the treatments for Periventricular he...,How might periventricular nodular heterotopia ...


In [18]:
df[df.Focus=="periventricular heterotopia"]

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
9973,periventricular heterotopia,C3714789,,Disorders,What is (are) periventricular heterotopia ?,Periventricular heterotopia is a condition in ...
9974,periventricular heterotopia,C3714789,,Disorders,How many people are affected by periventricula...,Periventricular heterotopia is a rare conditio...
9975,periventricular heterotopia,C3714789,,Disorders,What are the genetic changes related to perive...,Periventricular heterotopia is related to chro...
9976,periventricular heterotopia,C3714789,,Disorders,Is periventricular heterotopia inherited ?,Periventricular heterotopia can have different...
9977,periventricular heterotopia,C3714789,,Disorders,What are the treatments for periventricular he...,These resources address the diagnosis or manag...


All 32 rows can be dropped because nothing can be filled in from other rows of the same focus.

In [19]:
orig = df.shape[0]
print("Number of rows before dropping null values present in SemanticType columns : ", orig)
print("======================================================================")
df = df.dropna(subset=['SemanticType'])
print("Number of rows after dropping null values present in SemanticType columns : ", df.shape[0])
print("======================================================================")
print("Rows removed: ", orig-df.shape[0])

Number of rows before dropping null values present in SemanticType columns :  15842
Number of rows after dropping null values present in SemanticType columns :  15810
Rows removed:  32


In [20]:
print("Check if any missing values")
print("=======================================================================")
print(df.isna().sum())

Check if any missing values
Focus            0
CUI              0
SemanticType     0
SemanticGroup    0
Question         0
Answer           0
dtype: int64


- **Remove duplicates from data considering `Question` and `Answer` columns**

In [21]:
# Check duplicates
# YOUR CODE HERE
print("Checking for duplicate rows:")
print("=======================================================================")
dup = df[df[["Question", "Answer"]].duplicated()]
print("Rows are exactly duplicate: ", dup.shape[0])
print("=======================================================================")
print("%age of rows having duplicate : ", (dup.shape[0]/df.shape[0]) *100 )

Checking for duplicate rows:
Rows are exactly duplicate:  48
%age of rows having duplicate :  0.3036053130929791


In [22]:
# Drop duplicates
# YOUR CODE HERE
print("Shape before dropping duplicate rows: ", df.shape, sep="\n")
print("=======================================================================")
print("Drop duplicate rows", sep="\n")
print("=======================================================================")
# Deduplicate on the Question/Answer pair, as the exercise asks, and rebuild a clean index
df = df.drop_duplicates(subset=["Question", "Answer"])
df = df.reset_index(drop=True)
print("Shape after dropping duplicate rows: ", df.shape, sep="\n")
print("=======================================================================")

Shape before dropping duplicate rows: 
(15810, 6)
Drop duplicate rows
Shape after dropping duplicate rows: 
(15762, 6)


In [23]:
# Check duplicates
# YOUR CODE HERE
print("Checking for duplicate rows:")
print("=======================================================================")
dup = df[df[["Question", "Answer"]].duplicated()]
print("Rows are exactly duplicate: ", dup.shape[0])
print("=======================================================================")
print("%age of rows having duplicate : ", (dup.shape[0]/df.shape[0]) *100 )

Checking for duplicate rows:
Rows are exactly duplicate:  0
%age of rows having duplicate :  0.0
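Full-row and column-subset duplicate detection can give different counts when two rows share a question and answer but differ in another column. A small sketch on made-up rows (the toy frame is an assumption for illustration only):

```python
import pandas as pd

# Rows 0 and 1 share Question and Answer but differ in the casing of Focus.
toy = pd.DataFrame({
    "Focus": ["Flu", "flu", "Cold"],
    "Question": ["What is flu?", "What is flu?", "What is a cold?"],
    "Answer": ["Placeholder answer A.", "Placeholder answer A.", "Placeholder answer B."],
})

pair_cols = ["Question", "Answer"]
n_full_row_dups = int(toy.duplicated().sum())               # compares every column
n_pair_dups = int(toy.duplicated(subset=pair_cols).sum())   # compares only Question and Answer

# Keep the first occurrence of each Question/Answer pair.
deduped = toy.drop_duplicates(subset=pair_cols).reset_index(drop=True)
print(n_full_row_dups, n_pair_dups, len(deduped))
```

Passing `subset` keeps the check and the removal aligned on the same columns.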


**Exercise 3: Display the category name, and the number of records belonging to top 100 categories of `Focus` column [1 Mark]**

In [24]:
# YOUR CODE HERE
print("Unique Categories i.e. Focus area")
print("====================================================================")
df["Focus"].value_counts().reset_index()

Unique Categories i.e. Focus area


Unnamed: 0,Focus,count
0,Breast Cancer,53
1,Prostate Cancer,43
2,Stroke,35
3,Skin Cancer,34
4,Alzheimer's Disease,30
...,...,...
4765,Growth hormone deficiency,1
4766,Ghosal hematodiaphyseal dysplasia syndrome,1
4767,Giant platelet syndrome,1
4768,"Gingival fibromatosis, 1",1


In [25]:
# Top 100 Focus categories names
# YOUR CODE HERE
print("Top 100 Focus categories")
print("====================================================================")
df["Focus"].value_counts().reset_index().head(100)

Top 100 Focus categories


Unnamed: 0,Focus,count
0,Breast Cancer,53
1,Prostate Cancer,43
2,Stroke,35
3,Skin Cancer,34
4,Alzheimer's Disease,30
...,...,...
95,Alzheimer's Caregiving,11
96,Polycythemia Vera,11
97,"Diabetes, Heart Disease, and Stroke",11
98,Pelizaeus-Merzbacher disease,10


### Create Training and Validation set

**Exercise 4: Create training and validation set [2 Marks]**

- Consider 4 samples per `Focus` category, for each of the top 100 categories, from the dataset (this gives 400 samples for training)

- Consider 1 sample per `Focus` category (different from the training set), for each of the top 100 categories, from the dataset (this gives 100 samples for validation)
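One way to draw disjoint per-category samples is to shuffle once and then take rows from opposite ends of each group. The sketch below uses a toy frame with three categories of six rows each; it assumes every category has at least five rows, which is what keeps the two selections from overlapping.

```python
import pandas as pd

# Toy data: 3 categories with 6 rows each.
toy = pd.DataFrame({
    "Focus": [f for f in ["A", "B", "C"] for _ in range(6)],
    "Question": ["q{}".format(i) for i in range(18)],
})

shuffled = toy.sample(frac=1, random_state=0)
train = shuffled.groupby("Focus").head(4)   # first 4 shuffled rows of each category
val = shuffled.groupby("Focus").tail(1)     # last shuffled row of each category

# With at least 5 rows per category, head(4) and tail(1) cannot pick the same row.
overlap = train.index.intersection(val.index)
print(len(train), len(val), len(overlap))
```

Checking `overlap` before resetting the index is a cheap guard against leakage between the two sets.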

In [26]:
top_100_categories = df["Focus"].value_counts().reset_index().head(100)['Focus'].to_list()

In [27]:
# YOUR CODE HERE
df_for_100_categories = df[df['Focus'].isin(top_100_categories)]
df_for_100_categories.shape

(1532, 6)

Expected training set shape: (400, 6)

Expected validation set shape: (100, 6)

In [29]:
shuffled_dataset = df_for_100_categories.sample(frac=1, replace=False, random_state=42)
train_df = shuffled_dataset.groupby('Focus').head(4).reset_index(drop=True)
val_df = shuffled_dataset.groupby('Focus').tail(1).reset_index(drop=True)

In [30]:
print("Shape of training dataset: ", train_df.shape)
print("Shape of validation dataset: ", val_df.shape)

Shape of training dataset:  (400, 6)
Shape of validation dataset:  (100, 6)


In [31]:
train_df.head()

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
0,Colorectal Cancer,C1527249,T191,Disorders,Who is at risk for Colorectal Cancer? ?,Researchers are working hard to understand and...
1,Pituitary Tumors,C0032019,T191,Disorders,What are the stages of Pituitary Tumors ?,Key Points - Once a pituitary tumor has been d...
2,Liddle syndrome,C0221043,T047,Disorders,How many people are affected by Liddle syndrome ?,"Liddle syndrome is a rare condition, although ..."
3,Urinary Retention,C0080274,T033,Disorders,What are the treatments for Urinary Retention ?,A health care provider treats urinary retentio...
4,Medullary Sponge Kidney,C0022681,T019,Disorders,What is (are) Medullary Sponge Kidney ?,"Medullary sponge kidney, also known as Cacchi-..."


In [32]:
val_df.head()

Unnamed: 0,Focus,CUI,SemanticType,SemanticGroup,Question,Answer
0,Langerhans cell histiocytosis,C0019621,T191,Disorders,Is Langerhans cell histiocytosis inherited ?,Is Langerhans cell histiocytosis inherited? Al...
1,Camurati-Engelmann disease,C0011989,T019,Disorders,What are the genetic changes related to Camura...,Mutations in the TGFB1 gene cause Camurati-Eng...
2,National Hormone and Pituitary Program (NHPP):...,C0032002,T047,Disorders,What are the treatments for National Hormone a...,Some parents did not tell their children about...
3,Depression,C0349217,T048,Disorders,What is (are) Depression ?,Depression is more than just feeling blue or s...
4,Danon disease,C0878677,T047,Disorders,Is Danon disease inherited ?,How is Danon disease inherited? Dannon disease...


### Pre-process `Question` and `Answer` text

**Exercise 5: Perform the following tasks: [1.5 Marks]**

- Combine `Question` and `Answer` for train and validation data as shown below:
    - sequence = *'\<question\>' + question-text + '\<answer\>' + answer-text*

- Join the combined text using '\n' into a single string for training and validation separately

- Save the training and validation strings as separate text files
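The three steps above can be sketched without pandas. The helper names `to_sequence` and `write_corpus`, the placeholder pairs, and the file name `toy_train.txt` are assumptions made for this illustration.

```python
import os
import tempfile

def to_sequence(question, answer):
    """Wrap one question-answer pair in the tags used for fine-tuning."""
    return "<question>" + question + "<answer>" + answer

def write_corpus(pairs, path):
    """Join all tagged sequences with newlines and save them as one text file."""
    text = "\n".join(to_sequence(q, a) for q, a in pairs)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
    return text

# Placeholder pairs, not taken from MedQuAD.
pairs = [
    ("What is X ?", "X is a placeholder."),
    ("How to prevent X ?", "Placeholder advice."),
]
path = os.path.join(tempfile.mkdtemp(), "toy_train.txt")
corpus = write_corpus(pairs, path)
print(corpus)
```

Because sequences are separated by `'\n'`, an answer that itself contains a newline would span several lines of the file.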

- **Combine Question and Answer for train and val data**

In [35]:
# YOUR CODE HERE
# Build one '<question>...<answer>...' string per row, accessing the columns by label
train_df['Question-Answer'] = train_df[['Question', 'Answer']].apply(lambda x: '<question>{}<answer>{}'.format(x['Question'], x['Answer']), axis=1)
val_df['Question-Answer'] = val_df[['Question', 'Answer']].apply(lambda x: '<question>{}<answer>{}'.format(x['Question'], x['Answer']), axis=1)

In [38]:
train_df.loc[0,"Question"], train_df.loc[0,"Answer"], train_df.loc[0,"Question-Answer"]

('Who is at risk for Colorectal Cancer? ?',
 'Researchers are working hard to understand and identify the genes involved in colorectal cancer. Hereditary nonpolyposis colorectal cancer, or HNPCC, is one condition that causes people to develop colorectal cancer at a young age. The discovery of four genes involved with this disease has provided crucial clues about the role of DNA repair in colorectal and other cancers.',
 '<question>Who is at risk for Colorectal Cancer? ?<answer>Researchers are working hard to understand and identify the genes involved in colorectal cancer. Hereditary nonpolyposis colorectal cancer, or HNPCC, is one condition that causes people to develop colorectal cancer at a young age. The discovery of four genes involved with this disease has provided crucial clues about the role of DNA repair in colorectal and other cancers.')

In [39]:
val_df.loc[0,"Question"], val_df.loc[0,"Answer"], val_df.loc[0,"Question-Answer"]

('Is Langerhans cell histiocytosis inherited ?',
 'Is Langerhans cell histiocytosis inherited? Although Langerhans cell histiocytosis is generally considered a sporadic, non-hereditary condition, it has reportedly affected more than one individual in a family in a very limited number of cases (particularly identical twins).',
 '<question>Is Langerhans cell histiocytosis inherited ?<answer>Is Langerhans cell histiocytosis inherited? Although Langerhans cell histiocytosis is generally considered a sporadic, non-hereditary condition, it has reportedly affected more than one individual in a family in a very limited number of cases (particularly identical twins).')

- **Join the combined text using '\n' into a single string for training and validation separately**

In [41]:
# YOUR CODE HERE
train_text = '\n'.join(train_df['Question-Answer'].to_list())
val_text = '\n'.join(val_df['Question-Answer'].to_list())

- **Save the training and validation strings as text files**

In [43]:
# YOUR CODE HERE
with open('train_text.txt', 'w') as f:
  f.write(train_text)

with open('val_text.txt', 'w') as f:
  f.write(val_text)

**Exercise 6: Load pre-trained GPT2Tokenizer [0.5 Mark]**

- Use checkpoint = "gpt2"

In [44]:
# YOUR CODE HERE
checkpoint = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)

**Exercise 7: Tokenize train and validation data and form TextDataset objects [0.5 Mark]**

- Use the loaded pre-trained tokenizer
- Use training and validation data saved in text files
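A block-based text dataset turns the whole file into one long stream of token ids and cuts it into fixed-length examples. The sketch below imitates that idea in plain Python; it assumes the trailing tokens that do not fill a complete block are discarded, which is an assumption worth checking against the installed `transformers` version. `chunk_into_blocks` is a hypothetical helper, not a library function.

```python
def chunk_into_blocks(token_ids, block_size):
    """Split a flat list of token ids into consecutive full blocks, dropping the remainder."""
    return [
        token_ids[i:i + block_size]
        for i in range(0, len(token_ids) - block_size + 1, block_size)
    ]

token_ids = list(range(10))   # stand-in for the ids a tokenizer would produce
blocks = chunk_into_blocks(token_ids, block_size=4)
print(blocks)
```

Under this scheme the number of training examples depends on the total token count and `block_size`, not on the number of question-answer pairs, and a block can start or end in the middle of an answer.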

In [45]:
# YOUR CODE HERE
train_dataset = TextDataset(tokenizer=tokenizer, file_path="train_text.txt", block_size=128)
val_dataset = TextDataset(tokenizer=tokenizer, file_path="val_text.txt", block_size=128)

**Exercise 8: Create a DataCollator object [0.5 Mark]**

In [46]:
# YOUR CODE HERE
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
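With `mlm=False` the collator prepares batches for causal language modelling, where the labels are a copy of the inputs. The plain-Python sketch below mimics that behaviour under the assumption that padding positions are masked with `-100`; `causal_lm_batch` is a hypothetical stand-in, not the library implementation.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def causal_lm_batch(examples, pad_token_id=None):
    """Return input ids plus labels that copy the inputs, masking any padding."""
    input_ids = [list(ex) for ex in examples]
    labels = [
        [IGNORE_INDEX if pad_token_id is not None and tok == pad_token_id else tok
         for tok in ex]
        for ex in examples
    ]
    return {"input_ids": input_ids, "labels": labels}

# Token id 0 plays the role of padding in this toy batch.
batch = causal_lm_batch([[5, 6, 7], [8, 9, 0]], pad_token_id=0)
print(batch["labels"])
```

The shift between inputs and targets (predicting token *t+1* from tokens up to *t*) happens inside the model's loss computation, so the batch itself does not need shifted labels.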

**Exercise 9: Load pre-trained GPT2LMHeadModel [0.5 Mark]**

In [47]:
# YOUR CODE HERE
model = GPT2LMHeadModel.from_pretrained(checkpoint)

**Exercise 10: Fine-tune GPT2 Model [1 Mark]**

- Specify training arguments and create a TrainingArguments object (Use 30 epochs)

- Train a GPT-2 model using the provided training arguments

- Save the resulting trained model and tokenizer to a specified output directory
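Before launching a 30-epoch run it helps to estimate how many optimizer steps it will take and which checkpoints will remain on disk. The sketch below assumes a single device and no gradient accumulation. The example count of 668 blocks is hypothetical, since the real number depends on the tokenized length of the training file. Both helper functions are illustrative.

```python
import math

def total_optimizer_steps(num_examples, batch_size, epochs):
    """Steps for a full run, counting the last partial batch of each epoch."""
    return math.ceil(num_examples / batch_size) * epochs

def kept_checkpoints(total_steps, save_steps, save_total_limit):
    """Checkpoint steps left on disk when only the most recent ones are kept."""
    all_steps = list(range(save_steps, total_steps + 1, save_steps))
    return all_steps[-save_total_limit:]

steps = total_optimizer_steps(num_examples=668, batch_size=4, epochs=30)
print(steps)
print(kept_checkpoints(steps, save_steps=500, save_total_limit=2))
```

Halving the batch size doubles the step count for the same number of epochs, which also changes how many checkpoints are written.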

In [48]:
# Set up the training arguments

# YOUR CODE HERE
model_output_path = "./medical_qna_gpt2_model"

training_args = TrainingArguments(
    output_dir = model_output_path,
    overwrite_output_dir = True,
    per_device_train_batch_size = 4, # try with 2
    per_device_eval_batch_size = 4,  #  try with 2
    num_train_epochs = 30,
    save_steps = 500,
    save_total_limit = 2,
    logging_dir = './logs',
    )

In [51]:
# Train the model
# YOUR CODE HERE
trainer = Trainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
)
trainer.train()

# Save the model
# YOUR CODE HERE
trainer.save_model(model_output_path)

# Save the tokenizer
# YOUR CODE HERE
tokenizer.save_pretrained(model_output_path)


Step,Training Loss
500,2.5247
1000,1.9658
1500,1.5994
2000,1.3079
2500,1.0748
3000,0.8902
3500,0.7437
4000,0.6327
4500,0.5444
5000,0.4737


('./medical_qna_gpt2_model/tokenizer_config.json',
 './medical_qna_gpt2_model/special_tokens_map.json',
 './medical_qna_gpt2_model/vocab.json',
 './medical_qna_gpt2_model/merges.txt',
 './medical_qna_gpt2_model/added_tokens.json')

**Exercise 11: Test Model with user input prompts [1 Mark]**

- Create `generate_response()` function that takes a trained *model*, *tokenizer*, and a *prompt* string as input and generates a response using the GPT-2 model

- Test it with some user input prompts
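The model is fine-tuned on sequences of the form `<question>...<answer>...`, so one option is to give prompts the same shape and to cut the answer out of the generated text afterwards. The helpers `build_prompt` and `extract_answer` below are hypothetical and only handle strings; whether this formatting improves the generated answers is not evaluated in this notebook.

```python
QUESTION_TAG = "<question>"
ANSWER_TAG = "<answer>"

def build_prompt(question):
    """Format a user question the same way the training sequences were formatted."""
    return QUESTION_TAG + question + ANSWER_TAG

def extract_answer(generated_text):
    """Return the text after the first answer tag, stopping at the next question tag."""
    _, sep, tail = generated_text.partition(ANSWER_TAG)
    if not sep:
        # No answer tag was generated, so return the text unchanged.
        return generated_text.strip()
    return tail.split(QUESTION_TAG, 1)[0].strip()

prompt = build_prompt("What is X ?")
fake_output = prompt + "X is a placeholder. <question>What causes X ?<answer>Unknown."
print(extract_answer(fake_output))
```

The tags are ordinary text to the tokenizer here, so `skip_special_tokens=True` does not remove them from decoded output.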

In [52]:
# YOUR CODE HERE
def generate_response(model, tokenizer, prompt, max_length=100):

    input_ids = tokenizer.encode(prompt, return_tensors="pt")      # 'pt' for returning pytorch tensor

    # Create the attention mask and pad token id
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id

    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id
    )

    return tokenizer.decode(output[0], skip_special_tokens=True)

In [53]:
# Load the fine-tuned model and tokenizer

# YOUR CODE HERE
my_model = GPT2LMHeadModel.from_pretrained(model_output_path)
my_tokenizer = GPT2Tokenizer.from_pretrained(model_output_path)

In [54]:
# Response from model

# YOUR CODE HERE
prompt = "What precautions to take for a healthy life?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

What precautions to take for a healthy life? - If you smoke, quit. - Maintain a healthy weight. - Be as physically active as you can. - Follow a heart healthy diet. - If you smoke, quit. Maintain a healthy weight. Be as physically active as you can. Follow a heart healthy diet. If you smoke, quit. Smoking can raise your risk for coronary heart disease and heart attack. Talk with your doctor about programs and products that can help you quit. Also


In [55]:
# Testing with given prompt 1

# YOUR CODE HERE
prompt = "What to do after being diagnosed with cancer?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

What to do after being diagnosed with cancer? - Talk with your doctor about what you can do to keep your disease from getting worse. - If you have high blood pressure, follow your doctor's advice about keeping it under control. If you have diabetes, sometimes called high blood sugar, try to control your blood sugar level through diet and physical activity (as your doctor recommends). If needed, take medicine as prescribed. If you have diabetes, sometimes called high blood sugar, try to control your blood sugar


In [57]:
# Testing with given prompt 2

# YOUR CODE HERE
prompt = "Symptoms of Depression?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

Symptoms of Depression? ?<answer>Symptoms of depression vary depending upon the cause. Some people with depression have no symptoms at all. The more common cause of depression is over-production of alcohol. Overproduction of alcohol causes depression in men and women, and in older adults. Short periods of unemployment (i.e. when unemployment does not exist) and low levels of education (i.e. high school) can also be caused by depression. Other common symptoms include - feeling nervous


**Exercise 12: Compare the performance of a *GPT2 model* with the *GPT2 model fine-tuned* on MedQuAD data [1 Mark]**

- Load another pre-trained GPT2LMHeadModel and do not fine-tune it

- To generate response using the untuned model, pass it as a parameter to `generate_response()` function

- Test both models (fine-tuned and untuned) with the following user input prompts:

    - "What precautions to take for a healthy life?"
    - "What to do after being diagnosed with cancer?"
    - "What to do when feeling sick?"
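A side-by-side comparison is easier to read, and harder to mislabel, when every prompt is run through every model in one loop. The sketch below uses stub generators so it runs without any model; `compare_models` is a hypothetical helper, and in the notebook each stub would be replaced by a call such as `lambda p: generate_response(pretrained_model, my_tokenizer, p)`.

```python
def compare_models(generators, prompts):
    """Run every prompt through every named generator and collect the outputs."""
    rows = []
    for prompt in prompts:
        row = {"prompt": prompt}
        for name, generate in generators.items():
            row[name] = generate(prompt)
        rows.append(row)
    return rows

# Stub generators standing in for the untuned and fine-tuned models.
generators = {
    "untuned": lambda p: p + " [untuned output]",
    "fine_tuned": lambda p: p + " [fine-tuned output]",
}
prompts = [
    "What precautions to take for a healthy life?",
    "What to do after being diagnosed with cancer?",
    "What to do when feeling sick?",
]

results = compare_models(generators, prompts)
for row in results:
    print(row["prompt"])
    print("  untuned   :", row["untuned"])
    print("  fine-tuned:", row["fine_tuned"])
```

Keying the outputs by model name ties each response to the model that produced it.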

In [58]:
# Load a pre-trained GPT2 model, do not finetune it with MedQuAD data

# YOUR CODE HERE
pretrained_model = GPT2LMHeadModel.from_pretrained(checkpoint)

In [59]:
# Testing with untuned (pre-trained) model: prompt 1

# YOUR CODE HERE
prompt = "Symptoms of Depression?"
response = generate_response(pretrained_model, my_tokenizer, prompt)
print(response)

Symptoms of Depression?

The symptoms of depression are often similar to those of other mental illnesses.

Symptoms of depression are often similar to those of other mental illnesses. Depression is a mental illness that is caused by a lack of motivation or control.

A person with depression is often unable to control their emotions, and they often feel like they are being controlled.

A person with depression is often unable to control their emotions, and they often feel like they are being controlled


In [60]:
# Testing with fine-tuned model: prompt 1

# YOUR CODE HERE
prompt = "Symptoms of Depression?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

Symptoms of Depression? ?<answer>Symptoms of depression vary depending upon the cause. Some people with depression have no symptoms at all. The more common cause of depression is over-production of alcohol. Overproduction of alcohol causes depression in men and women, and in older adults. Short periods of unemployment (i.e. when unemployment does not exist) and low levels of education (i.e. high school) can also be caused by depression. Other common symptoms include - feeling nervous


In [61]:
# Testing with untuned (pre-trained) model: prompt 2

# YOUR CODE HERE
prompt = "What precautions to take for a healthy life?"
response = generate_response(pretrained_model, my_tokenizer, prompt)
print(response)

What precautions to take for a healthy life?

The following are some of the most common questions you'll hear from your doctor or nurse about your health.

What are the risks of taking a drug that can cause cancer?

The risks of taking a drug that can cause cancer are very high.

What are the risks of taking a drug that can cause cancer?

The risks of taking a drug that can cause cancer are very high.

What are the risks


In [62]:
# Testing with fine-tuned model: prompt 2

# YOUR CODE HERE
prompt = "What precautions to take for a healthy life?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

What precautions to take for a healthy life? - If you smoke, quit. - Maintain a healthy weight. - Be as physically active as you can. - Follow a heart healthy diet. - If you smoke, quit. Maintain a healthy weight. Be as physically active as you can. Follow a heart healthy diet. If you smoke, quit. Smoking can raise your risk for coronary heart disease and heart attack. Talk with your doctor about programs and products that can help you quit. Also


In [63]:
# Testing with untuned (pre-trained) model: prompt 3

# YOUR CODE HERE
prompt = "What to do after being diagnosed with cancer?"
response = generate_response(pretrained_model, my_tokenizer, prompt)
print(response)

What to do after being diagnosed with cancer?

The first step is to get your doctor's approval for a treatment.

If you have a cancer diagnosis, you may need to get a second opinion.

If you have a cancer diagnosis, you may need to get a second opinion. If you have a cancer diagnosis, you may need to get a third opinion.

If you have a cancer diagnosis, you may need to get a third opinion. If you have a cancer


In [64]:
# Testing with fine-tuned model: prompt 3

# YOUR CODE HERE
prompt = "What to do after being diagnosed with cancer?"
response = generate_response(my_model, my_tokenizer, prompt)
print(response)

What to do after being diagnosed with cancer? - Talk with your doctor about what you can do to keep your disease from getting worse. - If you have high blood pressure, follow your doctor's advice about keeping it under control. If you have diabetes, sometimes called high blood sugar, try to control your blood sugar level through diet and physical activity (as your doctor recommends). If needed, take medicine as prescribed. If you have diabetes, sometimes called high blood sugar, try to control your blood sugar
