<a href="https://colab.research.google.com/github/tashidu/HealthGPT/blob/main/HealthGPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install opendatasets



In [2]:
import opendatasets as od

In [3]:
od.download('https://www.kaggle.com/datasets/thedevastator/comprehensive-medical-q-a-dataset')

Skipping, found downloaded files in "./comprehensive-medical-q-a-dataset" (use force=True to force download)


In [4]:
import pandas  as pd

In [5]:
datasetpath = '/content/comprehensive-medical-q-a-dataset'

In [6]:
df = pd.read_csv(datasetpath+'/train.csv')

In [7]:
print(df.columns)
print(df.head())


Index(['qtype', 'Question', 'Answer'], dtype='object')
             qtype                                           Question  \
0   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   
1         symptoms  What are the symptoms of Lymphocytic Choriomen...   
2   susceptibility  Who is at risk for Lymphocytic Choriomeningiti...   
3  exams and tests  How to diagnose Lymphocytic Choriomeningitis (...   
4        treatment  What are the treatments for Lymphocytic Chorio...   

                                              Answer  
0  LCMV infections can occur after exposure to fr...  
1  LCMV is most commonly recognized as causing ne...  
2  Individuals of all ages who come into contact ...  
3  During the first phase of the disease, the mos...  
4  Aseptic meningitis, encephalitis, or meningoen...  


In [8]:
df.head()

Unnamed: 0,qtype,Question,Answer
0,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,LCMV infections can occur after exposure to fr...
1,symptoms,What are the symptoms of Lymphocytic Choriomen...,LCMV is most commonly recognized as causing ne...
2,susceptibility,Who is at risk for Lymphocytic Choriomeningiti...,Individuals of all ages who come into contact ...
3,exams and tests,How to diagnose Lymphocytic Choriomeningitis (...,"During the first phase of the disease, the mos..."
4,treatment,What are the treatments for Lymphocytic Chorio...,"Aseptic meningitis, encephalitis, or meningoen..."


In [9]:
df['input_text'] = df['qtype'] + ": " + df['Question']
df['target_text'] = df['Answer']


In [10]:
from sklearn.model_selection import train_test_split

train_df, val_df = train_test_split(df[['input_text', 'target_text']], test_size=0.1, random_state=42)


In [11]:
import torch
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')

class MedicalQADataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, tokenizer, source_max_len=512, target_max_len=512):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.source_max_len = source_max_len
        self.target_max_len = target_max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        source = self.data.iloc[index]['input_text']
        target = self.data.iloc[index]['target_text']

        source_encoding = self.tokenizer(
            source,
            max_length=self.source_max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )
        target_encoding = self.tokenizer(
            target,
            max_length=self.target_max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        labels = target_encoding.input_ids.squeeze()
        labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in loss

        return {
            'input_ids': source_encoding.input_ids.squeeze(),
            'attention_mask': source_encoding.attention_mask.squeeze(),
            'labels': labels,
        }


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [12]:
!pip install --upgrade transformers




In [15]:
import torch
torch.cuda.empty_cache()


In [17]:
import gc
import torch

gc.collect()
torch.cuda.empty_cache()


In [18]:
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments

# Load the pretrained T5 model
model = T5ForConditionalGeneration.from_pretrained('t5-base')

# Create dataset objects for train and validation splits
train_dataset = MedicalQADataset(train_df, tokenizer)
val_dataset = MedicalQADataset(val_df, tokenizer)

# Define training arguments
training_args = TrainingArguments(
    output_dir='./healthgpt_checkpoints',
    num_train_epochs=3,
    per_device_train_batch_size=1,        # smallest batch size
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,        # simulate 4 as effective batch size
    fp16=True,
    eval_strategy='steps',
    eval_steps=500,
    save_steps=1000,
    logging_dir='./logs',
    logging_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    run_name='healthgpt_exp_tinybatch',
)


# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

# Start training
trainer.train()

Step,Training Loss,Validation Loss
500,2.4721,2.145517
1000,2.2005,2.005263
1500,2.1859,1.940332
2000,2.2841,1.963891
2500,2.1756,1.963879


Step,Training Loss,Validation Loss
500,2.4721,2.145517
1000,2.2005,2.005263
1500,2.1859,1.940332
2000,2.2841,1.963891
2500,2.1756,1.963879
3000,2.2276,1.963879
3500,2.3282,1.963879
4000,2.2276,1.963879
4500,2.3185,1.963879
5000,0.0,


KeyboardInterrupt: 