In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Trainer, TrainingArguments, EvalPrediction,AutoModelForSequenceClassification,AutoModelForCausalLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments,DataCollatorForSeq2Seq
from peft import LoraConfig, get_peft_model,  TaskType
from datasets import  load_dataset ,Dataset,DatasetDict,load_metric
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
import numpy as np
import sentencepiece
from transformers import T5Tokenizer, T5ForConditionalGeneration


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import wandb


# Initialize a new run
wandb.init(project='loramedical')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msaahilkatariads[0m ([33msaahilkatariads-MCKV Institute of Engineering[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")



In [5]:
dataset = pd.read_csv('train.csv')
dataset = dataset.drop('qtype', axis=1)
dataset = dataset.rename(columns={'Question': 'question', 'Answer': 'answer'})

In [6]:
df_full_train, df_test = train_test_split(dataset, test_size=0.2, random_state=56)
df_train, df_val = train_test_split(df_full_train, test_size=0.25, random_state=56)

In [7]:
df_train = df_train.reset_index(drop=True)
df_val = df_train.reset_index(drop=True)
df_test = df_train.reset_index(drop=True)
train_dataset = Dataset.from_pandas(df_train)
val_dataset = Dataset.from_pandas(df_val)
test_dataset = Dataset.from_pandas(df_test)

In [8]:
health_dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

In [9]:
import sentencepiece
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map=device)


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [10]:
prefix = "Assuming you are working as Doctor. Please answer this question: "



def preprocess_function(examples):
   inputs = [prefix + doc for doc in examples["question"]]
   model_inputs = tokenizer(inputs, max_length=256, truncation=True)
   labels = tokenizer(text_target=examples["answer"], 
                      max_length=256,         
                      truncation=True)

   model_inputs["labels"] = labels["input_ids"]
   return model_inputs

In [11]:
tokenized_dataset = health_dataset_dict.map(preprocess_function, batched=True)

Map: 100%|██████████| 9843/9843 [00:05<00:00, 1835.76 examples/s]
Map: 100%|██████████| 9843/9843 [00:05<00:00, 1850.64 examples/s]
Map: 100%|██████████| 9843/9843 [00:05<00:00, 1870.54 examples/s]


In [12]:
peft_config = LoraConfig(
    r=32,  # Rank
    lora_alpha=16,  # Scaling
    lora_dropout=0.01,  # Dropout probability
    bias="none",  # Bias term setting
    target_modules=["q",'v'],
    task_type=TaskType.SEQ_2_SEQ_LM
        # Apply LoRA to the query and value projection layers
     
)

In [13]:
model = get_peft_model(model, peft_config)

In [15]:
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',           # output directory
    evaluation_strategy='epoch',      # evaluation is done at the end of each epoch
    learning_rate=1e-4,               # learning rate
    per_device_train_batch_size=8,    # batch size for training
    per_device_eval_batch_size=16,    # batch size for evaluation
    weight_decay=0.01,                # strength of weight decay
    save_total_limit=3,               # limits the total amount of checkpoints. Deletes older checkpoints.
    num_train_epochs=15,               # number of training epochs
    predict_with_generate=True,       # generate predictions for evaluation
    report_to='wandb'
)



In [16]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model)
    
)

In [17]:

trainer.train()

  3%|▎         | 500/18465 [2:04:01<120:43:12, 24.19s/it]  

{'loss': 2.273, 'grad_norm': 0.1970219910144806, 'learning_rate': 9.729217438396968e-05, 'epoch': 0.41}


  5%|▌         | 1000/18465 [4:41:40<8:08:24,  1.68s/it]   

{'loss': 2.024, 'grad_norm': 0.2271500676870346, 'learning_rate': 9.458434876793935e-05, 'epoch': 0.81}


  7%|▋         | 1231/18465 [5:56:39<6:13:50,  1.30s/it]    
  7%|▋         | 1231/18465 [8:10:47<6:13:50,  1.30s/it]

