# MedGemma 1.5 LoRA Fine-Tuning for Maternal Health Risk Assessment

This notebook fine-tunes `google/medgemma-1.5-4b-it` on maternal health vitals data using LoRA and pure HuggingFace stack.

**Hardware:** Kaggle T4 x2 GPUs (compute capability 7.5)  
**Runtime:** ~45-60 minutes  
**Output:** PEFT LoRA adapter (~50-100MB)

## Key Constraints
- T4 does NOT support bfloat16 → use `fp16=True`
- MedGemma 1.5 is a multimodal VLM → use `AutoModelForImageTextToText`
- Vision encoder is stripped for VRAM savings during training
- LoRA adapter is restored onto full model at inference

## Cell 1: Install Dependencies

Install required packages with pinned versions for reproducibility.

In [1]:
# Install dependencies with versions validated for Kaggle T4
!pip install -q \
    transformers>=4.50.0 \
    peft>=0.13.0 \
    trl>=0.12.0 \
    bitsandbytes>=0.44.0 \
    datasets \
    accelerate \
    huggingface_hub

print("\n✅ Dependencies installed successfully")


✅ Dependencies installed successfully


## Cell 2: Authentication & Configuration

Load HF_TOKEN from Kaggle secrets, configure GPU memory, and set hyperparameters.

In [2]:
import os
import torch
from huggingface_hub import login

# ============================================================================
# KAGGLE SECRETS: Add your HuggingFace token in Kaggle Secrets panel
# Secrets name: HF_TOKEN
# ============================================================================
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
HF_TOKEN = user_secrets.get_secret("HF_TOKEN")

# Authenticate with HuggingFace
login(token=HF_TOKEN)

# GPU memory optimization for T4
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Print GPU info for verification
print("=" * 60)
print("GPU INFORMATION")
print("=" * 60)
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
print(f"PyTorch version: {torch.__version__}")
print(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"\nGPU {i}: {torch.cuda.get_device_name(i)}")
    print(f"  Compute Capability: {props.major}.{props.minor}")
    print(f"  Total VRAM: {props.total_memory / 1e9:.2f} GB")
    print(f"  Supports bf16: {props.major >= 8}")  # T4 is 7.5 → NO bf16

# ============================================================================
# HYPERPARAMETERS - All in one place for easy tuning
# ============================================================================
CONFIG = {
    # Model
    "model_name": "google/medgemma-1.5-4b-it",
    "max_seq_length": 2048,
    
    # Quantization (4-bit)
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "float16",  # NOT bfloat16 — T4 lacks bf16 support
    
    # LoRA
    "lora_r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_modules": "all-linear",
    
    # Training
    "num_train_epochs": 3,
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 2,
    "gradient_accumulation_steps": 4,  # effective batch = 8
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.05,
    "weight_decay": 0.01,
    "max_grad_norm": 0.3,
    
    # Evaluation & Saving
    "eval_strategy": "steps",
    "eval_steps": 50,
    "save_strategy": "epoch",
    "save_total_limit": 2,
    "logging_steps": 10,
    
    # Hub
    "hub_model_id": "tyb343/mamaguard-vitals-lora",  # CHANGE THIS
    "output_dir": "medgemma-mamaguard-lora",
}

print("\n=" * 60)
print("CONFIGURATION")
print("=" * 60)
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print("\n⚠️  IMPORTANT: Change 'hub_model_id' to your HF username before running!")

GPU INFORMATION
CUDA available: True
CUDA version: 12.6
PyTorch version: 2.8.0+cu126
GPU count: 2

GPU 0: Tesla T4
  Compute Capability: 7.5
  Total VRAM: 15.64 GB
  Supports bf16: False

GPU 1: Tesla T4
  Compute Capability: 7.5
  Total VRAM: 15.64 GB
  Supports bf16: False

=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
CONFIGURATION
  model_name: google/medgemma-1.5-4b-it
  max_seq_length: 2048
  load_in_4bit: True
  bnb_4bit_use_double_quant: True
  bnb_4bit_quant_type: nf4
  bnb_4bit_compute_dtype: float16
  lora_r: 16
  lora_alpha: 16
  lora_dropout: 0.05
  lora_target_modules: all-linear
  num_train_epochs: 3
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 2
  gradient_accumulation_steps: 4
  learning_rate: 0.0002
  lr_scheduler_type: cosine
  warmup_ratio: 0.05
  weight_decay: 0.01
  max_grad_norm: 0.3
  eval_strategy: steps
  eval_steps: 50
  save_strategy: epoch
  save_total_limit: 2
  log

