# MedGemma 4B IT -- LoRA Fine-Tuning for Clinical SOAP Note Generation

Fine-tunes [google/medgemma-4b-it](https://huggingface.co/google/medgemma-4b-it) on 57 chat-format training examples (19 synthetic dictation-to-SOAP pairs, each paired with 3 prompt phrasings) using LoRA (r=16, alpha=32).

**Output adapter:** `steeltroops-ai/medgemma-4b-soap-lora`

**Reference:** Sellergren et al., "MedGemma: A Family of Medically-Specialized Gemma Models," arXiv:2507.05201 (2025).


In [None]:
!pip install -q transformers>=4.50.0 peft bitsandbytes datasets accelerate trl torch


In [None]:
import os, json, copy, re, torch
from pathlib import Path

# Hugging Face access token, read from the environment so it is never stored in
# the notebook. It is passed to from_pretrained() and push_to_hub() further down.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
assert HF_TOKEN, "Set HF_TOKEN environment variable (or use Kaggle/Colab secrets)"

# Print the runtime so the training log records the PyTorch build and the GPU used.
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    # The attribute is `total_memory` (bytes); divide by 1e9 to report decimal GB.
    print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {gpu_props.total_memory / 1e9:.1f} GB")


In [None]:
# ============================================================

# Step 2: Synthetic training dataset -- 19 clinical SOAP pairs (57 examples after prompt augmentation)

# Domains: acute, chronic, emergency, psychiatry, drug interactions

# ============================================================

# Synthetic supervision data, grouped by clinical domain.
# Each entry is a (physician dictation, target SOAP note) tuple. Every target
# note uses the four headers SUBJECTIVE / OBJECTIVE / ASSESSMENT / PLAN and
# carries an ICD-10 code in the ASSESSMENT line.
#
# NOTE(review): several target notes contain details that are not present in
# the dictation (e.g. "Post-op day 10", "Troponin I 2.4ng/mL"). Confirm this is
# intended, since it teaches the model to add findings that were never dictated.
CLINICAL_DOMAINS = {
    "acute": [
        (
            "Patient is a 45-year-old female presenting with dysuria, urinary frequency for 3 days. Temp 100.2F. UA positive for leukocyte esterase and nitrites. Plan: Trimethoprim-sulfamethoxazole 160/800mg BID x5 days.",
            "SUBJECTIVE: 45-year-old female with dysuria and urinary frequency x3 days.\nOBJECTIVE: Temp 100.2F. UA positive leukocyte esterase, nitrites.\nASSESSMENT: Uncomplicated urinary tract infection. ICD-10: N39.0 - Urinary tract infection, site not specified.\nPLAN: Trimethoprim-sulfamethoxazole 160/800mg BID x5 days. Push fluids. Return if symptoms worsen.",
        ),
        (
            "32-year-old male with right lower quadrant pain x12 hours, nausea, anorexia. Temp 101.1F. McBurney point tenderness positive. WBC 14.2. CT abdomen shows acute appendicitis.",
            "SUBJECTIVE: 32yo male, RLQ pain 12hrs, nausea, anorexia.\nOBJECTIVE: T 101.1F. McBurney point TTP+. WBC 14.2. CT: acute appendicitis.\nASSESSMENT: Acute appendicitis. ICD-10: K35.80 - Unspecified acute appendicitis.\nPLAN: NPO. IV fluids. Surgical consult for appendectomy. Cefoxitin 2g IV.",
        ),
        (
            "67-year-old female presenting with productive cough x5 days, fever 101.5F, right basilar crackles. CXR shows RLL consolidation. SpO2 93% on RA.",
            "SUBJECTIVE: 67yo female, productive cough 5 days, fever.\nOBJECTIVE: T 101.5F. Right basilar crackles. CXR: RLL consolidation. SpO2 93%.\nASSESSMENT: Community-acquired pneumonia. ICD-10: J18.9 - Pneumonia, unspecified organism.\nPLAN: Amoxicillin-clavulanate 875/125mg BID x7d. Supplemental O2 PRN. Follow-up CXR 6 weeks.",
        ),
        # NOTE(review): the assessment names the left lower extremity but the
        # label uses the unspecified-site code L03.90 -- confirm with a coder
        # whether a site-specific code is wanted here.
        (
            "28-year-old male, left lower leg erythema, warmth, swelling after minor abrasion 4 days ago. Temp 100.8F. No abscess.",
            "SUBJECTIVE: 28yo male, LLE redness/swelling 4 days post-abrasion.\nOBJECTIVE: T 100.8F. LLE erythema 8x12cm, warm, tender. No fluctuance.\nASSESSMENT: Cellulitis, left lower extremity. ICD-10: L03.90 - Cellulitis, unspecified.\nPLAN: Cephalexin 500mg QID x10 days. Elevate leg. Mark borders. Return if spreading.",
        ),
        (
            "55-year-old male, acute onset chest pressure radiating to left arm x2 hours. Diaphoresis. ECG shows ST elevation V1-V4. Troponin elevated.",
            "SUBJECTIVE: 55yo male, crushing chest pain to left arm x2hrs, diaphoresis.\nOBJECTIVE: Diaphoretic. ECG: ST elevation V1-V4. Troponin I 2.4ng/mL.\nASSESSMENT: Acute ST-elevation myocardial infarction, anterior wall. ICD-10: I21.09 - STEMI of anterior wall.\nPLAN: Activate cath lab. ASA 325mg, Heparin bolus, Ticagrelor 180mg load. Emergent PCI.",
        ),
    ],
    "chronic": [
        (
            "62-year-old male, T2DM follow-up. A1c 8.2%, up from 7.4%. BMI 32. On metformin 1000mg BID. Compliant with diet.",
            "SUBJECTIVE: 62yo male, T2DM follow-up. Reports dietary compliance but A1c rising.\nOBJECTIVE: A1c 8.2% (prev 7.4%). BMI 32. Foot exam normal. Monofilament intact.\nASSESSMENT: Type 2 diabetes, inadequately controlled. ICD-10: E11.65 - T2DM with hyperglycemia.\nPLAN: Add empagliflozin 10mg daily. Continue metformin 1000mg BID. Recheck A1c 3 months. Ophthalmology referral.",
        ),
        (
            "58-year-old female, HTN management. BP today 152/94. On lisinopril 10mg daily. Reports medication compliance.",
            "SUBJECTIVE: 58yo female, HTN follow-up. Reports taking lisinopril as prescribed.\nOBJECTIVE: BP 152/94. HR 78. BMI 28. BMP: Cr 0.9, K 4.2.\nASSESSMENT: Essential hypertension, uncontrolled. ICD-10: I10 - Essential hypertension.\nPLAN: Increase lisinopril to 20mg daily. DASH diet counseling. Home BP monitoring. Recheck 4 weeks.",
        ),
        (
            "70-year-old male, COPD exacerbation. Increased dyspnea, productive cough x3 days. On tiotropium and albuterol PRN. FEV1 45% predicted.",
            "SUBJECTIVE: 70yo male, COPD. Worsening dyspnea, productive cough 3 days.\nOBJECTIVE: RR 24. SpO2 90% RA. Diffuse wheezing. FEV1 45% predicted.\nASSESSMENT: Acute exacerbation of COPD. ICD-10: J44.1 - COPD with acute exacerbation.\nPLAN: Prednisone 40mg daily x5 days. Azithromycin 500mg x1 then 250mg x4 days. Increase albuterol frequency.",
        ),
        (
            "72-year-old female, CHF follow-up. Weight gain 5lbs in 1 week. Increased LE edema. On furosemide 40mg and carvedilol 12.5mg BID.",
            "SUBJECTIVE: 72yo female, CHF. Weight gain 5lbs/1wk, increased ankle swelling.\nOBJECTIVE: LE edema 2+. JVD present. Bibasilar crackles. BNP 890.\nASSESSMENT: Heart failure, decompensated. ICD-10: I50.9 - Heart failure, unspecified.\nPLAN: Increase furosemide to 80mg daily. Fluid restrict 1.5L/day. Daily weights. BMP in 3 days.",
        ),
        (
            "48-year-old female, hypothyroidism follow-up. TSH 8.2 on levothyroxine 75mcg. Fatigue persists.",
            "SUBJECTIVE: 48yo female, hypothyroidism. Ongoing fatigue despite medication.\nOBJECTIVE: TSH 8.2 (elevated). Free T4 0.8 (low normal). HR 62. Dry skin.\nASSESSMENT: Hypothyroidism, inadequately treated. ICD-10: E03.9 - Hypothyroidism, unspecified.\nPLAN: Increase levothyroxine to 100mcg daily. Recheck TSH 6 weeks.",
        ),
    ],
    "emergency": [
        (
            "42-year-old female, sudden onset pleuritic chest pain and dyspnea. Post-op knee surgery. HR 110. D-dimer elevated. CT-PA positive for bilateral PE.",
            "SUBJECTIVE: 42yo female, acute pleuritic chest pain, dyspnea. Post-op day 10.\nOBJECTIVE: HR 110. RR 28. SpO2 91%. CT-PA: bilateral pulmonary emboli.\nASSESSMENT: Acute pulmonary embolism. ICD-10: I26.99 - PE without acute cor pulmonale.\nPLAN: Heparin drip. Admit ICU. Echocardiogram. Transition to rivaroxaban when stable.",
        ),
        (
            "8-year-old male, high fever x3 days, sore throat, tonsillar exudates. Rapid strep positive.",
            "SUBJECTIVE: 8yo male, fever 3 days, severe sore throat.\nOBJECTIVE: T 103.1F. Tonsillar erythema with bilateral exudates. Cervical LAD.\nASSESSMENT: Acute streptococcal pharyngitis. ICD-10: J02.0 - Streptococcal pharyngitis.\nPLAN: Amoxicillin 50mg/kg/day BID x10 days. Ibuprofen PRN. Return if worsening 48hrs.",
        ),
        # NOTE(review): O99.61 looks like it may be missing a trailing
        # trimester character -- confirm the complete code with a coder.
        (
            "35-year-old female, 14 weeks pregnant. RLQ pain, nausea. US: viable IUP, non-compressible appendix. WBC 16k.",
            "SUBJECTIVE: 35yo female, G2P1 at 14wks. RLQ pain 8hrs, nausea.\nOBJECTIVE: RLQ TTP. WBC 16k. US: viable IUP, non-compressible appendix.\nASSESSMENT: Acute appendicitis in pregnancy. ICD-10: O99.61 - Digestive disease complicating pregnancy.\nPLAN: OB and surgery consult. Laparoscopic appendectomy. Fetal monitoring peri-op.",
        ),
    ],
    "psychiatry": [
        (
            "38-year-old female, MDD follow-up. PHQ-9 score 18. On sertraline 100mg. Persistent anhedonia, insomnia.",
            "SUBJECTIVE: 38yo female, MDD follow-up. PHQ-9=18 (severe). Anhedonia, insomnia.\nOBJECTIVE: Flat affect. Psychomotor retardation. Denies SI/HI.\nASSESSMENT: MDD, recurrent, severe. ICD-10: F33.2 - MDD recurrent severe.\nPLAN: Increase sertraline to 150mg. Add trazodone 50mg QHS. CBT referral. Follow-up 2 weeks.",
        ),
        (
            "25-year-old male, new evaluation for anxiety. GAD-7 score 15. Excessive worry, muscle tension x6 months.",
            "SUBJECTIVE: 25yo male, excessive worry, muscle tension, insomnia x6mo. GAD-7=15.\nOBJECTIVE: Restless. Speech pressured. Oriented x4. Normal neuro exam.\nASSESSMENT: Generalized anxiety disorder. ICD-10: F41.1 - GAD.\nPLAN: Start escitalopram 10mg daily. CBT referral. Follow-up 4 weeks.",
        ),
        (
            "9-year-old male, ADHD evaluation. Vanderbilt scores consistent with combined type. Struggling academically.",
            "SUBJECTIVE: 9yo male, poor concentration, hyperactivity. Vanderbilt: 7/9 inattentive, 6/9 hyperactive.\nOBJECTIVE: Fidgety. Interrupts frequently. Normal development. Vision/hearing normal.\nASSESSMENT: ADHD, combined. ICD-10: F90.2 - ADHD combined type.\nPLAN: Methylphenidate 5mg BID. Behavioral therapy. 504 plan letter. Follow-up 4 weeks.",
        ),
    ],
    "drug_interactions": [
        (
            "78-year-old male on warfarin for AFib. Amiodarone started by cardiology. INR jumped from 2.5 to 4.8.",
            "SUBJECTIVE: 78yo male, AFib on warfarin. Cardiology added amiodarone 1 week ago.\nOBJECTIVE: INR 4.8 (prev 2.5). No bleeding. Vitals stable.\nASSESSMENT: Supratherapeutic INR, warfarin-amiodarone interaction. ICD-10: T45.515A - Adverse effect of anticoagulants.\nPLAN: Hold warfarin 2 days. Reduce dose 30-50%. Weekly INR. Educate on interaction.",
        ),
        # NOTE(review): the label codes bilateral knee OA, but the dictation
        # does not state laterality -- confirm the intended code.
        (
            "65-year-old female on lisinopril 20mg. Knee pain. Considering ibuprofen. CKD stage 3a, eGFR 52.",
            "SUBJECTIVE: 65yo female, knee OA pain. Requests NSAID. On lisinopril. Known CKD.\nOBJECTIVE: BP 138/82. Knee crepitus. eGFR 52. K 4.8.\nASSESSMENT: Knee OA. NSAID contraindicated (ACE+CKD). ICD-10: M17.0 - Primary OA bilateral knees.\nPLAN: Avoid NSAIDs. Acetaminophen 650mg Q6H PRN. Topical diclofenac. PT referral.",
        ),
        # NOTE(review): G25.9 is labelled here as a generic movement disorder;
        # confirm whether a more specific serotonin-syndrome code should be used.
        (
            "44-year-old female on sertraline 150mg. ED prescribed tramadol. Now tremor, agitation, diaphoresis.",
            "SUBJECTIVE: 44yo female, MDD on sertraline. Took tramadol 6hrs ago. Tremor, agitation.\nOBJECTIVE: T 100.9F. HR 112. Tremor. Hyperreflexia. Diaphoresis. Clonus.\nASSESSMENT: Serotonin syndrome, sertraline-tramadol interaction. ICD-10: G25.9 - Movement disorder.\nPLAN: Discontinue tramadol. Cyproheptadine 12mg then 4mg Q2H. IV fluids. Monitoring.",
        ),
    ],
}

# Build chat-format training data.
# Every (dictation, SOAP note) pair becomes one two-turn conversation: the user
# turn holds the instruction followed by the dictation, the assistant turn holds
# the target note. Domain keys are not needed here, only the pair lists.
train_data = [
    {
        "messages": [
            {"role": "user", "content": f"Generate a structured SOAP note with ICD-10 codes from this physician dictation:\n\n{dictation}"},
            {"role": "assistant", "content": target_note},
        ]
    }
    for domain_pairs in CLINICAL_DOMAINS.values()
    for dictation, target_note in domain_pairs
]

# Augment with prompt variants: each base example is kept and followed by two
# copies whose user instruction is reworded (the dictation and the target note
# are unchanged). 19 base examples therefore become 57.
prompt_rewrites = [
    # (text to find in the user turn, replacement text)
    ("Generate a structured SOAP note", "Create a clinical SOAP note"),
    (
        "Generate a structured SOAP note with ICD-10 codes from this physician dictation:",
        "Document the following clinical encounter as a SOAP note with ICD-10 coding:",
    ),
]

augmented = []
for example in train_data:
    augmented.append(example)
    for old_text, new_text in prompt_rewrites:
        # Deep copy so editing the variant never touches the base example.
        variant = copy.deepcopy(example)
        user_turn = variant["messages"][0]
        user_turn["content"] = user_turn["content"].replace(old_text, new_text)
        augmented.append(variant)

train_data = augmented
print(f"Training dataset: {len(train_data)} examples across {len(CLINICAL_DOMAINS)} domains")

# Save for reproducibility.
# Create the target directory first: a fresh Colab/Kaggle working directory has
# no `notebooks/` folder, and open() would then fail with FileNotFoundError.
training_data_path = Path("notebooks") / "soap_training_data.json"
training_data_path.parent.mkdir(parents=True, exist_ok=True)
# Explicit UTF-8 so the file does not depend on the platform default encoding.
with training_data_path.open("w", encoding="utf-8") as f:
    json.dump(train_data, f, indent=2)
print(f"Saved to {training_data_path.as_posix()}")


In [None]:
# ============================================================

# Step 3: Load MedGemma 4B IT with 4-bit quantisation + LoRA

# ============================================================

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, TaskType

model_id = "google/medgemma-4b-it"

# 4-bit NF4 quantisation with double quantisation; matmuls are computed in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    # Load the non-quantised weights in bf16 as well, so they match the bf16
    # compute dtype above and the bf16 training configured in Step 4 instead of
    # being left to the library's default dtype.
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)

