# A Dialogue-Based SNOMED Prediction Benchmark

# Prepare Data
For data, I decided to use [Synthea](https://synthea.mitre.org/downloads) 1K sample dataset. 

In [3]:
data_root = '../../data/synthea_sample_data_csv_nov2021/'

In [4]:
import pandas as pd
import os

# Load CSVs
patients_df = pd.read_csv(os.path.join(data_root, "patients.csv"))
conditions_df = pd.read_csv(os.path.join(data_root, "conditions.csv"))
observations_df = pd.read_csv(os.path.join(data_root, "observations.csv"))

# Number of patients to sample
SAMPLE_SIZE = 500

# Sample patients
sample_patient_ids = patients_df['Id'].sample(SAMPLE_SIZE, random_state=42).tolist()
sample_patients = patients_df[patients_df['Id'].isin(sample_patient_ids)]
sample_conditions = conditions_df[conditions_df['PATIENT'].isin(sample_patient_ids)]
sample_observations = observations_df[observations_df['PATIENT'].isin(sample_patient_ids)]

# Build structured context per patient
flat_records = []

for _, patient in sample_patients.iterrows():
    pid = patient['Id']
    age = 2025 - int(patient['BIRTHDATE'][:4])
    gender = patient['GENDER']
    city = patient['CITY']
    state = patient['STATE']

    # Conditions with filtering
    raw_conds = (
        sample_conditions[sample_conditions['PATIENT'] == pid]['DESCRIPTION']
        .dropna()
        .unique()
        .tolist()
    )

    # Filter valid condition rows (as DataFrame)
    irrelevant_keywords = ["employment", "education", "finding", "social", "victim", "student", "contact"]
    valid_cond_df = sample_conditions[
        (sample_conditions['PATIENT'] == pid) &
        (~sample_conditions['DESCRIPTION'].str.lower().str.contains('|'.join(irrelevant_keywords)))
    ][['DESCRIPTION', 'CODE']].dropna().drop_duplicates()

    # SNOMED gold codes
    snomed_gold = [
        f"{row['CODE']} - {row['DESCRIPTION']}" for _, row in valid_cond_df.iterrows()
    ]

    # Conditions (for oracle_context)
    filtered_conds = valid_cond_df['DESCRIPTION'].tolist()

    # Observations
    obs_df = sample_observations[sample_observations['PATIENT'] == pid]
    obs_pairs = []
    for _, row in obs_df.iterrows():
        label = row['DESCRIPTION']
        value = row['VALUE']
        if pd.notna(label) and pd.notna(value):
            obs_pairs.append(f"{label}: {value}")

    # Create flat context
    context = f"Age: {age}\nGender: {gender}\nLocation: {city}, {state}"
    if filtered_conds:
        context += "\nConditions: " + "; ".join(filtered_conds[:10])
    if obs_pairs:
        context += "\nObservations:\n- " + "\n- ".join(obs_pairs[:10])

    flat_records.append({
        "patient_id": pid,
        "oracle_context": context,
        "snomed_gold": snomed_gold[:10]
    })

# Optional: Save to JSON
import json
with open("synthea_oracle_context.json", "w") as f:
    json.dump(flat_records, f, indent=2)

# Print a few samples
for record in flat_records[:3]:
    print(json.dumps(record, indent=2))


{
  "patient_id": "339144f8-50e1-633e-a013-f361391c4cff",
  "oracle_context": "Age: 27\nGender: M\nLocation: Boston, Massachusetts\nConditions: Sprain of ankle; Hypertension; Acute viral pharyngitis (disorder); Viral sinusitis (disorder); Refugee (person)\nObservations:\n- Body Height: 166.8\n- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 3.0\n- Body Weight: 70.9\n- Body Mass Index: 25.5\n- Body mass index (BMI) [Percentile] Per age and gender: 94.0\n- Diastolic Blood Pressure: 82.0\n- Systolic Blood Pressure: 129.0\n- Heart rate: 87.0\n- Respiratory rate: 15.0\n- Tobacco smoking status NHIS: Never smoker",
  "snomed_gold": [
    "44465007 - Sprain of ankle",
    "59621000 - Hypertension",
    "195662009 - Acute viral pharyngitis (disorder)",
    "444814009 - Viral sinusitis (disorder)",
    "446654005 - Refugee (person)"
  ]
}
{
  "patient_id": "d488232e-bf14-4bed-08c0-a82f34b6a197",
  "oracle_context": "Age: 22\nGender: F\nLocation: Hingham, Massachusetts\nCondition

In [None]:
import openai
import asyncio
import json
from tqdm import tqdm

client = openai.AsyncOpenAI(api_key="sk-proj-") # Replace with your actual OpenAI API key to run this script

# Load structured patients
with open("synthea_oracle_context.json", "r") as f:
    patients = json.load(f)

output_file = "oracle_qa_finetune_data.jsonl"

def build_chat_messages(context, snomed_list):
    gold = "\n".join(snomed_list)
    return [
        {
            "role": "system",
            "content": (
                "You are simulating a patient with the following medical conditions. "
                "Respond in plain language without naming diagnoses unless asked. Vary your answers. "
                "Focus on realistic symptoms, complaints, or feelings."
            )
        },
        {
            "role": "user",
            "content": f"""PATIENT INFO:
            {context}

            GOLD CONDITIONS (SNOMED):
            {gold}

            Generate 5 different question/answer pairs that reflect the patient's symptoms or experiences. 
            Do NOT use clinical language. Each answer should sound like something a regular person would say. 
            Only output a JSON array like this:
            [
                {{ "question": "...", "answer": "..." }},
            ...
            ]
            """
        }
    ]

def build_negative_qa_prompt(context, snomed_list):
    gold = "\n".join(snomed_list)
    return [
        {
            "role": "system",
            "content": "You are simulating a patient with limited medical knowledge."
        },
        {
            "role": "user",
            "content": f"""Patient Info:
            {context}

            GOLD CONDITIONS (SNOMED):
            {gold}

            TASK:
            - Imagine a doctor is asking the patient about symptoms NOT related to the conditions above.
            - Generate 2 question/answer pairs.
            - The **question** should come from the doctor, phrased in plain language.
            - The **answer** should sound like the patient is unsure or hasn’t experienced it.
            - The patient should **not explain or speculate** — just admit uncertainty or lack of awareness.

            Only output a JSON array like this:
            [
            {{ "question": "Have you noticed any changes in your memory?", "answer": "I’m not really sure. I haven’t paid much attention to that." }},
            ...
            ]
            """
        }
    ]


# Generate QA examples
# Async function to process one patient
async def process_patient(patient, sem, out_fp):
    async with sem:
        try:
            # GPT call: symptom-based QAs
            res1 = await client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=build_chat_messages(patient["oracle_context"], patient["snomed_gold"]),
                temperature=0.9,
                timeout=20,
            )
            qa_pairs = json.loads(res1.choices[0].message.content.strip())

            # GPT call: negative QAs
            res2 = await client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=build_negative_qa_prompt(patient["oracle_context"], patient["snomed_gold"]),
                temperature=0.8,
                timeout=20,
            )
            neg_pairs = json.loads(res2.choices[0].message.content.strip())
            qa_pairs.extend(neg_pairs)

            # Write to file (append mode)
            for qa in qa_pairs:
                record = {
                    "instruction": "Answer the user's question based on the patient context. Use lay language only.",
                    "snomed_gold": patient["snomed_gold"],
                    "input": f"Patient Info:\n{patient['oracle_context']}\nQuestion: {qa['question']}",
                    "output": qa["answer"]
                }
                out_fp.write(json.dumps(record) + "\n")

        except Exception as e:
            print(f"⚠️ Error for {patient['patient_id']}: {str(e)}")

