In [17]:
#! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl
#! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [18]:
import os
import torch
import bitsandbytes as bnb
from pathlib import Path

DATA_PATH = Path("../data")
OMI_PATH_processed = DATA_PATH / "processed" / "omi-health"
OMI_PATH_raw = DATA_PATH / "raw" / "omi-health"

print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device name: {torch.cuda.get_device_name(0)}")

print(f"Bitsandbytes version: {bnb.__version__}")
# You could also try a simple operation with bnb if you have a model loaded

PyTorch CUDA available: True
CUDA device count: 1
Current CUDA device name: NVIDIA RTX A4000
Bitsandbytes version: 0.45.5


In [19]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/txgemma-9b-chat"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.39s/it]


In [20]:
import pandas as pd

# Load the dataset
train_df = pd.read_csv(OMI_PATH_processed / "train_v1.csv")
train_df.head()

Unnamed: 0,dialogue,soap,prompt,messages,messages_nosystem,event_tags
0,"Doctor: Hello, how can I help you today?\nPati...",S: The patient's mother reports that her 13-ye...,Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",['(After the tests)']
1,"Doctor: Hello, what brings you in today?\nPati...","S: The patient, a 21-month-old male, presented...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...","['[After the tests]', '[After 3 weeks of thera..."
2,"Doctor: Hello, how can I help you today?\nPati...","S: Patient reports experiencing fatigue, night...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
3,"Doctor: Hello, Patient D. How are you feeling ...","S: Patient D, a 60-year-old African American m...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
4,"Doctor: Hello, I see that you have a history o...","S: The patient, a married woman with a 7-year ...",Create a Medical SOAP note summary from the di...,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]


In [21]:
train_df.iloc[0,0]

"Doctor: Hello, how can I help you today?\nPatient: My son has been having some issues with speech and development. He's 13 years old now.\nDoctor: I see. Can you tell me more about his symptoms? Does he have any issues with muscle tone or hypotonia?\nPatient: No, he doesn't have hypotonia. But he has mild to moderate speech and developmental delay, and he's been diagnosed with attention deficit disorder.\nDoctor: Thank you for sharing that information. We'll run some tests, including an MRI, to get a better understanding of your son's condition. \n(After the tests)\nDoctor: The MRI results are in, and I'm glad to say that there are no structural brain anomalies. However, I did notice some physical characteristics. Does your son have any facial features like retrognathia, mild hypertelorism, or a slightly elongated philtrum and thin upper lip?\nPatient: Yes, he has all of those features. His hands are also broad and short. And his feet have mild syndactyly of the second and third toe, 

Trying out the pretrained model

In [22]:
prompt = ("Create a Medical SOAP note summary from the dialogue, following these guidelines:\n"
          "S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Rely on the patient's statements as the primary source and ensure standardized terminology.\n    "
          "O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages. Include normal ranges where relevant.\n    "
          "A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.\n    "
          "P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.\n    "

          "Considerations: Compile the report based solely on the transcript provided. Maintain confidentiality and document sensitively. Use concise medical jargon and abbreviations for effective doctor communication. Add explanations to medical terms if needed\n    Format the summary in a clean, simple list format without using markdown or bullet points. Use 'S:', 'O:', 'A:', 'P:' directly followed by the text. Avoid any styling or special characters.")

full_prompt = f"{prompt}\n\nDialogue to summarize:\n{train_df.iloc[0,0]}"
print(full_prompt)

Create a Medical SOAP note summary from the dialogue, following these guidelines:
S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Rely on the patient's statements as the primary source and ensure standardized terminology.
    O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages. Include normal ranges where relevant.
    A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.
    P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.
    Considerations: Compile the report based solely on the transcript provided. Maintain confidenti

In [23]:
inputs = tokenizer(full_prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Create a Medical SOAP note summary from the dialogue, following these guidelines:
S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Rely on the patient's statements as the primary source and ensure standardized terminology.
    O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages. Include normal ranges where relevant.
    A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.
    P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.
    Considerations: Compile the report based solely on the transcript provided. Maintain confidenti

In [24]:
from peft import LoraConfig
lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

In [25]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)