# LoRA adapters on the four attention projections only.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Expected output: a sub-1% trainable fraction (originally estimated at ~0.5%;
# check against the value printed above).


In [None]:
# ============================================================

# Step 4: Train with SFTTrainer (TRL)

# ============================================================

from datasets import Dataset
from trl import SFTTrainer, SFTConfig

import inspect

dataset = Dataset.from_list(train_data)
print(f"Dataset size: {len(dataset)} examples")

# TRL has renamed two arguments in newer releases (`max_seq_length` ->
# `max_length` on SFTConfig, `tokenizer` -> `processing_class` on SFTTrainer).
# The install cell does not pin `trl`, so use whichever name the installed
# version actually accepts rather than failing with a TypeError.
sft_config_params = inspect.signature(SFTConfig.__init__).parameters
sft_trainer_params = inspect.signature(SFTTrainer.__init__).parameters
length_kwarg = "max_length" if "max_length" in sft_config_params else "max_seq_length"
tokenizer_kwarg = "processing_class" if "processing_class" in sft_trainer_params else "tokenizer"

training_args = SFTConfig(
    output_dir="./medscribe-soap-lora",
    num_train_epochs=3,
    # Effective batch size is 2 * 4 = 8 sequences per optimiser step.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    report_to="none",
    **{length_kwarg: 1024},
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    **{tokenizer_kwarg: tokenizer},
)

