# MedGemma LoRA Fine-Tuning — Patient Visit Summaries

Fine-tune MedGemma 4B on patient-friendly visit summaries using Gemini Flash as teacher.

**Runtime:** T4 GPU on Google Colab (~2-3 hours for full training)

**What this does:**
1. Generate synthetic training data from MTSamples using Gemini Flash
2. QLoRA fine-tune MedGemma 4B to produce patient summaries with a short prompt
3. Export adapter weights (~28MB) to use locally with MLX

In [1]:
# Install dependencies
!pip install -q \
    transformers>=4.46.0 \
    peft>=0.13.0 \
    trl>=0.12.0 \
    bitsandbytes>=0.44.0 \
    datasets \
    accelerate \
    google-genai \
    huggingface_hub \
    nest_asyncio \
    flash-attn --no-build-isolation

In [2]:
# Configuration — auto-detects Kaggle vs Colab
import os
import torch
from pathlib import Path

ON_KAGGLE = os.path.exists('/kaggle')

# --- Load secrets ---
if ON_KAGGLE:
    from kaggle_secrets import UserSecretsClient
    secrets = UserSecretsClient()
    HF_TOKEN = secrets.get_secret('HF_TOKEN')
    GEMINI_API_KEY = secrets.get_secret('GEMINI_API_KEY')
    WORK_DIR = '/kaggle/working'
    print('Running on Kaggle')
else:
    try:
        from google.colab import userdata
        HF_TOKEN = userdata.get('HF_TOKEN')
        GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
    except Exception:
        HF_TOKEN = ''
        GEMINI_API_KEY = ''
    WORK_DIR = '/content'
    print('Running on Colab')

os.environ['HF_TOKEN'] = HF_TOKEN
os.environ['GEMINI_API_KEY'] = GEMINI_API_KEY
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# --- Auto-detect GPU and set optimal config ---
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'cpu'
vram_gb = torch.cuda.get_device_properties(0).total_mem / 1e9 if torch.cuda.is_available() else 0

if 'A100' in gpu_name:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    N_TRAIN_SAMPLES = 1500
    BATCH_SIZE = 4
    GRAD_ACCUM_STEPS = 2
    MAX_SEQ_LENGTH = 2048
    USE_FLASH_ATTN = True
    print(f'A100 detected ({vram_gb:.0f}GB) — full speed config')
elif 'L4' in gpu_name or 'A10' in gpu_name or 'P100' in gpu_name:
    N_TRAIN_SAMPLES = 1500
    BATCH_SIZE = 2
    GRAD_ACCUM_STEPS = 4
    MAX_SEQ_LENGTH = 2048
    USE_FLASH_ATTN = 'P100' not in gpu_name  # P100 doesn't support flash attn
    print(f'{gpu_name} detected ({vram_gb:.0f}GB) — mid-tier config')
else:
    N_TRAIN_SAMPLES = 200
    BATCH_SIZE = 1
    GRAD_ACCUM_STEPS = 8
    MAX_SEQ_LENGTH = 1536
    USE_FLASH_ATTN = False
    print(f'{gpu_name} detected ({vram_gb:.0f}GB) — conservative config')

# Shared config
LORA_RANK = 8
LORA_ALPHA = 16
NUM_TRAIN_EPOCHS = 3
LEARNING_RATE = 5e-5

MODEL_ID = 'google/medgemma-4b-it'
OUTPUT_DIR = f'{WORK_DIR}/medgemma-lora-patient-summary'
DATA_DIR = f'{WORK_DIR}/training_data'

print(f'batch={BATCH_SIZE} x grad_accum={GRAD_ACCUM_STEPS} = effective {BATCH_SIZE * GRAD_ACCUM_STEPS}, '
      f'max_len={MAX_SEQ_LENGTH}, samples={N_TRAIN_SAMPLES}, flash_attn={USE_FLASH_ATTN}')

BackendError: Unexpected response from the service. Response: {'errors': ['No user secrets exist for kernel id 109225408 and label HF_TOKEN.'], 'error': {'code': 5}, 'wasSuccessful': False}.

## Step 1: Generate Training Data with Gemini Flash

**Skip this section** if you already generated data in a previous session — run the cell below to load from Google Drive instead.

In [None]:
# Load existing training data (Google Drive on Colab, skip on Kaggle)
import shutil

if ON_KAGGLE:
    print('On Kaggle — skipping Drive load, will generate fresh data below')
