In [10]:
from pathlib import Path
import pandas as pd

DATA_PATH = Path("../data")
OMI_PATH_processed = DATA_PATH / "processed" / "omi-health"
OMI_PATH_predicted = DATA_PATH / "predicted" / "omi-health"
OMI_PATH_raw = DATA_PATH / "raw" / "omi-health"
MODEL_PATH =  Path("../models")

# Load the test dataset

In [4]:
test_df = pd.read_csv(OMI_PATH_processed / "test_v1.csv")
test_df

Unnamed: 0,dialogue,soap,prompt,messages,messages_nosystem,event_tags
0,"Doctor: Hello, can you please tell me about yo...","S: The patient, a flooring installer with no s...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
1,"Doctor: Hello, I understand that you're a 7-ye...",S: The patient is a 7-year-old boy with congen...,Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
2,"Doctor: Hello, we've received your results fro...",S: The patient reported undergoing an ultrasou...,Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
3,"Doctor: Hello, can you tell me what brought yo...","S: The patient reports a progressive headache,...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
4,"Doctor: Hello, I understand that you have been...","S: The patient, a post-liver transplant recipi...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
...,...,...,...,...,...,...
245,"Doctor: Hello, how can I help you today?\nPati...","S: The patient reports experiencing ataxia, tr...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
246,"Doctor: Hello, I'm Dr. Smith. How can I help y...","S: Patient reports abdominal pain for 2 weeks,...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
247,"Doctor: Hi there, I see that you've presented ...","S: The patient, a 10-year post-diagnosis breas...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]
248,"Doctor: Hello, I understand that you were diag...","S: The patient, previously diagnosed with infe...",Create a medical SOAP summary of this dialogue.,"[{'role': 'system', 'content': 'You are an exp...","[{'role': 'user', 'content': ""You are an exper...",[]


# Load fine-tuned model

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# For 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_path = MODEL_PATH / "SOAPgemma_v1"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto"
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:19<00:00,  6.34s/it]


# Testing on 5 points

In [6]:
from src.inference import generate_soap_note
from tqdm import tqdm

device = "cuda:0"
model.to(device)

# Apply your function to the dataset
tqdm.pandas()  # enables progress bar with apply
# create a copy of 5 datapoints
test_df_5 = test_df.head(5).copy()
test_df_5["generated_soap_note"] = test_df_5["dialogue"].progress_apply(
    lambda x: generate_soap_note(x, model, tokenizer, device)
)


  0%|          | 0/5 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
 40%|████      | 2/5 [00:46<01:09, 23.15s/it]

SOAP note generation complete.


 60%|██████    | 3/5 [01:12<00:49, 24.50s/it]

SOAP note generation complete.


 80%|████████  | 4/5 [02:04<00:34, 34.83s/it]

SOAP note generation complete.


100%|██████████| 5/5 [03:00<00:00, 42.19s/it]

SOAP note generation complete.


100%|██████████| 5/5 [03:48<00:00, 45.73s/it]

SOAP note generation complete.





In [7]:
test_df_5.to_csv("test_with_generated_soap_notes.csv", index=False)

In [8]:
from src.inference_batch import generate_batch

batch_size = 4
results = []

for i in tqdm(range(0, len(test_df), batch_size)):
    batch_dialogues = test_df['dialogue'][i:i+batch_size].tolist()
    batch_results = generate_batch(batch_dialogues, model, tokenizer)
    results.extend(batch_results)

test_df["generated_soap_note"] = results


100%|██████████| 63/63 [1:51:49<00:00, 106.50s/it]


In [12]:
test_df.to_csv(OMI_PATH_predicted / 'test_v1.csv', index=False)