trainer.train()
print("Training complete!")


In [None]:
# ============================================================

# Step 5: Evaluate the fine-tuned model on 10 held-out test cases (base-model figures in the summary are fixed estimates, not measured here)

# ============================================================

# Held-out dictations used only for evaluation; none of these exact strings
# appears in CLINICAL_DOMAINS, though several cover the same kinds of condition.
TEST_CASES = [
    "22-year-old female, sore throat x4 days, fever 101.3F, tonsillar exudates, rapid strep positive.",
    "71-year-old male, CHF exacerbation, weight gain 8lbs, bilateral LE edema, BNP 1200.",
    "45-year-old male, new onset T2DM, A1c 9.1%. BMI 34. Fasting glucose 210.",
    "33-year-old female, migraine with aura, 3 episodes this month, photophobia, nausea.",
    "60-year-old male on warfarin INR 5.2 after starting fluconazole for fungal infection.",
    "19-year-old male, ankle inversion injury, swelling, unable to bear weight. X-ray negative.",
    "52-year-old female, screening colonoscopy, 2 adenomatous polyps removed, no dysplasia.",
    "40-year-old male, new diagnosis GAD, GAD-7 score 14, insomnia, muscle tension.",
    "75-year-old female, UTI with E. coli, eGFR 38, allergic to sulfa drugs.",
    "29-year-old female, prenatal visit 28 weeks, GDM screening glucose 162mg/dL.",
]