{'eval_loss': 1.7467290163040161, 'eval_runtime': 8048.7748, 'eval_samples_per_second': 1.223, 'eval_steps_per_second': 0.077, 'epoch': 1.0}


  8%|▊         | 1500/18465 [9:46:43<13:50:36,  2.94s/it]     

{'loss': 1.9672, 'grad_norm': 0.2017890363931656, 'learning_rate': 9.187652315190903e-05, 'epoch': 1.22}


 11%|█         | 2000/18465 [10:02:28<8:14:45,  1.80s/it]

{'loss': 1.9213, 'grad_norm': 0.3887452781200409, 'learning_rate': 8.916869753587869e-05, 'epoch': 1.62}


                                                          
 13%|█▎        | 2462/18465 [10:30:36<8:37:05,  1.94s/it]

{'eval_loss': 1.680251121520996, 'eval_runtime': 828.3144, 'eval_samples_per_second': 11.883, 'eval_steps_per_second': 0.744, 'epoch': 2.0}


 14%|█▎        | 2500/18465 [10:32:16<8:58:02,  2.02s/it]    

{'loss': 1.8741, 'grad_norm': 0.20996493101119995, 'learning_rate': 8.646087191984836e-05, 'epoch': 2.03}


 16%|█▌        | 3000/18465 [10:50:28<10:17:06,  2.39s/it]

{'loss': 1.8742, 'grad_norm': 0.21190248429775238, 'learning_rate': 8.375304630381803e-05, 'epoch': 2.44}


 19%|█▉        | 3500/18465 [11:11:50<8:08:44,  1.96s/it]  

{'loss': 1.849, 'grad_norm': 0.2802028954029083, 'learning_rate': 8.104522068778771e-05, 'epoch': 2.84}


                                                          
 20%|██        | 3693/18465 [11:34:56<10:37:27,  2.59s/it]

{'eval_loss': 1.6409293413162231, 'eval_runtime': 983.5409, 'eval_samples_per_second': 10.008, 'eval_steps_per_second': 0.626, 'epoch': 3.0}


 22%|██▏       | 4000/18465 [11:49:10<9:05:48,  2.26s/it]    

{'loss': 1.8315, 'grad_norm': 0.20034688711166382, 'learning_rate': 7.833739507175738e-05, 'epoch': 3.25}


 24%|██▍       | 4500/18465 [12:10:51<9:46:58,  2.52s/it] 

{'loss': 1.8037, 'grad_norm': 0.21320746839046478, 'learning_rate': 7.562956945572706e-05, 'epoch': 3.66}


                                                          
 27%|██▋       | 4924/18465 [12:48:52<8:19:55,  2.22s/it]

{'eval_loss': 1.613669514656067, 'eval_runtime': 1013.8823, 'eval_samples_per_second': 9.708, 'eval_steps_per_second': 0.608, 'epoch': 4.0}


 27%|██▋       | 5000/18465 [12:52:35<8:33:06,  2.29s/it]    

{'loss': 1.8274, 'grad_norm': 0.2463410645723343, 'learning_rate': 7.292174383969674e-05, 'epoch': 4.06}


 30%|██▉       | 5500/18465 [13:15:36<8:23:25,  2.33s/it] 

{'loss': 1.8058, 'grad_norm': 0.2571028470993042, 'learning_rate': 7.02139182236664e-05, 'epoch': 4.47}


 32%|███▏      | 6000/18465 [13:36:27<7:51:05,  2.27s/it] 

{'loss': 1.7926, 'grad_norm': 0.27406466007232666, 'learning_rate': 6.750609260763608e-05, 'epoch': 4.87}


                                                          
 33%|███▎      | 6155/18465 [14:02:02<12:37:09,  3.69s/it]

{'eval_loss': 1.5924218893051147, 'eval_runtime': 1062.7325, 'eval_samples_per_second': 9.262, 'eval_steps_per_second': 0.58, 'epoch': 5.0}


 35%|███▌      | 6500/18465 [14:15:33<6:21:10,  1.91s/it]    

{'loss': 1.7749, 'grad_norm': 0.2934473156929016, 'learning_rate': 6.479826699160574e-05, 'epoch': 5.28}


 38%|███▊      | 7000/18465 [14:33:20<6:54:02,  2.17s/it] 