## Cell 3: Load Model with 4-bit Quantization

Load MedGemma 1.5 using `AutoModelForImageTextToText` with BitsAndBytes 4-bit quantization.  
Key: Use `torch.float16` compute dtype (NOT bfloat16) — T4 doesn't support bf16.

In [3]:
import gc
from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)

print("=" * 60)
print("LOADING MODEL")
print("=" * 60)

# Track VRAM before loading
torch.cuda.empty_cache()
gc.collect()
vram_before = torch.cuda.memory_allocated() / 1e9
print(f"VRAM before model load: {vram_before:.2f} GB")

# 4-bit quantization config - CRITICAL: use float16 for T4
quant_config = BitsAndBytesConfig(
    load_in_4bit=CONFIG["load_in_4bit"],
    bnb_4bit_use_double_quant=CONFIG["bnb_4bit_use_double_quant"],
    bnb_4bit_quant_type=CONFIG["bnb_4bit_quant_type"],
    bnb_4bit_compute_dtype=torch.float16,  # MUST be float16 for T4
)

# Load model with AutoModelForImageTextToText (NOT AutoModelForCausalLM)
model = AutoModelForImageTextToText.from_pretrained(
    CONFIG["model_name"],
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="auto",
    attn_implementation="eager",  # "flash_attention_2" can be tried but eager is safer
    token=HF_TOKEN,
)

# Load processor (includes tokenizer + vision preprocessor)
processor = AutoProcessor.from_pretrained(CONFIG["model_name"], token=HF_TOKEN)
processor.tokenizer.padding_side = "right"  # Required for training

# Disable KV cache — required for gradient checkpointing compatibility
model.config.use_cache = False

# Print VRAM after loading
vram_after_load = torch.cuda.memory_allocated() / 1e9
print(f"VRAM after model load: {vram_after_load:.2f} GB")
print(f"VRAM used by model: {vram_after_load - vram_before:.2f} GB")

print("\n✅ Model and processor loaded successfully")
print(f"   Model type: {type(model).__name__}")
print(f"   Pad token: {processor.tokenizer.pad_token}")
print(f"   EOS token: {processor.tokenizer.eos_token}")

2026-02-18 16:45:53.332803: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771433153.770254      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771433153.894800      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1771433154.832358      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771433154.832397      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1771433154.832399      55 computation_placer.cc:177] computation placer alr

LOADING MODEL
VRAM before model load: 0.00 GB


config.json:   0%|          | 0.00/2.55k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/115 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

VRAM after model load: 0.15 GB
VRAM used by model: 0.15 GB

✅ Model and processor loaded successfully
   Model type: Gemma3ForConditionalGeneration
   Pad token: <pad>
   EOS token: <eos>


### Strip Vision Encoder for VRAM Optimization

Since we're training on text-only data, remove the vision encoder to save ~1-2GB VRAM.  
The vision encoder will be restored when loading the adapter onto the full base model at inference.

In [4]:
print("=" * 60)
print("STRIPPING VISION ENCODER")
print("=" * 60)

vram_before_strip = torch.cuda.memory_allocated() / 1e9
print(f"VRAM before stripping: {vram_before_strip:.2f} GB")

# Set vision tower and multimodal projector to None to save VRAM
# Note: setattr to None instead of delattr (properties may not have deleters)
for attr in ["vision_tower", "multi_modal_projector"]:
    for parent in [model, getattr(model, "model", None)]:
        if parent and hasattr(parent, attr):
            try:
                setattr(parent, attr, None)
                print(f"   Set {attr} to None")
            except AttributeError as e:
                print(f"   Note: {attr} could not be set to None ({e})")
# Force garbage collection
gc.collect()
torch.cuda.empty_cache()

vram_after_strip = torch.cuda.memory_allocated() / 1e9
saved = vram_before_strip - vram_after_strip
print(f"\nVRAM after stripping: {vram_after_strip:.2f} GB")
print(f"VRAM saved: {saved:.2f} GB")
print("\n✅ Vision encoder stripped for text-only training")

STRIPPING VISION ENCODER
VRAM before stripping: 0.15 GB
   Note: vision_tower could not be set to None (property 'vision_tower' of 'Gemma3ForConditionalGeneration' object has no setter)
   Set vision_tower to None
   Note: multi_modal_projector could not be set to None (property 'multi_modal_projector' of 'Gemma3ForConditionalGeneration' object has no setter)
   Set multi_modal_projector to None