else:
    from google.colab import drive
    DRIVE_DATA_DIR = '/content/drive/MyDrive/medasr-mlx/training_data'
    drive.mount('/content/drive', force_remount=False)

    if Path(f'{DRIVE_DATA_DIR}/train.jsonl').exists():
        Path(DATA_DIR).mkdir(parents=True, exist_ok=True)
        for split in ['train', 'valid', 'test']:
            src = f'{DRIVE_DATA_DIR}/{split}.jsonl'
            dst = f'{DATA_DIR}/{split}.jsonl'
            if Path(src).exists():
                shutil.copy2(src, dst)
        n_train = sum(1 for _ in open(f'{DATA_DIR}/train.jsonl'))
        n_valid = sum(1 for _ in open(f'{DATA_DIR}/valid.jsonl'))
        print(f'Loaded from Drive: {n_train} train, {n_valid} valid')
        print('>>> Skip the generation cells below and go straight to Step 2')
    else:
        print(f'No data found at {DRIVE_DATA_DIR} — run generation cells below')

In [None]:
import json
import random
import time
from pathlib import Path
from datasets import load_dataset
from google import genai

random.seed(42)

# Teacher prompt for Gemini Flash
TEACHER_PROMPT = """You are a helpful medical communication assistant. A patient has just \
recorded their doctor visit. You will receive the transcript and produce \
a clear, friendly summary that helps the patient understand what happened.

Rules:
- Write at a 6th-grade reading level
- Explain every medical term in parentheses the first time it appears
- Use "your doctor" instead of doctor names
- Organize into these exact sections:
  1. VISIT SUMMARY — What happened in 2-3 sentences
  2. KEY FINDINGS — What the doctor found, explained simply
  3. DIAGNOSIS — What the doctor thinks is going on
  4. ACTION ITEMS — What I need to do next (medications, follow-ups, tests, etc.)
  5. QUESTIONS TO ASK — Things I might want to clarify at my next visit
- Include ALL action items (medications, follow-ups, tests, lifestyle changes)
- If something is unclear from the transcript, say so honestly
- Do NOT invent information not in the transcript
- Be warm and reassuring but accurate

Here is the transcript:

---
{transcript}
---

Please write the patient-friendly summary now."""

SHORT_USER_PROMPT = "Summarize this doctor visit for me in plain language:\n\n{transcript}"

TARGET_SPECIALTIES = [
    ' Cardiovascular / Pulmonary', ' Orthopedic', ' Neurology',
    ' Gastroenterology', ' General Medicine', ' Radiology',
    ' Emergency Room Reports', ' Consult - History and Phy.',
]

# Load and filter MTSamples
print('Loading MTSamples...')
ds = load_dataset('harishnair04/mtsamples', split='train')
filtered = [
    row for row in ds
    if row['medical_specialty'] in TARGET_SPECIALTIES
    and row.get('transcription')
    and len(row['transcription'].strip()) > 200
]
samples = random.sample(filtered, min(N_TRAIN_SAMPLES, len(filtered)))
print(f'{len(samples)} samples selected from {len(filtered)} available')

In [None]:
# Generate summaries with Gemini Flash — 10 concurrent requests
import asyncio
import nest_asyncio
nest_asyncio.apply()  # needed for Kaggle (Colab has native await support)

client = genai.Client(api_key=GEMINI_API_KEY)
CONCURRENCY = 10

async def generate_one(idx, row):
    """Generate a single summary. Returns (idx, pair_dict) or (idx, None) on failure."""
    transcript = row['transcription'].strip()
    if len(transcript) > 4000:
        transcript = transcript[:4000] + '...'
    specialty = row['medical_specialty'].strip()

    try:
        t0 = time.time()
        response = await client.aio.models.generate_content(
            model='gemini-2.0-flash',
            contents=TEACHER_PROMPT.format(transcript=transcript),
            config={'temperature': 0.2, 'max_output_tokens': 4096},
        )
        summary = response.text
        elapsed = time.time() - t0
    except Exception as e:
        print(f'  [{idx+1}] {specialty} ERROR: {e}')
        return idx, None

    if len(summary) < 100 or len(summary) > 5000:
        print(f'  [{idx+1}] {specialty} SKIP ({len(summary)} chars)')
        return idx, None

    print(f'  [{idx+1}] {specialty} {elapsed:.1f}s, {len(summary)} chars')
    return idx, {
        'messages': [
            {'role': 'user', 'content': SHORT_USER_PROMPT.format(transcript=transcript)},
            {'role': 'assistant', 'content': summary},
        ]
    }