{'loss': 1.7764, 'grad_norm': 0.2927219867706299, 'learning_rate': 6.209044137557541e-05, 'epoch': 5.69}


                                                          
 40%|████      | 7386/18465 [15:03:19<9:05:08,  2.95s/it]

{'eval_loss': 1.5752484798431396, 'eval_runtime': 924.3071, 'eval_samples_per_second': 10.649, 'eval_steps_per_second': 0.666, 'epoch': 6.0}


 41%|████      | 7500/18465 [15:07:33<6:22:12,  2.09s/it]   

{'loss': 1.7748, 'grad_norm': 0.2277224361896515, 'learning_rate': 5.938261575954509e-05, 'epoch': 6.09}


 43%|████▎     | 8000/18465 [15:28:42<6:00:43,  2.07s/it] 

{'loss': 1.7679, 'grad_norm': 0.19758470356464386, 'learning_rate': 5.667479014351476e-05, 'epoch': 6.5}


 46%|████▌     | 8500/18465 [15:44:02<4:42:56,  1.70s/it] 

{'loss': 1.7611, 'grad_norm': 0.2536078691482544, 'learning_rate': 5.396696452748443e-05, 'epoch': 6.9}


                                                         
 47%|████▋     | 8617/18465 [16:02:15<9:34:43,  3.50s/it]

{'eval_loss': 1.5626046657562256, 'eval_runtime': 852.0821, 'eval_samples_per_second': 11.552, 'eval_steps_per_second': 0.723, 'epoch': 7.0}


 49%|████▊     | 9000/18465 [16:15:50<5:20:44,  2.03s/it]   

{'loss': 1.7487, 'grad_norm': 0.229537233710289, 'learning_rate': 5.125913891145411e-05, 'epoch': 7.31}


 51%|█████▏    | 9500/18465 [16:34:39<5:47:51,  2.33s/it] 

{'loss': 1.7307, 'grad_norm': 0.23780874907970428, 'learning_rate': 4.855131329542378e-05, 'epoch': 7.72}


                                                          
 53%|█████▎    | 9848/18465 [17:09:12<7:12:31,  3.01s/it]

{'eval_loss': 1.5475730895996094, 'eval_runtime': 1279.0529, 'eval_samples_per_second': 7.696, 'eval_steps_per_second': 0.482, 'epoch': 8.0}


 54%|█████▍    | 10000/18465 [17:17:21<5:30:58,  2.35s/it]  

{'loss': 1.7584, 'grad_norm': 0.23942220211029053, 'learning_rate': 4.584348767939345e-05, 'epoch': 8.12}


 57%|█████▋    | 10500/18465 [17:36:15<4:18:14,  1.95s/it] 

{'loss': 1.7313, 'grad_norm': 0.2655499279499054, 'learning_rate': 4.3135662063363117e-05, 'epoch': 8.53}


 60%|█████▉    | 11000/18465 [17:54:07<4:31:35,  2.18s/it] 

{'loss': 1.7435, 'grad_norm': 0.24380697309970856, 'learning_rate': 4.0427836447332795e-05, 'epoch': 8.94}


                                                          
 60%|██████    | 11079/18465 [18:11:22<5:57:06,  2.90s/it]

{'eval_loss': 1.5386842489242554, 'eval_runtime': 864.887, 'eval_samples_per_second': 11.381, 'eval_steps_per_second': 0.712, 'epoch': 9.0}


 62%|██████▏   | 11500/18465 [18:26:32<3:53:21,  2.01s/it]   

{'loss': 1.7352, 'grad_norm': 0.2692797780036926, 'learning_rate': 3.772001083130247e-05, 'epoch': 9.34}


 65%|██████▍   | 12000/18465 [18:45:19<4:00:44,  2.23s/it]

{'loss': 1.7327, 'grad_norm': 0.23657335340976715, 'learning_rate': 3.501218521527214e-05, 'epoch': 9.75}


                                                          
 67%|██████▋   | 12310/18465 [19:09:33<5:33:12,  3.25s/it]