# Main async runner
async def main():
    output_path = "oracle_qa_finetune_data.json"
    sem = asyncio.Semaphore(10)  # Max 10 concurrent GPT calls

    with open(output_path, "w") as out_fp:
        tasks = [process_patient(p, sem, out_fp) for p in patients]
        for f in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
            await f
 
 # Run the event loop
import nest_asyncio
nest_asyncio.apply()
await main()


  2%|▏         | 11/500 [00:05<03:17,  2.48it/s]

⚠️ Error for 4c7bbf23-68f4-86e8-c61e-1a1b1e81c4ea: Expecting value: line 1 column 1 (char 0)


  7%|▋         | 34/500 [00:10<01:45,  4.43it/s]

⚠️ Error for eb2f8ba1-fd23-fdbb-dbe3-484a3d58244e: Expecting value: line 1 column 1 (char 0)


  8%|▊         | 42/500 [00:13<02:31,  3.02it/s]

⚠️ Error for 1a3c4dbc-9cf9-3967-ac48-389057e4f2aa: Expecting value: line 1 column 1 (char 0)


  9%|▉         | 44/500 [00:13<02:07,  3.58it/s]

⚠️ Error for 70c444f6-2f20-e9c3-0c1c-13fab72f9a4a: Expecting value: line 1 column 1 (char 0)


 10%|▉         | 48/500 [00:14<01:45,  4.27it/s]