In [26]:
INSTRUCTIONS_PROMPT = (
    "Create a Medical SOAP note summary from the dialogue, following these guidelines:\n"
    "S (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Rely on the patient's statements as the primary source and ensure standardized terminology.\n    "
    "O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages. Include normal ranges where relevant.\n    "
    "A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.\n    "
    "P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.\n    "
    "Considerations: Compile the report based solely on the transcript provided. Maintain confidentiality and document sensitively. Use concise medical jargon and abbreviations for effective doctor communication. Add explanations to medical terms. Pay attention to events happened using provided event tags.\n    Format the summary in a clean, simple list format without using markdown or bullet points. Use 'S:', 'O:', 'A:', 'P:' directly followed by the text. Avoid any styling or special characters."
)

def format_dialogue_for_soap_synthesis(data, dialogue_field_name="dialogue", soap_field_name="soap", event_tags_field_name="event_tags"):
    if dialogue_field_name not in data:
        raise KeyError(
            f"The field '{dialogue_field_name}' was not found in the input data. "
            f"Available keys in the data: {list(data.keys())}."
        )
    if soap_field_name not in data:
        raise KeyError(
            f"The field '{soap_field_name}' was not found in the input data. "
            f"Available keys in the data: {list(data.keys())}."
        )
    if event_tags_field_name not in data:
        raise KeyError(
            f"The field '{event_tags_field_name}' was not found in the input data. "
            f"Available keys in the data: {list(data.keys())}."
        )

    prompt = f"""Instruction: {INSTRUCTIONS_PROMPT}\ndialogue: {data[dialogue_field_name]}\nevent_tags: {data[event_tags_field_name]}\nSOAP Note:"""
    completion = f"{data[soap_field_name]}<eos>"

    return {"prompt": prompt, "completion": completion}

In [27]:
full_prompt = format_dialogue_for_soap_synthesis(train_df.iloc[0])
print(full_prompt)

{'prompt': "Instruction: Create a Medical SOAP note summary from the dialogue, following these guidelines:\nS (Subjective): Summarize the patient's reported symptoms, including chief complaint and relevant history. Rely on the patient's statements as the primary source and ensure standardized terminology.\n    O (Objective): Highlight critical findings such as vital signs, lab results, and imaging, emphasizing important details like the side of the body affected and specific dosages. Include normal ranges where relevant.\n    A (Assessment): Offer a concise assessment combining subjective and objective data. State the primary diagnosis and any differential diagnoses, noting potential complications and the prognostic outlook.\n    P (Plan): Outline the management plan, covering medication, diet, consultations, and education. Ensure to mention necessary referrals to other specialties and address compliance challenges.\n    Considerations: Compile the report based solely on the transcript

In [28]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)



In [29]:
from datasets import Dataset

dataset = Dataset.from_pandas(train_df)
dataset = dataset.map(format_dialogue_for_soap_synthesis)
dataset

Map: 100%|██████████| 9250/9250 [00:00<00:00, 10104.45 examples/s]


Dataset({
    features: ['dialogue', 'soap', 'prompt', 'messages', 'messages_nosystem', 'event_tags', 'completion'],
    num_rows: 9250
})

In [30]:
def check_identical_prompt_completion(dataset, prompt_field="prompt", completion_field="completion", verbose=False):
    identical_indices = []
    for idx, example in enumerate(dataset):
        prompt = example[prompt_field]
        completion = example[completion_field]
        if prompt == completion:
            identical_indices.append(idx)
            if verbose:
                print(f"Example {idx} has identical prompt and completion.")
    return identical_indices

check_identical_prompt_completion(dataset)

[]

In [31]:
import transformers
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=50,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none"
    ),
    peft_config=lora_config,
    formatting_func=None,
)


Converting train dataset to ChatML: 100%|██████████| 9250/9250 [00:00<00:00, 10650.67 examples/s]
Adding EOS to train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 15533.66 examples/s]
Tokenizing train dataset: 100%|██████████| 9250/9250 [00:32<00:00, 288.28 examples/s]
Truncating train dataset: 100%|██████████| 9250/9250 [00:00<00:00, 107632.78 examples/s]
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
#trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
5,0.0
