<a href="https://colab.research.google.com/github/randomrahulm/clincalFindings_gemma_fine_tunned/blob/main/clinical_findings_gemma_finetuning_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/bhattbhavesh91/google-gemma-finetuning-n2sql/blob/main/n2sql-google-gemma-finetuning-notebook.ipynb" target="_blank"><img height="40" alt="Run your own notebook in Colab" src = "https://colab.research.google.com/assets/colab-badge.svg"></a>

In [None]:
!pip3 install -q -U bitsandbytes==0.42.0
!pip3 install -q -U peft==0.8.2
!pip3 install -q -U trl==0.7.10
!pip3 install -q -U accelerate==0.27.1
!pip3 install -q -U datasets==2.17.0
!pip3 install -q -U transformers==4.38.0

In [None]:
import os
import transformers
import torch
from google.colab import userdata
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig, GemmaTokenizer

In [None]:
# Copy the Colab secret named 'ht_token' into an environment variable of the same name;
# the model-loading cell reads it back and passes it as `token=` to `from_pretrained`.
# NOTE(review): 'ht_token' looks like a typo for 'hf_token' — confirm against the name of
# the stored Colab secret before renaming anything.
os.environ["ht_token"] = userdata.get('ht_token')

In [None]:
# Hub id of the base checkpoint that gets quantized and fine-tuned.
model_id = "google/gemma-2b"

# Quantization settings handed to `from_pretrained`: load weights in 4-bit using the
# "nf4" quant type, with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
# Both downloads authenticate with the token stored in the `ht_token` environment variable.
_hub_auth = {"token": os.environ['ht_token']}

tokenizer = AutoTokenizer.from_pretrained(model_id, **_hub_auth)

# Load the base model with the quantization config; `device_map={"": 0}` maps the root
# module (and therefore the whole model) to device 0.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},
    **_hub_auth,
)

tokenizer_config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/555 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [None]:
# Smoke test of the freshly loaded model: continue a short prompt on the GPU.
device = "cuda:0"
text = "Quote: Our doubts are traitors,"

# Tokenize to PyTorch tensors, then move them to the same device as the model.
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)

# Generate at most 20 new tokens and print the decoded sequence without special tokens.
outputs = model.generate(max_new_tokens=20, **inputs)
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(completion)

Quote: Our doubts are traitors, and nothing but fools doth venture.

The above quote is from the play, <em>The Tempest


In [None]:
# Switch off Weights & Biases reporting for the training run.
# The flag is only honoured for truthy values ("true", "1", ...); setting it to "false"
# is a no-op that leaves W&B reporting enabled wherever `wandb` is installed.
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Names of the projection modules that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "o_proj", "k_proj", "v_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Rank-8 LoRA adapters for causal language modelling on the modules listed above.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    target_modules=lora_target_modules,
)

In [None]:
# Load the `iamcreator/clinical_findings` dataset with `datasets.load_dataset`.
data = load_dataset("iamcreator/clinical_findings")

In [None]:
# Display the dataset summary (splits, column names, row counts) as the cell output.
data

DatasetDict({
    train: Dataset({
        features: ['Input', 'Output', 'Context', '__index_level_0__'],
        num_rows: 3157
    })
    test: Dataset({
        features: ['Input', 'Output', 'Context', '__index_level_0__'],
        num_rows: 1353
    })
})

In [None]:
# Inspect the first training row to see what the Input / Output / Context fields contain.
data['train'][0]

{'Input': 'ACUTE IWMI ( STK + 14/08/2022) NORMAL LV SYSTOLIC FUNCTION CAG-DOUBLE VESSEL DISEASE(15/08/22) RESCUE WITH STENTING TO DISTAL RCA & PTCA MID LAD(15/08/22)',
 'Output': 'Patient was admitted with the complaints of chest pain retrosternal radiates to left upper limb and lower jaw on set 2 AM intermittent continuous pain from 1 PM .cag was done which shows double vessel disease.Hence PTCA WITH STENTING TO DISTAL RCA & PTCA MID LAD. Post procedure period was uneventful. Other than that he was treated with Antiplatelets statin bronchodilators diuretics IV Fluids and antibiotic along with other supportive measures. Patient become symptomatically better and discharged in a stable condition on the following medications.',
 'Context': 'Patient was admitted with the complaints of chest pain retrosternal radiates to left upper limb and lower jaw on set 2 AM intermittent continuous pain from 1 PM .',
 '__index_level_0__': 3908}

In [None]:
def _tokenize_input_context(samples):
    """Tokenize a batch of rows, pairing each "Input" text with its "Context" text."""
    return tokenizer(samples["Input"], samples["Context"])


# batched=True hands the function whole batches of rows (each column is a list).
data = data.map(_tokenize_input_context, batched=True)