⚠️ Error for 72b71d33-b242-8f86-7ce0-198b3dfd3bf1: Expecting value: line 1 column 1 (char 0)


 14%|█▍        | 69/500 [00:20<01:57,  3.68it/s]

⚠️ Error for 5b0609a7-49e8-65b4-dc32-9edc8301c646: Expecting ',' delimiter: line 6 column 89 (char 656)


 15%|█▌        | 75/500 [00:22<01:45,  4.05it/s]

⚠️ Error for 9bbc77f6-c1e1-0c48-c64e-eb5958f15faf: Expecting value: line 1 column 1 (char 0)


 16%|█▌        | 80/500 [00:23<01:44,  4.04it/s]

⚠️ Error for 93d970b1-94b9-8455-de35-0bcc1a45eb7b: Expecting value: line 1 column 1 (char 0)


 18%|█▊        | 91/500 [00:26<01:58,  3.45it/s]

⚠️ Error for 07a1f80c-f4e2-8e91-5913-270fc6afa25d: Expecting value: line 1 column 1 (char 0)


 18%|█▊        | 92/500 [00:27<02:10,  3.14it/s]

⚠️ Error for eebbb6c9-87b3-9978-5eab-84edacbccc7c: Expecting value: line 1 column 1 (char 0)


 24%|██▎       | 118/500 [00:35<01:39,  3.85it/s]

⚠️ Error for 520fdcd4-8d55-8a00-67b8-a0d61a50d9b5: Expecting value: line 1 column 1 (char 0)
⚠️ Error for c5340af3-dde7-d403-78dd-bf20a03f0238: Expecting value: line 1 column 1 (char 0)


 26%|██▌       | 131/500 [00:39<02:14,  2.74it/s]

⚠️ Error for 9724e8c9-3f30-702a-f75a-95082c024706: Expecting value: line 1 column 1 (char 0)


 28%|██▊       | 139/500 [00:41<01:42,  3.52it/s]

⚠️ Error for 1eb1b9e0-4cdf-7910-e76d-9f7534bcc7f0: Expecting value: line 1 column 1 (char 0)


 32%|███▏      | 160/500 [00:47<02:15,  2.51it/s]

⚠️ Error for 08c885d3-43c9-efd4-e8b3-c6140832b18f: Expecting value: line 1 column 1 (char 0)


 34%|███▍      | 170/500 [00:50<01:14,  4.43it/s]