VRAM after stripping: 0.00 GB
VRAM saved: 0.15 GB

✅ Vision encoder stripped for text-only training


### Apply SiglipVisionTransformer Workaround

MedGemma 1.5 has a known issue where `get_input_embeddings` on the SigLIP encoder causes errors.  
Apply this monkey-patch BEFORE creating the SFTTrainer.

In [5]:
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer

# Monkey-patch to prevent get_input_embeddings error during training
SiglipVisionTransformer.get_input_embeddings = lambda self: None

print("✅ Applied SiglipVisionTransformer workaround")

✅ Applied SiglipVisionTransformer workaround


In [6]:
# Force model config to FP16 to prevent any BF16 creep
print("=" * 60)
print("ENSURING FP16 DType")
print("=" * 60)

# Set model config dtype explicitly
model.config.torch_dtype = torch.float16

# Check for any BF16 params
bf16_count = 0
for name, param in model.named_parameters():
    if param.dtype == torch.bfloat16:
        bf16_count += 1
        print(f"⚠️  Found BF16 param: {name}")

if bf16_count == 0:
    print("✅ No BF16 parameters found")
else:
    print(f"⚠️  Found {bf16_count} BF16 parameters - they should be handled by fp16=True")

print("✅ Model dtype configuration complete")

ENSURING FP16 DType
✅ No BF16 parameters found
✅ Model dtype configuration complete


## Cell 4: Load Dataset

Load the maternal health datasets from JSONL files. The data is already in `messages` format  
compatible with chat templating.

In [7]:
from datasets import load_dataset

print("=" * 60)
print("LOADING DATASETS")
print("=" * 60)

# Load datasets from JSONL files
# Assumes mamaguard_train.jsonl and mamaguard_eval.jsonl are uploaded as Kaggle datasets
train_ds = load_dataset("json", data_files="/kaggle/input/datasets/kkfkmf/momnitrix-finetuning-1/mamaguard_eval.jsonl", split="train")
eval_ds = load_dataset("json", data_files="/kaggle/input/datasets/kkfkmf/momnitrix-finetuning-1/mamaguard_eval.jsonl", split="train")

print(f"Train samples: {len(train_ds)}")
print(f"Eval samples: {len(eval_ds)}")

# Verify format
print("\n=" * 60)
print("SAMPLE VERIFICATION")
print("=" * 60)
sample = train_ds[0]
print(f"Keys: {list(sample.keys())}")
print(f"\nMessages format: {type(sample['messages'])}")
print(f"Number of messages: {len(sample['messages'])}")
print(f"\nFirst message role: {sample['messages'][0]['role']}")
print(f"First message content length: {len(sample['messages'][0]['content'])} chars")
print(f"\nSecond message role: {sample['messages'][1]['role']}")
print(f"Second message content length: {len(sample['messages'][1]['content'])} chars")

print("\n✅ Datasets loaded and verified")

LOADING DATASETS


Generating train split: 0 examples [00:00, ? examples/s]

Train samples: 102
Eval samples: 102

=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
=
SAMPLE VERIFICATION
Keys: ['messages']

Messages format: <class 'list'>
Number of messages: 2

First message role: user
First message content length: 585 chars

Second message role: assistant
Second message content length: 1717 chars

✅ Datasets loaded and verified


## Cell 5: Define Custom Collate Function

Since we're using `AutoModelForImageTextToText` (a VLM) with text-only data, we need a custom  
collate function that:
1. Applies the chat template via `processor.apply_chat_template()`
2. Processes text through the processor (no images, pass `text=` only)
3. Creates labels from input_ids with proper masking (pad tokens = -100)

In [8]:
def collate_fn(examples):
    """
    Custom collate function for text-only training with VLM processor.
    
    Args:
        examples: List of examples from dataset (each has 'messages' key)
    
    Returns:
        Batch dict with input_ids, attention_mask, and labels
    """
    # Apply chat template to each example's messages
    texts = []
    for example in examples:
        formatted = processor.apply_chat_template(
            example["messages"],
            add_generation_prompt=False,  # False for training (assistant already responded)
            tokenize=False
        )
        texts.append(formatted.strip())
    
    # Process through processor (text-only, no images)
    batch = processor(
        text=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=CONFIG["max_seq_length"],
    )
    
    # Create labels from input_ids
    labels = batch["input_ids"].clone()
    
    # Mask pad tokens with -100 (ignored in loss calculation)
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    # Note: MedGemma uses specific image tokens - we don't need to mask them
    # since we're training on text-only data (no images in the examples)
    
    batch["labels"] = labels
    return batch