# Report key -> SOAP section header that must appear in the generated note.
_SOAP_HEADERS = {"S": "SUBJECTIVE", "O": "OBJECTIVE", "A": "ASSESSMENT", "P": "PLAN"}

# ICD-10-shaped token: a letter and two digits, optionally followed by either a
# dotted extension of 1-4 alphanumerics (N39.0, T45.515A) or 1-4 bare digits
# (N390). The word boundaries keep clinical shorthand such as "T2DM" or "A1C"
# from counting as a code.
_ICD10_PATTERN = re.compile(r"\b[A-Z]\d{2}(?:\.[0-9A-Z]{1,4}|\d{1,4})?\b")


def evaluate_output(text: str) -> tuple[dict[str, bool], bool, bool]:
    """Score a generated SOAP note.

    Args:
        text: The model's generated note.

    Returns:
        A ``(sections, icd_match, structured)`` tuple where ``sections`` maps
        ``"S"``/``"O"``/``"A"``/``"P"`` to whether that header occurs as a whole
        word (case-insensitive), ``icd_match`` is True when the text contains an
        ICD-10-shaped code (upper-case, format check only -- the code is not
        validated against a reference), and ``structured`` is True when all
        four headers are present.
    """
    upper = text.upper()
    # Whole-word match so that e.g. "EXPLANATION" does not satisfy "PLAN".
    sections = {
        key: re.search(rf"\b{header}\b", upper) is not None
        for key, header in _SOAP_HEADERS.items()
    }
    icd_match = _ICD10_PATTERN.search(text) is not None
    structured = all(sections.values())
    return sections, icd_match, structured

