In [28]:
import torch
import bitsandbytes as bnb
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset

DATA_PATH = Path("../data")
OMI_PATH_processed = DATA_PATH / "processed" / "omi-health"
OMI_PATH_raw = DATA_PATH / "raw" / "omi-health"
MODEL_PATH =  Path("../models")

print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device name: {torch.cuda.get_device_name(0)}")

print(f"Bitsandbytes version: {bnb.__version__}")

PyTorch CUDA available: True
CUDA device count: 1
Current CUDA device name: NVIDIA RTX A4000
Bitsandbytes version: 0.45.5


In [3]:
model_id = "google/txgemma-2b-predict"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Fetching 3 files: 100%|██████████| 3/3 [01:28<00:00, 29.37s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:05<00:00,  1.77s/it]


In [4]:
import pandas as pd

# Load the dataset
train_df = pd.read_csv(OMI_PATH_processed / "train_v1.csv")
train_df.head()

Unnamed: 0,dialogue,soap,prompt,messages,messages_nosystem,event_tags
0,"Doctor: Hello, how can I help you today?\nPati...",S: The patient's mother reports that her 13-ye...,Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",['(After the tests)']
1,"Doctor: Hello, what brings you in today?\nPati...","S: The patient, a 21-month-old male, presented...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...","['[After the tests]', '[After 3 weeks of thera..."
2,"Doctor: Hello, how can I help you today?\nPati...","S: Patient reports experiencing fatigue, night...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
3,"Doctor: Hello, Patient D. How are you feeling ...","S: Patient D, a 60-year-old African American m...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
4,"Doctor: Hello, I see that you have a history o...","S: The patient, a married woman with a 7-year ...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]


In [22]:
def format_dialogue_for_soap_synthesis_v2(data):
    dialogue = data["dialogue"]
    soap = data["completion"]
    return f"dialogue: {dialogue}<soap_start> {soap} <eos>"

In [23]:
dataset = Dataset.from_pandas(train_df)
dataset = dataset.rename_column('soap', 'completion')

dataset

Dataset({
    features: ['dialogue', 'completion', 'prompt', 'messages', 'messages_nosystem', 'event_tags'],
    num_rows: 9250
})

In [24]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [25]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)



In [26]:
import transformers
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=20,
        max_steps=500,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none",
    ),
    peft_config=lora_config,
    formatting_func=format_dialogue_for_soap_synthesis_v2,
)

Applying formatting function to train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 12986.17 examples/s]
Converting train dataset to ChatML: 100%|██████████| 9250/9250 [00:00<00:00, 11153.36 examples/s]
Adding EOS to train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 12064.56 examples/s]
Tokenizing train dataset: 100%|██████████| 9250/9250 [00:12<00:00, 736.50 examples/s]
Truncating train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 117028.57 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [27]:
trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
5,16.453
10,15.9474
15,13.2621
20,7.8662
25,5.9365
30,4.4854
35,3.885
40,3.3543
45,3.0452
50,3.0473


TrainOutput(global_step=500, training_loss=2.756281015396118, metrics={'train_runtime': 1308.8063, 'train_samples_per_second': 1.528, 'train_steps_per_second': 0.382, 'total_flos': 1.1355701297624064e+16, 'train_loss': 2.756281015396118})

In [20]:
def generate_soap_note(dialogue, model, tokenizer, device="cuda:0"):
    """Generates a SOAP note from a given dialogue."""

    # Format the input (you might need to adapt this based on your training format)
    input_text = f"{dialogue} []<eos>"  # Assuming empty event tags for simplicity

    # Tokenize the input
    inputs = tokenizer.encode_plus(
        input_text, return_tensors="pt"
    ).to(device)

    # Generate the SOAP note
    outputs = model.generate(
        inputs.input_ids,
        max_length=2048,  # Adjust as needed
        num_beams=4,
        temperature=0.7,
        early_stopping=True,
    )

    # Decode the output
    soap_note = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return soap_note

In [29]:
# Save the fine-tuned model
model.save_pretrained(MODEL_PATH)

# Save the tokenizer
tokenizer.save_pretrained(MODEL_PATH)

print(f"Model and tokenizer saved to {MODEL_PATH}")

Model and tokenizer saved to ..\models


In [30]:
sample_dialogue = train_df.iloc[0,0]
generated_note = generate_soap_note(sample_dialogue, model, tokenizer)
print("Generated SOAP Note:")
print(generated_note)

  return fn(*args, **kwargs)


Generated SOAP Note:
Doctor: Hello, how can I help you today?
Patient: My son has been having some issues with speech and development. He's 13 years old now.
Doctor: I see. Can you tell me more about his symptoms? Does he have any issues with muscle tone or hypotonia?
Patient: No, he doesn't have hypotonia. But he has mild to moderate speech and developmental delay, and he's been diagnosed with attention deficit disorder.
Doctor: Thank you for sharing that information. We'll run some tests, including an MRI, to get a better understanding of your son's condition. 
(After the tests)
Doctor: The MRI results are in, and I'm glad to say that there are no structural brain anomalies. However, I did notice some physical characteristics. Does your son have any facial features like retrognathia, mild hypertelorism, or a slightly elongated philtrum and thin upper lip?
Patient: Yes, he has all of those features. His hands are also broad and short. And his feet have mild syndactyly of the second an