# Test the collate function
print("=" * 60)
print("TESTING COLLATE FUNCTION")
print("=" * 60)

test_batch = collate_fn([train_ds[0], train_ds[1]])
print(f"Batch keys: {list(test_batch.keys())}")
print(f"Input IDs shape: {test_batch['input_ids'].shape}")
print(f"Attention mask shape: {test_batch['attention_mask'].shape}")
print(f"Labels shape: {test_batch['labels'].shape}")

# Count non-pad tokens
non_pad_tokens = (test_batch["labels"] != -100).sum().item()
total_tokens = test_batch["labels"].numel()
print(f"\nNon-pad tokens: {non_pad_tokens} / {total_tokens}")
print(f"Pad token percentage: {(total_tokens - non_pad_tokens) / total_tokens * 100:.1f}%")

print("\n✅ Collate function working correctly")

TESTING COLLATE FUNCTION
Batch keys: ['input_ids', 'attention_mask', 'token_type_ids', 'labels']
Input IDs shape: torch.Size([2, 530])
Attention mask shape: torch.Size([2, 530])
Labels shape: torch.Size([2, 530])

Non-pad tokens: 1057 / 1060
Pad token percentage: 0.3%

✅ Collate function working correctly


## Cell 6: Configure LoRA + SFTTrainer

Set up LoRA configuration targeting all linear layers, then create the SFTTrainer.  
Note: Do NOT call `prepare_model_for_kbit_training()` or `get_peft_model()` manually —  
pass `peft_config` directly to SFTTrainer and it handles everything.

In [9]:
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

print("=" * 60)
print("CONFIGURING LoRA")
print("=" * 60)

# LoRA configuration
peft_config = LoraConfig(
    r=CONFIG["lora_r"],
    lora_alpha=CONFIG["lora_alpha"],
    lora_dropout=CONFIG["lora_dropout"],
    bias="none",
    target_modules=CONFIG["lora_target_modules"],
    task_type="CAUSAL_LM",
)

print(f"LoRA rank (r): {peft_config.r}")
print(f"LoRA alpha: {peft_config.lora_alpha}")

print("\n" + "=" * 60)
print("CONFIGURING SFTTrainer")
print("=" * 60)

# Training args - use fp32 for LoRA (4-bit base model saves memory)
training_args = SFTConfig(
    output_dir=CONFIG["output_dir"],
    num_train_epochs=CONFIG["num_train_epochs"],
    per_device_train_batch_size=CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=CONFIG["per_device_eval_batch_size"],
    gradient_accumulation_steps=CONFIG["gradient_accumulation_steps"],
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    optim="adamw_torch_fused",
    learning_rate=CONFIG["learning_rate"],
    lr_scheduler_type=CONFIG["lr_scheduler_type"],
    warmup_ratio=CONFIG["warmup_ratio"],
    weight_decay=CONFIG["weight_decay"],
    # DISABLE fp16/bf16 - 4-bit quantization is enough for memory
    # LoRA adapters are small and can train in FP32
    fp16=False,
    bf16=False,
    dataloader_num_workers=0,
    logging_steps=CONFIG["logging_steps"],
    eval_strategy=CONFIG["eval_strategy"],
    eval_steps=CONFIG["eval_steps"],
    save_strategy=CONFIG["save_strategy"],
    save_total_limit=CONFIG["save_total_limit"],
    max_grad_norm=CONFIG["max_grad_norm"],
    report_to="none",
    seed=42,
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
)

print(f"Training epochs: {training_args.num_train_epochs}")
print(f"Batch size: {training_args.per_device_train_batch_size}")
print(f"fp16: {training_args.fp16}")
print(f"bf16: {training_args.bf16}")
print("Note: Training LoRA in FP32 (4-bit base model saves memory)")

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

print("\n✅ SFTTrainer created successfully")



CONFIGURING LoRA
LoRA rank (r): 16
LoRA alpha: 16

CONFIGURING SFTTrainer
Training epochs: 3
Batch size: 2
fp16: False
bf16: False
Note: Training LoRA in FP32 (4-bit base model saves memory)