print("=" * 70)
print("EVALUATION: Fine-tuned MedGemma 4B IT on 10 held-out cases")
print("=" * 70)

# The model may still be in training mode after trainer.train(); switch to eval
# mode so LoRA dropout is disabled while generating.
model.eval()

eval_results = []
for i, case in enumerate(TEST_CASES):
    prompt = f"Generate a structured SOAP note with ICD-10 codes:\n\n{case}"
    # return_dict=True yields input_ids plus attention_mask, so generate()
    # receives an explicit mask instead of having to infer one.
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=512, temperature=0.1, do_sample=True)
    # Drop the prompt tokens so that only the newly generated note is scored.
    response = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

    sections, icd, structured = evaluate_output(response)
    eval_results.append({"sections": sections, "icd": icd, "structured": structured, "output": response})
    soap_count = sum(sections.values())
    print(f"[{i+1:2d}/{len(TEST_CASES)}] SOAP: {soap_count}/4 | ICD: {'Y' if icd else 'N'} | {case[:55]}...")

# Summary table
#
# The "Base MedGemma 4B" figures are fixed estimates carried over from the
# original notebook; the cells shown here never run the base model.
# "Structured output" uses the same all-four-sections criterion as
# "SOAP completeness", so the two fine-tuned counts are always equal.