async def generate_all(samples):
    sem = asyncio.Semaphore(CONCURRENCY)

    async def bounded(idx, row):
        async with sem:
            return await generate_one(idx, row)

    tasks = [bounded(i, row) for i, row in enumerate(samples)]
    return await asyncio.gather(*tasks)

t_start = time.time()
print(f'Generating {len(samples)} summaries ({CONCURRENCY} concurrent)...')
results = asyncio.get_event_loop().run_until_complete(generate_all(samples))

train_pairs = [pair for _, pair in results if pair is not None]
errors = sum(1 for _, pair in results if pair is None)
total_elapsed = time.time() - t_start

print(f'\nDone: {len(train_pairs)} pairs from {len(samples)} samples '
      f'({errors} errors) in {total_elapsed:.0f}s '
      f'({total_elapsed/len(samples):.1f}s/sample avg)')

In [None]:
# Save as train/valid/test JSONL splits (local + Google Drive backup on Colab)
Path(DATA_DIR).mkdir(parents=True, exist_ok=True)

random.shuffle(train_pairs)
n = len(train_pairs)

# Guarantee at least 1 sample per split
n_test = max(1, int(n * 0.1))
n_valid = max(1, int(n * 0.1))
n_train = n - n_valid - n_test

if n_train < 1:
    raise ValueError(f'Only {n} training pairs generated — need at least 3. Check Gemini errors above.')

splits = {
    'train': train_pairs[:n_train],
    'valid': train_pairs[n_train:n_train + n_valid],
    'test': train_pairs[n_train + n_valid:],
}

for name, data in splits.items():
    path = f'{DATA_DIR}/{name}.jsonl'
    with open(path, 'w') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')
    print(f'{name}: {len(data)} samples -> {path}')

# Backup to Google Drive (Colab only)
if not ON_KAGGLE:
    DRIVE_DATA_DIR = '/content/drive/MyDrive/medasr-mlx/training_data'
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
        Path(DRIVE_DATA_DIR).mkdir(parents=True, exist_ok=True)
        for name in splits:
            shutil.copy2(f'{DATA_DIR}/{name}.jsonl', f'{DRIVE_DATA_DIR}/{name}.jsonl')
        print(f'\nBacked up to Google Drive: {DRIVE_DATA_DIR}')
    except Exception as e:
        print(f'\nDrive backup skipped: {e}')
else:
    print(f'\nData saved to {DATA_DIR} (Kaggle working dir — persists in output)')

## Step 2: Load MedGemma with 4-bit Quantization

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_kwargs = dict(
    quantization_config=bnb_config,
    device_map='auto',
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
)
if USE_FLASH_ATTN:
    model_kwargs['attn_implementation'] = 'flash_attention_2'

print(f'Loading {MODEL_ID}' + (' with Flash Attention 2' if USE_FLASH_ATTN else '') + '...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
model.config.use_cache = False

vram_before = torch.cuda.memory_allocated() / 1e9
print(f'Full VLM loaded. GPU memory: {vram_before:.1f} GB')

# --- Strip vision encoder (SigLIP) and multi-modal projector ---
vision_attrs = ['vision_tower', 'multi_modal_projector']
stripped = []
for attr in vision_attrs:
    if hasattr(model, attr):
        delattr(model, attr)
        stripped.append(attr)
    if hasattr(model, 'model') and hasattr(model.model, attr):
        delattr(model.model, attr)
        stripped.append(f'model.{attr}')

import gc
gc.collect()
torch.cuda.empty_cache()

vram_after = torch.cuda.memory_allocated() / 1e9
print(f'Stripped {stripped} — freed {vram_before - vram_after:.1f} GB')
print(f'Text-only model. GPU memory: {vram_after:.1f} GB')

In [None]:
# Quick sanity check — generate with full prompt before training
test_input = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'What is costochondritis? Explain in simple terms.'}],
    tokenize=False, add_generation_prompt=True,
)
inputs = tokenizer(test_input, return_tensors='pt').to(model.device)
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=100, temperature=0.2, do_sample=True)
print(tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True))

## Step 3: Configure LoRA and Train

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.05,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj',
                    'gate_proj', 'up_proj', 'down_proj'],
    bias='none',
    task_type='CAUSAL_LM',
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
import datasets
from trl import SFTConfig, SFTTrainer
import torch
import functools
import json

