# MedGemma Model Exploration

This notebook explores the capabilities of MedGemma 1.5 4B for the MedGemma Impact Challenge.

**Requirements:**
- GPU with CUDA support
- Hugging Face account with the HAI-DEF terms accepted (the model is gated, so log in before downloading)
- No local GPU? Run on Kaggle or Google Colab instead

## 1. Setup

In [None]:
# Install dependencies (uncomment and run once)
# !pip install -U transformers accelerate torch datasets pillow huggingface-hub

In [None]:
import torch
from transformers import pipeline, AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
from datasets import load_dataset

# Check GPU availability
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
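
The memory figure printed above is easier to interpret next to a rough estimate of what the weights need. The sketch below is back-of-the-envelope arithmetic only: it assumes about 4e9 parameters (the nominal "4B" size, not an exact count) and 2 bytes per parameter for bfloat16. `estimate_weight_memory_gb` is a hypothetical helper introduced here.

```python
# Rough estimate of the memory needed just to hold model weights.
# Assumption: ~4e9 parameters (nominal "4B") and 2 bytes per parameter
# for bfloat16. Activations and the KV cache need additional room.

def estimate_weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Return the approximate weight footprint in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

weights_gb = estimate_weight_memory_gb(4e9)
print(f"Approximate weight memory in bfloat16: {weights_gb:.1f} GB")
```

If the reported GPU memory is close to this estimate, long generations may still run out of memory, since the estimate covers weights only.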

## 2. Load MedGemma Model

In [None]:
# Model ID
MODEL_ID = "google/medgemma-1.5-4b-it"

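# Load using pipeline (recommended for quick start).
# Fall back to CPU if CUDA is unavailable; CPU inference will be very slow.
pipe = pipeline(
    "image-text-to-text",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)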

print("Model loaded successfully!")

## 3. Test with Sample Chest X-ray

In [None]:
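# Load a sample chest X-ray from Wikimedia Commons (public domain)
import io

image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
resp = requests.get(image_url, headers={"User-Agent": "example"}, timeout=30)
resp.raise_for_status()  # fail early on HTTP errors instead of handing an error page to PIL
image = Image.open(io.BytesIO(resp.content))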

# Display the image
display(image)

In [None]:
# Analyze the X-ray
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "Describe this chest X-ray in detail. Include any findings and your assessment."}
        ]
    }
]

output = pipe(text=messages, max_new_tokens=2000)
print(output[0]["generated_text"][-1]["content"])

## 4. Explore Different Prompts

In [None]:
# Helper function for analysis
def analyze_image(image, prompt):
    """Send one image and one text prompt to the pipeline and return the model's reply."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    output = pipe(text=messages, max_new_tokens=2000)
    return output[0]["generated_text"][-1]["content"]
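
The helper above sends a single user turn. A minimal sketch of a variant that also accepts an optional system instruction follows. `build_messages` and its `system_prompt` parameter are hypothetical names introduced here, and the sketch assumes the chat template accepts a `system` role whose content uses the same list-of-parts layout as the user turn.

```python
from typing import Any, Optional

def build_messages(image: Any, prompt: str, system_prompt: Optional[str] = None) -> list:
    """Build a chat-style message list with one image and one text prompt."""
    messages = []
    if system_prompt:
        # Assumed layout: a system turn whose content is a list of text parts.
        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
    messages.append({
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    })
    return messages

# Placeholder object standing in for a PIL image, so the sketch runs without a model.
placeholder_image = object()
msgs = build_messages(
    placeholder_image,
    "Describe this chest X-ray.",
    system_prompt="You are an expert radiologist.",
)
print([m["role"] for m in msgs])
```

The returned list has the same shape as `messages` in `analyze_image`, so it could be passed to `pipe(text=..., max_new_tokens=...)` in the same way.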

In [None]:
# Test different prompt styles
prompts = [
    "Is there any pathology visible in this chest X-ray?",
    "List all findings in this chest X-ray in bullet points.",
    "What is your differential diagnosis based on this chest X-ray?",
    "Describe the cardiac silhouette and lung fields.",
    "Rate the quality of this X-ray and identify any technical issues."
]

for prompt in prompts:
    print(f"\n{'='*60}")
    print(f"PROMPT: {prompt}")
    print(f"{'='*60}")
    response = analyze_image(image, prompt)
    print(response)
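
Printing each response works for a quick look, but comparing prompt styles is easier when the results are kept in a structure. The sketch below collects one record per prompt. `run_prompts` and `fake_analyze` are hypothetical names, and the stub stands in for `analyze_image` so the example runs without the model.

```python
from typing import Any, Callable

def run_prompts(analyze_fn: Callable[[Any, str], str], image: Any, prompts: list) -> list:
    """Call analyze_fn once per prompt and return one record per prompt."""
    records = []
    for prompt in prompts:
        response = analyze_fn(image, prompt)
        records.append({
            "prompt": prompt,
            "response": response,
            "n_words": len(response.split()),  # crude length measure for comparison
        })
    return records

# Stub used only to make the sketch runnable without the model.
def fake_analyze(image: Any, prompt: str) -> str:
    return f"stub answer to: {prompt}"

records = run_prompts(fake_analyze, None, ["Is there any pathology?", "List all findings."])
for r in records:
    print(f"{r['n_words']:>4} words | {r['prompt']}")
```

In the notebook, passing `analyze_image`, `image`, and `prompts` in place of the stub arguments would give the same records for the real model.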

## 5. Load NIH Chest X-ray Dataset

In [None]:
# Load the NIH Chest X-ray dataset from Hugging Face
dataset = load_dataset("alkzar90/NIH-Chest-X-ray-dataset", split="train", streaming=True)

# Take the first five samples from the stream
samples = list(dataset.take(5))
print(f"Sample keys: {samples[0].keys()}")

In [None]:
# Explore a sample
sample = samples[0]
print(f"Labels: {sample.get('labels', 'N/A')}")
display(sample['image'])

In [None]:
# Analyze a sample from the dataset
response = analyze_image(sample['image'], "Describe this chest X-ray and identify any abnormalities.")
print(response)
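
One crude way to relate a response to a sample's ground truth is to check which label names the text mentions. The sketch below assumes the labels are already available as plain strings; the format of the dataset's `labels` field is not confirmed here, and the label names and response text are made up for illustration. `mentioned_labels` is a hypothetical helper.

```python
def mentioned_labels(response: str, label_names: list) -> list:
    """Return the label names that appear in the response (case-insensitive substring match)."""
    text = response.lower()
    return [name for name in label_names if name.lower() in text]

# Hypothetical label names and response text, for illustration only.
label_names = ["Cardiomegaly", "Effusion", "Pneumothorax"]
example_response = "The cardiac silhouette is enlarged, consistent with cardiomegaly. No pneumothorax."
print(mentioned_labels(example_response, label_names))
```

Note that substring matching also counts negated mentions such as "No pneumothorax", so it indicates what the model talked about rather than what it concluded.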

## 6. Text-Only Medical QA

In [None]:
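# Test text-only medical knowledge
def ask_medical_question(question):
    """Send a text-only question to the pipeline and return the model's reply."""
    messages = [
        {
            "role": "user",
            # Use the same list-of-parts layout as the image prompts
            "content": [{"type": "text", "text": question}]
        }
    ]
    output = pipe(text=messages, max_new_tokens=1000)
    return output[0]["generated_text"][-1]["content"]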

# Test questions
questions = [
    "What are the common causes of pneumonia?",
    "A 55-year-old patient presents with shortness of breath and bilateral leg swelling. What is your differential diagnosis?",
    "What workup would you order for a patient with suspected pulmonary embolism?"
]

for q in questions:
    print(f"\n{'='*60}")
    print(f"Q: {q}")
    print(f"{'='*60}")
    print(ask_medical_question(q))

## 7. Next Steps

Based on this exploration:

1. **Identify the best prompt strategies** for your use case
2. **Decide on a project direction** (imaging focus, clinical decision support, etc.)
3. **Build a prototype** in `03_prototype.ipynb`
4. **Document your findings** for the competition writeup

In [None]:
# Record useful observations here
notes = """
## Observations from Exploration

### Model Strengths:
- 

### Model Limitations:
- 

### Best Prompts:
- 

### Project Ideas:
- 
"""
print(notes)
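
The cell above only prints the notes. A minimal sketch for writing them to a file, so they survive a kernel restart, could look like this. `save_notes` and the file name are hypothetical, and the demo writes to a temporary directory rather than a real project path.

```python
import tempfile
from pathlib import Path

def save_notes(text: str, path: Path) -> Path:
    """Write the notes text to a UTF-8 file and return the path."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
    return path

# Demo in a temporary directory; in the notebook you would pass `notes` and a real path.
with tempfile.TemporaryDirectory() as tmp:
    saved = save_notes("## Observations from Exploration\n", Path(tmp) / "notes" / "exploration_notes.md")
    saved_text = saved.read_text(encoding="utf-8")
    print(saved_text)
```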