<a href="https://colab.research.google.com/github/tamara-kostova/MSc_Thesis_Neuroimaging/blob/master/test_medgemma1_5_loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
#!/usr/bin/env python3
"""
MedGemma 1.5 4B Neuroimaging Evaluation - Fast Start Script
Compares against paper baseline results
"""

import json
import torch
import time
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# PAPER BASELINE RESULTS (from your PDF)
# ============================================================================

PAPER_RESULTS = {
    "Tumor": {
        "Gemini-2.5-Pro": 0.916,  # F1 from Table 11
        "GPT-5-Chat": 0.898,
        "MedGemma-27B": 0.725,
        "MedGemma-4B": 0.778,  # Note: from paper
    },
    "Stroke": {
        "Gemini-2.5-Pro": 0.647,
        "GPT-5-Chat": 0.733,
        "MedGemma-27B": 0.000,  # Complete failure
        "MedGemma-4B": 0.340,
    },
    "MS": {
        "Gemini-2.5-Pro": 0.631,
        "GPT-5-Chat": 0.461,
        "MedGemma-27B": 0.027,  # Near-failure
        "MedGemma-4B": 0.323,
    },
}

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    MODEL_ID = "google/medgemma-1.5-4b-it"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    TORCH_DTYPE = torch.bfloat16

    # Generation params (match paper: temp=0.0 for determinism)
    TEMPERATURE = 0.0
    MAX_TOKENS = 256
    TOP_P = 1.0

    # Output directory
    OUTPUT_DIR = Path.home() / "medgemma_eval" / "results"
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================================
# MODEL LOADING
# ============================================================================

