# PrimaCare AI - Agentic Workflow Demo

**MedGemma Impact Challenge** - Agentic Workflow Prize Submission

This notebook demonstrates a **multi-agent diagnostic support system** for primary care.

## Architecture

```
┌─────────────┐     ┌─────────────┐     ┌─────────────┐
│   Intake    │ --> │   Imaging   │ --> │  Reasoning  │
│   Agent     │     │   Agent     │     │   Agent     │
└─────────────┘     └─────────────┘     └─────────────┘
     │                    │                   │
     v                    v                   v
Structured HPI      Image Analysis      Differential Dx
                                        + Recommendations
```

## Agents

1. **IntakeAgent**: Structures patient history into HPI format
2. **ImagingAgent**: Analyzes chest X-rays (MedGemma + MedSigLIP)
3. **ReasoningAgent**: Generates differential diagnosis and recommendations
4. **PrimaCareOrchestrator**: Coordinates all agents
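The control flow in the diagram can be exercised without a GPU: intake, then optional imaging, then reasoning, with each step recorded. The sketch below uses hypothetical stub agents (`StubIntake`, `StubImaging`, `StubReasoning`) that only mimic the call shapes of the real agents defined later. No model is loaded, and the stubs' logic is illustrative only.

```python
from dataclasses import dataclass, field
from typing import List, Optional


# Hypothetical stand-ins for the real agents: same call shape, no model calls
@dataclass
class StubPatient:
    chief_complaint: str
    red_flags: List[str] = field(default_factory=list)


class StubIntake:
    def process(self, chief_complaint: str) -> StubPatient:
        flags = [f for f in ("fever", "chest pain") if f in chief_complaint.lower()]
        return StubPatient(chief_complaint, flags)


class StubImaging:
    def analyze(self, image: str) -> str:
        return f"findings for {image}"


class StubReasoning:
    def assess(self, patient: StubPatient, imaging: Optional[str]) -> str:
        source = imaging if imaging is not None else "no imaging"
        return f"assessment of '{patient.chief_complaint}' using {source}"


def run_pipeline(chief_complaint: str, image: Optional[str] = None) -> dict:
    """Intake -> (optional) Imaging -> Reasoning, recording each step."""
    steps = []
    patient = StubIntake().process(chief_complaint)
    steps.append("intake_complete")
    imaging = None
    if image is not None:
        imaging = StubImaging().analyze(image)
        steps.append("imaging_complete")
    else:
        steps.append("imaging_skipped")
    assessment = StubReasoning().assess(patient, imaging)
    steps.append("reasoning_complete")
    return {"patient": patient, "imaging": imaging,
            "assessment": assessment, "steps": steps}


print(run_pipeline("Cough with fever", image="cxr.png")["steps"])
print(run_pipeline("Chest pain")["steps"])
```

The reasoning stage always runs; only imaging is conditional, which matches how `PrimaCareOrchestrator.run` treats a missing image.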

**Run on:** Kaggle (GPU T4/P100) or Google Colab (GPU)

## 1. Setup

In [None]:
# Uncomment for Colab (pre-installed on Kaggle):
# !pip install -q -U "transformers>=4.50.0" accelerate datasets pillow huggingface-hub

In [None]:
import torch
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
from huggingface_hub import login

try:
    from kaggle_secrets import UserSecretsClient
    login(token=UserSecretsClient().get_secret("HF_TOKEN"))
    print("✓ Logged in via Kaggle")
except Exception:
    # Not on Kaggle (or no HF_TOKEN secret): fall back to an interactive prompt
    login()
    print("✓ Logged in interactively")

## 2. Load Models

In [None]:
from transformers import pipeline, AutoProcessor, AutoModel

# Load MedGemma
print("Loading MedGemma...")
medgemma = pipeline(
    "image-text-to-text",
    model="google/medgemma-1.5-4b-it",
    torch_dtype=torch.bfloat16,
    device="cuda",
)
print("✓ MedGemma loaded")

# Load MedSigLIP
print("Loading MedSigLIP...")
siglip_model = AutoModel.from_pretrained("google/medsiglip-448").to("cuda")
siglip_processor = AutoProcessor.from_pretrained("google/medsiglip-448")
print("✓ MedSigLIP loaded")

## 3. Define Agent Classes

In [None]:
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
from PIL import Image
from enum import Enum

# ============================================================================
# Data Classes
# ============================================================================

class Urgency(Enum):
    ROUTINE = "routine"
    SOON = "soon"
    URGENT = "urgent"
    EMERGENT = "emergent"

@dataclass
class PatientContext:
    chief_complaint: str
    history: str = ""
    age: Optional[int] = None
    gender: Optional[str] = None
    structured_hpi: str = ""
    red_flags: List[str] = field(default_factory=list)
    urgency: Urgency = Urgency.ROUTINE

@dataclass  
class ImagingResult:
    findings: str = ""
    impression: str = ""
    classification: Dict[str, float] = field(default_factory=dict)
    urgent: bool = False

@dataclass
class ClinicalAssessment:
    differential: List[str] = field(default_factory=list)
    most_likely: str = ""
    workup: List[str] = field(default_factory=list)
    disposition: str = ""
    follow_up: str = ""
    patient_instructions: str = ""

print("✓ Data classes defined")
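`Urgency` is a plain `Enum`, so its members have no built-in ordering. The pipeline sets urgency only from intake red flags and keeps the ImagingAgent's `urgent` flag separate; combining the two needs an explicit severity rank. Below is a minimal sketch, assuming definition order (ROUTINE < SOON < URGENT < EMERGENT) is the intended severity order. `escalate` is a hypothetical helper, not part of the notebook, and the enum is copied only so the snippet runs on its own.

```python
from enum import Enum


# Copy of the notebook's Urgency enum so this snippet runs on its own
class Urgency(Enum):
    ROUTINE = "routine"
    SOON = "soon"
    URGENT = "urgent"
    EMERGENT = "emergent"


def escalate(current: Enum, proposed: Enum) -> Enum:
    """Return the more severe of two members of the same Enum.

    Severity is taken from definition order, so a level is never downgraded.
    """
    order = list(type(current))
    return proposed if order.index(proposed) > order.index(current) else current


level = Urgency.ROUTINE
for signal in (Urgency.SOON, Urgency.URGENT, Urgency.SOON):
    level = escalate(level, signal)
print(level)  # Urgency.URGENT
```

In the orchestrator, a call like `escalate(patient.urgency, Urgency.URGENT)` when `imaging_result.urgent` is true would fold the imaging signal into the patient's urgency without ever lowering it.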

In [None]:
# ============================================================================
# AGENT 1: Intake Agent
# ============================================================================

class IntakeAgent:
    """Structures patient history into clinical format."""
    
    RED_FLAGS = [
        "chest pain", "shortness of breath", "hemoptysis", "syncope",
        "severe headache", "fever", "weight loss", "night sweats"
    ]
    
    def process(self, chief_complaint: str, history: str = "",
                age: Optional[int] = None, gender: Optional[str] = None) -> PatientContext:
        """Structure patient information."""
        
        # Build input text
        input_text = f"Chief Complaint: {chief_complaint}"
        if history:
            input_text += f"\nHistory: {history}"
        if age is not None:
            input_text += f"\nAge: {age}"
        if gender:
            input_text += f"\nGender: {gender}"
        
        # Structure HPI using MedGemma
        prompt = f"""Structure this patient information into a formal History of Present Illness (HPI).

{input_text}

Extract: Onset, Location, Duration, Character, Aggravating/Relieving factors, Timing, Severity, Associated symptoms.
Also identify any red flag symptoms."""
        
        # Note: image-text-to-text pipeline requires content as list of dicts
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        output = medgemma(text=messages, max_new_tokens=800)
        structured_hpi = output[0]["generated_text"][-1]["content"]
        
        # Check for red flags (plain substring match, so negated mentions such as
        # "no fever" are still flagged)
        combined_text = (chief_complaint + " " + history).lower()
        found_flags = [f for f in self.RED_FLAGS if f in combined_text]
        
        # Determine urgency
        urgency = Urgency.ROUTINE
        if found_flags:
            urgency = Urgency.SOON
        if any(f in combined_text for f in ["chest pain", "shortness of breath", "syncope"]):
            urgency = Urgency.URGENT
        
        return PatientContext(
            chief_complaint=chief_complaint,
            history=history,
            age=age,
            gender=gender,
            structured_hpi=structured_hpi,
            red_flags=found_flags,
            urgency=urgency
        )

intake_agent = IntakeAgent()
print("✓ IntakeAgent ready")
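The intake check above is a plain substring test, so "no fever" counts as fever and "feverish" does too. A minimal sketch of a stricter matcher, assuming whole-word matching plus a crude negation window is acceptable: `find_red_flags` and `NEGATION_CUES` are hypothetical names, and the heuristic is a rough approximation, not a validated clinical negation detector.

```python
import re
from typing import List

RED_FLAGS = [
    "chest pain", "shortness of breath", "hemoptysis", "syncope",
    "severe headache", "fever", "weight loss", "night sweats",
]
# Hypothetical negation cues; a crude heuristic, not a clinical negation detector
NEGATION_CUES = ["no", "denies", "without", "negative for"]


def find_red_flags(text: str, window: int = 4) -> List[str]:
    """Return red flags that appear as whole words and are not negated.

    A flag counts as negated when a cue occurs within the `window` words before
    it in the same sentence (sentences split on '.', ';' and newlines).
    """
    found: List[str] = []
    for sentence in re.split(r"[.;\n]", text.lower()):
        for flag in RED_FLAGS:
            for match in re.finditer(rf"\b{re.escape(flag)}\b", sentence):
                preceding = " ".join(sentence[:match.start()].split()[-window:])
                negated = any(re.search(rf"\b{cue}\b", preceding) for cue in NEGATION_CUES)
                if not negated and flag not in found:
                    found.append(flag)
    return found


print(find_red_flags("Productive cough. Low grade fever. Denies chest pain."))  # ['fever']
```

The window is deliberately short; it still misfires on phrasing like "no cough, has fever", so it only narrows the false positives rather than removing them.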

In [None]:
# ============================================================================
# AGENT 2: Imaging Agent
# ============================================================================

class ImagingAgent:
    """Analyzes medical images using MedGemma + MedSigLIP."""
    
    CXR_LABELS = [
        "normal chest x-ray", "pneumonia", "pleural effusion",
        "cardiomegaly", "pulmonary edema", "atelectasis",
        "pneumothorax", "consolidation", "mass or nodule"
    ]
    
    def analyze(self, image: Image.Image, clinical_context: str = "") -> ImagingResult:
        """Analyze chest X-ray."""
        
        # IMPORTANT: Convert to RGB for MedSigLIP
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        # Step 1: Zero-shot classification with MedSigLIP
        inputs = siglip_processor(
            text=self.CXR_LABELS,
            images=[image],
            padding="max_length",
            return_tensors="pt"
        ).to("cuda")
        
        with torch.no_grad():
            outputs = siglip_model(**inputs)
            probs = torch.softmax(outputs.logits_per_image, dim=1)[0]
        
        classification = {label: prob.item() for label, prob in zip(self.CXR_LABELS, probs)}
        
        # Step 2: Detailed analysis with MedGemma
        context = f"Clinical context: {clinical_context}\n" if clinical_context else ""
        prompt = f"""{context}Analyze this chest X-ray systematically.

Provide:
1. FINDINGS: Describe cardiac silhouette, lung fields, pleura, mediastinum, bones
2. IMPRESSION: Key findings summary
3. URGENT: Are there findings requiring immediate attention? (YES/NO)"""
        
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt}
            ]
        }]
        
        output = medgemma(text=messages, max_new_tokens=1500)
        analysis = output[0]["generated_text"][-1]["content"]
        
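        # Parse urgency from the "URGENT: YES/NO" field only. The prompt asks for
        # an URGENT field, so searching the whole reply for "URGENT" and "YES"
        # flags almost any answer that happens to contain the word "yes".
        import re
        urgent_match = re.search(r"URGENT\W*(YES|NO)\b", analysis, re.IGNORECASE)
        urgent = urgent_match is not None and urgent_match.group(1).upper() == "YES"
        
        return ImagingResult(
            findings=analysis,
            impression="",  # Not parsed yet; the full report is in `findings`
            classification=classification,
            urgent=urgent
        )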

imaging_agent = ImagingAgent()
print("✓ ImagingAgent ready")
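The ImagingAgent turns MedSigLIP logits into scores with a softmax across the nine labels, so the scores compete and sum to 1: co-occurring findings (for example pneumonia with a pleural effusion) split the probability mass between them. SigLIP-family models are trained with a pairwise sigmoid loss, and applying a sigmoid to each logit instead gives independent per-label scores. The snippet contrasts the two on made-up logits in plain Python; it does not call the model, and the numbers are illustrative only.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits: List[float]) -> List[float]:
    return [1.0 / (1.0 + math.exp(-x)) for x in logits]


labels = ["normal chest x-ray", "pneumonia", "pleural effusion"]
logits = [-2.0, 1.5, 1.4]  # made-up logits for illustration only

soft = dict(zip(labels, softmax(logits)))
sig = dict(zip(labels, sigmoid(logits)))
for label in labels:
    print(f"{label:20s} softmax={soft[label]:.2f} sigmoid={sig[label]:.2f}")
```

With these logits both abnormal findings score about 0.8 under the sigmoid, while the softmax splits them roughly in half. The softmax view is still useful for ranking the top labels, which is how the pipeline uses it.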

In [None]:
# ============================================================================
# AGENT 3: Reasoning Agent
# ============================================================================

class ReasoningAgent:
    """Generates clinical assessment from combined inputs."""
    
    def assess(self, patient: PatientContext, imaging: Optional[ImagingResult] = None) -> ClinicalAssessment:
        """Generate clinical assessment."""
        
        # Build context
        clinical_info = f"""
## Patient Information
Chief Complaint: {patient.chief_complaint}
Age: {patient.age or 'Unknown'}
Gender: {patient.gender or 'Unknown'}

## Structured HPI
{patient.structured_hpi}

## Red Flags Identified
{', '.join(patient.red_flags) if patient.red_flags else 'None'}
"""
        
        imaging_info = ""
        if imaging:
            top_class = sorted(imaging.classification.items(), key=lambda x: x[1], reverse=True)[:3]
            class_str = ", ".join([f"{k} ({v*100:.0f}%)" for k, v in top_class])
            imaging_info = f"""
## Imaging Findings
{imaging.findings}

## Classification
{class_str}
"""
        
        prompt = f"""You are an experienced primary care physician. Based on the following information, provide a clinical assessment.

{clinical_info}
{imaging_info}

Provide:

**MOST LIKELY DIAGNOSIS:**
[Single most likely diagnosis]

**DIFFERENTIAL DIAGNOSIS:**
1. [Diagnosis 1]
2. [Diagnosis 2]
3. [Diagnosis 3]
4. [Diagnosis 4]
5. [Diagnosis 5]

**RECOMMENDED WORKUP:**
- [Test 1]
- [Test 2]

**DISPOSITION:**
[Outpatient / Urgent Care / ED / Admit]

**FOLLOW-UP:**
[Timing and specialist if needed]

**PATIENT INSTRUCTIONS:**
[Key instructions in plain language]"""
        
        # Note: image-text-to-text pipeline requires content as list of dicts
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        output = medgemma(text=messages, max_new_tokens=2000)
        assessment_text = output[0]["generated_text"][-1]["content"]
        
        # Parsing into the structured fields is not implemented yet: the full
        # reply is kept in `patient_instructions`, which the display cells print
        return ClinicalAssessment(
            differential=[],  # Not parsed yet
            most_likely="See full assessment",
            workup=[],
            disposition="See full assessment",
            follow_up="See full assessment",
            patient_instructions=assessment_text
        )

reasoning_agent = ReasoningAgent()
print("✓ ReasoningAgent ready")
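ReasoningAgent stores the whole reply in `patient_instructions` and leaves the other `ClinicalAssessment` fields as placeholders. Because the prompt asks for fixed bold headers, a small regex splitter can recover them. This is a sketch that assumes the model echoes the headers as `**HEADER:**`, which is not guaranteed. `parse_sections` and `parse_list` are hypothetical helpers, and the sample reply is invented for illustration.

```python
import re
from typing import Dict, List

# Header names copied from the ReasoningAgent prompt
SECTION_HEADERS = [
    "MOST LIKELY DIAGNOSIS", "DIFFERENTIAL DIAGNOSIS", "RECOMMENDED WORKUP",
    "DISPOSITION", "FOLLOW-UP", "PATIENT INSTRUCTIONS",
]


def parse_sections(text: str) -> Dict[str, str]:
    """Split a reply on '**HEADER:**' markers; missing headers map to ''."""
    pattern = r"\*\*\s*(" + "|".join(map(re.escape, SECTION_HEADERS)) + r")\s*:?\s*\*\*:?"
    matches = list(re.finditer(pattern, text, re.IGNORECASE))
    sections = {h: "" for h in SECTION_HEADERS}
    for i, m in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        sections[m.group(1).upper()] = text[m.end():end].strip()
    return sections


def parse_list(block: str) -> List[str]:
    """Strip '1.' / '1)' / '-' / '*' bullets from each non-empty line."""
    items = []
    for line in block.splitlines():
        item = re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line).strip()
        if item:
            items.append(item)
    return items


# Invented reply, for illustration only
reply = """**MOST LIKELY DIAGNOSIS:**
Community-acquired pneumonia

**DIFFERENTIAL DIAGNOSIS:**
1. Community-acquired pneumonia
2. Lung cancer
3. Tuberculosis

**RECOMMENDED WORKUP:**
- CBC
- Sputum culture
"""
sections = parse_sections(reply)
print(sections["MOST LIKELY DIAGNOSIS"])
print(parse_list(sections["DIFFERENTIAL DIAGNOSIS"]))
```

The result maps directly onto the dataclass, e.g. `most_likely=sections["MOST LIKELY DIAGNOSIS"]` and `differential=parse_list(sections["DIFFERENTIAL DIAGNOSIS"])`, with the raw text kept as a fallback when a header is missing.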

In [None]:
# ============================================================================
# ORCHESTRATOR
# ============================================================================

class PrimaCareOrchestrator:
    """Coordinates all agents for end-to-end diagnostic support."""
    
    def __init__(self):
        self.intake = intake_agent
        self.imaging = imaging_agent
        self.reasoning = reasoning_agent
    
    def run(self, chief_complaint: str, history: str = "",
            image: Optional[Image.Image] = None, age: Optional[int] = None,
            gender: Optional[str] = None):
        """
        Run complete diagnostic pipeline.
        
        Returns dict with all agent outputs.
        """
        results = {"steps": []}
        
        # Step 1: Intake
        print("\n[Step 1] IntakeAgent: Processing patient information...")
        patient = self.intake.process(chief_complaint, history, age, gender)
        results["patient"] = patient
        results["steps"].append("intake_complete")
        print(f"  - Urgency: {patient.urgency.value}")
        print(f"  - Red flags: {patient.red_flags or 'None'}")
        
        # Step 2: Imaging (if provided)
        imaging_result = None
        if image is not None:
            print("\n[Step 2] ImagingAgent: Analyzing chest X-ray...")
            clinical_ctx = f"{chief_complaint}. {history}"
            imaging_result = self.imaging.analyze(image, clinical_ctx)
            results["imaging"] = imaging_result
            results["steps"].append("imaging_complete")
            
            # Show top classifications
            top = sorted(imaging_result.classification.items(), key=lambda x: x[1], reverse=True)[:3]
            print("  Top classifications:")
            for label, prob in top:
                print(f"    - {label}: {prob*100:.1f}%")
        else:
            print("\n[Step 2] ImagingAgent: No image provided, skipping...")
            results["steps"].append("imaging_skipped")
        
        # Step 3: Reasoning
        print("\n[Step 3] ReasoningAgent: Generating clinical assessment...")
        assessment = self.reasoning.assess(patient, imaging_result)
        results["assessment"] = assessment
        results["steps"].append("reasoning_complete")
        
        print("\n✓ Pipeline complete!")
        return results

orchestrator = PrimaCareOrchestrator()
print("✓ PrimaCareOrchestrator ready")
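In `run`, an exception in any agent (for example running out of GPU memory during image analysis) aborts the whole pipeline, and `results["steps"]` records only step names. A minimal sketch of a hypothetical `run_step` wrapper that records status and wall-clock time per step and returns `None` on failure instead of raising; the names and the simulated failure are assumptions for illustration.

```python
import time
from typing import Any, Callable, Dict, Optional


def run_step(results: Dict[str, Any], name: str, fn: Callable[[], Any]) -> Optional[Any]:
    """Run one agent call and log its name, status, and wall-clock time.

    Returns the agent's output, or None if it raised; the error is recorded in
    results["steps"] instead of aborting the whole pipeline.
    """
    start = time.perf_counter()
    try:
        output = fn()
        status = "complete"
    except Exception as exc:
        output = None
        status = f"failed: {type(exc).__name__}: {exc}"
    results["steps"].append({
        "step": name,
        "status": status,
        "seconds": round(time.perf_counter() - start, 3),
    })
    return output


def failing_imaging():
    # Simulated failure, e.g. running out of GPU memory on a large image
    raise RuntimeError("CUDA out of memory")


demo_results = {"steps": []}
demo_hpi = run_step(demo_results, "intake", lambda: "structured HPI")
demo_imaging = run_step(demo_results, "imaging", failing_imaging)
for step in demo_results["steps"]:
    print(f'{step["step"]} -> {step["status"]}')
```

Because `ReasoningAgent.assess` already accepts `imaging=None`, a failed imaging step wrapped this way would fall through to the same text-only path the notebook demonstrates in Section 6.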

## 4. Load Test Data

In [None]:
from datasets import load_dataset

print("Loading chest X-ray dataset...")
# Stream the split instead of downloading it in full; only the first rows are read
dataset = load_dataset("hf-vision/chest-xray-pneumonia", split="train", streaming=True)
samples = list(dataset.take(3))
print(f"✓ Loaded {len(samples)} samples")
print(f"Labels: 0=Normal, 1=Pneumonia")
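`dataset.take(3)` returns the first three rows in stream order, which may all share one label depending on how the split is ordered. A minimal sketch of a hypothetical `one_per_label` helper that scans a bounded number of streamed rows and keeps the first example of each class. It assumes each row exposes an integer `label` field, as the demo cell below expects.

```python
from itertools import islice
from typing import Dict, Iterable, List, Sequence


def one_per_label(stream: Iterable[dict], labels: Sequence[int] = (0, 1),
                  max_scan: int = 500) -> List[dict]:
    """Return the first row seen for each label, scanning at most max_scan rows.

    Labels never seen within the scan budget are simply missing from the result.
    """
    picked: Dict[int, dict] = {}
    for row in islice(stream, max_scan):
        lab = row.get("label")
        if lab in labels and lab not in picked:
            picked[lab] = row
            if len(picked) == len(labels):
                break
    return [picked[lab] for lab in labels if lab in picked]


# Toy rows standing in for streamed dataset records
toy_stream = [{"label": 1, "id": 0}, {"label": 1, "id": 1}, {"label": 0, "id": 2}]
print([row["id"] for row in one_per_label(toy_stream)])  # [2, 0]
```

Calling `samples = one_per_label(dataset)` would then put a Normal example at `samples[0]` and a Pneumonia example at `samples[1]`, if both appear within the scan budget.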

## 5. Demo: Complete Pipeline

In [None]:
# Test case: Simulated patient with chest X-ray
test_image = samples[0]["image"]

# Convert to RGB (required for MedSigLIP)
if test_image.mode != "RGB":
    test_image = test_image.convert("RGB")

label = "Pneumonia" if samples[0].get('label', 0) == 1 else "Normal"

print("TEST CASE")
print("="*60)
print(f"Ground truth: {label}")
display(test_image.resize((300, 300)))

In [None]:
# Run the full agentic pipeline
results = orchestrator.run(
    chief_complaint="Cough for 2 weeks with fever",
    history="65 year old male smoker. Started with dry cough, now productive with yellow sputum. Low grade fever. Night sweats. 10 pound weight loss over past month.",
    image=test_image,
    age=65,
    gender="male"
)

In [None]:
# Display structured HPI
print("\n" + "="*60)
print("STRUCTURED HPI (IntakeAgent)")
print("="*60)
print(results["patient"].structured_hpi)

In [None]:
# Display imaging findings
if "imaging" in results:
    print("\n" + "="*60)
    print("IMAGING ANALYSIS (ImagingAgent)")
    print("="*60)
    print(results["imaging"].findings)
    
    print("\nClassification Results:")
    for label, prob in sorted(results["imaging"].classification.items(), key=lambda x: x[1], reverse=True):
        bar = "█" * int(prob * 30) + "░" * (30 - int(prob * 30))
        print(f"  {label:25s} {bar} {prob*100:5.1f}%")

In [None]:
# Display clinical assessment
print("\n" + "="*60)
print("CLINICAL ASSESSMENT (ReasoningAgent)")
print("="*60)
print(results["assessment"].patient_instructions)

## 6. Demo: Text-Only (No Imaging)

In [None]:
# Test without imaging
results_no_img = orchestrator.run(
    chief_complaint="Chest pain for 3 hours",
    history="45 year old female. Substernal chest pressure radiating to left arm. Started at rest. Diaphoresis. No prior cardiac history. Takes birth control pills.",
    image=None,  # No image
    age=45,
    gender="female"
)

In [None]:
print("\n" + "="*60)
print("CLINICAL ASSESSMENT (No Imaging)")
print("="*60)
print(f"Urgency: {results_no_img['patient'].urgency.value.upper()}")
print(f"Red Flags: {results_no_img['patient'].red_flags}")
print("\n" + results_no_img["assessment"].patient_instructions)

## 7. Interactive Demo

In [None]:
import gradio as gr

def run_primacare(chief_complaint, history, image, age, gender):
    """Gradio interface for PrimaCare AI."""
    if not chief_complaint:
        return "Please enter a chief complaint."
    
    try:
        age_int = int(age) if age else None
    except ValueError:
        return "Age must be a whole number."
    
    # Convert image to RGB if provided
    if image is not None and image.mode != "RGB":
        image = image.convert("RGB")
    
    results = orchestrator.run(
        chief_complaint=chief_complaint,
        history=history or "",
        image=image,
        age=age_int,
        gender=gender or None
    )
    
    # Format output
    output = []
    output.append("=" * 50)
    output.append("PRIMACARE AI - CLINICAL ASSESSMENT")
    output.append("=" * 50)
    output.append(f"\nUrgency: {results['patient'].urgency.value.upper()}")
    
    if results['patient'].red_flags:
        output.append(f"Red Flags: {', '.join(results['patient'].red_flags)}")
    
    output.append("\n" + "-" * 50)
    output.append("STRUCTURED HPI")
    output.append("-" * 50)
    output.append(results['patient'].structured_hpi)
    
    if 'imaging' in results:
        output.append("\n" + "-" * 50)
        output.append("IMAGING ANALYSIS")
        output.append("-" * 50)
        top = sorted(results['imaging'].classification.items(), key=lambda x: x[1], reverse=True)[:3]
        for label, prob in top:
            output.append(f"  {label}: {prob*100:.1f}%")
        output.append("\n" + results['imaging'].findings[:1000])
    
    output.append("\n" + "-" * 50)
    output.append("CLINICAL ASSESSMENT")
    output.append("-" * 50)
    output.append(results['assessment'].patient_instructions)
    
    return "\n".join(output)

# Create interface
demo = gr.Interface(
    fn=run_primacare,
    inputs=[
        gr.Textbox(label="Chief Complaint", placeholder="Cough for 2 weeks"),
        gr.Textbox(label="History", placeholder="65yo male smoker...", lines=3),
        gr.Image(type="pil", label="Chest X-ray (optional)"),
        gr.Textbox(label="Age", placeholder="65"),
        gr.Dropdown(["male", "female", "other"], label="Gender"),
    ],
    outputs=gr.Textbox(label="Assessment", lines=30),
    title="PrimaCare AI - Agentic Diagnostic Support",
    description="Multi-agent system for primary care clinical decision support."
)

demo.launch(share=True)

## 8. Summary

### Agentic Architecture

This notebook demonstrates a **multi-agent diagnostic support system** with:

1. **IntakeAgent**: Structures patient history, identifies red flags
2. **ImagingAgent**: Analyzes X-rays with MedGemma + MedSigLIP
3. **ReasoningAgent**: Generates differential diagnosis and recommendations
4. **PrimaCareOrchestrator**: Coordinates all agents

### Key Features

- **Modular Design**: Each agent has a specific role
- **Multimodal**: Combines text (history) and images (X-ray)
- **Urgency Tracking**: Identifies red flags and escalates appropriately
- **End-to-End**: From patient intake to clinical recommendation

### Competition Fit

This targets the **Agentic Workflow Prize** by demonstrating:
- Complex workflow reimagined with AI agents
- HAI-DEF models as intelligent tools
- Real clinical use case (primary care)

---

**MedGemma Impact Challenge 2026**