{'eval_loss': 1.529298186302185, 'eval_runtime': 791.4374, 'eval_samples_per_second': 12.437, 'eval_steps_per_second': 0.778, 'epoch': 10.0}


 68%|██████▊   | 12500/18465 [19:16:21<3:29:04,  2.10s/it]   

{'loss': 1.717, 'grad_norm': 0.2811238765716553, 'learning_rate': 3.230435959924181e-05, 'epoch': 10.15}


 70%|███████   | 13000/18465 [19:36:41<3:42:45,  2.45s/it] 

{'loss': 1.7167, 'grad_norm': 0.2645471692085266, 'learning_rate': 2.9596533983211482e-05, 'epoch': 10.56}


 73%|███████▎  | 13500/18465 [19:54:32<2:51:02,  2.07s/it]

{'loss': 1.7228, 'grad_norm': 0.31828320026397705, 'learning_rate': 2.6888708367181153e-05, 'epoch': 10.97}


                                                          
 73%|███████▎  | 13541/18465 [20:09:31<4:16:57,  3.13s/it]

{'eval_loss': 1.5237170457839966, 'eval_runtime': 810.1524, 'eval_samples_per_second': 12.15, 'eval_steps_per_second': 0.76, 'epoch': 11.0}


 76%|███████▌  | 14000/18465 [20:26:50<2:46:35,  2.24s/it]   

{'loss': 1.7121, 'grad_norm': 0.24876298010349274, 'learning_rate': 2.418088275115083e-05, 'epoch': 11.37}


 79%|███████▊  | 14500/18465 [20:45:17<2:23:21,  2.17s/it]

{'loss': 1.7076, 'grad_norm': 0.3201294243335724, 'learning_rate': 2.14730571351205e-05, 'epoch': 11.78}


                                                          
 80%|████████  | 14772/18465 [21:09:17<3:19:33,  3.24s/it]

{'eval_loss': 1.5178865194320679, 'eval_runtime': 835.7198, 'eval_samples_per_second': 11.778, 'eval_steps_per_second': 0.737, 'epoch': 12.0}


 81%|████████  | 15000/18465 [21:18:00<2:12:57,  2.30s/it]   

{'loss': 1.7034, 'grad_norm': 0.28505027294158936, 'learning_rate': 1.8765231519090172e-05, 'epoch': 12.19}


 84%|████████▍ | 15500/18465 [21:36:25<2:15:03,  2.73s/it]

{'loss': 1.722, 'grad_norm': 0.2720876932144165, 'learning_rate': 1.6057405903059843e-05, 'epoch': 12.59}


 87%|████████▋ | 16000/18465 [21:55:07<1:38:38,  2.40s/it]

{'loss': 1.711, 'grad_norm': 0.30954286456108093, 'learning_rate': 1.3349580287029517e-05, 'epoch': 13.0}


                                                          
 87%|████████▋ | 16003/18465 [22:10:05<2:30:42,  3.67s/it]

{'eval_loss': 1.5148471593856812, 'eval_runtime': 884.8254, 'eval_samples_per_second': 11.124, 'eval_steps_per_second': 0.696, 'epoch': 13.0}


 89%|████████▉ | 16500/18465 [22:30:34<1:17:53,  2.38s/it]   

{'loss': 1.7081, 'grad_norm': 0.2785217761993408, 'learning_rate': 1.0641754670999188e-05, 'epoch': 13.4}


 92%|█████████▏| 17000/18465 [22:51:40<57:47,  2.37s/it]  

{'loss': 1.6975, 'grad_norm': 0.326979398727417, 'learning_rate': 7.933929054968862e-06, 'epoch': 13.81}


                                                          
 93%|█████████▎| 17234/18465 [23:24:58<1:40:36,  4.90s/it]

{'eval_loss': 1.511889934539795, 'eval_runtime': 1197.5932, 'eval_samples_per_second': 8.219, 'eval_steps_per_second': 0.514, 'epoch': 14.0}


 95%|█████████▍| 17500/18465 [23:36:24<37:41,  2.34s/it]     

