<a href="https://colab.research.google.com/github/tamara-kostova/MSc_Thesis_Neuroimaging/blob/master/10_medgema1_5_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!ls "/content/drive/MyDrive/MSc_Thesis_Neuroimaging/data/split/labels"

CT_stroke_binary_norm_test_labels.json
CT_stroke_binary_norm_train_labels.json
CT_stroke_binary_norm_val_labels.json
MRI_ms_norm_test_labels.json
MRI_ms_norm_train_labels.json
MRI_ms_norm_val_labels.json
MRI_tumor_binary_norm_test_labels.json
MRI_tumor_binary_norm_train_labels.json
MRI_tumor_binary_norm_val_labels.json
MRI_tumor_multiclass_norm_test_labels.json
MRI_tumor_multiclass_norm_train_labels.json
MRI_tumor_multiclass_norm_val_labels.json


In [None]:
"""
MedGemma-1.5-4B multimodal evaluation with vLLM

- One inference per image
- Resume-safe checkpoints
- Dataset-agnostic
- Thesis / benchmark ready
"""

import os
import json
import base64
import requests
from pathlib import Path
from io import BytesIO
from PIL import Image
from tqdm import tqdm

# ==============================
# CONFIG
# ==============================

VLLM_ENDPOINT = "http://<SERVER_IP>:8000/v1/chat/completions"
MODEL_NAME = "google/medgemma-1.5-4b"

BASE_DIR = Path("/content/drive/MyDrive/MSc_Thesis_Neuroimaging")
SPLIT_DIR = BASE_DIR / "data/split"
LABELS_DIR = SPLIT_DIR / "labels"

RESULTS_DIR = BASE_DIR / "results/medgemma"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

DATASETS = [
    "CT_stroke_binary_norm",
    "MRI_ms_norm",
    "MRI_tumor_binary_norm",
    "MRI_tumor_multiclass_norm",
]

SPLIT = "test"

MAX_TOKENS = 256
TEMPERATURE = 0.0
TIMEOUT = 90

# ==============================
# PROMPT
# ==============================

SYSTEM_PROMPT = """You are a medical imaging expert.

Analyze the provided neuroimaging scan and output a JSON object:

{
  "diagnosis": string or null,
  "subtype": string or null,
  "modality": "CT" | "MRI" | null,
  "sequence": string or null,
  "plane": "Axial" | "Sagittal" | "Coronal" | null
}

If uncertain, use null.
Output JSON only.
"""

# ==============================
# UTILS
# ==============================

def encode_image(image_path: Path) -> str:
    """Encode image as base64 PNG for vLLM."""
    img = Image.open(image_path).convert("RGB")
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def infer(image_path: Path) -> str:
    """Send one multimodal inference request to vLLM."""
    image_b64 = encode_image(image_path)

    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Analyze this scan."},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{image_b64}"
                        },
                    },
                ],
            },
        ],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
    }

    r = requests.post(VLLM_ENDPOINT, json=payload, timeout=TIMEOUT)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]


# ==============================
# LOAD LABELS
# ==============================

for dataset in DATASETS:
    print(f"\n{'='*70}")
    print(f"DATASET: {dataset} [{SPLIT}]")
    print(f"{'='*70}")

    labels_file = LABELS_DIR / f"{dataset}_{SPLIT}_labels.json"
    image_root = SPLIT_DIR / dataset / SPLIT

    if not labels_file.exists():
        print(f"⚠ Missing labels: {labels_file}")
        continue

    with open(labels_file) as f:
        labels = json.load(f)

    output_file = RESULTS_DIR / f"{dataset}_{SPLIT}_outputs.jsonl"

    processed = set()
    if output_file.exists():
        with open(output_file) as f:
            for line in f:
                try:
                    processed.add(json.loads(line)["image_id"])
                except Exception:
                    pass

    print(f"✔ Resuming at {len(processed)} samples")

    with open(output_file, "a") as out:
        for relpath, gt in tqdm(labels.items(), desc=dataset):
            image_id = relpath.replace("/", "_")

            if image_id in processed:
                continue

            img_path = image_root / relpath
            if not img_path.exists():
                print(f"⚠ Missing image: {img_path}")
                continue

            try:
                response = infer(img_path)

                record = {
                    "dataset": dataset,
                    "split": SPLIT,
                    "image_id": image_id,
                    "relpath": relpath,
                    "ground_truth": gt,
                    "raw_response": response,
                }

                out.write(json.dumps(record) + "\n")
                out.flush()

            except Exception as e:
                print(f"❌ Error {image_id}: {e}")

print("\n✅ ALL DATASETS COMPLETED")



DATASET: CT_stroke_binary_norm [test]
✔ Resuming at 0 samples


CT_stroke_binary_norm:   0%|          | 1/999 [01:34<26:15:34, 94.72s/it]

❌ Error _content_drive_MyDrive_MSc_Thesis_Neuroimaging_data_split_CT_stroke_binary_norm_test_normal_16347.png: HTTPConnectionPool(host='185.153.49.170', port=8000): Max retries exceeded with url: /v1/chat/completions (Caused by ConnectTimeoutError(<urllib3.connection.HTTPConnection object at 0x7f3ffcf35d90>, 'Connection to 185.153.49.170 timed out. (connect timeout=90)'))


CT_stroke_binary_norm:   0%|          | 2/999 [03:05<25:35:10, 92.39s/it]

❌ Error _content_drive_MyDrive_MSc_Thesis_Neuroimaging_data_split_CT_stroke_binary_norm_test_normal_12577.png: HTTPConnectionPool(host='185.153.49.170', port=8000): Max retries exceeded with url: /v1/chat/completions (Caused by ConnectTimeoutError(<urllib3.connection.HTTPConnection object at 0x7f3ffcd424e0>, 'Connection to 185.153.49.170 timed out. (connect timeout=90)'))