# --- Diagnostics: check what's actually in the files ---
for split_name in ['train', 'valid', 'test']:
    fpath = f'{DATA_DIR}/{split_name}.jsonl'
    try:
        with open(fpath) as f:
            lines = [l for l in f if l.strip()]
        print(f'{split_name}.jsonl: {len(lines)} lines')
    except FileNotFoundError:
        print(f'{split_name}.jsonl: FILE NOT FOUND')

# --- Clear datasets cache to avoid stale results ---
import shutil
cache_dir = os.path.expanduser('~/.cache/huggingface/datasets')
json_caches = [d for d in os.listdir(cache_dir) if d.startswith('json')] if os.path.exists(cache_dir) else []
for d in json_caches:
    shutil.rmtree(os.path.join(cache_dir, d), ignore_errors=True)
if json_caches:
    print(f'Cleared {len(json_caches)} cached json datasets')

# --- Load train data ---
train_path = f'{DATA_DIR}/train.jsonl'
with open(train_path) as f:
    train_records = [json.loads(l) for l in f if l.strip()]

if len(train_records) == 0:
    raise ValueError(f'{train_path} is empty! Re-run data generation (cells 5-7) or check your Drive backup.')

train_ds = datasets.Dataset.from_list(train_records)

# --- Load or create valid data ---
valid_path = f'{DATA_DIR}/valid.jsonl'
try:
    with open(valid_path) as f:
        valid_records = [json.loads(l) for l in f if l.strip()]
except FileNotFoundError:
    valid_records = []

if len(valid_records) == 0:
    print('WARNING: valid set is empty — splitting 10% off train')
    split = train_ds.train_test_split(test_size=0.1, seed=42)
    train_ds = split['train']
    valid_ds = split['test']
else:
    valid_ds = datasets.Dataset.from_list(valid_records)

print(f'Train: {len(train_ds)}, Valid: {len(valid_ds)}')

# Patch model forward to inject token_type_ids (all 0 = text tokens)
# Gemma 3 VLM requires this during training, even with vision stripped.
_original_forward = model.forward.__wrapped__ if hasattr(model.forward, '__wrapped__') else model.forward

@functools.wraps(_original_forward)
def _forward_with_token_type_ids(*args, **kwargs):
    if 'token_type_ids' not in kwargs or kwargs['token_type_ids'] is None:
        input_ids = kwargs.get('input_ids')
        if input_ids is None and len(args) > 0:
            input_ids = args[0]
        if input_ids is not None:
            kwargs['token_type_ids'] = torch.zeros_like(input_ids)
    return _original_forward(*args, **kwargs)

model.forward = _forward_with_token_type_ids
print('Patched model.forward to inject token_type_ids')

# Training config
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='cosine',
    warmup_ratio=0.05,
    weight_decay=0.01,
    logging_steps=10,
    eval_strategy='steps',
    eval_steps=50,
    save_strategy='steps',
    save_steps=100,
    save_total_limit=2,
    bf16=True,
    max_length=MAX_SEQ_LENGTH,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    optim='adamw_torch_fused',
    report_to='none',
    seed=42,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    processing_class=tokenizer,
)

In [None]:
# Train!
print('Starting training...')
trainer.train()
print(f'\nTraining complete. GPU memory: {torch.cuda.memory_allocated() / 1e9:.1f} GB')

In [None]:
# Save the LoRA adapter
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Check adapter size
import os
total = sum(os.path.getsize(os.path.join(OUTPUT_DIR, f)) for f in os.listdir(OUTPUT_DIR))
print(f'Adapter saved to {OUTPUT_DIR} ({total / 1e6:.1f} MB)')

## Step 4: Test the Fine-Tuned Model

In [None]:
# Compare base vs fine-tuned on a held-out transcript
test_transcript = """SUBJECTIVE: The patient is a 45-year-old male who presents with chest pain \
that started two days ago. He describes it as a sharp, intermittent pain in \
the left side of his chest. The pain worsens with deep breathing and movement. \
He denies shortness of breath, palpitations, or radiation to the arm or jaw. \
No fever, cough, or recent illness. He has a history of hypertension controlled \
with lisinopril 10mg daily. No family history of heart disease.

OBJECTIVE: Blood pressure 138/88, heart rate 76, respiratory rate 16, \
temperature 98.6. Chest wall is tender to palpation over the left costochondral \
junction. Heart sounds regular, no murmurs. Lungs clear bilaterally. \
ECG shows normal sinus rhythm, no ST changes.

ASSESSMENT: Costochondritis, likely musculoskeletal etiology. \
Hypertension, stable on current medication.

PLAN: Ibuprofen 400mg three times daily with food for 7 days. \
Apply ice to affected area. Avoid heavy lifting for one week. \
Follow up in 2 weeks if symptoms persist. Continue lisinopril."""