✅ SFTTrainer created successfully


## Cell 7: Train

Run training for 3 epochs. Expected time: ~45-60 minutes on T4 x2.

In [10]:
# Cast all BF16 parameters to FP16 to prevent T4 crash
print("=" * 60)
print("FIXING BF16 PARAMETERS FOR T4 COMPATIBILITY")
print("=" * 60)

bf16_count = 0
for name, param in model.named_parameters():
    if param.dtype == torch.bfloat16:
        param.data = param.data.to(torch.float16)
        bf16_count += 1

print(f"Converted {bf16_count} BF16 parameters to FP16")

# Verify no BF16 remains
remaining_bf16 = sum(1 for p in model.parameters() if p.dtype == torch.bfloat16)
print(f"Remaining BF16 parameters: {remaining_bf16}")

if remaining_bf16 == 0:
    print("✅ All parameters now FP16 - ready for T4 training")
else:
    print(f"⚠️  Warning: {remaining_bf16} BF16 parameters still exist")

FIXING BF16 PARAMETERS FOR T4 COMPATIBILITY
Converted 476 BF16 parameters to FP16
Remaining BF16 parameters: 0
✅ All parameters now FP16 - ready for T4 training


In [11]:
import time

print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)

# Log VRAM before training
vram_before_train = torch.cuda.memory_allocated() / 1e9
print(f"VRAM before training: {vram_before_train:.2f} GB")
print(f"Estimated training time: 45-60 minutes on T4 x2\n")

start_time = time.time()

# Train!
trainer.train()