def _relative_change(count, baseline):
    """Return the signed percentage change of *count* relative to *baseline*."""
    # `:+d` supplies the sign itself, so a regression prints "-17%" not "+-17%".
    return f"{round((count - baseline) * 100 / baseline):+d}%"


n_cases = len(eval_results)
soap_complete = sum(1 for r in eval_results if all(r["sections"].values()))
icd_found = sum(1 for r in eval_results if r["icd"])
structured_count = sum(1 for r in eval_results if r["structured"])

# (row label, estimated base count out of 10, measured fine-tuned count)
summary_rows = [
    ("SOAP completeness (4/4)", 6, soap_complete),
    # Labelled "present" rather than "exact match": evaluate_output() only
    # checks that something ICD-10-shaped appears, not that the code is right.
    ("ICD-10 code present", 4, icd_found),
    ("Structured output", 3, structured_count),
]

rule = "=" * 70
print(f"\n{rule}")
print(f"| {'Metric':<25} | {'Base MedGemma 4B':<16} | {'Fine-tuned':<10} | {'Delta':<5} |")
print(f"|{'-' * 27}|{'-' * 18}|{'-' * 12}|{'-' * 7}|")
for label, base_count, tuned_count in summary_rows:
    base_cell = f"~{base_count}/10"
    tuned_cell = f"{tuned_count}/{n_cases}"
    delta_cell = _relative_change(tuned_count, base_count)
    print(f"| {label:<25} | {base_cell:<16} | {tuned_cell:<10} | {delta_cell:<5} |")
print(rule)


In [None]:
# ============================================================

# Step 6: Save adapter and push to HuggingFace Hub

# ============================================================

ADAPTER_NAME = "medscribe-soap-lora"
HUB_REPO = "steeltroops-ai/medgemma-4b-soap-lora"

# `model` is the PEFT-wrapped model from Step 3; it is always handled before
# the tokenizer, for both the local save and the Hub upload.
artifacts = (model, tokenizer)

# Save locally
for artifact in artifacts:
    artifact.save_pretrained(ADAPTER_NAME)
print(f"Adapter saved locally to ./{ADAPTER_NAME}/")

# Push to HF Hub
for artifact in artifacts:
    artifact.push_to_hub(HUB_REPO, token=HF_TOKEN)

# Run summary. The configuration values are written out literally and must be
# kept in sync by hand with lora_config / training_args.
run_summary = [
    f"\nAdapter pushed to: https://huggingface.co/{HUB_REPO}",
    "Base model: google/medgemma-4b-it",
    "LoRA config: r=16, alpha=32, target=q/k/v/o_proj, dropout=0.05",
    f"Training: {len(train_data)} examples, 3 epochs, lr=2e-4, bf16",
]
print("\n".join(run_summary))
