# Module 2

In [1]:
import os
import wandb

os.environ['HF_HOME'] = '/media/shrish/Data/medgemma_finetune/hf_models/'
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"))

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch #
from datasets import load_dataset
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor, 
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer


In [4]:
MODEL_ID = "google/medgemma-1.5-4b-it" # Replace with your specific MedGemma/Gemma 1.5 path
DATASET_PATH = "medgemma_risk_training_augmented.jsonl"
OUTPUT_DIR = os.path.join("models", "lora_risk_module2")

In [5]:
print(f"Loading dataset from {DATASET_PATH}...")
full_dataset = load_dataset("json", data_files=DATASET_PATH, split="train")

Loading dataset from medgemma_risk_training_augmented.jsonl...


In [6]:
print("Splitting dataset into Training (90%) and Validation (10%) sets...")
# Set a seed so your splits are reproducible across different runs
split_dataset = full_dataset.train_test_split(test_size=0.1, seed=42)


train_dataset = split_dataset["train"]
eval_dataset = split_dataset["test"]
print(f"Training samples: {len(train_dataset)} | Validation samples: {len(eval_dataset)}")


Splitting dataset into Training (90%) and Validation (10%) sets...
Training samples: 16056 | Validation samples: 1784


In [7]:
# --- 2. Processor & Chat Template Formatting ---
print("Initializing MedGemma 1.5 Processor...")
# MedGemma 1.5 uses a Processor instead of a standalone Tokenizer
processor = AutoProcessor.from_pretrained(MODEL_ID)

# We extract the underlying tokenizer to pass to the SFTTrainer for text processing
tokenizer = processor.tokenizer
tokenizer.padding_side = 'right' 
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Initializing MedGemma 1.5 Processor...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [8]:
# Function to apply MedGemma's specific multimodal chat template to our text array
def format_chat_template(example):
    example["text"] = processor.apply_chat_template(
        example["messages"], 
        tokenize=False, 
        add_generation_prompt=False
    )
    return example

In [9]:
print("Formatting splits to MedGemma conversational structure...")
# Map the formatting function to both splits independently
train_dataset = train_dataset.map(format_chat_template)
eval_dataset = eval_dataset.map(format_chat_template)

Formatting splits to MedGemma conversational structure...