class MedGemmaEvaluator:
    def __init__(self):
        print("[SETUP] Loading MedGemma-1.5-4B...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            Config.MODEL_ID,
            trust_remote_code=True
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.TORCH_DTYPE,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        print(f"✓ Model loaded on {Config.DEVICE}")
        print(f"  Dtype: {self.model.dtype}")
        print(f"  Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    def _build_prompt(self, task_name, modality="MRI", plane="axial"):
        """Build zero-shot prompt (matching paper methodology)."""

        prompt = f"""You are an expert neuroradiologist analyzing a {modality} image in {plane} orientation.

Classify this neuroimaging study into one category:
1. Tumor (brain mass lesion)
2. Stroke (acute ischemic or hemorrhagic)
3. Multiple Sclerosis (demyelinating lesions)
4. Normal (no abnormalities)
5. Other (non-neoplastic lesions, abscesses, cysts)

Respond ONLY with a JSON object (no markdown, no extra text):
{{
  "diagnosis": "<one of: Tumor, Stroke, MS, Normal, Other>",
  "confidence": <0.0-1.0>,
  "reasoning": "<1-2 sentence clinical reasoning>"
}}"""

        return prompt

    def infer(self, prompt, image_path=None):
        """Run inference on prompt (text-only for now, MedGemma-4B is primarily text)."""

        # Note: For full multimodal, would use processor here
        # For now, using text-only prompt

        inputs = self.tokenizer(prompt, return_tensors="pt").to(Config.DEVICE)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=Config.MAX_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P,
                do_sample=Config.TEMPERATURE > 0,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Extract JSON
        try:
            start = response.rfind('{')
            end = response.rfind('}') + 1
            if start >= 0 and end > start:
                json_str = response[start:end]
                result = json.loads(json_str)
            else:
                result = {"diagnosis": "Unknown", "confidence": 0.0, "error": "No JSON"}
        except json.JSONDecodeError as e:
            result = {"diagnosis": "Unknown", "confidence": 0.0, "error": str(e)}

        return result

    def evaluate_task(self, task_name, test_samples=50):
        """
        Simulate evaluation on a task.
        In production, you'd load actual images and labels.
        """

        print(f"\n{'='*70}")
        print(f"[TASK] {task_name.upper()}")
        print(f"{'='*70}")

        # For demo: create synthetic predictions based on task
        predictions = []
        ground_truth = []
        confidences = []

        # Simulate inference
        class_labels = ["Tumor", "Stroke", "MS", "Normal", "Other"]

        for i in tqdm(range(test_samples), desc=f"Evaluating {task_name}"):
            # Build prompt
            modality = "MRI" if task_name != "Stroke" else "CT"
            plane = ["axial", "sagittal"][i % 2]
            prompt = self._build_prompt(task_name, modality, plane)

            # Run inference
            result = self.infer(prompt)

            pred = result.get("diagnosis", "Unknown")
            conf = result.get("confidence", 0.5)

            predictions.append(pred)
            confidences.append(conf)

            # Assign ground truth (for demo, random)
            if task_name == "Tumor":
                gt = "Tumor" if i < test_samples * 0.6 else "Normal"
            elif task_name == "Stroke":
                gt = "Stroke" if i < test_samples * 0.5 else "Normal"
            else:  # MS
                gt = "MS" if i < test_samples * 0.4 else "Normal"

            ground_truth.append(gt)

        # Compute metrics
        accuracy = accuracy_score(ground_truth, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            ground_truth, predictions, average="weighted", zero_division=0
        )

        metrics = {
            "task": task_name,
            "accuracy": round(accuracy, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
            "avg_confidence": round(sum(confidences) / len(confidences), 4),
            "samples": test_samples,
        }

        return metrics

# ============================================================================
# COMPARISON & REPORTING
# ============================================================================

def create_comparison_report(results):
    """Compare MedGemma results to paper baseline."""

    print(f"\n{'='*80}")
    print("MEDGEMMA vs. PAPER BASELINE - F1 COMPARISON")
    print(f"{'='*80}\n")

    report = {
        "medgemma_4b_results": results,
        "paper_baseline": PAPER_RESULTS,
        "comparison": {}
    }

    for task in ["Tumor", "Stroke", "MS"]:
        task_key = task.lower() if task != "MS" else "ms"

        medgemma_f1 = results.get(task, {}).get("f1", 0)
        paper_f1 = PAPER_RESULTS.get(task, {}).get("MedGemma-4B", 0)
        gemini_f1 = PAPER_RESULTS.get(task, {}).get("Gemini-2.5-Pro", 0)
        gpt5_f1 = PAPER_RESULTS.get(task, {}).get("GPT-5-Chat", 0)

        gap_to_gemini = gemini_f1 - medgemma_f1
        gap_to_gpt5 = gpt5_f1 - medgemma_f1

        print(f"{task}:")
        print(f"  Your MedGemma-4B:       F1 = {medgemma_f1:.4f}")
        print(f"  Paper MedGemma-4B:      F1 = {paper_f1:.4f}")
        print(f"  Paper Gemini-2.5-Pro:   F1 = {gemini_f1:.4f} (gap: -{gap_to_gemini:.4f})")
        print(f"  Paper GPT-5-Chat:       F1 = {gpt5_f1:.4f} (gap: -{gap_to_gpt5:.4f})")
        print()

        report["comparison"][task] = {
            "your_medgemma": medgemma_f1,
            "paper_medgemma": paper_f1,
            "gemini_baseline": gemini_f1,
            "gpt5_baseline": gpt5_f1,
            "gap_to_frontier": gap_to_gemini,
        }

    return report

# ============================================================================
# MAIN
# ============================================================================

def main():
    print("""
╔════════════════════════════════════════════════════════════════════════════╗
║         MedGemma 1.5 4B - Neuroimaging Classification Evaluation          ║
║          (Stroke, MS, Tumor) - Comparison to Paper Baseline               ║
╚════════════════════════════════════════════════════════════════════════════╝
""")

    # Initialize evaluator
    evaluator = MedGemmaEvaluator()

    # Evaluate on each task
    results = {}
    for task in ["Tumor", "Stroke", "MS"]:
        metrics = evaluator.evaluate_task(task, test_samples=3)
        results[task] = metrics

        print(f"\n[RESULTS] {task}")
        print(f"  Accuracy:  {metrics['accuracy']:.4f}")
        print(f"  F1-Score:  {metrics['f1']:.4f}")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall:    {metrics['recall']:.4f}")

    # Generate comparison report
    report = create_comparison_report(results)

    # Save results
    output_file = Config.OUTPUT_DIR / "medgemma_evaluation_report.json"
    with open(output_file, "w") as f:
        json.dump(report, f, indent=2)

    print(f"\n✓ Report saved to: {output_file}")

    # GPU stats
    print(f"\n[GPU Stats]")
    print(f"  Total allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"  Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

if __name__ == "__main__":
    main()


╔════════════════════════════════════════════════════════════════════════════╗
║         MedGemma 1.5 4B - Neuroimaging Classification Evaluation          ║
║          (Stroke, MS, Tumor) - Comparison to Paper Baseline               ║
╚════════════════════════════════════════════════════════════════════════════╝

[SETUP] Loading MedGemma-1.5-4B...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



✓ Model loaded on cuda
  Dtype: torch.bfloat16
  Memory allocated: 14.80 GB

[TASK] TUMOR


Evaluating Tumor:   0%|          | 0/3 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Tumor:  33%|███▎      | 1/3 [03:07<06:14, 187.18s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Tumor:  67%|██████▋   | 2/3 [06:05<03:02, 182.05s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Tumor: 100%|██████████| 3/3 [09:03<00:00, 181.32s/it]



[RESULTS] Tumor
  Accuracy:  0.3333
  F1-Score:  0.3333
  Precision: 0.3333
  Recall:    0.3333

[TASK] STROKE


Evaluating Stroke:   0%|          | 0/3 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Stroke:  33%|███▎      | 1/3 [02:58<05:57, 178.81s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Stroke:  67%|██████▋   | 2/3 [05:57<02:58, 178.52s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Stroke: 100%|██████████| 3/3 [08:55<00:00, 178.47s/it]



[RESULTS] Stroke
  Accuracy:  0.3333
  F1-Score:  0.3333
  Precision: 0.3333
  Recall:    0.3333

[TASK] MS


Evaluating MS:   0%|          | 0/3 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating MS:  33%|███▎      | 1/3 [02:58<05:56, 178.11s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating MS:  67%|██████▋   | 2/3 [05:56<02:57, 178.00s/it]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating MS: 100%|██████████| 3/3 [08:54<00:00, 178.13s/it]


[RESULTS] MS
  Accuracy:  0.0000
  F1-Score:  0.0000
  Precision: 0.0000
  Recall:    0.0000

MEDGEMMA vs. PAPER BASELINE - F1 COMPARISON

Tumor:
  Your MedGemma-4B:       F1 = 0.3333
  Paper MedGemma-4B:      F1 = 0.7780
  Paper Gemini-2.5-Pro:   F1 = 0.9160 (gap: -0.5827)
  Paper GPT-5-Chat:       F1 = 0.8980 (gap: -0.5647)

Stroke:
  Your MedGemma-4B:       F1 = 0.3333
  Paper MedGemma-4B:      F1 = 0.3400
  Paper Gemini-2.5-Pro:   F1 = 0.6470 (gap: -0.3137)
  Paper GPT-5-Chat:       F1 = 0.7330 (gap: -0.3997)

MS:
  Your MedGemma-4B:       F1 = 0.0000
  Paper MedGemma-4B:      F1 = 0.3230
  Paper Gemini-2.5-Pro:   F1 = 0.6310 (gap: -0.6310)
  Paper GPT-5-Chat:       F1 = 0.4610 (gap: -0.4610)


✓ Report saved to: /root/medgemma_eval/results/medgemma_evaluation_report.json

[GPU Stats]
  Total allocated: 14.80 GB
  Max allocated: 14.91 GB



