# MedGemma Model Exploration

**MedGemma Impact Challenge** - Model Exploration Notebook

This notebook explores MedGemma 1.5 4B capabilities for the competition.

**Requirements:**
- Kaggle GPU (T4/P100) or Google Colab (GPU)
- HF_TOKEN secret with access to MedGemma
- Accept HAI-DEF terms at: https://huggingface.co/google/medgemma-1.5-4b-it

**Models explored:**
- MedGemma 1.5 4B (multimodal - images + text)
- MedSigLIP (zero-shot classification)

## 1. Setup

In [None]:
# Install dependencies (skip on Kaggle - already installed)
# Uncomment if running on Colab or locally:
# (quote the version specifier so the shell does not treat ">" as a redirect)
# !pip install -q -U "transformers>=4.50.0" accelerate datasets pillow huggingface-hub
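
The install line above pins a minimum `transformers` version. A minimal sketch of how that requirement could be checked before loading the model follows. The helper names `version_tuple` and `meets_minimum` are hypothetical, and the parsing only looks at the leading numeric components, so it is a rough check rather than full version-specifier handling. In the notebook, the installed version string could be obtained with `importlib.metadata.version("transformers")`.

```python
import re


def version_tuple(version: str) -> tuple:
    """Turn '4.50.0' or '4.51.0.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component such as 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "4.50.0") -> bool:
    """Return True when the installed version is at least the minimum."""
    return version_tuple(installed) >= version_tuple(minimum)


# Hypothetical version strings, used only to exercise the helper
print(meets_minimum("4.49.2"))
print(meets_minimum("4.51.0.dev0"))
```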

In [None]:
import torch
from transformers import pipeline, AutoProcessor, AutoModel
from PIL import Image
from datasets import load_dataset
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

# Check GPU - REQUIRED for MedGemma
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("ERROR: GPU required! Enable GPU in notebook settings.")
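
The memory figure printed above can be compared against a back-of-the-envelope estimate of the weight footprint. The sketch below assumes that "4B" means roughly 4 billion parameters (the exact count is not stated in this notebook) and counts weights only. Activations, the KV cache, and framework overhead come on top of this.

```python
def weights_gb(n_params: float, bytes_per_param: int) -> float:
    """Approximate size of the model weights in GB (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


N_PARAMS = 4e9  # assumption: "4B" taken as about 4 billion parameters

# Bytes per parameter for two common dtypes
for name, size in {"float32": 4, "bfloat16": 2}.items():
    print(f"{name}: ~{weights_gb(N_PARAMS, size):.0f} GB of weights")
```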

In [None]:
# Hugging Face Login
# You must accept HAI-DEF terms at: https://huggingface.co/google/medgemma-1.5-4b-it
from huggingface_hub import login

# Option 1: Kaggle secrets (recommended)
try:
    from kaggle_secrets import UserSecretsClient
    secrets = UserSecretsClient()
    hf_token = secrets.get_secret("HF_TOKEN")
    login(token=hf_token)
    print("Logged in via Kaggle secrets")
except Exception:  # not on Kaggle, or the secret is missing
    # Option 2: Interactive login (Colab/local)
    login()
    print("Logged in interactively")
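
The fallback order in the login cell (Kaggle secret first, then something else) can be written as a small pure function, which makes it easy to test without network access. This is a sketch under assumptions: `resolve_hf_token` is a hypothetical helper, and the `HF_TOKEN` environment variable is added here as a middle option that the original cell does not use.

```python
import os
from typing import Callable, Mapping, Optional


def resolve_hf_token(
    secret_getter: Optional[Callable[[str], str]] = None,
    env: Mapping[str, str] = os.environ,
) -> Optional[str]:
    """Return a token from the secret store, then the environment, else None."""
    if secret_getter is not None:
        try:
            token = secret_getter("HF_TOKEN")
            if token:
                return token
        except Exception:
            pass  # secret missing or store unavailable: fall through
    return env.get("HF_TOKEN") or None


# Stand-in secret store and environment, for illustration only
print(resolve_hf_token(lambda name: "hf_from_secret", {"HF_TOKEN": "hf_from_env"}))
print(resolve_hf_token(None, {}))
```

In the notebook, `secret_getter` could be `UserSecretsClient().get_secret`, and a `None` result would be the cue to fall back to the interactive `login()`.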

## 2. Load MedGemma

In [None]:
MODEL_ID = "google/medgemma-1.5-4b-it"

print(f"Loading {MODEL_ID}...")
print("This takes 2-3 minutes on first run.")

pipe = pipeline(
    "image-text-to-text",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device="cuda",
)

print("✓ MedGemma loaded!")

In [None]:
def analyze_image(image, prompt, max_tokens=2000):
    """Analyze a medical image with MedGemma."""
    # Ensure RGB format
    if image.mode != "RGB":
        image = image.convert("RGB")
    
    messages = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt}
        ]
    }]
    output = pipe(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]


def ask_question(question, max_tokens=1000):
    """Ask a medical question without an image."""
    # Use the same typed content list as analyze_image, with only a text part
    messages = [{
        "role": "user",
        "content": [{"type": "text", "text": question}]
    }]
    output = pipe(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]


print("✓ Helper functions ready")
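
Both helpers rely on the same two conventions: the request is a list of chat messages whose content is a list of typed parts, and the reply is the `content` of the last message in `generated_text`. The sketch below isolates those two steps so they can be checked without loading the model. `build_messages` and `last_reply` are hypothetical names, and `fake_output` is a hand-built stand-in for what `pipe(...)` returns, not real model output.

```python
from typing import Any, Optional


def build_messages(prompt: str, image: Optional[Any] = None) -> list:
    """Build a single-turn user message, with an optional image part first."""
    content = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


def last_reply(output: list) -> str:
    """Extract the final message's content from a pipeline-style result."""
    return output[0]["generated_text"][-1]["content"]


# Hand-built stand-in: the conversation with an assistant turn appended
conversation = build_messages("Hi") + [{"role": "assistant", "content": "Hello"}]
fake_output = [{"generated_text": conversation}]
print(last_reply(fake_output))
```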

## 3. Test with Sample Chest X-ray

In [None]:
# Load sample chest X-ray (public domain)
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
response = requests.get(image_url, headers={"User-Agent": "MedGemma-Demo"}, timeout=30)
response.raise_for_status()  # fail here on an HTTP error rather than inside Image.open
sample_image = Image.open(BytesIO(response.content)).convert("RGB")

print("Sample Chest X-ray:")
display(sample_image.resize((400, 400)))

In [None]:
# Basic analysis
print("Analyzing chest X-ray...\n")
result = analyze_image(sample_image, "Describe this chest X-ray in detail.")
print(result)

## 4. Explore Different Prompts

In [None]:
prompts = {
    "Findings": "List all findings in this chest X-ray in bullet points.",
    "Differential": "What is your differential diagnosis based on this chest X-ray?",
    "Report": """Generate a structured radiology report for this chest X-ray:
1. Technique
2. Findings
3. Impression""",
    "Primary Care": """As a primary care physician reviewing this X-ray:
1. Key findings
2. Differential diagnosis
3. Recommended next steps""",
}

for name, prompt in prompts.items():
    print(f"\n{'='*60}")
    print(f"PROMPT TYPE: {name}")
    print("="*60)
    result = analyze_image(sample_image, prompt)
    print(result)
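
The "Report" prompt asks for three numbered sections, which suggests splitting the generated text back into those sections for later processing. The sketch below is one way to do that, under the assumption that each section starts on its own line with its heading name. The model is not guaranteed to follow that layout, and a body line that happens to begin with a heading word would be misread as a heading. `split_sections` is a hypothetical helper, and `example_report` is hand-written for illustration, not model output.

```python
import re

HEADINGS = ("Technique", "Findings", "Impression")


def split_sections(report: str, headings=HEADINGS) -> dict:
    """Split a report into {heading: text}, keyed by the headings found."""
    names = "|".join(re.escape(h) for h in headings)
    # Optional markdown markers or numbering, then the heading, then the rest
    heading_re = re.compile(rf"^[\s#*\d.]*({names})\b[\s*:]*(.*)$", re.IGNORECASE)
    sections, current = {}, None
    for line in report.splitlines():
        match = heading_re.match(line)
        if match:
            current = match.group(1).capitalize()
            sections[current] = []
            rest = match.group(2).strip()
            if rest:
                sections[current].append(rest)
        elif current is not None and line.strip():
            sections[current].append(line.strip())
    return {name: " ".join(lines) for name, lines in sections.items()}


# Hand-written example text, not produced by MedGemma
example_report = """**1. Technique:** PA view.
2. Findings
- Lungs are clear.
- Heart size normal.
3. Impression: No acute disease."""

for name, text in split_sections(example_report).items():
    print(f"{name}: {text}")
```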

## 5. Load Dataset

In [None]:
# Load the Chest X-ray Pneumonia dataset (loads without a custom dataset script)
print("Loading chest X-ray dataset (streaming)...")
dataset = load_dataset(
    "hf-vision/chest-xray-pneumonia",
    split="train",
    streaming=True
)

samples = list(dataset.take(5))
print(f"✓ Loaded {len(samples)} samples")
print(f"Keys: {list(samples[0].keys())}")
print("Labels: 0=Normal, 1=Pneumonia")

In [None]:
# Explore a sample
sample = samples[0]
label_name = "Pneumonia" if sample.get('label', 0) == 1 else "Normal"
print(f"Ground Truth: {label_name}")
display(sample['image'].resize((400, 400)))

In [None]:
# Analyze sample from dataset
print(f"Ground truth: {label_name}")
print("\nMedGemma Analysis:")
result = analyze_image(sample['image'], "List all abnormalities visible in this chest X-ray.")
print(result)

## 6. MedSigLIP Classification

In [None]:
print("Loading MedSigLIP...")
siglip_model = AutoModel.from_pretrained("google/medsiglip-448").to("cuda")
siglip_processor = AutoProcessor.from_pretrained("google/medsiglip-448")
print("✓ MedSigLIP loaded")

In [None]:
def classify_image(image, labels):
    """Zero-shot classification with MedSigLIP."""
    # IMPORTANT: Convert to RGB
    if image.mode != "RGB":
        image = image.convert("RGB")
    
    inputs = siglip_processor(
        text=labels,
        images=[image],
        padding="max_length",
        return_tensors="pt"
    ).to("cuda")
    
    with torch.no_grad():
        outputs = siglip_model(**inputs)
        probs = torch.softmax(outputs.logits_per_image, dim=1)[0]
    
    return {label: prob.item() for label, prob in zip(labels, probs)}


# Classification labels
labels = [
    "normal chest x-ray",
    "pneumonia",
    "pleural effusion",
    "cardiomegaly",
    "pulmonary edema"
]

print("Zero-shot classification results:")
results = classify_image(sample_image, labels)
for label, prob in sorted(results.items(), key=lambda x: x[1], reverse=True):
    bar = "█" * int(prob * 30) + "░" * (30 - int(prob * 30))
    print(f"  {label:20s} {bar} {prob*100:5.1f}%")
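
`classify_image` applies a softmax across the supplied labels, so the scores always sum to 1 over that list. A high percentage therefore means "most likely among these labels", not "this finding is present". The pure-Python sketch below reproduces the softmax and the bar rendering on made-up logits to show the effect. The logit values and the short label names are illustrative assumptions, not MedSigLIP output.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def bar(prob, width=30):
    """Render a probability as a fixed-width text bar."""
    filled = int(prob * width)
    return "█" * filled + "░" * (width - filled)


# Made-up logits for three labels
example_logits = [2.0, 1.0, 0.0]
for name, prob in zip(["label A", "label B", "label C"], softmax(example_logits)):
    print(f"  {name:10s} {bar(prob)} {prob*100:5.1f}%")
```

Dropping or adding a label changes every score, since the softmax renormalizes over whatever list is passed in.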

## 7. Text-Only Medical QA

In [None]:
questions = [
    "What are the classic findings of pneumonia on a chest X-ray?",
    "A 65-year-old smoker presents with hemoptysis and weight loss. What should be considered?",
    "What is the difference between consolidation and ground-glass opacity?",
]

for q in questions:
    print(f"\n{'='*60}")
    print(f"Q: {q}")
    print("="*60)
    answer = ask_question(q)
    print(f"A: {answer}")

## 8. Batch Analysis

In [None]:
print("Analyzing multiple chest X-rays...\n")

for i, sample in enumerate(samples[:3]):
    label_name = "Pneumonia" if sample.get('label', 0) == 1 else "Normal"
    
    print(f"\n{'='*60}")
    print(f"Sample {i+1} | Ground Truth: {label_name}")
    print("="*60)
    
    display(sample['image'].resize((200, 200)))
    
    # Classification
    probs = classify_image(sample['image'], labels)
    top_3 = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]
    print("Classification:")
    for label, prob in top_3:
        print(f"  - {label}: {prob*100:.1f}%")
    
    # MedGemma analysis
    result = analyze_image(
        sample['image'],
        "In one sentence, describe the key finding in this chest X-ray."
    )
    print(f"\nMedGemma: {result}")
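
The batch loop prints the ground truth next to the top classification labels, but does not tally how often they agree. A minimal sketch of such a tally follows, assuming the top zero-shot label is mapped back to the dataset's two classes and anything else counts as a miss. `LABEL_TO_TRUTH`, `top_label`, and `agreement_rate` are hypothetical names, and the probabilities in `example_records` are made up for illustration.

```python
# Map zero-shot labels back to the dataset's two ground-truth names
LABEL_TO_TRUTH = {"normal chest x-ray": "Normal", "pneumonia": "Pneumonia"}


def top_label(probs: dict) -> str:
    """Return the label with the highest probability."""
    return max(probs, key=probs.get)


def agreement_rate(records) -> float:
    """Fraction of (truth, probs) pairs whose top label maps to the truth."""
    if not records:
        return 0.0
    hits = sum(
        1
        for truth, probs in records
        if LABEL_TO_TRUTH.get(top_label(probs), "Other") == truth
    )
    return hits / len(records)


# Made-up probabilities, not MedSigLIP output
example_records = [
    ("Normal", {"normal chest x-ray": 0.7, "pneumonia": 0.2, "cardiomegaly": 0.1}),
    ("Pneumonia", {"normal chest x-ray": 0.5, "pneumonia": 0.3, "cardiomegaly": 0.2}),
]
print(f"Agreement: {agreement_rate(example_records):.2f}")
```

With only a handful of samples, the resulting rate is a sanity check rather than an accuracy estimate.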

## 9. Summary

In [None]:
import json
from datetime import datetime

summary = {
    "timestamp": datetime.now().isoformat(),
    "model": MODEL_ID,
    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
    "samples_loaded": len(samples),  # 5 streamed samples; section 8 analyzes the first 3
    "status": "Exploration complete"
}

print("\n" + "="*60)
print("EXPLORATION SUMMARY")
print("="*60)
print(json.dumps(summary, indent=2))
print("\nNext: Run 03_prototype.ipynb or 04_agentic_workflow.ipynb")
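
The summary is only printed, so it is lost when the session ends. A minimal sketch of persisting it as JSON follows. `save_summary` and the file name `exploration_summary.json` are hypothetical, and the example writes to a temporary directory so it has no side effects on the working directory.

```python
import json
import tempfile
from pathlib import Path


def save_summary(summary: dict, path) -> Path:
    """Write the summary dict as indented JSON and return the file path."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return path


# Reduced example dict; the notebook's own `summary` could be passed instead
example_summary = {"model": "google/medgemma-1.5-4b-it", "status": "Exploration complete"}
out_path = save_summary(example_summary, Path(tempfile.mkdtemp()) / "exploration_summary.json")
print(json.loads(out_path.read_text(encoding="utf-8"))["status"])
```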

---

## Key Findings

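**MedGemma 1.5 4B:**
- Describes chest X-rays in detail from a free-form prompt (section 3)
- Generates structured radiology reports when the prompt lists the sections (section 4)
- Answers text-only medical questions without an image (section 7)

**MedSigLIP:**
- Fast zero-shot classification against a caller-supplied label list
- Good for initial triage ahead of a full MedGemma analysis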

**Resources:**
- [MedGemma Model](https://huggingface.co/google/medgemma-1.5-4b-it)
- [Competition Page](https://www.kaggle.com/competitions/med-gemma-impact-challenge)