# PrimaCare AI - Prototype

**MedGemma Impact Challenge** - Working Prototype

Complete diagnostic support system combining MedGemma + MedSigLIP.

**Requirements:**
- Kaggle GPU (T4/P100) or Colab GPU
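- A Hugging Face access token stored as the `HF_TOKEN` secret (Kaggle); in other environments you will be prompted to log in interactively
- Accept the model terms on Hugging Face before running: https://huggingface.co/google/medgemma-1.5-4b-it (the notebook also loads `google/medsiglip-448`, which may require the same step)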

## 1. Setup

In [None]:
# Uncomment for Colab:
# !pip install -q -U "transformers>=4.50.0" accelerate datasets pillow huggingface-hub

In [None]:
import torch
import warnings
# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Report the environment, then fail fast when no CUDA device is visible:
# the model-loading cell below places everything on "cuda".
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    raise RuntimeError("GPU required! Enable in Settings.")
print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub. On Kaggle the token is read from the
# "HF_TOKEN" secret; if kaggle_secrets cannot be imported or the secret lookup
# fails, fall back to the interactive login prompt.
try:
    from kaggle_secrets import UserSecretsClient
    login(token=UserSecretsClient().get_secret("HF_TOKEN"))
    print("✓ Logged in via Kaggle")
except Exception:
    # `except Exception` instead of a bare `except:` so that a kernel
    # interrupt (KeyboardInterrupt) or SystemExit is not swallowed and
    # silently turned into a login prompt.
    login()
    print("✓ Logged in interactively")

## 2. Load Models

In [None]:
from transformers import pipeline, AutoProcessor, AutoModel

# MedGemma is wrapped in a transformers pipeline and called later by
# analyze_image() / ask_question().
print("Loading MedGemma...")
medgemma = pipeline(
    task="image-text-to-text",
    model="google/medgemma-1.5-4b-it",
    device="cuda",
    torch_dtype=torch.bfloat16,
)
print("✓ MedGemma loaded")

# MedSigLIP is used as a raw model + processor pair by classify_image().
print("Loading MedSigLIP...")
siglip_model = AutoModel.from_pretrained("google/medsiglip-448")
siglip_model = siglip_model.to("cuda")
siglip_processor = AutoProcessor.from_pretrained("google/medsiglip-448")
print("✓ MedSigLIP loaded")

## 3. Core Functions

In [None]:
from PIL import Image

def analyze_image(image, prompt, max_tokens=2000):
    """Send one image plus a text instruction to MedGemma and return its answer.

    Args:
        image: PIL image; converted to RGB first if it is in any other mode.
        prompt: Instruction or question sent alongside the image.
        max_tokens: Passed to the pipeline as ``max_new_tokens``.

    Returns:
        The ``content`` field of the last message in the pipeline's
        ``generated_text`` chat.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": rgb},
            {"type": "text", "text": prompt},
        ],
    }
    result = medgemma(text=[user_turn], max_new_tokens=max_tokens)
    last_message = result[0]["generated_text"][-1]
    return last_message["content"]


def ask_question(question, max_tokens=1000):
    """Send a text-only question to MedGemma and return its answer.

    Args:
        question: The question text.
        max_tokens: Passed to the pipeline as ``max_new_tokens``.

    Returns:
        The ``content`` field of the last message in the pipeline's
        ``generated_text`` chat.
    """
    # The image-text-to-text pipeline requires ``content`` to be a list of
    # typed parts, even when there is only text.
    text_part = {"type": "text", "text": question}
    user_turn = {"role": "user", "content": [text_part]}
    result = medgemma(text=[user_turn], max_new_tokens=max_tokens)
    last_message = result[0]["generated_text"][-1]
    return last_message["content"]


def classify_image(image, labels):
    """Zero-shot classification with MedSigLIP.

    Args:
        image: PIL image to classify; converted to RGB if it is in another mode.
        labels: Candidate text labels. A single string is treated as one
            label; any other iterable of strings is accepted.

    Returns:
        Dict mapping each label to its probability, computed as a softmax
        over the image-to-text logits of the given labels. An empty dict is
        returned when no labels are given.
    """
    # A bare string would otherwise be zipped character by character below.
    labels = [labels] if isinstance(labels, str) else list(labels)
    if not labels:
        # Nothing to score against; skip the model call entirely.
        return {}

    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = siglip_processor(
        text=labels, images=[image],
        padding="max_length", return_tensors="pt"
    ).to(siglip_model.device)  # follow the model rather than assuming "cuda"
    with torch.no_grad():
        outputs = siglip_model(**inputs)
        # One image in the batch, so row 0 holds the scores for every label.
        probs = torch.softmax(outputs.logits_per_image, dim=1)[0]
    return dict(zip(labels, probs.tolist()))


# Confirmation that the helper functions in this cell were defined.
print("✓ Functions ready")

## 4. PrimaCare AI Class

In [None]:
class PrimaCareAI:
    """Primary Care Diagnostic Support System.

    Thin orchestration layer over the module-level helpers
    ``classify_image`` and ``analyze_image``. Holds no instance state.
    """

    # Candidate labels scored by the zero-shot classifier in analyze_xray().
    CXR_LABELS = [
        "normal chest x-ray", "pneumonia", "pleural effusion",
        "cardiomegaly", "pulmonary edema", "atelectasis",
        "pneumothorax", "consolidation", "mass or nodule"
    ]

    # (dict key, label shown in the prompt) for each supported patient field.
    _PATIENT_FIELDS = (("age", "Age"), ("gender", "Gender"), ("history", "History"))

    @staticmethod
    def _has_value(value):
        """Return True when a patient field should appear in the prompt."""
        # Numbers are always meaningful: an age of 0 (newborn) must not be
        # dropped by a truthiness check. bool is excluded because it is an
        # int subclass and False should still count as "not provided".
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return True
        return bool(value)

    @classmethod
    def _patient_context(cls, patient_info):
        """Format the provided patient fields as newline-joined "Label: value" lines."""
        if not patient_info:
            return ""
        return "\n".join(
            f"{title}: {patient_info[key]}"
            for key, title in cls._PATIENT_FIELDS
            if cls._has_value(patient_info.get(key))
        )

    def analyze_xray(self, image, clinical_context=None):
        """Complete X-ray analysis pipeline.

        Args:
            image: PIL image of the chest X-ray.
            clinical_context: Optional free-text context; when given it is
                prepended to the differential-diagnosis prompt only.

        Returns:
            Dict with keys 'classification' (label -> probability dict from
            classify_image) and 'findings', 'differential',
            'recommendations' (each the text returned by analyze_image).
        """
        results = {}

        # Step 1: Classification
        print("  [1/4] Classifying...")
        results['classification'] = classify_image(image, self.CXR_LABELS)

        # Step 2: Findings
        print("  [2/4] Extracting findings...")
        results['findings'] = analyze_image(image,
            "List all findings in this chest X-ray systematically.")

        # Step 3: Differential
        print("  [3/4] Generating differential...")
        context = f"Clinical context: {clinical_context}\n" if clinical_context else ""
        results['differential'] = analyze_image(image,
            f"{context}Provide differential diagnosis for this chest X-ray.")

        # Step 4: Recommendations
        print("  [4/4] Generating recommendations...")
        results['recommendations'] = analyze_image(image,
            "What follow-up or workup is recommended based on this X-ray?")

        return results

    def generate_report(self, image, patient_info=None):
        """Generate structured radiology report.

        Args:
            image: PIL image of the chest X-ray.
            patient_info: Optional dict; the keys 'age', 'gender' and
                'history' are read. Missing, None or empty values are
                omitted, while numeric values (including 0) are kept.

        Returns:
            The report text returned by analyze_image.
        """
        context = self._patient_context(patient_info)

        prompt = f"""Generate a radiology report for this chest X-ray.
{f'Patient: {context}' if context else ''}

**TECHNIQUE:** [describe]
**FINDINGS:** [systematic findings]
**IMPRESSION:** [summary]
**RECOMMENDATIONS:** [if any]"""

        return analyze_image(image, prompt, max_tokens=2000)


# Shared instance used by the test, report and batch cells that follow.
primacare = PrimaCareAI()
print("✓ PrimaCareAI initialized")

## 5. Load Test Data

In [None]:
from datasets import load_dataset
import requests
from io import BytesIO

# Option 1: Load from dataset
# streaming=True lets us take three samples without fetching the full split.
print("Loading dataset...")
dataset = load_dataset("hf-vision/chest-xray-pneumonia", split="train", streaming=True)
samples = list(dataset.take(3))
print(f"✓ Loaded {len(samples)} samples")

# Option 2: Sample from web
url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"
# The timeout keeps the cell from hanging on a stalled connection, and
# raise_for_status() reports HTTP errors directly instead of letting PIL fail
# on an error page with an unrelated "cannot identify image" message.
response = requests.get(url, headers={"User-Agent": "MedGemma"}, timeout=30)
response.raise_for_status()
web_image = Image.open(BytesIO(response.content)).convert("RGB")
print("✓ Web sample loaded")

## 6. Test Pipeline

In [None]:
# Use the first loaded sample as the running example for the cells below.
first_sample = samples[0]
test_image = first_sample['image']

# A label value of 1 is shown as pneumonia; any other value, or a missing
# label, is shown as normal.
if first_sample.get('label', 0) == 1:
    label = "Pneumonia"
else:
    label = "Normal"

print(f"Ground Truth: {label}")
display(test_image.resize((300, 300)))

In [None]:
# Run the four-step analysis on the selected image; the per-section results
# are printed by the following cells.
print("Running PrimaCare analysis...\n")
context_note = "Adult patient with cough and fever"
results = primacare.analyze_xray(test_image, clinical_context=context_note)
print("\n✓ Analysis complete!")

In [None]:
# Classification results, highest probability first, each with a 25-cell bar.
print("=" * 60)
print("CLASSIFICATION (MedSigLIP)")
print("=" * 60)
ranked = sorted(results['classification'].items(), key=lambda x: x[1], reverse=True)
# Loop variable is `lbl`, not `label`, so the ground-truth `label` set in an
# earlier cell is not overwritten (the batch-test cell uses `lbl` as well).
for lbl, prob in ranked:
    filled = int(prob * 25)
    bar = "█" * filled + "░" * (25 - filled)
    print(f"{lbl:25s} {bar} {prob*100:5.1f}%")

In [None]:
# Banner header followed by the findings text from the analysis results.
print("=" * 60, "FINDINGS", "=" * 60, sep="\n")
print(results['findings'])

In [None]:
# Banner header followed by the differential-diagnosis text.
print("=" * 60, "DIFFERENTIAL DIAGNOSIS", "=" * 60, sep="\n")
print(results['differential'])

In [None]:
# Banner header followed by the recommendations text.
print("=" * 60, "RECOMMENDATIONS", "=" * 60, sep="\n")
print(results['recommendations'])

## 7. Structured Report

In [None]:
# Build a structured report for the same test image using example patient data.
print("Generating structured report...\n")
patient = {'age': '45', 'gender': 'Female', 'history': 'Cough x 1 week'}
report = primacare.generate_report(test_image, patient_info=patient)

print("=" * 60, "RADIOLOGY REPORT", "=" * 60, sep="\n")
print(report)

## 8. Batch Test

In [None]:
# For every loaded sample: show a thumbnail with its ground-truth label and
# the three highest-scoring zero-shot labels.
print("Testing on multiple samples...\n")

for number, sample in enumerate(samples, start=1):
    is_pneumonia = sample.get('label', 0) == 1
    label = "Pneumonia" if is_pneumonia else "Normal"

    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Sample {number} | Ground Truth: {label}")
    print(rule)

    display(sample['image'].resize((150, 150)))

    # Quick classification
    scores = classify_image(sample['image'], primacare.CXR_LABELS)
    ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    print("Top 3:")
    for lbl, prob in ranked[:3]:
        print(f"  {lbl}: {prob*100:.1f}%")

## 9. Summary

In [None]:
# Static closing banner: lists what this notebook demonstrated and names the
# next notebook to run. Nothing here is computed from earlier results.
print("""
========================================
PRIMACARE AI - PROTOTYPE SUMMARY
========================================

✓ MedGemma 1.5 4B - Image analysis
✓ MedSigLIP - Zero-shot classification
✓ Full analysis pipeline
✓ Structured report generation

Next: Run 04_agentic_workflow.ipynb
for the multi-agent demo.
""")