⚠️ Error for b795bf03-28fb-ad1d-5a43-3e44422700a7: Expecting value: line 1 column 1 (char 0)


 37%|███▋      | 187/500 [00:55<01:59,  2.63it/s]

⚠️ Error for 3fda13ae-c868-c505-10ee-3f6d4b694b84: Expecting value: line 1 column 1 (char 0)


 40%|████      | 200/500 [01:00<02:03,  2.43it/s]

⚠️ Error for d4dd09e5-8eae-7136-407c-df2715f2643b: Expecting value: line 1 column 1 (char 0)


 42%|████▏     | 209/500 [01:03<01:35,  3.04it/s]

⚠️ Error for 9ac43dbb-44d6-871e-b411-a1c18c61b55e: Expecting value: line 1 column 1 (char 0)


 44%|████▍     | 220/500 [01:06<01:00,  4.65it/s]

⚠️ Error for db827d6a-d1c7-cd51-9b86-626ab88d58da: Expecting value: line 1 column 1 (char 0)


 47%|████▋     | 234/500 [01:10<00:56,  4.73it/s]

⚠️ Error for ef82862b-c0f8-7440-2dcb-4d7137349f0d: Expecting value: line 1 column 1 (char 0)


 48%|████▊     | 240/500 [01:12<01:31,  2.85it/s]

⚠️ Error for c987d011-9807-23da-8487-4d7ce3511a86: Expecting value: line 1 column 1 (char 0)


 50%|█████     | 252/500 [01:15<01:15,  3.27it/s]

⚠️ Error for c1c5bcd1-d232-8de3-7e0f-3ef13280455c: Expecting value: line 1 column 1 (char 0)


 51%|█████▏    | 257/500 [01:16<00:51,  4.74it/s]

⚠️ Error for 68f34b8e-efce-ce3e-836d-a815503c7b56: Expecting value: line 1 column 1 (char 0)


 52%|█████▏    | 261/500 [01:18<01:32,  2.58it/s]

⚠️ Error for 2376720a-56ca-5474-3737-32cf783a169c: Expecting value: line 1 column 1 (char 0)


 54%|█████▎    | 268/500 [01:19<01:01,  3.78it/s]

⚠️ Error for 9e84e569-7adc-ff42-ccdb-9fe9c23842a6: Expecting value: line 1 column 1 (char 0)


 56%|█████▌    | 281/500 [01:24<01:18,  2.77it/s]

⚠️ Error for 246fb368-8991-dc93-f6a5-eca807e7dbde: Expecting value: line 1 column 1 (char 0)


 61%|██████    | 304/500 [01:30<01:04,  3.02it/s]

⚠️ Error for 77e0ed77-2f51-7fe3-792d-ff1e0be7f3c2: Expecting value: line 1 column 1 (char 0)


 63%|██████▎   | 313/500 [01:32<01:04,  2.88it/s]

⚠️ Error for 91f3dfb6-9545-a0cf-7c12-64d89cab1db3: Expecting value: line 1 column 1 (char 0)


 64%|██████▍   | 321/500 [01:34<00:38,  4.61it/s]

⚠️ Error for d7bb0340-9894-8bd0-056a-29efc5444fa0: Expecting value: line 1 column 1 (char 0)


 64%|██████▍   | 322/500 [01:35<01:08,  2.58it/s]

⚠️ Error for ec56f618-56ed-cf49-78fd-4d561bf7cba3: Expecting value: line 1 column 1 (char 0)


 65%|██████▌   | 326/500 [01:36<00:59,  2.91it/s]

⚠️ Error for 70231807-0c38-db06-806e-a73e94bcd591: Expecting value: line 1 column 1 (char 0)


 67%|██████▋   | 336/500 [01:39<00:48,  3.41it/s]

⚠️ Error for d90705f0-2c2d-4126-6504-93d4da9c6f7a: Expecting value: line 1 column 1 (char 0)


 75%|███████▌  | 376/500 [01:51<00:24,  4.97it/s]

