In [2]:
import torch
import bitsandbytes as bnb
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset

DATA_PATH = Path("../data")
OMI_PATH_processed = DATA_PATH / "processed" / "omi-health"
OMI_PATH_raw = DATA_PATH / "raw" / "omi-health"
MODEL_PATH =  Path("../models")

print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device name: {torch.cuda.get_device_name(0)}")

print(f"Bitsandbytes version: {bnb.__version__}")

  from .autonotebook import tqdm as notebook_tqdm


PyTorch CUDA available: True
CUDA device count: 1
Current CUDA device name: NVIDIA RTX A4000
Bitsandbytes version: 0.45.5


In [3]:
model_id = "google/txgemma-2b-predict"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:11<00:00,  3.68s/it]


In [4]:
import pandas as pd

# Load the dataset
train_df = pd.read_csv(OMI_PATH_processed / "train_v1.csv")
train_df.head()

Unnamed: 0,dialogue,soap,prompt,messages,messages_nosystem,event_tags
0,"Doctor: Hello, how can I help you today?\nPati...",S: The patient's mother reports that her 13-ye...,Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",['(After the tests)']
1,"Doctor: Hello, what brings you in today?\nPati...","S: The patient, a 21-month-old male, presented...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...","['[After the tests]', '[After 3 weeks of thera..."
2,"Doctor: Hello, how can I help you today?\nPati...","S: Patient reports experiencing fatigue, night...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
3,"Doctor: Hello, Patient D. How are you feeling ...","S: Patient D, a 60-year-old African American m...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
4,"Doctor: Hello, I see that you have a history o...","S: The patient, a married woman with a 7-year ...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]


In [5]:
dataset = Dataset.from_pandas(train_df)
dataset = dataset.rename_column('soap', 'completion')
cols_to_drop = ['prompt', 'messages', 'messages_nosystem']
dataset = dataset.remove_columns(cols_to_drop)
dataset

Dataset({
    features: ['dialogue', 'completion', 'event_tags'],
    num_rows: 9250
})

In [6]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [7]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)

In [8]:
def format_dialogue_for_soap_synthesis_v2(data):
    dialogue = data["dialogue"]
    soap = data["completion"]
    return f"dialogue: {dialogue}<soap_start> soap_note:{soap} <eos>"

In [9]:
import transformers
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=8,
        max_steps=100,
        learning_rate=5e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=2048,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none",
    ),
    peft_config=lora_config,
    formatting_func=format_dialogue_for_soap_synthesis_v2,
)

Applying formatting function to train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 14740.38 examples/s]
Converting train dataset to ChatML: 100%|██████████| 9250/9250 [00:00<00:00, 17824.02 examples/s]
Adding EOS to train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 13149.23 examples/s]
Tokenizing train dataset: 100%|██████████| 9250/9250 [00:18<00:00, 493.59 examples/s]
Truncating train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 92027.26 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [10]:
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
5,14.1888
10,6.749
15,3.6012
20,2.6644
25,2.3311
30,2.0706
35,1.8829
40,1.8499
45,1.8113
50,1.8314


TrainOutput(global_step=100, training_loss=2.78219181060791, metrics={'train_runtime': 549.8728, 'train_samples_per_second': 0.727, 'train_steps_per_second': 0.182, 'total_flos': 4708887651749376.0, 'train_loss': 2.78219181060791})

In [15]:
# Save the fine-tuned model
model.save_pretrained(MODEL_PATH)

# Save the tokenizer
tokenizer.save_pretrained(MODEL_PATH)

print(f"Model and tokenizer saved to {MODEL_PATH}")

Model and tokenizer saved to ..\models


In [32]:
def generate_soap_note(dialogue, model, tokenizer, device="cuda:0"):
    """Generates a SOAP note from a given dialogue."""

    # MODIFIED: Align with the training prompt structure to guide generation
    input_text = f"dialogue: {dialogue}<soap_start> soap_note:"

    # Tokenize the input
    inputs = tokenizer.encode_plus(
        input_text,
        return_tensors="pt",
        # Ensure tokenizer doesn't add EOS token here if the model adds it during generation
        add_special_tokens=True # Or False, depending on tokenizer and model behavior with this specific prompt
    ).to(device)

    # Generate the SOAP note
    outputs = model.generate(
        inputs.input_ids,
        # MODIFIED: Use max_new_tokens to control the length of the *generated* text
        max_new_tokens=512,  # Adjust as needed for typical SOAP note length
        num_beams=4,
        temperature=0.7,
        # IMPORTANT: Add pad_token_id if not set in model config, especially for open-ended generation
        pad_token_id=tokenizer.eos_token_id, # Or tokenizer.pad_token_id if different and model expects it
        # Consider adding an early stopping criterion or specific stop sequences if needed
        # eos_token_id=tokenizer.eos_token_id # Ensure model stops at <eos>
    )

    # Decode the output
    # MODIFIED: Slice the output to remove the input prompt text
    # The generated tokens start *after* the input_ids length
    generated_ids = outputs[0][inputs.input_ids.shape[1]:]
    soap_note = tokenizer.decode(generated_ids, skip_special_tokens=True)

    return soap_note

In [33]:
sample_dialogue = train_df.iloc[0,0]
generated_note = generate_soap_note(sample_dialogue, model, tokenizer)
print("Generated SOAP Note:")
print(generated_note)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
  return fn(*args, **kwargs)


Generated SOAP Note:
S: The patient reports that their son has been experiencing mild to moderate speech and developmental delay, diagnosed with attention deficit disorder at age 13. The patient also notes that the son has mild to moderate hypotonia.
O: MRI results indicate no structural brain anomalies. Physical examination revealed facial features like retrognathia, mild hypertelorism, a thin upper lip, and feet with mild syndactyly of the second and third toe with a sandal gap in both feet. Genetic analysis revealed a de novo frameshift variant in Chr1 (GRCH3:37]) located more than 400 codons upstream of the canonical termination codon, leading to a premature termination codon.
A: The primary diagnosis is a de novo frameshift variant in Chr1 (GRCH3:37]) located more than 400 codons upstream of the canonical termination codon, leading to a premature termination codon. This variant may contribute to the son's speech, developmental delay, and attention deficit disorder.
P: The manageme

In [35]:
train_df.iloc[0,1]

"S: The patient's mother reports that her 13-year-old son has mild to moderate speech and developmental delays and has been diagnosed with attention deficit disorder. She denies any issues with muscle tone or hypotonia. The patient also exhibits certain physical characteristics, including retrognathia, mild hypertelorism, an elongated philtrum, thin upper lip, broad and short hands, mild syndactyly of the second and third toes, and a sandal gap in both feet.\nO: An MRI of the brain showed no structural anomalies. Whole Exome Sequencing (WES) revealed a de novo frameshift variant Chr1(GRCh37):g.244217335del, NM_205768.2(ZBTB18):c.259del(p.(Leu87Cysfs*21)), indicating a premature termination codon located more than 400 codons upstream of the canonical termination codon.\nA: The primary diagnosis is a genetic disorder associated with the identified frameshift mutation, which likely contributes to the patient's speech and developmental delays and attention deficit disorder. The physical ch