# PrimaCare AI - Prototype Development

**MedGemma Impact Challenge** - Prototype Notebook

This notebook builds and tests the PrimaCare AI prototype - a multimodal diagnostic support system for primary care.

**Features:**
- Chest X-ray analysis with structured reporting
- Differential diagnosis generation
- Zero-shot classification
- Integration with clinical context

**Run on:** Kaggle (GPU T4/P100) or Google Colab (GPU)

**Competition:** https://www.kaggle.com/competitions/med-gemma-impact-challenge

## 1. Setup & Dependencies

In [None]:
# Install dependencies
# Quote the version spec so the shell does not treat ">" as an output redirect
!pip install -q -U "transformers>=4.50.0" accelerate datasets pillow huggingface-hub gradio

In [None]:
# Standard imports
import torch
from PIL import Image
from datasets import load_dataset
import warnings
warnings.filterwarnings('ignore')

# Check GPU
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Hugging Face login (required for MedGemma access)
from huggingface_hub import login

# Option 1: Kaggle secrets
# from kaggle_secrets import UserSecretsClient
# secrets = UserSecretsClient()
# login(token=secrets.get_secret("HF_TOKEN"))

# Option 2: Interactive
login()

## 2. Load Models

In [None]:
from transformers import pipeline, AutoProcessor, AutoModel

# Model IDs
MEDGEMMA_ID = "google/medgemma-1.5-4b-it"
MEDSIGLIP_ID = "google/medsiglip-448"

print(f"Loading MedGemma: {MEDGEMMA_ID}")
medgemma = pipeline(
    "image-text-to-text",
    model=MEDGEMMA_ID,
    torch_dtype=torch.bfloat16,
    device="cuda",
)
print("✓ MedGemma loaded")

print(f"\nLoading MedSigLIP: {MEDSIGLIP_ID}")
siglip_model = AutoModel.from_pretrained(MEDSIGLIP_ID).to("cuda")
siglip_processor = AutoProcessor.from_pretrained(MEDSIGLIP_ID)
print("✓ MedSigLIP loaded")

## 3. Core Functions

In [None]:
def analyze_image(image, prompt, max_tokens=2000):
    """Analyze medical image with MedGemma."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    output = medgemma(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]


def ask_question(question, max_tokens=1000):
    """Ask medical question without image."""
    # Use the same typed-content format as analyze_image for consistency
    messages = [{"role": "user", "content": [{"type": "text", "text": question}]}]
    output = medgemma(text=messages, max_new_tokens=max_tokens)
    return output[0]["generated_text"][-1]["content"]


def classify_image(image, labels):
    """Zero-shot classification with MedSigLIP."""
    inputs = siglip_processor(
        text=labels,
        images=[image],
        padding="max_length",
        return_tensors="pt"
    ).to("cuda")
    
    with torch.no_grad():
        outputs = siglip_model(**inputs)
        probs = torch.softmax(outputs.logits_per_image, dim=1)[0]
    
    return {label: prob.item() for label, prob in zip(labels, probs)}


print("✓ Core functions defined")
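`classify_image` applies a softmax across the candidate labels, so the scores always sum to 1 and are relative to the label set rather than independent per-finding probabilities. The sketch below illustrates this in pure Python and contrasts it with a per-label sigmoid (SigLIP-style models are trained with a sigmoid loss). The logits are made-up illustrative numbers, not model outputs, and the helper names are assumptions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits):
    """Independent per-label sigmoid; scores need not sum to 1."""
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


# Illustrative logits only (not real MedSigLIP outputs).
example_logits = [-2.0, -3.0, -8.0]
softmax_scores = softmax(example_logits)
sigmoid_scores = sigmoid(example_logits)

print([round(s, 3) for s in softmax_scores])
print([round(s, 3) for s in sigmoid_scores])
```

With these toy logits the top softmax score looks confident even though every per-label sigmoid score is low, which is worth keeping in mind when reading the percentages printed later in the notebook.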

## 4. PrimaCare AI - Main Pipeline

In [None]:
class PrimaCareAI:
    """
    Primary Care Diagnostic Support System.
    
    Combines MedGemma for analysis and report generation
    with MedSigLIP for classification.
    """
    
    # Classification labels for chest X-rays
    CXR_LABELS = [
        "normal chest x-ray",
        "pneumonia",
        "pleural effusion", 
        "cardiomegaly",
        "pulmonary edema",
        "atelectasis",
        "pneumothorax",
        "consolidation",
        "mass or nodule",
        "interstitial lung disease",
    ]
    
    def __init__(self):
        """Initialize the pipeline; it uses the models loaded at module level."""
        print("PrimaCare AI initialized")
    
    def analyze_xray(self, image, clinical_context=None):
        """
        Complete chest X-ray analysis pipeline.
        
        Returns:
            dict with 'classification', 'findings', 'differential', 'recommendations'
        """
        results = {}
        
        # Step 1: Zero-shot classification
        print("Running classification...")
        probs = classify_image(image, self.CXR_LABELS)
        results['classification'] = probs
        
        # Get top findings for context
        top_findings = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]
        classification_context = ", ".join([f"{k} ({v*100:.0f}%)" for k, v in top_findings])
        
        # Step 2: Detailed findings
        print("Generating findings...")
        findings_prompt = """Analyze this chest X-ray systematically:

1. **Technical Quality:** Image quality, positioning, inspiration
2. **Cardiac Silhouette:** Size, contour, position
3. **Lung Fields:** Opacity, lucency, distribution
4. **Pleura:** Effusions, thickening, pneumothorax
5. **Mediastinum:** Width, contour, lymphadenopathy
6. **Bones & Soft Tissues:** Fractures, lesions

Be specific and descriptive."""
        
        results['findings'] = analyze_image(image, findings_prompt)
        
        # Step 3: Differential diagnosis with context
        print("Generating differential diagnosis...")
        ddx_prompt = f"""Based on this chest X-ray, provide a differential diagnosis.

Initial classification suggests: {classification_context}
"""
        if clinical_context:
            ddx_prompt += f"\nClinical context: {clinical_context}\n"
            
        ddx_prompt += """
List the most likely diagnoses in order of probability.
For each diagnosis, explain the supporting findings from the image."""
        
        results['differential'] = analyze_image(image, ddx_prompt)
        
        # Step 4: Recommendations
        print("Generating recommendations...")
        rec_prompt = """Based on the findings in this chest X-ray, provide:

1. **Immediate Actions:** Any urgent findings requiring immediate attention?
2. **Additional Imaging:** CT, lateral view, ultrasound, etc.?
3. **Laboratory Tests:** What bloodwork would help?
4. **Specialist Referral:** Pulmonology, cardiology, oncology, etc.?
5. **Follow-up:** When should this be repeated?

Be practical and specific for a primary care setting."""
        
        results['recommendations'] = analyze_image(image, rec_prompt)
        
        return results
    
    def generate_report(self, image, patient_info=None):
        """
        Generate a complete radiology-style report.
        
        Args:
            image: PIL Image
            patient_info: dict with 'age', 'gender', 'history'
        """
        context = ""
        if patient_info:
            parts = []
            if patient_info.get('age'):
                parts.append(f"Age: {patient_info['age']}")
            if patient_info.get('gender'):
                parts.append(f"Gender: {patient_info['gender']}")
            if patient_info.get('history'):
                parts.append(f"Clinical History: {patient_info['history']}")
            context = "\n".join(parts)
        
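        # Build the optional patient block outside the f-string: a backslash
        # inside an f-string expression is a SyntaxError before Python 3.12.
        patient_block = f"Patient Information:\n{context}" if context else ""

        prompt = f"""Generate a comprehensive radiology report for this chest X-ray.

{patient_block}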

**TECHNIQUE:**
[Describe imaging technique and quality]

**COMPARISON:**
[Note if prior studies available]

**FINDINGS:**
Provide detailed, systematic findings:
- Lungs and Airways
- Cardiac Silhouette  
- Mediastinum
- Pleura
- Bones and Soft Tissues

**IMPRESSION:**
[Concise summary of key findings]

**DIFFERENTIAL DIAGNOSIS:**
[List possible diagnoses]

**RECOMMENDATIONS:**
[Clinical correlation and follow-up]"""
        
        return analyze_image(image, prompt, max_tokens=2500)
    
    def explain_to_patient(self, image):
        """
        Generate patient-friendly explanation of X-ray findings.
        """
        prompt = """I need to explain this chest X-ray to my patient in simple, non-medical terms.

Please provide:
1. **What the X-ray shows:** Simple description without jargon
2. **What this means:** Explain implications in everyday language
3. **What happens next:** Simple explanation of next steps
4. **Questions to expect:** Common patient questions and answers

Use empathetic, reassuring language appropriate for a patient consultation."""
        
        return analyze_image(image, prompt)


# Initialize
primacare = PrimaCareAI()
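The prompt assembly inside `analyze_xray` is plain string handling, so it can be factored out and checked without loading either model. A minimal sketch, assuming a hypothetical helper named `build_ddx_prompt` (not part of the notebook) and illustrative scores:

```python
def build_ddx_prompt(probs, clinical_context=None, top_k=3):
    """Assemble the differential-diagnosis prompt from classifier scores.

    Hypothetical helper that mirrors the string handling in analyze_xray.
    """
    # Keep the top_k labels by score, highest first.
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    summary = ", ".join(f"{label} ({score * 100:.0f}%)" for label, score in top)

    parts = [
        "Based on this chest X-ray, provide a differential diagnosis.",
        "",
        f"Initial classification suggests: {summary}",
    ]
    # Clinical context is optional; skip the line when it is empty or None.
    if clinical_context:
        parts += ["", f"Clinical context: {clinical_context}"]
    parts += [
        "",
        "List the most likely diagnoses in order of probability.",
        "For each diagnosis, explain the supporting findings from the image.",
    ]
    return "\n".join(parts)


# Illustrative scores only, not model outputs.
demo_prompt = build_ddx_prompt(
    {"pneumonia": 0.6, "atelectasis": 0.3, "cardiomegaly": 0.1},
    clinical_context="65yo male, cough x 2 weeks",
)
print(demo_prompt)
```

Keeping the prompt text in a pure function makes it easy to iterate on wording before spending GPU time on generation.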

## 5. Load Test Data

In [None]:
# Load NIH Chest X-ray dataset
print("Loading dataset (streaming)...")
dataset = load_dataset(
    "alkzar90/NIH-Chest-X-ray-dataset",
    split="train",
    streaming=True
)

# Take the first 5 samples from the stream (not shuffled)
samples = []
for sample in dataset:
    samples.append(sample)
    if len(samples) >= 5:
        break

print(f"\n✓ Loaded {len(samples)} samples")
print(f"Sample labels: {[s.get('labels', []) for s in samples]}")
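With `streaming=True` the dataset is a plain iterable, so the first few records can also be taken with `itertools.islice`. The sketch below uses a stand-in generator (`fake_stream` is hypothetical) so that it runs without downloading anything; the record layout is an assumption for illustration.

```python
from itertools import islice


def fake_stream():
    """Hypothetical stand-in for a streaming dataset: yields dict records."""
    index = 0
    while True:
        yield {"image": None, "labels": [index]}
        index += 1


# Take the first 5 records; islice stops pulling from the stream after that.
first_five = list(islice(fake_stream(), 5))
print([record["labels"] for record in first_five])
```

Either way, these are the first records in stream order, so a more varied sample would need some form of shuffling before slicing.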

## 6. Test Complete Pipeline

In [None]:
# Test with first sample
test_sample = samples[0]
test_image = test_sample['image']
ground_truth = test_sample.get('labels', [])

print(f"Ground Truth Labels: {ground_truth}")
print(f"Patient Age: {test_sample.get('Patient Age', 'N/A')}")
print(f"Patient Gender: {test_sample.get('Patient Gender', 'N/A')}")

# Display image
display(test_image.resize((400, 400)))

In [None]:
# Run full analysis pipeline
print("Running PrimaCare AI analysis pipeline...\n")

clinical_context = f"Patient is {test_sample.get('Patient Age', 'unknown age')}, {test_sample.get('Patient Gender', 'unknown gender')}"

results = primacare.analyze_xray(test_image, clinical_context=clinical_context)

print("\n" + "="*60)
print("ANALYSIS COMPLETE")
print("="*60)

In [None]:
# Display classification results
print("\n" + "="*60)
print("CLASSIFICATION RESULTS (MedSigLIP)")
print("="*60)

sorted_probs = sorted(results['classification'].items(), key=lambda x: x[1], reverse=True)
for label, prob in sorted_probs:
    bar = "█" * int(prob * 30) + "░" * (30 - int(prob * 30))
    print(f"{label:25s} {bar} {prob*100:5.1f}%")
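The same bar rendering appears again in the Gradio demo with a different width, so a small helper could keep the two consistent. A sketch, assuming a hypothetical function name `render_bar`:

```python
def render_bar(prob, width=30):
    """Return a fixed-width text bar for a probability in [0, 1]."""
    # Clamp so out-of-range inputs cannot produce a bar of the wrong length.
    clamped = min(max(prob, 0.0), 1.0)
    filled = int(clamped * width)
    return "█" * filled + "░" * (width - filled)


# Illustrative value only.
line = f"{'pneumonia':25s} {render_bar(0.5)} {0.5 * 100:5.1f}%"
print(line)
```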

In [None]:
# Display findings
print("\n" + "="*60)
print("DETAILED FINDINGS (MedGemma)")
print("="*60)
print(results['findings'])

In [None]:
# Display differential diagnosis
print("\n" + "="*60)
print("DIFFERENTIAL DIAGNOSIS")
print("="*60)
print(results['differential'])

In [None]:
# Display recommendations
print("\n" + "="*60)
print("RECOMMENDATIONS")
print("="*60)
print(results['recommendations'])

## 7. Test Structured Report

In [None]:
# Generate full radiology-style report
patient_info = {
    'age': test_sample.get('Patient Age', 'Unknown'),
    'gender': test_sample.get('Patient Gender', 'Unknown'),
    # The else branch only runs when ground_truth is non-empty, so index it directly
    'history': 'Routine screening' if not ground_truth or 'No Finding' in ground_truth else f'Evaluation for {ground_truth[0]}'
}

print("Generating structured radiology report...\n")
report = primacare.generate_report(test_image, patient_info)

print("="*60)
print("RADIOLOGY REPORT")
print("="*60)
print(report)

## 8. Test Patient Communication

In [None]:
# Generate patient-friendly explanation
print("Generating patient-friendly explanation...\n")
patient_explanation = primacare.explain_to_patient(test_image)

print("="*60)
print("PATIENT COMMUNICATION")
print("="*60)
print(patient_explanation)

## 9. Batch Testing

In [None]:
# Test on multiple samples
print("Testing on multiple samples...\n")

for i, sample in enumerate(samples[:3]):
    print(f"\n{'='*60}")
    print(f"SAMPLE {i+1}")
    print(f"Ground Truth: {sample.get('labels', [])}")
    print("="*60)
    
    # Display thumbnail
    display(sample['image'].resize((200, 200)))
    
    # Quick classification
    probs = classify_image(sample['image'], primacare.CXR_LABELS)
    top_3 = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]
    
    print("\nTop 3 Classifications:")
    for label, prob in top_3:
        print(f"  - {label}: {prob*100:.1f}%")
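To turn these spot checks into a number, one option is a top-k hit rate: the fraction of samples whose k highest-scoring labels include a true label. The sketch below assumes the ground-truth labels have already been mapped to the same strings as `CXR_LABELS` (that mapping is not shown in the notebook); `top_k_hit_rate` is a hypothetical helper and the values are toy data.

```python
def top_k_hit_rate(score_dicts, truth_sets, k=3):
    """Fraction of samples whose top-k predicted labels include a true label."""
    if not score_dicts:
        return 0.0
    hits = 0
    for scores, truth in zip(score_dicts, truth_sets):
        # Labels sorted by score, highest first, truncated to k.
        top = sorted(scores, key=scores.get, reverse=True)[:k]
        if truth.intersection(top):
            hits += 1
    return hits / len(score_dicts)


# Toy values only, not results from the notebook.
toy_scores = [
    {"pneumonia": 0.7, "atelectasis": 0.2, "pneumothorax": 0.1},
    {"pneumonia": 0.1, "atelectasis": 0.2, "pneumothorax": 0.7},
]
toy_truth = [{"pneumonia"}, {"pneumonia"}]
rate = top_k_hit_rate(toy_scores, toy_truth, k=1)
print(f"Top-1 hit rate: {rate:.2f}")
```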

## 10. Interactive Demo (Gradio)

In [None]:
import gradio as gr

def demo_analyze(image, analysis_type, clinical_history):
    """Demo analysis function."""
    if image is None:
        return "Please upload an image."
    
    if analysis_type == "Quick Classification":
        probs = classify_image(image, primacare.CXR_LABELS)
        sorted_probs = sorted(probs.items(), key=lambda x: x[1], reverse=True)
        output = "**Classification Results:**\n\n"
        for label, prob in sorted_probs:
            bar = "█" * int(prob * 20) + "░" * (20 - int(prob * 20))
            output += f"{label}: {bar} {prob*100:.1f}%\n"
        return output
    
    elif analysis_type == "Full Analysis":
        results = primacare.analyze_xray(image, clinical_history)
        output = "**Classification:**\n"
        top_3 = sorted(results['classification'].items(), key=lambda x: x[1], reverse=True)[:3]
        for label, prob in top_3:
            output += f"- {label}: {prob*100:.1f}%\n"
        output += f"\n**Findings:**\n{results['findings']}\n"
        output += f"\n**Differential:**\n{results['differential']}\n"
        output += f"\n**Recommendations:**\n{results['recommendations']}"
        return output
    
    elif analysis_type == "Structured Report":
        return primacare.generate_report(image)
    
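    elif analysis_type == "Patient Explanation":
        return primacare.explain_to_patient(image)

    # Fallback so the function always returns a string
    return f"Unknown analysis type: {analysis_type}"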


# Create demo interface
demo = gr.Interface(
    fn=demo_analyze,
    inputs=[
        gr.Image(type="pil", label="Chest X-ray"),
        gr.Dropdown(
            choices=["Quick Classification", "Full Analysis", "Structured Report", "Patient Explanation"],
            value="Full Analysis",
            label="Analysis Type"
        ),
        gr.Textbox(label="Clinical History (optional)", placeholder="e.g., 65yo male, cough x 2 weeks")
    ],
    outputs=gr.Textbox(label="Results", lines=25),
    title="PrimaCare AI - Chest X-ray Analysis",
    description="Upload a chest X-ray for AI-powered diagnostic support."
)

# Launch (in notebook mode)
demo.launch(share=True, inline=True)

## 11. Performance Metrics

In [None]:
import time

# Benchmark inference times
print("Benchmarking inference times...\n")

test_image = samples[0]['image']

# Classification time
start = time.time()
_ = classify_image(test_image, primacare.CXR_LABELS)
classify_time = time.time() - start
print(f"Classification (MedSigLIP): {classify_time:.2f}s")

# Simple analysis time
start = time.time()
_ = analyze_image(test_image, "List the key findings.")
simple_time = time.time() - start
print(f"Simple analysis (MedGemma): {simple_time:.2f}s")

# Full report time
start = time.time()
_ = primacare.generate_report(test_image)
report_time = time.time() - start
print(f"Full report (MedGemma): {report_time:.2f}s")

print(f"\nGPU Memory Used: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"GPU Memory Cached: {torch.cuda.memory_reserved() / 1e9:.2f} GB")
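Single-run timings like those above can be noisy, and the first call to a model often includes one-off warm-up cost. A minimal sketch of a repeat-and-take-the-median timer, assuming a hypothetical helper `time_call` and using `sum` as a stand-in workload:

```python
import statistics
import time


def time_call(fn, *args, repeats=3, **kwargs):
    """Run fn several times and return the median wall-clock seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn(*args, **kwargs)
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


# Stand-in workload; in the notebook this could wrap classify_image instead.
elapsed = time_call(sum, range(1000), repeats=5)
print(f"Median: {elapsed:.6f}s")
```

In the notebook, something like `time_call(classify_image, test_image, primacare.CXR_LABELS)` would apply the same pattern to the real functions.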

## 12. Summary & Next Steps

In [None]:
summary = """
## PrimaCare AI - Prototype Summary

### Capabilities Demonstrated:
✓ Zero-shot classification with MedSigLIP
✓ Detailed findings analysis with MedGemma
✓ Differential diagnosis generation
✓ Structured radiology-style reports
✓ Patient-friendly explanations
✓ Interactive Gradio demo

### Technical Performance:
- Classification: ~{:.1f}s per image
- Full report: ~{:.1f}s per image
- GPU memory: ~{:.1f} GB

### Next Steps:
1. [ ] Fine-tune prompts for better outputs
2. [ ] Add more pathology-specific prompts
3. [ ] Implement multi-image comparison
4. [ ] Create final submission notebook
5. [ ] Record video demo
6. [ ] Write competition writeup
""".format(classify_time, report_time, torch.cuda.memory_allocated() / 1e9)

print(summary)

---

## Resources

- [MedGemma Model](https://huggingface.co/google/medgemma-1.5-4b-it)
- [MedSigLIP Model](https://huggingface.co/google/medsiglip-448)
- [MedGemma GitHub](https://github.com/Google-Health/medgemma)
- [Competition Page](https://www.kaggle.com/competitions/med-gemma-impact-challenge)