⚠️ Error for ef167059-cef0-12c4-49db-993ca3a20c01: Expecting value: line 1 column 1 (char 0)


 76%|███████▌  | 378/500 [01:52<00:50,  2.42it/s]

⚠️ Error for ebe6f5b1-c05d-6043-33d1-83ef3dc4da13: Expecting value: line 1 column 1 (char 0)


 79%|███████▊  | 393/500 [01:56<00:32,  3.28it/s]

⚠️ Error for d26086c9-4205-29b4-0e28-a77231a9724a: Expecting value: line 1 column 1 (char 0)


 79%|███████▉  | 396/500 [01:57<00:33,  3.13it/s]

⚠️ Error for faf1d0f2-c429-72a4-5be0-e542949829ab: Expecting value: line 1 column 1 (char 0)


 80%|███████▉  | 399/500 [01:58<00:36,  2.74it/s]

⚠️ Error for 2798ae24-ef3b-1906-6e41-a31e0fd833a0: Expecting value: line 1 column 1 (char 0)


 82%|████████▏ | 408/500 [02:01<00:25,  3.59it/s]

⚠️ Error for 23d4a494-443e-fe6b-d0ec-e21a762d2a90: Expecting value: line 1 column 1 (char 0)


 86%|████████▋ | 432/500 [02:08<00:18,  3.61it/s]

⚠️ Error for e61fe1f4-3daa-bbeb-649a-fdae0eeee227: Expecting value: line 1 column 1 (char 0)


 88%|████████▊ | 441/500 [02:11<00:22,  2.61it/s]

⚠️ Error for 05801c88-9f2e-536e-9068-4b7b4431d072: Expecting value: line 1 column 1 (char 0)


 93%|█████████▎| 464/500 [02:18<00:07,  4.70it/s]

⚠️ Error for 3d5f91c2-0539-37c9-0691-d2bbe28c78e3: Expecting value: line 1 column 1 (char 0)


 93%|█████████▎| 467/500 [02:18<00:05,  5.54it/s]

⚠️ Error for ce1b17d2-15ee-aa1c-bdd1-8be89ba22acf: Expecting value: line 1 column 1 (char 0)


 94%|█████████▍| 471/500 [02:20<00:09,  3.21it/s]

⚠️ Error for 123a2225-13bc-8bd7-2868-0245cb02856d: Expecting value: line 1 column 1 (char 0)


 97%|█████████▋| 485/500 [02:24<00:03,  4.17it/s]

⚠️ Error for 767e1d64-d100-a9ec-fe05-073c6857ffa8: Expecting value: line 1 column 1 (char 0)


 98%|█████████▊| 488/500 [02:25<00:03,  3.56it/s]

⚠️ Error for 16374251-beab-44f9-ac53-faece4e66bc2: Expecting value: line 1 column 1 (char 0)


100%|██████████| 500/500 [02:29<00:00,  3.34it/s]


Let's try without fine-tune.

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "biomistral/biomistral-7b"
local_dir = "./biomistral-7b"

# Download and cache model/tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=local_dir)
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=local_dir)


Test zero-shot!

In [None]:
import json
import random

# run 5 sample
for _ in range(5):
    # Load a QA sample from your dataset
    with open("oracle_qa_finetune_data.json", "r") as f:
        lines = f.readlines()
        sample = json.loads(random.choice(lines))

    # Unpack the fields
    instruction = sample["instruction"]
    input_block = sample["input"]
    true_output = sample["output"]


    # Format prompt using your real training example
    prompt = f"""### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    {input_block}

    ### Response:
    """

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            temperature=0.6
        )

    # Decode
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_response = response.split("### Response:")[-1].strip()

    # Show comparison
    print("\n🔍 Input Prompt:\n", prompt)
    print("\n🧠 Model Response:\n", model_response)
    print("\n✅ Expected (Fine-Tuned) Answer:\n", true_output)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 114