{'loss': 1.7044, 'grad_norm': 0.33532893657684326, 'learning_rate': 5.226103438938533e-06, 'epoch': 14.22}


 97%|█████████▋| 18000/18465 [23:56:41<30:42,  3.96s/it]

{'loss': 1.7034, 'grad_norm': 0.2590237259864807, 'learning_rate': 2.5182778229082047e-06, 'epoch': 14.62}


                                                        
100%|██████████| 18465/18465 [24:31:43<00:00,  4.78s/it]

{'eval_loss': 1.5114909410476685, 'eval_runtime': 925.7911, 'eval_samples_per_second': 10.632, 'eval_steps_per_second': 0.665, 'epoch': 15.0}
{'train_runtime': 88303.587, 'train_samples_per_second': 1.672, 'train_steps_per_second': 0.209, 'train_loss': 1.78713118106151, 'epoch': 15.0}





TrainOutput(global_step=18465, training_loss=1.78713118106151, metrics={'train_runtime': 88303.587, 'train_samples_per_second': 1.672, 'train_steps_per_second': 0.209, 'total_flos': 2.59039270075392e+16, 'train_loss': 1.78713118106151, 'epoch': 15.0})

In [None]:
final_results = trainer.evaluate()



In [21]:
print(final_results)

{'eval_runtime': 70.4313, 'eval_samples_per_second': 139.753, 'eval_steps_per_second': 8.746, 'epoch': 5.0}


In [18]:
def generate_answer(question, model, tokenizer, max_length=512):
    input_text = "Assuming you are working as Doctor. Please answer this question: " + question
    inputs = tokenizer(input_text, return_tensors="pt").to(device)
    outputs = model.generate(
        input_ids=inputs["input_ids"], 
        max_length=256, 
        num_beams=5, 
        early_stopping=True
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer



In [33]:
# Example question 1
question = "What are the treatments for spinocerebellar ataxia type 36"

# Generate the answer
answer = generate_answer(question, model, tokenizer)
print("Question:", question)
print("Answer:", answer)

Question: What are the treatments for spinocerebellar ataxia type 36
Answer: These resources address the diagnosis or management of spinocerebellar ataxia type 36: - Gene Review: Gene Review: Spinocerebellar Ataxia Type 36 - Genetic Testing Registry: Spinocerebellar ataxia type 36 These resources from MedlinePlus offer information about the diagnosis and management of various health conditions: - Diagnostic Tests - Drug Therapy - Surgery and Rehabilitation - Genetic Counseling - Palliative Care


In [21]:
# Example question 1
question = "Is Fuchs endothelial dystrophy inherited ?"

# Generate the answer
answer = generate_answer(question, model, tokenizer)
print("Question:", question)
print("Answer:", answer)

Question: Is Fuchs endothelial dystrophy inherited ?
Answer: This condition is inherited in an autosomal recessive pattern, which means both copies of the gene in each cell have mutations. The parents of an individual with an autosomal recessive condition each carry one copy of the mutated gene, but they typically do not show signs and symptoms of the condition.


In [20]:
train_dataset[4]

{'question': 'Is Fuchs endothelial dystrophy inherited ?',
 'answer': 'In some cases, Fuchs endothelial dystrophy appears to be inherited in an autosomal dominant pattern, which means one copy of the altered gene in each cell is sufficient to cause the disorder. When this condition is caused by a mutation in the COL8A2 gene, it is inherited in an autosomal dominant pattern. In addition, an autosomal dominant inheritance pattern is apparent in some situations in which the condition is caused by alterations in an unknown gene.  In many families, the inheritance pattern is unknown.  Some cases result from new mutations in a gene and occur in people with no history of the disorder in their family.'}

In [22]:
model.save_pretrained('./biomrcmodel')

In [23]:
tokenizer.save_pretrained('./biotokenizer')

('./biotokenizer/tokenizer_config.json',
 './biotokenizer/special_tokens_map.json',
 './biotokenizer/spiece.model',
 './biotokenizer/added_tokens.json')