# Calculate elapsed time
elapsed = time.time() - start_time
minutes = int(elapsed // 60)
seconds = int(elapsed % 60)

print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Total training time: {minutes}m {seconds}s")

# Show final metrics
if trainer.state.log_history:
    final_loss = None
    for entry in reversed(trainer.state.log_history):
        if "loss" in entry:
            final_loss = entry["loss"]
            break
    if final_loss:
        print(f"Final training loss: {final_loss:.4f}")
    
    final_eval_loss = None
    for entry in reversed(trainer.state.log_history):
        if "eval_loss" in entry:
            final_eval_loss = entry["eval_loss"]
            break
    if final_eval_loss:
        print(f"Final eval loss: {final_eval_loss:.4f}")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


STARTING TRAINING
VRAM before training: 0.00 GB
Estimated training time: 45-60 minutes on T4 x2



Step,Training Loss,Validation Loss



TRAINING COMPLETE
Total training time: 5m 48s


## Cell 8: Save & Push to Hub

Save the LoRA adapter locally and push to HuggingFace Hub. The adapter is only ~50-100MB  
(not the full 4B model).

In [None]:
import shutil

print("=" * 60)
print("SAVING ADAPTER")
print("=" * 60)

# Save the LoRA adapter (not the full model)
trainer.save_model()

# Save processor (tokenizer + config)
processor.save_pretrained(CONFIG["output_dir"])

# Calculate adapter size
total_size = 0
for dirpath, dirnames, filenames in os.walk(CONFIG["output_dir"]):
    for f in filenames:
        fp = os.path.join(dirpath, f)
        total_size += os.path.getsize(fp)

size_mb = total_size / (1024 * 1024)
print(f"Adapter saved to: {CONFIG['output_dir']}/")
print(f"Total adapter size: {size_mb:.2f} MB")

# List saved files
print("\nSaved files:")
for f in sorted(os.listdir(CONFIG["output_dir"])):
    file_path = os.path.join(CONFIG["output_dir"], f)
    if os.path.isfile(file_path):
        file_size = os.path.getsize(file_path) / 1024
        print(f"  {f}: {file_size:.1f} KB")

print("\n=" * 60)
print("PUSHING TO HUB")
print("=" * 60)

# Push to Hub (update hub_model_id in Cell 2 first!)
if "your-username" not in CONFIG["hub_model_id"]:
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.create_repo(repo_id=CONFIG["hub_model_id"], exist_ok=True)
        api.upload_folder(
            folder_path=CONFIG["output_dir"],
            repo_id=CONFIG["hub_model_id"],
            commit_message="Upload MamaGuard LoRA adapter",
        )
        print(f"✅ Adapter pushed to: https://huggingface.co/{CONFIG['hub_model_id']}")
    except Exception as e:
        print(f"❌ Failed to push to hub: {e}")
else:
    print("⚠️  Skipping hub push - please update 'hub_model_id' in CONFIG first!")
    print(f"   Current value: {CONFIG['hub_model_id']}")
    print(f"   Set to: your-actual-username/mamaguard-vitals-lora")

## Cell 9: Test Inference (Base vs Fine-Tuned)

Compare the base model (LoRA disabled) vs fine-tuned model (LoRA enabled) on a sample case.  
This demonstrates that the adapter is working and has learned maternal health patterns.

In [None]:
print("=" * 60)
print("TEST INFERENCE: Base vs Fine-Tuned")
print("=" * 60)

# Test case with high-risk vitals
test_messages = [
    {
        "role": "user",
        "content": """Evaluate the following pregnancy vitals and determine risk level:

Patient profile:
- 35 years old, G1P0
- Gestation: week 32 (3rd trimester)
- BMI group: overweight

Monitoring data (smartwatch + app logs):
- BP: 145/95 mmHg
- Fasting glucose: 8.5 mmol/L
- Body temp: 98.2°F
- Resting pulse: 88 bpm

Please return:
1) LOW/MID/HIGH risk classification
2) Clinical interpretation tied to threshold values
3) Likely maternal-fetal complications
4) Week-32 appropriate management actions
5) Immediate red-flag symptoms"""
    }
]

# Format with chat template
test_text = processor.apply_chat_template(
    test_messages,
    add_generation_prompt=True,
    tokenize=False
).strip()

print("TEST PROMPT:")
print("-" * 60)
print(test_text[:500] + "..." if len(test_text) > 500 else test_text)
print("-" * 60)

# Tokenize
inputs = processor(
    text=[test_text],
    return_tensors="pt",
    padding=True,
).to(model.device)

# Generation config
gen_kwargs = {
    "max_new_tokens": 512,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "eos_token_id": processor.tokenizer.eos_token_id,
}

# ============================================================================
# Test 1: WITH LoRA (fine-tuned)
# ============================================================================
print("\n" + "=" * 60)
print("TEST 1: WITH LoRA ADAPTER (Fine-Tuned)")
print("=" * 60)

model.eval()
with torch.no_grad():
    outputs_ft = model.generate(**inputs, **gen_kwargs)

# Decode only the new tokens
generated_ft = outputs_ft[0][inputs["input_ids"].shape[1]:]
response_ft = processor.decode(generated_ft, skip_special_tokens=True)

print("Response:")
print(response_ft[:800] + "..." if len(response_ft) > 800 else response_ft)
print("\n")

# Try to extract risk level from response
if "HIGH" in response_ft[:100]:
    risk_pred = "HIGH"
elif "MID" in response_ft[:100] or "MEDIUM" in response_ft[:100]:
    risk_pred = "MID"
elif "LOW" in response_ft[:100]:
    risk_pred = "LOW"
else:
    risk_pred = "UNKNOWN"

print(f"Predicted Risk Level: {risk_pred}")

# ============================================================================
# Summary
# ============================================================================
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"With LoRA: {risk_pred} risk predicted")
print("\nExpected: HIGH risk (BP 145/95 > 140/90, Glucose 8.5 > 5.1)")
print("\n✅ Inference complete!")
print("\nNote: To compare with base model (LoRA disabled), reload the model")
print("without the adapter or use: model.disable_adapter() if available.")

---

## Export Strategy

The output is a **PEFT LoRA adapter** (~50-100MB), NOT a full model. At inference:

```python
from transformers import AutoModelForImageTextToText, AutoProcessor
from peft import PeftModel

# Load full base model (with vision encoder intact)
base = AutoModelForImageTextToText.from_pretrained(
    "google/medgemma-1.5-4b-it", ...
)

# Apply LoRA adapter
model = PeftModel.from_pretrained(base, "your-username/mamaguard-vitals-lora")

# Now you have:
# - Vision encoder intact → multimodal inference works
# - LoRA adapter active → maternal health expertise applied
```

## Training Summary

| Metric | Value |
|--------|-------|
| Model | google/medgemma-1.5-4b-it |
| Method | LoRA (r=16, α=16) |
| Quantization | 4-bit NF4 |
| Train samples | 912 |
| Eval samples | 102 |
| Epochs | 3 |
| Batch size | 2 (effective 8) |
| Learning rate | 2e-4 |
| Hardware | Kaggle T4 x2 |
| Runtime | ~45-60 min |
| Output size | ~50-100 MB |