Gender: F
Location: Westfield, Massachusetts
Conditions: Refugee (person); Prediabetes; Anemia (disorder); Miscarriage in first trimester; Tubal pregnancy; Rheumatoid arthritis; Chronic sinusitis (disorder); Polyp of colon; Osteoporosis (disorder); Chronic intractable migraine without aura
Observations:
- Body Height: 161.9
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 2.0
- Body Weight: 72.9
- Body Mass Index: 27.8
- Diastolic Blood Pressure: 80.0
- Systolic Blood Pressure: 112.0
- Heart rate: 75.0
- Respiratory rate: 16.0
- Glucose: 65.4
- Urea Nitrogen: 19.6
Question: How have your sinus issues been lately?

    ### Response:
    

🧠 Model Response:
 1. My sinus issues have been pretty bad lately. I've been having 

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 65
Gender: M
Location: North Eastham, Massachusetts
Conditions: Hypertension; Cardiac Arrest; History of cardiac arrest (situation); Viral sinusitis (disorder); Laceration of foot; Otitis media
Observations:
- Body Height: 186.8
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 2.0
- Body Weight: 85.2
- Body Mass Index: 24.4
- Diastolic Blood Pressure: 83.0
- Systolic Blood Pressure: 115.0
- Heart rate: 74.0
- Respiratory rate: 16.0
- Tobacco smoking status NHIS: Never smoker
- In the past year  have you been afraid of your partner or ex-partner?: No
Question: Do you have any trouble with blurry vision or seeing double?

    ### Response:
    

🧠 Model Response:
 1. No, I do not have any trouble with blurry vision or seeing d

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 13
Gender: F
Location: East Brookfield, Massachusetts
Conditions: Acute bronchitis (disorder); Streptococcal sore throat (disorder); Sinusitis (disorder); Chronic sinusitis (disorder)
Observations:
- Body Height: 51.7
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 2.0
- Body Weight: 3.7
- Weight-for-length Per age and sex: 24.0
- Head Occipital-frontal circumference: 34.8
- Diastolic Blood Pressure: 85.0
- Systolic Blood Pressure: 109.0
- Heart rate: 72.0
- Respiratory rate: 15.0
- Leukocytes [#/volume] in Blood by Automated count: 10.0
Question: How have you been feeling lately?

    ### Response:
    

🧠 Model Response:
 ive been feeling really sick and tired.

✅ Expected (Fine-Tuned) Answer:
 I've been feeling really co

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 7
Gender: F
Location: Somerville, Massachusetts
Observations:
- Body Height: 49.4
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 2.0
- Body Weight: 3.1
- Weight-for-length Per age and sex: 32.4
- Head Occipital-frontal circumference: 33.2
- Diastolic Blood Pressure: 82.0
- Systolic Blood Pressure: 117.0
- Heart rate: 67.0
- Respiratory rate: 14.0
- Leukocytes [#/volume] in Blood by Automated count: 8.9
Question: Are you feeling tired or energetic lately?

    ### Response:
    

🧠 Model Response:
 7-year-old girl feeling energetic lately.

✅ Expected (Fine-Tuned) Answer:
 I've been feeling pretty energetic lately!

🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient co

Fine tune!

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import Trainer, DataCollatorForLanguageModeling

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,
    device_map=None  # ✅ Do NOT move to device yet
)

tokenizer = AutoTokenizer.from_pretrained(model_dir)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Apply LoRA
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)
model = get_peft_model(model, peft_config)
model = model.to("cuda")

# Load your dataset
data = load_dataset("json", data_files="oracle_qa_finetune_data.jsonl", split="train")

# Format prompt
def format_example(example):
    prompt = f"""### Instruction:
You are a patient. Answer the doctor's question based on the patient context given. 
Use lay language only. Please be concise and answer in one or two short sentences.

### Input:
{example['input']}

### Response:
{example['output']}"""

    return {"text": prompt}