In [10]:
# --- 3. 4-Bit Quantization Setup (QLoRA) ---
print("Loading Base Multimodal Model in 4-bit...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# CRITICAL: Load using the ImageTextToText class
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

Loading Base Multimodal Model in 4-bit...


`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:10<00:00,  5.39s/it]


In [11]:
# Prepare model for parameter-efficient fine-tuning
model = prepare_model_for_kbit_training(model)

In [12]:
# --- 4. LoRA Adapter Configuration ---
print("Injecting LoRA Adapters...")
peft_config = LoraConfig(
    r=16, # Rank: Controls the capacity of the adapter
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
new_model = get_peft_model(model, peft_config)
new_model.print_trainable_parameters()

Injecting LoRA Adapters...
trainable params: 32,788,480 || all params: 4,332,867,952 || trainable%: 0.7567


In [13]:
# --- 5. Training Arguments ---
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,      # Adjust based on your GPU VRAM
    per_device_eval_batch_size=2,   
    gradient_accumulation_steps=16,      # Simulates a batch size of 8
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    save_strategy="epoch",
    eval_strategy="steps",   
    logging_steps=50,
    eval_steps=50, 
    num_train_epochs=3,                 # 3 epochs is standard for medical instruction tuning
    max_steps=-1,
    fp16=False,
    bf16=True,                          # Use bf16 if you have an Ampere GPU (RTX 3000/4000/A100)
    group_by_length=True,
    report_to="wandb",                   
    run_name="medgemma_module2",
)

In [14]:

# --- 6. SFT Trainer Initialization ---
print("Initializing Supervised Fine-Tuning (SFT) Trainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset, 
    peft_config=peft_config,
    args=training_args,
)

Initializing Supervised Fine-Tuning (SFT) Trainer...


In [15]:
# --- 7. Execute Training ---
print("Beginning training... This will take some time.")
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Beginning training... This will take some time.


[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home/shrish/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mda24s004[0m ([33mda24s004-iitm[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss
50,0.5184,0.396847
100,0.313,0.354317
150,0.3088,0.347138
200,0.3078,0.336562
250,0.2957,0.330142
300,0.299,0.316288
350,0.2957,0.304725
400,0.2878,0.305818
450,0.2874,0.296518
500,0.2868,0.29649


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


TrainOutput(global_step=1506, training_loss=0.25365220677171885, metrics={'train_runtime': 103679.9449, 'train_samples_per_second': 0.465, 'train_steps_per_second': 0.015, 'total_flos': 1.9307331622164768e+17, 'train_loss': 0.25365220677171885})

In [16]:
# --- 8. Save the Final LoRA Adapter ---
print(f"Training complete. Saving adapter to {OUTPUT_DIR}...")
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print("‚úÖ Module 2 Adapter successfully compiled and saved.")

Training complete. Saving adapter to models/lora_risk_module2...
‚úÖ Module 2 Adapter successfully compiled and saved.


# Test

In [18]:
import torch
import random
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import wandb

# --- 9. Post-Training Evaluation on Validation Set ---
print("\n" + "="*50)
print("üöÄ Starting Final Evaluation on Validation Set...")
print("="*50)

# 1. Put the newly trained model in evaluation mode
trainer.model.eval()


üöÄ Starting Final Evaluation on Validation Set...


PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForConditionalGeneration(
      (model): Gemma3Model(
        (vision_tower): SiglipVisionModel(
          (vision_model): SiglipVisionTransformer(
            (embeddings): SiglipVisionEmbeddings(
              (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
              (position_embedding): Embedding(4096, 1152)
            )
            (encoder): SiglipEncoder(
              (layers): ModuleList(
                (0-26): 27 x SiglipEncoderLayer(
                  (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                  (self_attn): SiglipAttention(
                    (k_proj): lora.Linear4bit(
                      (base_layer): Linear4bit(in_features=1152, out_features=1152, bias=True)
                      (lora_dropout): ModuleDict(
                        (default): Dropout(p=0.05, inplace=False)
                      )
                  

In [19]:
results = []

# Helper function to parse the specific Risk Category from the text
def extract_risk_level(text):
    text_lower = text.lower()
    if "critical risk" in text_lower: 
        return "Critical Risk"
    elif "moderate/high risk" in text_lower or "moderate risk" in text_lower: 
        return "Moderate/High Risk"
    elif "high risk" in text_lower: 
        return "High Risk"
    elif "standard/low risk" in text_lower or "low risk" in text_lower: 
        return "Standard/Low Risk"
    return "Unknown/Invalid"

In [21]:
# 2. Iterate through the validation dataset
for example in tqdm(eval_dataset, desc="Evaluating Module 2"):
    # The dataset uses a "messages" schema
    user_msg = [msg for msg in example["messages"] if msg["role"] == "user"]
    true_target = [msg for msg in example["messages"] if msg["role"] == "assistant"][0]["content"]
    
    # Format the prompt using the chat template
    prompt = processor.apply_chat_template(
        user_msg, 
        tokenize=False, 
        add_generation_prompt=True
    )
    
    # Tokenize and move to GPU
    inputs = processor(text=prompt, return_tensors="pt", padding=True).to(trainer.model.device)
    
    inputs.pop("token_type_ids", None)
    inputs.pop("pixel_values", None)

    # Generate the response
    with torch.no_grad():
        outputs = trainer.model.generate(
            **inputs, 
            max_new_tokens=150, # Enough tokens to capture the risk score and the reasoning
            do_sample=False,    # Greedy decoding for consistent clinical evaluation
            pad_token_id=processor.tokenizer.pad_token_id
        )
        
    # Decode only the newly generated tokens
    input_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0, input_length:]
    pred_text = processor.decode(generated_tokens, skip_special_tokens=True).strip()
    
    # 3. Parse the risk classification
    true_risk = extract_risk_level(true_target)
    pred_risk = extract_risk_level(pred_text)
    
    is_correct = (true_risk == pred_risk)
    
    results.append({
        "User_Prompt": user_msg[0]["content"],
        "Ground_Truth_Text": true_target,
        "Model_Prediction_Text": pred_text,
        "True_Label": true_risk,
        "Predicted_Label": pred_risk,
        "Correct": is_correct
    })

print("\nInference complete! Calculating metrics...")

Evaluating Module 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1784/1784 [2:25:57<00:00,  4.91s/it]  


Inference complete! Calculating metrics...





In [22]:
# --- 10. Calculate Metrics & Categorize Samples ---

y_true = [r["True_Label"] for r in results]
y_pred = [r["Predicted_Label"] for r in results]

# Calculate Multi-Class Metrics (Using 'weighted' average since class distribution may be imbalanced)
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="weighted", zero_division=0
)

print("\n" + "="*40)
print("üìä MODULE 2 METRICS (Risk Stratification)")
print("="*40)
print(f"Accuracy:  {accuracy:.4f}")
print(f"Weighted Precision: {precision:.4f}")
print(f"Weighted Recall:    {recall:.4f}")
print(f"Weighted F1 Score:  {f1:.4f}")
print("="*40 + "\n")


üìä MODULE 2 METRICS (Risk Stratification)
Accuracy:  0.9989
Weighted Precision: 0.9989
Weighted Recall:    0.9989
Weighted F1 Score:  0.9989



In [26]:
# Select 5 random samples
random_5 = random.sample(results, min(5, len(results)))

print("--- 5 RANDOM VALIDATION SAMPLES ---")
for i, s in enumerate(random_5):
    status = "‚úÖ" if s['Correct'] else "‚ùå"
    print(f"\nSample #{i+1} {status}")
    print(f"  True Risk Level: {s['True_Label']} | Predicted: {s['Predicted_Label']}")
    print(f"User Prompt: {s['User_Prompt']}")
    print(f"  Ground Truth: {s['Ground_Truth_Text']}") # Truncated for terminal readability
    print(f"  Model Output: {s['Model_Prediction_Text']}")

--- 5 RANDOM VALIDATION SAMPLES ---

Sample #1 ‚úÖ
  True Risk Level: Standard/Low Risk | Predicted: Standard/Low Risk
User Prompt: Assess the Multiple Myeloma risk profile for this 64-year-old Male patient:
- Patient Reported Symptoms: Patient presented with no specific symptoms noted in the provided snippet.
- CRAB Panel -> Creatinine: 0.80 mg/dL, Calcium: 7.75 mg/dL, Hemoglobin: 11.70 g/dL
- Tumor/Staging Panel -> Albumin: Not tested, Beta-2 Microglobulin: Not tested, LDH: Not tested, M-Spike (SPEP): Not tested, FLC Ratio: Not tested
  Ground Truth: Standard/Low Risk based on currently available data. No overt CRAB criteria or tumor markers are met. Continue routine clinical monitoring based on presenting symptoms.
  Model Output: Standard/Low Risk based on currently available data. No overt CRAB criteria or tumor markers are met. Continue routine clinical monitoring based on presenting symptoms.

Sample #2 ‚úÖ
  True Risk Level: Standard/Low Risk | Predicted: Standard/Low Risk
User

In [29]:
# --- 11. Log Everything to Weights & Biases ---

# Ensure wandb is tracking (Trainer initializes it automatically, but we ensure it's active)
if wandb.run is None:
    wandb.init(project="huggingface", name="medgemma_module2_eval")

# Log numeric metrics
wandb.log({
    "eval/accuracy": accuracy,
    "eval/precision": precision,
    "eval/recall": recall,
    "eval/f1_score": f1
})



In [28]:



# Build a WandB Table to log the text samples cleanly
columns = ["True Risk Level", "Predicted Risk Level", "Correct?", "Ground Truth Reasoning", "Model Predicted Reasoning", "Input Clinical Profile"]
eval_table = wandb.Table(columns=columns)

# Add our random 5 samples to the table
for s in random_5:
    eval_table.add_data(
        s["True_Label"], 
        s["Predicted_Label"], 
        s["Correct"], 
        s["Ground_Truth_Text"], 
        s["Model_Prediction_Text"],
        s["User_Prompt"]
    )

# Log the table to your wandb dashboard
wandb.log({"Module_2_Validation_Samples": eval_table})

print("‚úÖ Successfully logged all metrics and triage samples to Weights & Biases!")

‚úÖ Successfully logged all metrics and triage samples to Weights & Biases!
