<a href="https://colab.research.google.com/github/tamara-kostova/MSc_Thesis_Neuroimaging/blob/master/08_test_medgemma1_5_loading.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
BASE_DIR = "/content/drive/MyDrive/MSc_Thesis_Neuroimaging"
RESULTS = f"{BASE_DIR}/results/medgemma"

In [2]:
"""
MedGemma 1.5 4B Neuroimaging Evaluation - Fast Start Script
"""

import json
import torch
import time
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import warnings
warnings.filterwarnings("ignore")

# ============================================================================
# CONFIGURATION
# ============================================================================

class Config:
    MODEL_ID = "google/medgemma-1.5-4b-it"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    TORCH_DTYPE = torch.bfloat16

    # Generation params (match paper: temp=0.0 for determinism)
    TEMPERATURE = 0.0
    MAX_TOKENS = 256
    TOP_P = 1.0

    # Output directory
    OUTPUT_DIR = Path.home() / "medgemma_eval" / "results"
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ============================================================================
# MODEL LOADING
# ============================================================================

class MedGemmaEvaluator:
    def __init__(self):
        print("[SETUP] Loading MedGemma-1.5-4B...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            Config.MODEL_ID,
            trust_remote_code=True
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            Config.MODEL_ID,
            torch_dtype=Config.TORCH_DTYPE,
            device_map="auto",
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        print(f"✓ Model loaded on {Config.DEVICE}")
        print(f"  Dtype: {self.model.dtype}")
        print(f"  Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")

    def _build_prompt(self, task_name, modality="MRI", plane="axial"):
        """Build zero-shot prompt (matching paper methodology)."""

        prompt = f"""You are an expert neuroradiologist analyzing a {modality} image in {plane} orientation.

Classify this neuroimaging study into one category:
1. Tumor (brain mass lesion)
2. Stroke (acute ischemic or hemorrhagic)
3. Multiple Sclerosis (demyelinating lesions)
4. Normal (no abnormalities)
5. Other (non-neoplastic lesions, abscesses, cysts)

Respond ONLY with a JSON object (no markdown, no extra text):
{{
  "diagnosis": "<one of: Tumor, Stroke, MS, Normal, Other>",
  "confidence": <0.0-1.0>,
  "reasoning": "<1-2 sentence clinical reasoning>"
}}"""

        return prompt

    def infer(self, prompt, image_path=None):
        """Run inference on prompt (text-only for now, MedGemma-4B is primarily text)."""

        # Note: For full multimodal, would use processor here
        # For now, using text-only prompt

        inputs = self.tokenizer(prompt, return_tensors="pt").to(Config.DEVICE)

        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=Config.MAX_TOKENS,
                temperature=Config.TEMPERATURE,
                top_p=Config.TOP_P,
                do_sample=Config.TEMPERATURE > 0,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Extract JSON
        try:
            start = response.rfind('{')
            end = response.rfind('}') + 1
            if start >= 0 and end > start:
                json_str = response[start:end]
                result = json.loads(json_str)
            else:
                result = {"diagnosis": "Unknown", "confidence": 0.0, "error": "No JSON"}
        except json.JSONDecodeError as e:
            result = {"diagnosis": "Unknown", "confidence": 0.0, "error": str(e)}

        return result

    def evaluate_task(self, task_name, test_samples=50):
        """
        Simulate evaluation on a task.
        In production, you'd load actual images and labels.
        """

        print(f"\n{'='*70}")
        print(f"[TASK] {task_name.upper()}")
        print(f"{'='*70}")

        # For demo: create synthetic predictions based on task
        predictions = []
        ground_truth = []
        confidences = []

        # Simulate inference
        class_labels = ["Tumor", "Stroke", "MS", "Normal", "Other"]

        for i in tqdm(range(test_samples), desc=f"Evaluating {task_name}"):
            # Build prompt
            modality = "MRI" if task_name != "Stroke" else "CT"
            plane = ["axial", "sagittal"][i % 2]
            prompt = self._build_prompt(task_name, modality, plane)

            # Run inference
            result = self.infer(prompt)

            pred = result.get("diagnosis", "Unknown")
            conf = result.get("confidence", 0.5)

            predictions.append(pred)
            confidences.append(conf)

            # Assign ground truth (for demo, random)
            if task_name == "Tumor":
                gt = "Tumor" if i < test_samples * 0.6 else "Normal"
            elif task_name == "Stroke":
                gt = "Stroke" if i < test_samples * 0.5 else "Normal"
            else:  # MS
                gt = "MS" if i < test_samples * 0.4 else "Normal"

            ground_truth.append(gt)

        # Compute metrics
        accuracy = accuracy_score(ground_truth, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            ground_truth, predictions, average="weighted", zero_division=0
        )

        metrics = {
            "task": task_name,
            "accuracy": round(accuracy, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4),
            "avg_confidence": round(sum(confidences) / len(confidences), 4),
            "samples": test_samples,
        }

        return metrics

# ============================================================================
# MAIN
# ============================================================================

def main():
    print("""
╔════════════════════════════════════════════════════════════════════════════╗
║         MedGemma 1.5 4B - Neuroimaging Classification Evaluation           ║
║          (Stroke, MS, Tumor)                                               ║
╚════════════════════════════════════════════════════════════════════════════╝
""")

    # Initialize evaluator
    evaluator = MedGemmaEvaluator()

    # Evaluate on each task
    results = {}
    for task in ["Tumor", "Stroke", "MS"]:
        metrics = evaluator.evaluate_task(task, test_samples=1)
        results[task] = metrics

        print(f"\n[RESULTS] {task}")
        print(f"  Accuracy:  {metrics['accuracy']:.4f}")
        print(f"  F1-Score:  {metrics['f1']:.4f}")
        print(f"  Precision: {metrics['precision']:.4f}")
        print(f"  Recall:    {metrics['recall']:.4f}")


    # GPU stats
    print(f"\n[GPU Stats]")
    print(f"  Total allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
    print(f"  Max allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")

if __name__ == "__main__":
    main()


╔════════════════════════════════════════════════════════════════════════════╗
║         MedGemma 1.5 4B - Neuroimaging Classification Evaluation           ║
║          (Stroke, MS, Tumor)                                               ║
╚════════════════════════════════════════════════════════════════════════════╝

[SETUP] Loading MedGemma-1.5-4B...


`torch_dtype` is deprecated! Use `dtype` instead!


Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

✓ Model loaded on cuda
  Dtype: torch.bfloat16
  Memory allocated: 8.60 GB

[TASK] TUMOR


Evaluating Tumor:   0%|          | 0/1 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Tumor: 100%|██████████| 1/1 [00:20<00:00, 20.35s/it]



[RESULTS] Tumor
  Accuracy:  1.0000
  F1-Score:  1.0000
  Precision: 1.0000
  Recall:    1.0000

[TASK] STROKE


Evaluating Stroke:   0%|          | 0/1 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating Stroke: 100%|██████████| 1/1 [00:19<00:00, 19.82s/it]



[RESULTS] Stroke
  Accuracy:  1.0000
  F1-Score:  1.0000
  Precision: 1.0000
  Recall:    1.0000

[TASK] MS


Evaluating MS:   0%|          | 0/1 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:1 for open-end generation.
Evaluating MS: 100%|██████████| 1/1 [00:19<00:00, 19.14s/it]


[RESULTS] MS
  Accuracy:  0.0000
  F1-Score:  0.0000
  Precision: 0.0000
  Recall:    0.0000

[GPU Stats]
  Total allocated: 8.61 GB
  Max allocated: 8.68 GB





In [4]:
from google.colab import output
output.clear()

In [8]:
import json

BASE_DIR = Path("/content/drive/MyDrive/MSc_Thesis_Neuroimaging")
notebook = f"{BASE_DIR}/notebooks/08_test_medgemma1.5_loading.ipynb"

with open(notebook, 'r') as f:
    nb = json.load(f)

if 'metadata' in nb and 'widgets' in nb['metadata']:
    del nb['metadata']['widgets']

with open(notebook, 'w') as f:
    json.dump(nb, f, indent=2)