data = data.map(format_example)
tokenized = data.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512), batched=True)

# Training config
training_args = TrainingArguments(
    output_dir="./biomistral_oracle_lora",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_train_epochs=3,
    learning_rate=2e-5,
    fp16=True,
    logging_steps=10,
    save_steps=200,
    save_total_limit=2,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

trainer.train()

model.save_pretrained("biomistral_oracle_lora")
tokenizer.save_pretrained("biomistral_oracle_lora")

  0%|          | 0/2382 [00:00<?, ?it/s]

{'loss': 2.0434, 'grad_norm': 10.874444961547852, 'learning_rate': 1.9916036943744756e-05, 'epoch': 0.01}
{'loss': 1.7315, 'grad_norm': 12.161198616027832, 'learning_rate': 1.9832073887489506e-05, 'epoch': 0.03}
{'loss': 1.4156, 'grad_norm': 11.05747127532959, 'learning_rate': 1.974811083123426e-05, 'epoch': 0.04}
{'loss': 1.0928, 'grad_norm': 11.315916061401367, 'learning_rate': 1.966414777497901e-05, 'epoch': 0.05}
{'loss': 0.8984, 'grad_norm': 11.166594505310059, 'learning_rate': 1.9580184718723762e-05, 'epoch': 0.06}
{'loss': 0.7656, 'grad_norm': 8.333418846130371, 'learning_rate': 1.9496221662468516e-05, 'epoch': 0.08}
{'loss': 0.6807, 'grad_norm': 9.426526069641113, 'learning_rate': 1.9412258606213266e-05, 'epoch': 0.09}
{'loss': 0.6128, 'grad_norm': 8.430987358093262, 'learning_rate': 1.932829554995802e-05, 'epoch': 0.1}
{'loss': 0.6199, 'grad_norm': 11.042398452758789, 'learning_rate': 1.9244332493702774e-05, 'epoch': 0.11}
{'loss': 0.553, 'grad_norm': 13.021499633789062, 'lear

('biomistral_oracle_lora\\tokenizer_config.json',
 'biomistral_oracle_lora\\special_tokens_map.json',
 'biomistral_oracle_lora\\tokenizer.json')

In [8]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Path to LoRA adapter and tokenizer
adapter_path = "biomistral_oracle_lora"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(adapter_path)

# Load base model (BioMistral-7B) — assumes you have it locally or access to Hugging Face
base_model_dir = "./biomistral-7b/models--biomistral--biomistral-7b/snapshots/9a11e1ffa817c211cbb52ee1fb312dc6b61b40a5"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_dir,
    device_map="auto",
    torch_dtype="auto"
)

# Load LoRA adapter
model = PeftModel.from_pretrained(
    base_model,
    adapter_path,
    torch_dtype="auto"
)

model.eval()  # Optional: set to eval mode for inference

Some parameters are on the meta device because they were offloaded to the cpu.


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralSdpaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): 

In [9]:
model = model.to("cuda")
next(model.parameters()).device


device(type='cuda', index=0)

Let's evaluate with some random samples.

In [10]:
import json
import random

# run 5 sample
for _ in range(5):
    # Load a QA sample from your dataset
    with open("oracle_qa_finetune_data.json", "r") as f:
        lines = f.readlines()
        sample = json.loads(random.choice(lines))

    # Unpack the fields
    instruction = sample["instruction"]
    input_block = sample["input"]
    true_output = sample["output"]


    # Format prompt using your real training example
    prompt = f"""### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    {input_block}

    ### Response:
    """

    # Tokenize and generate
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=False,
            temperature=0.6
        )

    # Decode
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    model_response = response.split("### Response:")[-1].strip()

    # Show comparison
    print("\n🔍 Input Prompt:\n", prompt)
    print("\n🧠 Model Response:\n", model_response)
    print("\n✅ Expected (Fine-Tuned) Answer:\n", true_output)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 30
