# MedGemma Model Exploration

**MedGemma Impact Challenge** - Model Exploration Notebook

This notebook explores the capabilities of MedGemma 1.5 4B and MedSigLIP for the competition.

**Run on:** Kaggle (GPU T4/P100) or Google Colab (GPU)

**Models explored:**
- MedGemma 1.5 4B (multimodal - images + text)
- MedSigLIP (zero-shot classification)

**Competition:** https://www.kaggle.com/competitions/med-gemma-impact-challenge

## 1. Setup & Installation

In [None]:
# Install dependencies (run once)
# Quote the version specifier so the shell does not treat ">" as output redirection
!pip install -q -U "transformers>=4.50.0" accelerate datasets pillow huggingface-hub gradio
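
To confirm that the installed `transformers` meets the 4.50.0 minimum requested in the install cell, a version comparison can be sketched as follows. The helper names `parse_version` and `meets_minimum` are hypothetical, and the parsing is deliberately simple: it reads only the leading integers of each dotted component and ignores pre-release ordering.

```python
import re


def parse_version(version: str) -> tuple:
    """Turn '4.50.0' or '4.51.0.dev0' into a tuple of leading integers."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # Stop at the first non-numeric component such as 'dev0'
        parts.append(int(match.group()))
    return tuple(parts)


def pad(parts: tuple, width: int) -> tuple:
    """Right-pad a version tuple with zeros so '4.50' compares equal to '4.50.0'."""
    return parts + (0,) * (width - len(parts))


def meets_minimum(installed: str, required: str) -> bool:
    """Return True if the installed version is at least the required one."""
    a, b = parse_version(installed), parse_version(required)
    width = max(len(a), len(b))
    return pad(a, width) >= pad(b, width)


print(meets_minimum("4.49.2", "4.50.0"))
print(meets_minimum("4.51.0.dev0", "4.50.0"))
```

In the notebook, the installed string could come from the standard library call `importlib.metadata.version("transformers")`.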

In [None]:
# Imports
import torch
from transformers import pipeline, AutoProcessor, AutoModel
from PIL import Image
from datasets import load_dataset
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Log in to Hugging Face (required for MedGemma access)
# You must first accept the HAI-DEF terms at: https://huggingface.co/google/medgemma-1.5-4b-it
from huggingface_hub import login

# Option 1: Use Kaggle secrets
# from kaggle_secrets import UserSecretsClient
# secrets = UserSecretsClient()
# hf_token = secrets.get_secret("HF_TOKEN")
# login(token=hf_token)

# Option 2: Interactive login
login()
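
The two login options can be combined so the same cell works with or without a stored token. The sketch below assumes the token is exposed as an `HF_TOKEN` environment variable. The helper name `resolve_hf_token` is hypothetical and not part of `huggingface_hub`.

```python
import os
from typing import Mapping, Optional


def resolve_hf_token(env: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Return a non-empty HF_TOKEN from the given mapping (default: os.environ), else None."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    return token or None


# Intended use in the notebook (not executed here):
#   token = resolve_hf_token()
#   login(token=token) if token else login()
print(resolve_hf_token({"HF_TOKEN": "  "}))  # Whitespace-only values count as missing
```

Passing the environment as a mapping keeps the helper easy to test without touching real credentials.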

## 2. Load MedGemma 1.5 4B

In [None]:
# Model configuration
MODEL_ID = "google/medgemma-1.5-4b-it"

print(f"Loading {MODEL_ID}...")
print("This may take a few minutes on first run.")

# Load using pipeline API (recommended)
pipe = pipeline(
    "image-text-to-text",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device="cuda",
)

print("✓ Model loaded successfully!")

In [None]:
# Helper function for image analysis
def analyze_image(image, prompt, max_tokens=2000):
    """Analyze a medical image with MedGemma."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    output = pipe(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]

# Helper function for text-only questions
def ask_question(question, max_tokens=1000):
    """Ask a medical question without an image."""
    # Use the same typed-content structure as analyze_image, minus the image entry
    messages = [{"role": "user", "content": [{"type": "text", "text": question}]}]
    output = pipe(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]

print("✓ Helper functions ready")
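
Both helpers build the same chat structure and differ only in whether an image entry is present. A minimal sketch of that shared payload is shown below. `build_messages` is a hypothetical name, and the image is treated as an opaque object so the example runs without PIL or a GPU.

```python
from typing import Any, Dict, List, Optional


def build_messages(prompt: str, image: Optional[Any] = None) -> List[Dict[str, Any]]:
    """Build a single-turn user message, with the image entry placed before the text."""
    content: List[Dict[str, Any]] = []
    if image is not None:
        content.append({"type": "image", "image": image})
    content.append({"type": "text", "text": prompt})
    return [{"role": "user", "content": content}]


# A placeholder string stands in for a PIL image here
messages = build_messages("Describe this chest X-ray.", image="<PIL image placeholder>")
print([entry["type"] for entry in messages[0]["content"]])
```

With a builder like this, `analyze_image` and `ask_question` could each reduce to one call to `pipe(text=build_messages(...), max_new_tokens=...)`.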

## 3. Test with Sample Chest X-ray

In [None]:
# Load a sample chest X-ray from Wikimedia Commons (public domain)
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
response = requests.get(image_url, headers={"User-Agent": "MedGemma-Demo"}, timeout=30)
response.raise_for_status()  # Fail early instead of passing an error page to PIL
sample_image = Image.open(BytesIO(response.content))

# Display
print("Sample Chest X-ray:")
display(sample_image.resize((400, 400)))

In [None]:
# Basic analysis
print("Analyzing chest X-ray...\n")
result = analyze_image(sample_image, "Describe this chest X-ray in detail.")
print(result)

## 4. Explore Different Prompts

In [None]:
# Test different prompt styles
prompts = {
    "Findings": "List all findings in this chest X-ray in bullet points.",
    "Differential": "What is your differential diagnosis based on this chest X-ray?",
    "Report": """Generate a structured radiology report for this chest X-ray:
1. Technique
2. Findings
3. Impression""",
    "Primary Care": """As a primary care physician reviewing this X-ray:
1. Key findings
2. Differential diagnosis
3. Recommended next steps""",
}

for name, prompt in prompts.items():
    print(f"\n{'='*60}")
    print(f"PROMPT TYPE: {name}")
    print("="*60)
    result = analyze_image(sample_image, prompt)
    print(result)
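
Printing each response makes the prompt styles hard to compare afterwards. One option is to collect the responses with basic statistics, as sketched below. `run_prompts` and `fake_analyze` are hypothetical names, and `fake_analyze` is a stand-in for the model call so the example runs without a GPU.

```python
import time
from typing import Callable, Dict


def run_prompts(prompts: Dict[str, str], analyze: Callable[[str], str]) -> Dict[str, dict]:
    """Run every prompt through `analyze` and record the response, latency, and word count."""
    results = {}
    for name, prompt in prompts.items():
        start = time.perf_counter()
        text = analyze(prompt)
        results[name] = {
            "prompt": prompt,
            "response": text,
            "seconds": time.perf_counter() - start,
            "words": len(text.split()),
        }
    return results


def fake_analyze(prompt: str) -> str:
    """Stand-in for the model call; echoes the prompt instead of running inference."""
    return f"(model output for: {prompt})"


demo = run_prompts({"Findings": "List all findings."}, fake_analyze)
print(demo["Findings"]["words"])
```

In the notebook, the stand-in would be replaced by something like `lambda p: analyze_image(sample_image, p)`.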

## 5. Load NIH Chest X-ray Dataset

In [None]:
# Load dataset in streaming mode (no full download needed)
print("Loading NIH Chest X-ray dataset (streaming)...")
dataset = load_dataset(
    "alkzar90/NIH-Chest-X-ray-dataset",
    split="train",
    streaming=True
)

# Get first few samples
samples = list(dataset.take(5))
print(f"\n✓ Loaded {len(samples)} samples")
print(f"Sample keys: {list(samples[0].keys())}")

In [None]:
# Explore a sample
sample = samples[0]
print(f"Labels: {sample.get('labels', 'N/A')}")
print(f"Patient Age: {sample.get('Patient Age', 'N/A')}")
print(f"Patient Gender: {sample.get('Patient Gender', 'N/A')}")
display(sample['image'].resize((400, 400)))

In [None]:
# Analyze sample from dataset
print(f"Ground truth labels: {sample.get('labels', [])}")
print("\nMedGemma Analysis:")
result = analyze_image(sample['image'], "List all abnormalities visible in this chest X-ray.")
print(result)
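
Comparing the free-text analysis against the ground truth labels requires mapping the text back to finding names. A rough first pass is keyword matching, sketched below under two assumptions. The `finding_names` list is an illustrative subset that should be checked against the dataset's own `labels` feature. The matching ignores negation, so "No pneumothorax" still counts as a mention.

```python
import re
from typing import List


def mentioned_findings(text: str, finding_names: List[str]) -> List[str]:
    """Return the finding names that appear as whole words in the text (case-insensitive)."""
    lowered = text.lower()
    found = []
    for name in finding_names:
        # Treat underscores in label names as spaces, e.g. 'Pleural_Thickening'
        pattern = r"\b" + re.escape(name.lower().replace("_", " ")) + r"\b"
        if re.search(pattern, lowered):
            found.append(name)
    return found


# Illustrative subset of finding names (assumption, not read from the dataset)
finding_names = ["Atelectasis", "Cardiomegaly", "Effusion", "Pneumothorax"]
report = "There is mild cardiomegaly with a small left pleural effusion. No pneumothorax."
print(mentioned_findings(report, finding_names))
```

The negated "pneumothorax" appearing in the result illustrates why this is only a starting point and not a substitute for proper label extraction.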

## 6. Test MedSigLIP (Zero-Shot Classification)

In [None]:
# Load MedSigLIP for classification
print("Loading MedSigLIP...")
siglip_model = AutoModel.from_pretrained("google/medsiglip-448").to("cuda")
siglip_processor = AutoProcessor.from_pretrained("google/medsiglip-448")
print("✓ MedSigLIP loaded")

In [None]:
# Zero-shot classification function
def classify_image(image, labels):
    """Zero-shot classification with MedSigLIP; returns {label: probability}."""
    inputs = siglip_processor(
        text=labels,
        images=[image],
        padding="max_length",
        return_tensors="pt"
    ).to("cuda")

    with torch.no_grad():
        outputs = siglip_model(**inputs)
        # Softmax normalizes over the candidate labels, so scores are relative to this label set
        probs = torch.softmax(outputs.logits_per_image, dim=1)[0]

    # Keep probabilities numeric; format them only when printing
    return {label: prob.item() for label, prob in zip(labels, probs)}

# Test classification
labels = [
    "normal chest x-ray",
    "pneumonia",
    "pleural effusion",
    "cardiomegaly",
    "pulmonary edema"
]

print("Zero-shot classification results:")
results = classify_image(sample_image, labels)
for label, prob in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"  {label}: {prob * 100:.1f}%")
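
The percentages come from a softmax over the image-to-text logits, so they always sum to 100% across the candidate labels and describe relative preference within that list, not absolute likelihood. The pure-Python sketch below shows the same computation. The logit values are made up for illustration and are not model output.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_labels(labels: List[str], logits: List[float]) -> List[Tuple[str, float]]:
    """Pair each label with its softmax probability, highest first."""
    return sorted(zip(labels, softmax(logits)), key=lambda pair: pair[1], reverse=True)


# Hypothetical logits for three candidate labels
ranked = rank_labels(["normal chest x-ray", "pneumonia", "cardiomegaly"], [2.0, 0.5, -1.0])
for label, prob in ranked:
    print(f"  {label}: {prob * 100:.1f}%")
```

Because of this normalization, adding or removing a candidate label changes every score, which is worth keeping in mind when comparing runs with different label lists.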

## 7. Text-Only Medical QA

In [None]:
# Test text-only medical knowledge
questions = [
    "What are the classic findings of pneumonia on a chest X-ray?",
    "A 65-year-old smoker presents with hemoptysis and weight loss. What should be considered?",
    "What is the difference between consolidation and ground-glass opacity?",
]

for q in questions:
    print(f"\n{'='*60}")
    print(f"Q: {q}")
    print("="*60)
    answer = ask_question(q)
    print(f"A: {answer}")

## 8. Batch Analysis

In [None]:
# Analyze multiple samples
print("Analyzing multiple chest X-rays...\n")

# Use a distinct loop variable so the earlier `sample` cell can be re-run unchanged
for i, batch_sample in enumerate(samples[:3]):
    print(f"\n{'='*60}")
    print(f"Sample {i+1}")
    print(f"Ground Truth: {batch_sample.get('labels', [])}")
    print("="*60)

    # Display thumbnail
    display(batch_sample['image'].resize((200, 200)))

    # Get MedGemma analysis
    result = analyze_image(
        batch_sample['image'],
        "In one paragraph, describe the key findings in this chest X-ray."
    )
    print(f"\nMedGemma: {result}")
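
Once the findings mentioned in each response have been reduced to a set of names, the side-by-side printout can be turned into numbers. The sketch below computes per-sample precision and recall. `set_metrics` is a hypothetical helper, and it assumes both inputs use the same finding names.

```python
from typing import Dict, Iterable


def set_metrics(predicted: Iterable[str], truth: Iterable[str]) -> Dict[str, float]:
    """Precision and recall of predicted finding names against ground-truth names."""
    predicted_set, truth_set = set(predicted), set(truth)
    hits = len(predicted_set & truth_set)
    precision = hits / len(predicted_set) if predicted_set else 0.0
    recall = hits / len(truth_set) if truth_set else 0.0
    return {"precision": precision, "recall": recall, "hits": hits}


# Hypothetical example: one of two predictions matches one of two true findings
print(set_metrics(["Cardiomegaly", "Effusion"], ["Effusion", "Mass"]))
```

Three samples are far too few for meaningful estimates, so this is only useful once the loop runs over a larger held-out set.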

## 9. Observations & Next Steps

In [None]:
# Document your observations
observations = """
## Exploration Observations

### Model Strengths:
- [ ] Detailed image descriptions
- [ ] Structured report generation
- [ ] Medical knowledge accuracy
- [ ] Multiple imaging modalities

### Model Limitations:
- [ ] Sensitivity to prompt wording
- [ ] May miss subtle findings
- [ ] Confidence calibration

### Best Prompts Found:
- 

### Project Ideas:
1. 
2. 
3. 

### Next Steps:
- [ ] Test on pathology-specific cases
- [ ] Evaluate on held-out data
- [ ] Build prototype demo
"""

print(observations)

## 10. Save Results (Optional)

In [None]:
# Save exploration results
import json
from datetime import datetime

results_summary = {
    "timestamp": datetime.now().isoformat(),
    "model": MODEL_ID,
    "gpu": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU",
    "samples_loaded": len(samples),  # Samples streamed from the dataset, not all were analyzed
    "notes": "Initial exploration complete"
}

print(json.dumps(results_summary, indent=2))

# Uncomment to save:
# with open('exploration_results.json', 'w') as f:
#     json.dump(results_summary, f, indent=2)
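
If the notebook is run several times, overwriting a single JSON file loses earlier runs. One alternative is to append each summary as a line in a JSON Lines file, sketched below. `append_record` and `read_records` are hypothetical helpers, and the demo writes to a temporary directory so nothing is left behind.

```python
import json
import os
import tempfile
from typing import List


def append_record(path: str, record: dict) -> None:
    """Append one JSON object as a single line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def read_records(path: str) -> List[dict]:
    """Read all non-empty lines back as JSON objects."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    log_path = os.path.join(tmp, "exploration_runs.jsonl")
    append_record(log_path, {"run": 1, "notes": "Initial exploration complete"})
    append_record(log_path, {"run": 2, "notes": "Tried report-style prompts"})
    history = read_records(log_path)

print(len(history))
```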

---

## Summary

This notebook demonstrated:

1. **MedGemma 1.5 4B** - Multimodal medical AI for image analysis and text generation
2. **MedSigLIP** - Zero-shot medical image classification
3. **NIH Chest X-ray Dataset** - 112K chest X-rays for training/testing

**Next:** Build a prototype application in `03_prototype.ipynb`

**Resources:**
- [MedGemma Model](https://huggingface.co/google/medgemma-1.5-4b-it)
- [MedGemma GitHub](https://github.com/Google-Health/medgemma)
- [Competition Page](https://www.kaggle.com/competitions/med-gemma-impact-challenge)