Map:   0%|          | 0/3157 [00:00<?, ? examples/s]

Map:   0%|          | 0/1353 [00:00<?, ? examples/s]

In [None]:
def formatting_func(example):
    """Render dataset rows as ``Input / Context / Output`` training prompts.

    The trainer applies this function to batches of rows, so every column arrives
    as a list. One prompt is produced per row; returning only the first row of a
    batch would silently drop every other row of that batch from the training set.

    Args:
        example: Mapping with "Input", "Context" and "Output" entries. Each entry
            is either a list of strings (a batch of rows) or a single string
            (one row).

    Returns:
        list[str]: One formatted prompt per row, in row order.
    """
    inputs, contexts, outputs = example["Input"], example["Context"], example["Output"]

    # A single, un-batched row carries plain strings; wrap them so the loop below
    # formats the row as a whole instead of iterating over its characters.
    if isinstance(inputs, str):
        inputs, contexts, outputs = [inputs], [contexts], [outputs]

    return [
        f"Input: {inp}\nContext: {ctx}\nOutput: {out}"
        for inp, ctx, out in zip(inputs, contexts, outputs)
    ]

In [None]:
# Optimisation settings for a short run: 50 optimizer steps, each accumulating
# 2 micro-batches of 2 sequences per device, logging the loss at every step.
training_args = transformers.TrainingArguments(
    output_dir="outputs",
    max_steps=50,
    warmup_steps=2,
    learning_rate=2e-4,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    optim="paged_adamw_8bit",
    fp16=True,
    logging_steps=1,
)

# Supervised fine-tuning on the train split; prompts come from `formatting_func`
# and the LoRA adapters are described by `lora_config`.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=data["train"],
    peft_config=lora_config,
    formatting_func=formatting_func,
)

In [None]:
# Run the fine-tuning loop; the returned TrainOutput is displayed as the cell output.
trainer.train()

Step,Training Loss
1,2.9349
2,2.9349
3,2.8789
4,2.7775
5,2.6515
6,2.5042
7,2.3688
8,2.252
9,2.1169
10,1.9759


TrainOutput(global_step=50, training_loss=0.8676691317558288, metrics={'train_runtime': 76.7105, 'train_samples_per_second': 2.607, 'train_steps_per_second': 0.652, 'total_flos': 610604427018240.0, 'train_loss': 0.8676691317558288, 'epoch': 50.0})

In [None]:
# Try the fine-tuned model on one case. The prompt stops after the Context line,
# leaving the "Output:" section for the model to write.
device = "cuda:0"
text = """Input: INFERIOR WALL MI (DELAYED PRESENTATION) ISCHEMIC HEART DISEASE MILD LV SYSTOLIC DYSFUNCTION CAG-DOUBLE VESSEL DISEASE (27/07/2022) PRIMARY PTCA  WITH STENTING TO  OSTIOPROXIMAL   OM (27/07/2022)
Context: Patient was admitted with the complaints of Chest Pain since 3 days associated with sweating and palpitation. He is a known case of hypertension and diabetes mellitus """

# Tokenize to PyTorch tensors, then move them to the same device as the model.
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs.to(device)

# Generate at most 200 new tokens and print the decoded sequence without special tokens.
outputs = model.generate(max_new_tokens=200, **inputs)
completion = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(completion)

Input: INFERIOR WALL MI (DELAYED PRESENTATION) ISCHEMIC HEART DISEASE MILD LV SYSTOLIC DYSFUNCTION CAG-DOUBLE VESSEL DISEASE (27/07/2022) PRIMARY PTCA  WITH STENTING TO  OSTIOPROXIMAL   OM (27/07/2022)
Context: Patient was admitted with the complaints of Chest Pain since 3 days associated with sweating and palpitation. He is a known case of hypertension and diabetes mellitus 
Output: Patient was admitted with the complaints of Chest Pain since 3 days associated with sweating and palpitation. Blood investigation shows that he is a patient of diabetes mellitus. CAG was done which shows Double vessel disease. Hence primary PTCA with stenting to OSTIOPROXIMAL  OM. Post procedure period was uneventful. Other than that he was treated with Antiplatelets statin diuretics IV Fluids and antibiotic along with other supportive measures. Patient become symptomatically better and discharged in a stable condition on the following medications. Post PTCA period was uneventful. He was treated with Antip

In [None]:
# Save the trainer's model into the relative directory "./model".
# NOTE(review): the export cell at the end zips "/content/model"; that is the same
# directory only if the working directory is /content — confirm.
trainer.save_model("./model")

In [None]:
!pip install datasets
!pip install huggingface_hub

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed datasets-

In [None]:
from huggingface_hub import login
from datasets import Dataset