Gender: M
Location: Boston, Massachusetts
Conditions: Perennial allergic rhinitis with seasonal variation; Chronic sinusitis (disorder); Acute bronchitis (disorder); Viral sinusitis (disorder); Laceration of foot; Streptococcal sore throat (disorder); Suspected COVID-19; COVID-19
Observations:
- Peanut IgE Ab in Serum: 0.0
- Walnut IgE Ab in Serum: 0.3
- Codfish IgE Ab in Serum: 0.3
- Shrimp IgE Ab in Serum: 0.3
- Wheat IgE Ab in Serum: 0.0
- Egg white IgE Ab in Serum: 0.3
- Soybean IgE Ab in Serum: 0.2
- Cow milk IgE Ab in Serum: 0.3
- White oak IgE Ab in Serum: 0.1
- Common Ragweed IgE Ab in Serum: 0.1
Question: How is your foot feeling after the laceration?

    ### Response:
    

🧠 Model Response:
 It's been healing well, but it's s

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 73
Gender: F
Location: Northampton, Massachusetts
Conditions: Miscarriage in first trimester; Normal pregnancy; Viral sinusitis (disorder); Acute bronchitis (disorder); Primary fibromyalgia syndrome; Chronic obstructive bronchitis (disorder); Osteoarthritis of knee
Observations:
- Body Height: 164.3
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 2.0
- Body Weight: 75.4
- Body Mass Index: 27.9
- Diastolic Blood Pressure: 86.0
- Systolic Blood Pressure: 123.0
- Heart rate: 67.0
- Respiratory rate: 14.0
- Total Cholesterol: 185.3
- Triglycerides: 146.2
Question: Have you experienced any sudden vision changes?

    ### Response:
    

🧠 Model Response:
 I'm not sure. I haven't really paid attention to my vision.

✅ Expected (F

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 67
Gender: M
Location: Pepperell, Massachusetts
Conditions: Hypertension; Chronic sinusitis (disorder); Viral sinusitis (disorder); Coronary Heart Disease; Acute viral pharyngitis (disorder); Stroke
Observations:
- Body Height: 174.5
- Pain severity - 0-10 verbal numeric rating [Score] - Reported: 3.0
- Body Weight: 87.3
- Body Mass Index: 28.7
- Diastolic Blood Pressure: 108.0
- Systolic Blood Pressure: 170.0
- Heart rate: 98.0
- Respiratory rate: 12.0
- Tobacco smoking status NHIS: Never smoker
- In the past year  have you been afraid of your partner or ex-partner?: No
Question: Have you experienced any unusual skin rashes or itching lately?

    ### Response:
    

🧠 Model Response:
 I'm not sure. I haven't really noticed anything like t

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.



🔍 Input Prompt:
 ### Instruction:
    You are a patient. Answer the doctor's question based on the patient context given. 
    Use lay language only. Please be concise and answer in one or two short sentences.

    ### Input:
    Patient Info:
Age: 58
Gender: F
Location: Holyoke, Massachusetts
Conditions: Hyperlipidemia; Viral sinusitis (disorder); Normal pregnancy; Miscarriage in first trimester; Blighted ovum; Localized  primary osteoarthritis of the hand; Fracture subluxation of wrist; Suspected lung cancer (situation); Non-small cell lung cancer (disorder); Anemia (disorder)
Observations:
- Glucose: 92.8
- Urea Nitrogen: 8.5
- Creatinine: 3.1
- Calcium: 9.3
- Sodium: 136.9
- Potassium: 4.3
- Chloride: 101.8
- Carbon Dioxide: 27.1
- Glomerular filtration rate/1.73 sq M.predicted: 78.3
- Protein [Mass/volume] in Serum or Plasma: 64.8
Question: Do you experience frequent headaches?

    ### Response:
    

🧠 Model Response:
 I'm not sure. I don't think I've had any headaches recently

It seems to behave much better.