prompt = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': f'Summarize this doctor visit for me in plain language:\n\n{test_transcript}'}],
    tokenize=False, add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

# --- Base model (LoRA disabled) ---
with model.disable_adapter():
    with torch.no_grad():
        out_base = model.generate(**inputs, max_new_tokens=512, temperature=0.2, do_sample=True)
base_response = tokenizer.decode(out_base[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# --- Fine-tuned model (LoRA enabled) ---
with torch.no_grad():
    out_ft = model.generate(**inputs, max_new_tokens=512, temperature=0.2, do_sample=True)
ft_response = tokenizer.decode(out_ft[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)

# --- Print side by side ---
print('=' * 80)
print('BASE MODEL (no LoRA)')
print('=' * 80)
print(base_response)
print()
print('=' * 80)
print('FINE-TUNED MODEL (LoRA)')
print('=' * 80)
print(ft_response)

## Step 5: Export for Local MLX Use

Two options to use the adapter locally:

**Option A:** Convert PEFT adapter to MLX format (recommended)

**Option B:** Download PEFT adapter and convert locally

In [None]:
# Option A: Convert to MLX-compatible safetensors right here
# This creates adapter weights that mlx-lm can load

from safetensors.torch import load_file, save_file
import re

adapter_weights = load_file(f'{OUTPUT_DIR}/adapter_model.safetensors')

# Remap PEFT keys to mlx-lm format:
# base_model.model.model.layers.X.self_attn.q_proj.lora_A.weight -> model.layers.X.self_attn.q_proj.lora_a
mlx_weights = {}
for key, tensor in adapter_weights.items():
    # Strip PEFT prefix
    new_key = key.replace('base_model.model.', '')
    # Rename lora_A.weight -> lora_a, lora_B.weight -> lora_b
    new_key = new_key.replace('.lora_A.weight', '.lora_a')
    new_key = new_key.replace('.lora_B.weight', '.lora_b')
    # Transpose: PEFT stores (out, rank) for A and (rank, in) for B
    # MLX LoRA expects (in, rank) for A and (rank, out) for B
    mlx_weights[new_key] = tensor.T.contiguous()

mlx_adapter_dir = f'{OUTPUT_DIR}/mlx_adapter'
os.makedirs(mlx_adapter_dir, exist_ok=True)
save_file(mlx_weights, f'{mlx_adapter_dir}/adapters.safetensors')

# Write mlx-lm compatible adapter_config.json
import json
mlx_config = {
    'num_layers': 16,
    'lora_parameters': {
        'rank': LORA_RANK,
        'alpha': float(LORA_ALPHA),
        'dropout': 0.0,
        'scale': float(LORA_ALPHA) / LORA_RANK,
        'keys': ['self_attn.q_proj', 'self_attn.v_proj', 'self_attn.k_proj',
                 'self_attn.o_proj', 'mlp.gate_proj', 'mlp.up_proj', 'mlp.down_proj'],
    },
}
with open(f'{mlx_adapter_dir}/adapter_config.json', 'w') as f:
    json.dump(mlx_config, f, indent=2)

print(f'MLX adapter saved to {mlx_adapter_dir}')
print('Keys:', list(mlx_weights.keys())[:5], '...')

In [None]:
# Download the MLX adapter
mlx_zip = f'{WORK_DIR}/mlx_adapter.zip'
!zip -r {mlx_zip} {mlx_adapter_dir}

if ON_KAGGLE:
    print(f'\nMLX adapter zipped at {mlx_zip}')
    print('Find it in the Output tab → /kaggle/working/mlx_adapter.zip')
    print('Then on your Mac:')
    print('  unzip mlx_adapter.zip -d artifacts/medgemma-lora-colab/')
else:
    from google.colab import files
    files.download(mlx_zip)
    print('\nDownload the zip, then on your Mac:')
    print('  unzip mlx_adapter.zip -d artifacts/medgemma-lora-colab/')
    print('  python main.py  # will auto-detect adapter')

In [None]:
# Optional: push adapter to HuggingFace Hub
# from huggingface_hub import HfApi
# api = HfApi(token=HF_TOKEN)
# api.upload_folder(
#     folder_path=mlx_adapter_dir,
#     repo_id='YOUR_USERNAME/medgemma-patient-summary-lora',
#     repo_type='model',
# )