In [None]:
# Interactive Hugging Face Hub login (renders a widget in the cell output).
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import pandas as pd
# Read the source CSV from the Colab filesystem into a DataFrame.
# NOTE(review): the name `input` shadows Python's built-in `input()` for the rest of the
# session. The following cell reads this variable, so a rename has to cover both cells.
input=pd.read_csv("/content/LLM_1.csv")
# Display the frame as the cell output.
input

Unnamed: 0,Input,Output,Context
0,ISCHEMIC HEART DISEASE AWMI(Thrombolysed on 15...,Patient was admitted with complaints of neck p...,Patient was admitted with complaints of neck p...
1,CORONARY ARTERY DISESASE SEVERE LV SYSTOLIC DY...,Admitted with history of\t non healing ulcer i...,Admitted with history of\t non healing ulcer i...
2,LRTI NORMAL LV SYSTOLIC FUNCTION DYSLIPIDEMIA ...,Patient was admitted with the complaints of fe...,Patient was admitted with the complaints of fe...
3,AF WITH FVR OLD MI ISCHEMIC HEART DISEASE SEVE...,Patient was admitted with the complaints of br...,Patient was admitted with the complaints of br...
4,ACUTE IWMI ISCHEMIC HEART DISEASE GOOD LV SYST...,Patient was admitted with the complaints of ch...,Patient was admitted with the complaints of ch...
...,...,...,...
4659,STABLE ISCHEMIC HEART DISEASE NORMAL LV SYSTO...,Patient was admitted with the complaints of re...,Patient was admitted with the complaints of re...
4660,RECENT INFERIOR WALL MI ISCHEMIC HEART DISEAS...,Patient was admitted for CAG. CAG was done whi...,Patient was admitted for CAG.
4661,LRTI DILATED CARDIOMYOPATHY ISCHEMIC HEART DIS...,Patient was admitted with the complaints of pe...,Patient was admitted with the complaints of pe...
4662,DILATED CARDIOMYOPATHY SEVERE LV SYSTOLIC DY...,Patient was admitted with the complaints of ch...,Patient was admitted with the complaints of ch...


In [None]:
# Convert the pandas DataFrame into a `datasets.Dataset`.
dataset=Dataset.from_pandas(input)

In [None]:
# Hold out 30% of the rows as the test split. The seed is fixed so that re-running the
# notebook yields the same train/test partition that gets pushed to the Hub.
dataset = dataset.train_test_split(test_size=0.3, seed=42)

print(dataset)

DatasetDict({
    train: Dataset({
        features: ['Input', 'Output', 'Context'],
        num_rows: 3264
    })
    test: Dataset({
        features: ['Input', 'Output', 'Context'],
        num_rows: 1400
    })
})


In [None]:
# Upload both splits to the Hub dataset repository "Rahuk/clinical_findings".
dataset.push_to_hub("Rahuk/clinical_findings")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/datasets/Rahuk/clinical_findings/commit/f3a585c3ad58f0ab250c7711c7cd5a6f85f676bf', commit_message='Upload dataset', commit_description='', oid='f3a585c3ad58f0ab250c7711c7cd5a6f85f676bf', pr_url=None, pr_revision=None, pr_num=None)

In [None]:
# Inspect the first test row of `data` (the dataset loaded with `load_dataset`,
# not the `dataset` object built from the CSV).
data['test'][0]

{'Input': 'INFERIOR WALL MI (DELAYED PRESENTATION) ISCHEMIC HEART DISEASE MILD LV SYSTOLIC DYSFUNCTION CAG-DOUBLE VESSEL DISEASE (27/07/2022) PRIMARY PTCA  WITH STENTING TO  OSTIOPROXIMAL   OM (27/07/2022)',
 'Output': 'Patient was admitted with the complaints of Chest Pain since 3 days associated with sweating and palpitation. 2D ECHO shows  RWMA(Inferior and inferoposterior wall is hypokinetic) Mild LV systolic dysfunction trivial mitral regurgitation trivial  tricuspid regurgitation grade LVDD  IVC normal in size and collapsing. CAG was done which shows double vessel disease hence Primary PTCA stenting to Ostioproximal OM. Post procedure period was uneventful. Other than that he was treated with Antiplatelets statin IV Fluids and antibiotic along with other supportive measures. Patient become symptomatically better and discharged in a stable condition on the following medications.',
 'Context': 'Patient was admitted with the complaints of Chest Pain since 3 days associated with swea

In [None]:
import shutil
import os

# Directory holding the saved model files to export.
folder_path = "/content/model"

# Pack the directory into a zip archive; make_archive appends ".zip" to the base
# name and returns the path of the file it wrote.
archive_path = shutil.make_archive("/content/my_folder", 'zip', folder_path)

# Hand the archive to the browser as a download.
from google.colab import files
files.download(archive_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>