# CultureSense — Longitudinal Clinical Hypothesis Engine
## Kaggle HAI-DEF Competition Submission

[![MedGemma](https://img.shields.io/badge/MedGemma-4b--it-blue)](https://huggingface.co/google/medgemma-4b-it)
[![Safety](https://img.shields.io/badge/Safety-Non--Diagnostic-green)]()
[![Mode](https://img.shields.io/badge/Mode-Patient%20%2B%20Clinician-purple)]()

> **CultureSense** processes 2–3 sequential urine or stool culture lab reports and produces
> structured, **non-diagnostic** interpretations through two distinct output modes.
> MedGemma handles natural language generation from already-structured inputs.
> Deterministic rules handle all temporal signal extraction.

---

## Architecture

```mermaid
flowchart TD
    A["[1] Raw Report Ingestion\nList[str] 2-3 free-text culture reports"] --> B
    B["[2] Structured Extraction Layer\nextract_structured_data() → CultureReport"] --> C
    C["[3] Temporal Comparison Engine\nanalyze_trend() → TrendResult"] --> D
    D["[4] Hypothesis Update Layer\ngenerate_hypothesis() → HypothesisResult\nconfidence [0.0–0.95]"] --> E
    E["[5] MedGemma Reasoning Layer\ncall_medgemma(structured_payload, mode) → str\nModes: patient | clinician"] --> F
    F["[6] Structured Safe Output Renderer\nrender_output() → FormattedOutput\nPatient: explanation + questions\nClinician: trajectory + confidence + flags"]

    style A fill:#e8f4f8
    style B fill:#d4edda
    style C fill:#d4edda
    style D fill:#fff3cd
    style E fill:#f8d7da
    style F fill:#e8f4f8
```

**Key safety invariant:** Raw report text is NEVER forwarded to MedGemma.
Only derived structured fields (typed dataclasses → JSON) are passed to the model.
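
A minimal sketch of how this invariant could be enforced at the serialization boundary. `MiniReport` and `to_model_fields` are hypothetical stand-ins introduced for illustration, not the notebook's own classes; the only point is that `raw_text` is dropped before anything is converted to JSON.

```python
import json
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class MiniReport:
    """Hypothetical, trimmed-down stand-in for a parsed culture report."""

    date: str
    organism: str
    cfu: int
    resistance_markers: List[str]
    raw_text: str  # original free text; must never reach the model


def to_model_fields(report: MiniReport) -> dict:
    """Return only the derived fields, with raw_text removed."""
    fields = asdict(report)
    fields.pop("raw_text")
    return fields


report = MiniReport(
    date="2026-01-01",
    organism="Escherichia coli",
    cfu=120000,
    resistance_markers=[],
    raw_text="Organism: E. coli\nCFU/mL: 120,000",
)
payload_json = json.dumps(to_model_fields(report))
print(payload_json)
```

Dropping the field in one helper, rather than at each call site, keeps the guarantee in a single place that is easy to test.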

---


## Cell A: Setup & Imports

In [None]:
# Cell A-1: Repository Setup (for Colab/Kaggle)
import os
if not os.path.exists('culturesense'):
    !git clone https://github.com/shuknuk/culturesense.git
    %cd culturesense
else:
    print("Repository 'culturesense' already exists.")


In [None]:
# Cell A-2: Library Installation
import subprocess, sys

packages = [
    "transformers>=4.40.0",
    "accelerate>=0.29.0",
    "sentencepiece>=0.1.99",
    "huggingface_hub>=0.22.0",
    "docling",
]

for pkg in packages:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", pkg])

print("Installation complete.")


In [None]:
# Cell A-3: Core Imports
from __future__ import annotations
import re, json, warnings
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Optional, Tuple

try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("transformers not available — stub mode will be used.")

print("Imports complete.")


## Cell B: Data Models & Rule Library

In [None]:

from dataclasses import dataclass, field
from typing import List, Optional, Dict


@dataclass
class CultureReport:
    """
    Structured representation of a single culture lab report.

    Fields:
        date: ISO 8601 formatted date string (YYYY-MM-DD)
        organism: Name of identified organism (e.g., "E. coli")
        cfu: Colony Forming Units per mL
        resistance_markers: List of resistance markers (subset of ["ESBL","CRE","MRSA","VRE","CRKP"])
        specimen_type: Type of specimen ("urine" | "stool" | "wound" | "blood" | "sputum" | "unknown")
        contamination_flag: True if organism matches contamination terms
        raw_text: Original report string (NEVER passed to LLM)
    """

    date: str
    organism: str
    cfu: int
    resistance_markers: List[str]
    specimen_type: str
    contamination_flag: bool
    raw_text: str


@dataclass
class TrendResult:
    """
    Temporal comparison analysis across multiple culture reports.

    Fields:
        cfu_trend: "decreasing" | "increasing" | "fluctuating" | "cleared" | "insufficient_data"
        cfu_values: Ordered list of CFU values across reports
        cfu_deltas: Per-interval changes in CFU
        organism_persistent: True if same organism across all reports
        organism_list: Organism name per report
        resistance_evolution: True if new markers appear in later reports
        resistance_timeline: Resistance markers per report
        report_dates: ISO dates in sorted order
        any_contamination: True if any report flagged as contamination
    """

    cfu_trend: str
    cfu_values: List[int]
    cfu_deltas: List[int]
    organism_persistent: bool
    organism_list: List[str]
    resistance_evolution: bool
    resistance_timeline: List[List[str]]
    report_dates: List[str]
    any_contamination: bool


@dataclass
class HypothesisResult:
    """
    Rule-generated hypothesis with confidence scoring.

    Fields:
        interpretation: Natural language pattern summary (rule-generated)
        confidence: Confidence score [0.0, 0.95] - never 1.0
        risk_flags: List of risk flags (e.g., ["EMERGING_RESISTANCE", "CONTAMINATION"])
        stewardship_alert: True if resistance_evolution is True
        requires_clinician_review: Always True - structural safety guarantee
    """

    interpretation: str
    confidence: float
    risk_flags: List[str]
    stewardship_alert: bool
    requires_clinician_review: bool = True


@dataclass
class MedGemmaPayload:
    """
    Structured payload for MedGemma model inference.

    CRITICAL: raw_text from CultureReport is NEVER included in this payload.
    Only derived structured fields are forwarded.

    Fields:
        mode: "patient" | "clinician"
        trend_summary: Serialized TrendResult
        hypothesis_summary: Serialized HypothesisResult
        safety_constraints: Injected safety instructions
        output_schema: Expected output fields for this mode
    """

    mode: str
    trend_summary: dict
    hypothesis_summary: dict
    safety_constraints: List[str]
    output_schema: dict


@dataclass
class FormattedOutput:
    """
    Final rendered output for either Patient or Clinician mode.

    Fields are mode-specific. Patient mode uses patient_* fields,
    Clinician mode uses clinician_* fields.
    """

    mode: str

    # Patient mode fields
    patient_explanation: Optional[str] = None
    patient_trend_phrase: Optional[str] = None
    patient_questions: Optional[List[str]] = None
    patient_disclaimer: str = ""

    # Clinician mode fields
    clinician_trajectory: Optional[dict] = None
    clinician_interpretation: Optional[str] = None
    clinician_confidence: Optional[float] = None
    clinician_resistance_detail: Optional[str] = None
    clinician_stewardship_flag: Optional[bool] = None
    clinician_disclaimer: str = ""
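
The `resistance_evolution` field is described above as "True if new markers appear in later reports". One way that flag could be derived from `resistance_timeline` is sketched here. The helper name `has_resistance_evolution` is hypothetical, and treating the first report as the baseline is an assumption, not something the data model states.

```python
from typing import List


def has_resistance_evolution(timeline: List[List[str]]) -> bool:
    """True if any report introduces a marker absent from all earlier reports.

    Assumption: the first report is the baseline, so its markers never
    count as "new".
    """
    if not timeline:
        return False
    seen = set(timeline[0])
    for markers in timeline[1:]:
        if set(markers) - seen:
            return True
        seen.update(markers)
    return False


print(has_resistance_evolution([[], [], ["ESBL"]]))   # marker appears later
print(has_resistance_evolution([["ESBL"], ["ESBL"]]))  # nothing new
```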

In [None]:

# ---------------------------------------------------------------------------
# Core clinical rules and thresholds
# ---------------------------------------------------------------------------
RULES = {
    # CFU/mL threshold above which a urine specimen is considered infected
    "infection_threshold_urine": 100000,
    # CFU/mL threshold above which a stool specimen is considered infected
    "infection_threshold_stool": 50000,
    # A reduction of 75%+ from the previous reading is a strong improvement
    "significant_reduction_pct": 0.75,
    # Organism names indicating sample contamination rather than true infection
    "contamination_terms": [
        "mixed flora",
        "skin flora",
        "normal flora",
        "commensal",
        "contamination",
        "mixed growth",
    ],
    # High-risk resistance markers tracked by the rule engine
    "high_risk_markers": ["ESBL", "CRE", "MRSA", "VRE", "CRKP"],
    # CFU/mL at or below this value is treated as effectively cleared
    "cleared_threshold": 1000,
    # Hard ceiling on confidence - epistemic humility; never 1.0
    "max_confidence": 0.95,
    # Starting confidence before any signal adjustments
    "base_confidence": 0.50,
}

# ---------------------------------------------------------------------------
# Organism alias normalisation lookup table
# Maps common shorthand/abbreviations → canonical organism name.
# Matching is performed case-insensitively against stripped input.
# ---------------------------------------------------------------------------
ORGANISM_ALIASES: dict = {
    # Escherichia coli variants
    "e. coli": "Escherichia coli",
    "e.coli": "Escherichia coli",
    "e coli": "Escherichia coli",
    "escherichia coli": "Escherichia coli",
    # Klebsiella
    "klebsiella": "Klebsiella pneumoniae",
    "klebsiella pneumoniae": "Klebsiella pneumoniae",
    # Staphylococcus
    "staph aureus": "Staphylococcus aureus",
    "staphylococcus aureus": "Staphylococcus aureus",
    "s. aureus": "Staphylococcus aureus",
    "mrsa": "Staphylococcus aureus (MRSA)",
    # Enterococcus
    "enterococcus": "Enterococcus faecalis",
    "enterococcus faecalis": "Enterococcus faecalis",
    "e. faecalis": "Enterococcus faecalis",
    # Pseudomonas
    "pseudomonas": "Pseudomonas aeruginosa",
    "pseudomonas aeruginosa": "Pseudomonas aeruginosa",
    "p. aeruginosa": "Pseudomonas aeruginosa",
    # Proteus
    "proteus": "Proteus mirabilis",
    "proteus mirabilis": "Proteus mirabilis",
    # Contamination terms (kept as-is but included for normalisation completeness)
    "mixed flora": "mixed flora",
    "skin flora": "skin flora",
    "normal flora": "normal flora",
    "commensal": "commensal",
    "mixed growth": "mixed growth",
}


def normalize_organism(raw: str) -> str:
    """
    Normalise a raw organism string to its canonical name.

    Performs case-insensitive lookup against ORGANISM_ALIASES.
    Returns the canonical name if found; otherwise returns the original
    input with surrounding whitespace stripped (its case is left unchanged).

    Args:
        raw: Raw organism string from extraction layer.

    Returns:
        Canonical organism name string.
    """
    key = raw.strip().lower()
    return ORGANISM_ALIASES.get(key, raw.strip())
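
The thresholds in `RULES` are easiest to read next to the comparisons they imply. The sketch below copies three values from the table above; the helper functions are hypothetical and only illustrate one plausible reading of the comments ("at or below", "75%+", "never 1.0"), not the notebook's actual trend or confidence logic.

```python
# Values copied from RULES; the helpers below are illustrative assumptions.
CLEARED_THRESHOLD = 1000
SIGNIFICANT_REDUCTION_PCT = 0.75
MAX_CONFIDENCE = 0.95


def is_cleared(cfu: int) -> bool:
    """CFU/mL at or below the cleared threshold counts as cleared."""
    return cfu <= CLEARED_THRESHOLD


def is_significant_reduction(previous: int, current: int) -> bool:
    """True when the reading dropped by at least 75% from the previous one."""
    if previous <= 0:
        return False  # no baseline to compare against
    return (previous - current) / previous >= SIGNIFICANT_REDUCTION_PCT


def clamp_confidence(score: float) -> float:
    """Keep a confidence score inside [0.0, MAX_CONFIDENCE]."""
    return max(0.0, min(score, MAX_CONFIDENCE))


print(is_cleared(800))
print(is_significant_reduction(120000, 20000))
print(clamp_confidence(1.2))
```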

## Cell C: Extraction Layer

In [None]:

import re
import warnings
from pathlib import Path
from typing import Optional



# ---------------------------------------------------------------------------
# Helper: Docling Processing
# ---------------------------------------------------------------------------
def _process_with_docling(input_text: str) -> str:
    """
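    Convert the input to text, using Docling when the input is a file path.

    If input_text is the path of an existing file, the file is converted with
    Docling and its markdown export is returned. Raw report text is returned
    unchanged so the regexes below see the original line breaks. The input is
    also returned unchanged if Docling is not installed or conversion fails.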
    """
    try:
        from docling.document_converter import DocumentConverter
    except ImportError:
        # Docling is optional; without it the input is returned unchanged.
        return input_text

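    # Long or multi-line report text is not a usable path: the filesystem check
    # can raise OSError (e.g. "File name too long") on some Python versions.
    try:
        input_path = Path(input_text)
        is_file = input_path.is_file()
    except (OSError, ValueError):
        return input_text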

    try:
        converter = DocumentConverter()

        if is_file:
            # Process directly from file path
            result = converter.convert(input_path)
            return result.document.export_to_markdown()
        else:
            # Input is raw text; Docling processing via temp file may distort layout (e.g. merging lines).
            # Fallback to returning raw text so regexes can use original newlines.
            return input_text

    except Exception as e:
        warnings.warn(
            f"Docling processing failed: {e}. Falling back to raw text.", UserWarning
        )
        return input_text


# ---------------------------------------------------------------------------
# Custom exception
# ---------------------------------------------------------------------------
class ExtractionError(ValueError):
    """Raised when both organism AND cfu fail to parse from a report."""


# ---------------------------------------------------------------------------
# Compiled regex patterns (Section 5.2) - ENHANCED for flexibility
# ---------------------------------------------------------------------------

# Organism: several labelled formats seen in lab reports.
# Each pattern lazily captures the rest of the line; dots inside the name
# (e.g. "E. coli") are kept, but the first captured character may not be a dot.
_RE_ORGANISM_PRIMARY = re.compile(r"Organism:\s*([^.].*?)(?:\n|$)", re.IGNORECASE)
_RE_ORGANISM_ALT1 = re.compile(
    r"Organism\s+identified:\s*([^.].*?)(?:\n|$)", re.IGNORECASE
)
_RE_ORGANISM_ALT2 = re.compile(r"Isolated:\s*([^.].*?)(?:\n|$)", re.IGNORECASE)
_RE_ORGANISM_ALT3 = re.compile(r"Identification:\s*([^.].*?)(?:\n|$)", re.IGNORECASE)
_RE_ORGANISM_ALT4 = re.compile(
    r"Culture\s+results?:\s*([^.].*?)(?:\n|$)", re.IGNORECASE
)
# With re.IGNORECASE this pattern is equivalent to _RE_ORGANISM_PRIMARY.
_RE_ORGANISM_ALT5 = re.compile(r"ORGANISM:\s*([^.].*?)(?:\n|$)", re.IGNORECASE)

# CFU/mL: Multiple patterns for various formats
_RE_CFU_PRIMARY = re.compile(r"CFU[/\\]?m?L?:\s*([><]?\s*[\d,]+)", re.IGNORECASE)
_RE_CFU_ALT1 = re.compile(
    r"(?:Count|Quantity|Result):\s*([><]?\s*[\d,]+)", re.IGNORECASE
)
_RE_CFU_ALT2 = re.compile(r"([\d,]+)\s*(?:CFU|colonies|cells)", re.IGNORECASE)
_RE_CFU_ALT3 = re.compile(r">\s*?([\d,]+)", re.IGNORECASE)  # >100,000
_RE_CFU_ALT4 = re.compile(r"(\d{2,3},\d{3})", re.IGNORECASE)  # 100,000 pattern

# Fallback CFU patterns
_RE_CFU_SCIENTIFIC = re.compile(r"10\^(\d+)", re.IGNORECASE)  # 10^5 → 100000
_RE_CFU_WORD = re.compile(r"(TNTC|Too\s+Numerous\s+To\s+Count)", re.IGNORECASE)
_RE_CFU_NO_GROWTH = re.compile(
    r"(No\s+growth|No\s+significant\s+growth|0\s+CFU|Negative)", re.IGNORECASE
)
_RE_CFU_RAW_NUMBER = re.compile(r"\b([\d]{5,})\b")  # bare large number (5+ digits)

# Date: Multiple patterns for various formats
_RE_DATE_PRIMARY = re.compile(
    r"(?:Date|Collected|Reported|Specimen\s+Date|Collection\s+Date|Date\s+Collected|Date\s+Reported)[\s:]*[\*_]*[\s:]+(\d{4}-\d{2}-\d{2}|\d{2}/\d{2}/\d{4}|\d{2}-\d{2}-\d{4})",
    re.IGNORECASE,
)
_RE_DATE_ALT1 = re.compile(r"\b(\d{4}-\d{2}-\d{2})\b")  # ISO format anywhere
_RE_DATE_ALT2 = re.compile(r"\b(\d{2}/\d{2}/\d{4})\b")  # MM/DD/YYYY anywhere
_RE_DATE_ALT3 = re.compile(r"\b(\d{2}-\d{2}-\d{4})\b")  # MM-DD-YYYY anywhere

# Resistance markers: exact case-insensitive word boundaries
_RE_RESISTANCE = re.compile(r"\b(ESBL|CRE|MRSA|VRE|CRKP)\b", re.IGNORECASE)

# Specimen type - ENHANCED: multiple patterns and keyword detection
_RE_SPECIMEN_PRIMARY = re.compile(
    r"(?:Specimen|Sample|Source|Type)[\s:]+(urine|stool|wound|blood|urinary|fecal|faecal)",
    re.IGNORECASE,
)
_RE_SPECIMEN_ALT1 = re.compile(
    r"(urine|stool|wound|blood)\s*(?:culture|specimen|sample|test)", re.IGNORECASE
)
_RE_SPECIMEN_ALT2 = re.compile(
    r"(?:culture|specimen|sample|test)\s*(?:type)?[\s:]+(urine|stool|wound|blood)",
    re.IGNORECASE,
)
# Match markdown headers and bold text: "## Urine Culture", "**Urine Culture**".
# A plain "Urine Culture" title is handled by _RE_SPECIMEN_ALT1 instead.
_RE_SPECIMEN_HEADER = re.compile(
    r"(?:^#{1,3}\s*|\*{2}|\_{2}|##\s*)\s*(urine|stool|wound|blood|sputum)\s+culture\b",
    re.IGNORECASE | re.MULTILINE,
)
_RE_SPECIMEN_URINE_KEYWORD = re.compile(
    r"\b(urine|urinary|bladder|catheter)\b", re.IGNORECASE
)
_RE_SPECIMEN_STOOL_KEYWORD = re.compile(
    r"\b(stool|fecal|faecal|feces|gi)\b", re.IGNORECASE
)


# ---------------------------------------------------------------------------
# CFU normalisation helper (Section 5.4) - ENHANCED
# ---------------------------------------------------------------------------


def _parse_cfu(report_text: str) -> tuple[int, bool]:
    """
    Attempt to parse the CFU/mL value from a report text string.

    Returns:
        (cfu_value, parse_success) tuple.

    Normalisation rules:
        - "TNTC" / "Too Numerous To Count" → 999999
        - "No growth" / "0 CFU"            → 0
        - "10^5"                            → 100000
        - ">100,000" or "<1,000"            → the number, comparison sign dropped
        - comma-separated integer           → int (commas stripped)
        - Missing/unparseable               → 0 with warning
    """
    text = report_text.strip()

    # 1. Scientific notation "10^5". Checked first because the numeric patterns
    #    below would otherwise capture only the "10" (or the bare exponent).
    m = _RE_CFU_SCIENTIFIC.search(text)
    if m:
        try:
            return 10 ** int(m.group(1)), True
        except (ValueError, OverflowError):
            pass

    # 2-6. Numeric patterns, most specific first:
    #      "CFU/mL: 120,000", "Count: 120,000", "120,000 CFU", ">100,000", "100,000"
    for pattern in (
        _RE_CFU_PRIMARY,
        _RE_CFU_ALT1,
        _RE_CFU_ALT2,
        _RE_CFU_ALT3,
        _RE_CFU_ALT4,
    ):
        m = pattern.search(text)
        if m:
            # Strip thousands separators, comparison signs, and whitespace.
            raw = re.sub(r"[,<>\s]", "", m.group(1))
            try:
                return int(raw), True
            except ValueError:
                pass  # e.g. only a lone comma matched; try the next pattern

    # 7. TNTC
    if _RE_CFU_WORD.search(text):
        return 999999, True

    # 8. No growth / negative
    if _RE_CFU_NO_GROWTH.search(text):
        return 0, True

    # 9. Bare large integer (≥5 digits) — last resort fallback
    m = _RE_CFU_RAW_NUMBER.search(text)
    if m:
        raw = m.group(1).replace(",", "")
        try:
            val = int(raw)
            warnings.warn(
                f"CFU parsed from bare number '{raw}' — review report text.",
                UserWarning,
                stacklevel=3,
            )
            return val, True
        except ValueError:
            pass

    warnings.warn(
        "CFU/mL could not be parsed; defaulting to 0.", UserWarning, stacklevel=3
    )
    return 0, False


def _parse_date(report_text: str) -> str:
    """Extract and normalise the collection date from report text."""
    # Primary: prefixed dates
    m = _RE_DATE_PRIMARY.search(report_text)
    if m:
        raw = m.group(1)
        return _normalize_date(raw)

    # Alt1: ISO format anywhere
    m = _RE_DATE_ALT1.search(report_text)
    if m:
        return m.group(1)

    # Alt2: MM/DD/YYYY anywhere
    m = _RE_DATE_ALT2.search(report_text)
    if m:
        return _normalize_date(m.group(1))

    # Alt3: MM-DD-YYYY anywhere
    m = _RE_DATE_ALT3.search(report_text)
    if m:
        raw = m.group(1).replace("-", "/")
        return _normalize_date(raw)

    return "unknown"


def _normalize_date(raw: str) -> str:
    """Convert various date formats to ISO 8601 (YYYY-MM-DD)."""
    raw = raw.strip()

    # Already ISO format
    if re.match(r"^\d{4}-\d{2}-\d{2}$", raw):
        return raw

    # MM/DD/YYYY or MM-DD-YYYY
    if "/" in raw or "-" in raw:
        sep = "/" if "/" in raw else "-"
        parts = raw.split(sep)
        if len(parts) == 3:
            # Determine if first part is month or day based on values
            first, second, year = parts[0], parts[1], parts[2]
            # If first > 12, it's likely DD/MM/YYYY
            if int(first) > 12:
                # DD/MM/YYYY → YYYY-MM-DD
                return f"{year}-{second.zfill(2)}-{first.zfill(2)}"
            else:
                # MM/DD/YYYY → YYYY-MM-DD
                return f"{year}-{first.zfill(2)}-{second.zfill(2)}"

    return "unknown"


def _parse_organism(report_text: str) -> Optional[str]:
    """
    Extract organism name from report text with multiple pattern attempts.
    """
    text = report_text.strip()

    # Try multiple organism patterns in order
    patterns = [
        _RE_ORGANISM_PRIMARY,
        _RE_ORGANISM_ALT5,  # ORGANISM: (all caps)
        _RE_ORGANISM_ALT1,  # Organism identified:
        _RE_ORGANISM_ALT2,  # Isolated:
        _RE_ORGANISM_ALT3,  # Identification:
        _RE_ORGANISM_ALT4,  # Culture result:
    ]

    for pattern in patterns:
        m = pattern.search(text)
        if m:
            raw_organism = m.group(1).strip()
            # Collapse runs of whitespace; dots are kept because they are part of
            # abbreviated names such as "E. coli".
            raw_organism = re.sub(r"\s+", " ", raw_organism)
            # Truncate only at clear sentence-ending punctuation. A period right
            # after a single-letter genus abbreviation ("E. Coli") does not count.
            match = re.search(r"[;!?]|(?<!\b[A-Za-z])\.\s+[A-Z]", raw_organism)
            if match:
                raw_organism = raw_organism[: match.start()]
            return normalize_organism(raw_organism)

    # Fallback: search for known organism aliases in full text
    lower_text = text.lower()

    for alias in sorted(ORGANISM_ALIASES.keys(), key=len, reverse=True):
        if alias in lower_text:
            return normalize_organism(alias)

    return None


def _parse_resistance_markers(report_text: str) -> list[str]:
    """Extract all high-risk resistance markers (deduplicated, uppercase)."""
    found = _RE_RESISTANCE.findall(report_text)
    # deduplicate, preserve order
    return list(dict.fromkeys(m.upper() for m in found))


def _parse_specimen(report_text: str) -> str:
    """
    Extract specimen type with multiple pattern attempts and keyword detection.
    Returns 'urine', 'stool', 'wound', 'blood', 'sputum', or 'unknown'.
    """
    text = report_text.strip()

    # Try markdown headers and bold text: ## Urine Culture, **Urine Culture**
    m = _RE_SPECIMEN_HEADER.search(text)
    if m:
        return _normalize_specimen(m.group(1).lower())

    # Try primary pattern: Specimen/Sample/Source/Type: urine/stool
    m = _RE_SPECIMEN_PRIMARY.search(text)
    if m:
        specimen = m.group(1).lower()
        return _normalize_specimen(specimen)

    # Try alternative: urine/stool culture
    m = _RE_SPECIMEN_ALT1.search(text)
    if m:
        return _normalize_specimen(m.group(1).lower())

    # Try alternative: culture: urine/stool
    m = _RE_SPECIMEN_ALT2.search(text)
    if m:
        return _normalize_specimen(m.group(1).lower())

    # Keyword detection: look for urine/urinary keywords anywhere
    if _RE_SPECIMEN_URINE_KEYWORD.search(text):
        return "urine"

    # Keyword detection: look for stool/fecal keywords anywhere
    if _RE_SPECIMEN_STOOL_KEYWORD.search(text):
        return "stool"

    return "unknown"


def _normalize_specimen(specimen: str) -> str:
    """Normalize specimen type to standard values."""
    specimen = specimen.lower().strip()

    # Map variations to standard types
    if specimen in ("urine", "urinary"):
        return "urine"
    elif specimen in ("stool", "fecal", "faecal", "feces"):
        return "stool"
    elif specimen == "wound":
        return "wound"
    elif specimen == "blood":
        return "blood"

    return specimen


def _is_contamination(organism: str) -> bool:
    """Return True if the organism name matches any contamination term."""
    lower = organism.lower()
    return any(term in lower for term in RULES["contamination_terms"])


# ---------------------------------------------------------------------------
# Debug helper
# ---------------------------------------------------------------------------


def debug_extraction(report_text: str, label: str = "Report") -> dict:
    """
    Debug helper to show what was extracted from a report.

    Returns a dictionary with all extraction results for debugging.
    """
    # _process_with_docling handles both file paths and raw text.
    processed_text = _process_with_docling(report_text)

    organism = _parse_organism(processed_text)
    cfu, cfu_ok = _parse_cfu(processed_text)
    specimen = _parse_specimen(processed_text)
    date = _parse_date(processed_text)
    resistance = _parse_resistance_markers(processed_text)

    return {
        "label": label,
        "organism": organism,
        "cfu": cfu,
        "cfu_ok": cfu_ok,
        "specimen": specimen,
        "date": date,
        "resistance": resistance,
        "is_contamination": _is_contamination(organism) if organism else False,
        "processed_text_preview": processed_text[:500] + "..."
        if len(processed_text) > 500
        else processed_text,
    }


# ---------------------------------------------------------------------------
# Public extraction function
# ---------------------------------------------------------------------------


def extract_structured_data(report_text: str) -> CultureReport:
    """
    Parse a free-text culture report into a typed CultureReport.

    Also accepts the path of a report file, which is converted via Docling.

    Rules:
        - Organism field: stripped, normalised via ORGANISM_ALIASES
        - CFU: commas removed, converted to int; TNTC=999999
        - resistance_markers: deduplicated, uppercase
        - contamination_flag: True if organism in contamination_terms
        - raw_text: stored as-is (or docling processed), NEVER forwarded to MedGemma

    Raises:
        ExtractionError: if both organism AND cfu fail to parse.
    """
    # Pre-process with Docling (handles file paths or raw text)
    processed_text = _process_with_docling(report_text)

    # Attempt extraction on processed text
    organism = _parse_organism(processed_text)
    cfu, cfu_ok = _parse_cfu(processed_text)

    # Fallback: if extraction failed and text was modified by Docling, try original
    if (organism is None and not cfu_ok) and processed_text != report_text:
        organism = _parse_organism(report_text)
        cfu, cfu_ok = _parse_cfu(report_text)
        if organism is not None or cfu_ok:
            processed_text = report_text  # Revert to original for other fields

    if organism is None and not cfu_ok:
        raise ExtractionError(
            "Extraction failed: could not parse organism OR CFU/mL from report. "
            "Check report format."
        )

    # If only organism failed, use a placeholder and warn
    if organism is None:
        warnings.warn(
            "Organism could not be parsed; using 'unknown'.", UserWarning, stacklevel=2
        )
        organism = "unknown"

    resistance_markers = _parse_resistance_markers(processed_text)
    specimen_type = _parse_specimen(processed_text)
    contamination_flag = _is_contamination(organism)
    date = _parse_date(processed_text)

    return CultureReport(
        date=date,
        organism=organism,
        cfu=cfu,
        resistance_markers=resistance_markers,
        specimen_type=specimen_type,
        contamination_flag=contamination_flag,
        raw_text=processed_text,  # Store the text actually used for extraction
    )
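
The date handling above reorders the parts of a slash- or dash-separated date but does not check that the result is a real calendar date. A minimal sketch of how such a check could be added with the standard library follows. `to_iso_date` is a hypothetical helper, it handles only the MM/DD/YYYY and DD/MM/YYYY shapes (not ISO input), and it reuses the same "first part greater than 12 means day-first" heuristic, which remains ambiguous for dates such as 03/04/2026.

```python
from datetime import date
from typing import Optional


def to_iso_date(raw: str) -> Optional[str]:
    """Convert MM/DD/YYYY (or DD/MM/YYYY when the first part exceeds 12) to ISO.

    Returns None when the parts do not form a real calendar date.
    """
    parts = raw.strip().replace("-", "/").split("/")
    if len(parts) != 3 or not all(p.isdecimal() for p in parts):
        return None
    first, second, year = (int(p) for p in parts)
    # Same heuristic as above: a first part over 12 cannot be a month.
    month, day = (second, first) if first > 12 else (first, second)
    try:
        return date(year, month, day).isoformat()
    except ValueError:
        return None


print(to_iso_date("01/15/2026"))
print(to_iso_date("15-01-2026"))
print(to_iso_date("02/30/2026"))
```

Returning `None` instead of a formatted but impossible date lets the caller fall back to `"unknown"` explicitly.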

In [None]:
# --- Extraction Unit Tests ---

import warnings


_PASS = 0
_FAIL = 0


def _assert(condition: bool, msg: str) -> None:
    global _PASS, _FAIL
    if condition:
        _PASS += 1
        print(f"  PASS  {msg}")
    else:
        _FAIL += 1
        print(f"  FAIL  {msg}")


# ---------------------------------------------------------------------------
# Test Report 1 — Normal improving report
# ---------------------------------------------------------------------------
REPORT_NORMAL = """
Specimen: Urine
Date Collected: 2026-01-01
Organism: E. coli
CFU/mL: 120,000
Sensitivity: Ampicillin - Resistant, Nitrofurantoin - Sensitive
"""

print("=== Test: Normal Report ===")
r = extract_structured_data(REPORT_NORMAL)
_assert(r.date == "2026-01-01", f"date == '2026-01-01'  (got '{r.date}')")
_assert(
    r.organism == "Escherichia coli",
    f"organism normalised to 'Escherichia coli'  (got '{r.organism}')",
)
_assert(r.cfu == 120000, f"cfu == 120000  (got {r.cfu})")
_assert(
    r.resistance_markers == [], f"no resistance markers  (got {r.resistance_markers})"
)
_assert(
    r.specimen_type == "urine", f"specimen_type == 'urine'  (got '{r.specimen_type}')"
)
_assert(
    r.contamination_flag is False,
    f"contamination_flag is False  (got {r.contamination_flag})",
)

# ---------------------------------------------------------------------------
# Test Report 2 — Contamination report (mixed flora, low CFU)
# ---------------------------------------------------------------------------
REPORT_CONTAMINATION = """
Specimen: Urine
Date Collected: 2026-02-05
Organism: mixed flora
CFU/mL: 5,000
No resistance markers detected.
"""

print("\n=== Test: Contamination Report ===")
r2 = extract_structured_data(REPORT_CONTAMINATION)
_assert(
    r2.contamination_flag is True,
    f"contamination_flag is True  (got {r2.contamination_flag})",
)
_assert(
    r2.organism == "mixed flora", f"organism == 'mixed flora'  (got '{r2.organism}')"
)
_assert(r2.cfu == 5000, f"cfu == 5000  (got {r2.cfu})")
_assert(
    r2.resistance_markers == [], f"no resistance markers  (got {r2.resistance_markers})"
)

# ---------------------------------------------------------------------------
# Test Report 3 — Resistance-containing report (ESBL marker)
# ---------------------------------------------------------------------------
REPORT_RESISTANCE = """
Specimen: Urine
Date Collected: 2026-01-20
Organism: Klebsiella pneumoniae
CFU/mL: 75,000
Resistance: ESBL detected.
"""

print("\n=== Test: Resistance Report ===")
r3 = extract_structured_data(REPORT_RESISTANCE)
_assert(
    r3.organism == "Klebsiella pneumoniae",
    f"organism == 'Klebsiella pneumoniae'  (got '{r3.organism}')",
)
_assert(
    "ESBL" in r3.resistance_markers,
    f"ESBL in resistance_markers  (got {r3.resistance_markers})",
)
_assert(
    r3.contamination_flag is False,
    f"contamination_flag is False  (got {r3.contamination_flag})",
)
_assert(r3.cfu == 75000, f"cfu == 75000  (got {r3.cfu})")

# ---------------------------------------------------------------------------
# Test — TNTC CFU normalisation
# ---------------------------------------------------------------------------
REPORT_TNTC = """
Specimen: Urine
Date Collected: 2026-03-01
Organism: E. coli
CFU/mL: TNTC
"""

print("\n=== Test: TNTC Normalisation ===")
r4 = extract_structured_data(REPORT_TNTC)
_assert(r4.cfu == 999999, f"TNTC → 999999  (got {r4.cfu})")

# ---------------------------------------------------------------------------
# Test — No growth / cleared
# ---------------------------------------------------------------------------
REPORT_NO_GROWTH = """
Specimen: Urine
Date Collected: 2026-03-15
Organism: E. coli
No growth observed.
"""

print("\n=== Test: No Growth ===")
r5 = extract_structured_data(REPORT_NO_GROWTH)
_assert(r5.cfu == 0, f"No growth → cfu == 0  (got {r5.cfu})")

# ---------------------------------------------------------------------------
# Test — ExtractionError on completely unparseable input
# ---------------------------------------------------------------------------
print("\n=== Test: ExtractionError on bad input ===")
try:
    extract_structured_data("this report contains absolutely nothing useful at all")
    _assert(False, "ExtractionError should have been raised")
except ExtractionError as e:
    _assert(True, f"ExtractionError raised correctly: {e}")
except Exception as e:
    _assert(False, f"Wrong exception type raised: {type(e).__name__}: {e}")

# ---------------------------------------------------------------------------
# Test — Adversarial: SQL injection in CFU field
# ---------------------------------------------------------------------------
REPORT_ADV = """
Specimen: Urine
Date Collected: 2026-04-01
Organism: E. coli
CFU/mL: 100000; DROP TABLE reports
"""

print("\n=== Test: Adversarial SQL Injection in CFU ===")
# Should parse 100000 from the start, or fallback gracefully
try:
    r6 = extract_structured_data(REPORT_ADV)
    # The regex only captures digits+commas, so "100000" is parsed, the rest is ignored
    _assert(r6.cfu == 100000, f"cfu == 100000 (injection ignored)  (got {r6.cfu})")
except ExtractionError:
    _assert(False, "Should not raise ExtractionError on adversarial CFU")

# ---------------------------------------------------------------------------
# Test — Alternate date format MM/DD/YYYY
# ---------------------------------------------------------------------------
REPORT_DATE_ALT = """
Specimen: Stool
Date Collected: 01/15/2026
Organism: Enterococcus faecalis
CFU/mL: 60,000
"""

print("\n=== Test: Alternate Date Format (MM/DD/YYYY) ===")
r7 = extract_structured_data(REPORT_DATE_ALT)
_assert(r7.date == "2026-01-15", f"date normalised to ISO  (got '{r7.date}')")
_assert(
    r7.specimen_type == "stool", f"specimen_type == 'stool'  (got '{r7.specimen_type}')"
)

# ---------------------------------------------------------------------------
# Test — Flexible specimen detection (alternate formats)
# ---------------------------------------------------------------------------
REPORT_SPECIMEN_FLEX1 = """
URINE CULTURE
Date: 2026-05-01
Organism: E. coli
CFU/mL: 80,000
"""

print("\n=== Test: Flexible Specimen Detection (Urine Culture title) ===")
r8 = extract_structured_data(REPORT_SPECIMEN_FLEX1)
_assert(
    r8.specimen_type == "urine",
    f"specimen_type detected as 'urine' from title  (got '{r8.specimen_type}')",
)

REPORT_SPECIMEN_FLEX2 = """
Specimen Type: Stool
Date: 2026-05-10
Organism: mixed flora
CFU/mL: 2,000
"""

print("\n=== Test: Flexible Specimen Detection (Specimen Type: Stool) ===")
r9 = extract_structured_data(REPORT_SPECIMEN_FLEX2)
_assert(
    r9.specimen_type == "stool",
    f"specimen_type detected as 'stool'  (got '{r9.specimen_type}')",
)

# ---------------------------------------------------------------------------
# Test — Flexible organism detection (alternate formats)
# ---------------------------------------------------------------------------
REPORT_ORG_FLEX1 = """
Specimen: Urine
Date: 2026-06-01
ORGANISM: Klebsiella pneumoniae
CFU/mL: 50,000
"""

print("\n=== Test: Flexible Organism Detection (ORGANISM: caps) ===")
r10 = extract_structured_data(REPORT_ORG_FLEX1)
_assert(
    r10.organism == "Klebsiella pneumoniae",
    f"organism detected from ORGANISM:  (got '{r10.organism}')",
)

REPORT_ORG_FLEX2 = """
Specimen: Urine
Date: 2026-06-15
Isolated: E. coli
CFU/mL: 150,000
"""

print("\n=== Test: Flexible Organism Detection (Isolated:) ===")
r11 = extract_structured_data(REPORT_ORG_FLEX2)
_assert(
    r11.organism == "Escherichia coli",
    f"organism detected from Isolated:  (got '{r11.organism}')",
)

# ---------------------------------------------------------------------------
# Test — Flexible CFU detection (alternate formats)
# ---------------------------------------------------------------------------
REPORT_CFU_FLEX1 = """
Specimen: Urine
Date: 2026-07-01
Organism: E. coli
Result: >100,000 CFU/mL
"""

print("\n=== Test: Flexible CFU Detection (>100,000 format) ===")
r12 = extract_structured_data(REPORT_CFU_FLEX1)
_assert(
    r12.cfu == 100000,
    f"cfu parsed from >100,000 format  (got {r12.cfu})",
)

REPORT_CFU_FLEX2 = """
Specimen: Urine
Date: 2026-07-15
Organism: Enterococcus faecalis
Count: 75,000 colonies per mL
"""

print("\n=== Test: Flexible CFU Detection (Count: + colonies) ===")
r13 = extract_structured_data(REPORT_CFU_FLEX2)
_assert(
    r13.cfu == 75000,
    f"cfu parsed from Count: format  (got {r13.cfu})",
)

# ---------------------------------------------------------------------------
# Test — Flexible date detection (alternate formats)
# ---------------------------------------------------------------------------
REPORT_DATE_FLEX1 = """
Specimen: Urine
Collection Date: 03/25/2026
Organism: E. coli
CFU/mL: 100,000
"""

print("\n=== Test: Flexible Date Detection (Collection Date MM/DD/YYYY) ===")
r14 = extract_structured_data(REPORT_DATE_FLEX1)
_assert(
    r14.date == "2026-03-25",
    f"date parsed from Collection Date:  (got '{r14.date}')",
)

REPORT_DATE_FLEX2 = """
Specimen: Urine
Date: 07-04-2026
Organism: E. coli
CFU/mL: 100,000
"""

print("\n=== Test: Flexible Date Detection (MM-DD-YYYY format) ===")
r15 = extract_structured_data(REPORT_DATE_FLEX2)
_assert(
    r15.date == "2026-07-04",
    f"date parsed from MM-DD-YYYY format  (got '{r15.date}')",
)

# ---------------------------------------------------------------------------
# Test — Keyword-based specimen detection (no explicit Specimen: line)
# ---------------------------------------------------------------------------
REPORT_KEYWORD_URINE = """
URINE CULTURE REPORT
Patient: John Doe
Date: 2026-08-01

MICROBIOLOGY RESULTS:
E. coli isolated at 100,000 CFU/mL
"""

print("\n=== Test: Keyword Specimen Detection (URINE CULTURE) ===")
r16 = extract_structured_data(REPORT_KEYWORD_URINE)
_assert(
    r16.specimen_type == "urine",
    f"specimen_type detected via urine keyword  (got '{r16.specimen_type}')",
)

REPORT_KEYWORD_STOOL = """
FECAL CULTURE
Patient: Jane Smith
Date: 2026-08-15

Salmonella detected
CFU/mL: 45,000
"""

print("\n=== Test: Keyword Specimen Detection (FECAL CULTURE) ===")
try:
    r17 = extract_structured_data(REPORT_KEYWORD_STOOL)
    _assert(
        r17.specimen_type == "stool",
        f"specimen_type detected via fecal keyword  (got '{r17.specimen_type}')",
    )
    _assert(
        r17.cfu == 45000,
        f"cfu == 45000  (got {r17.cfu})",
    )
except ExtractionError as e:
    _assert(False, f"Extraction failed for stool culture test: {e}")

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
print(f"\n{'=' * 50}")
print(f"Extraction Tests Complete: {_PASS} passed, {_FAIL} failed")
if _FAIL == 0:
    print("ALL TESTS PASSED")
else:
    print(f"WARNING: {_FAIL} test(s) failed — review extraction logic")
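
The CFU cases exercised above (TNTC, no growth, comma-separated counts, trailing junk after the number) can be reproduced with a small standalone parser. The sketch below is not the notebook's `extract_structured_data`. The name `parse_cfu`, the regex and the label list are hypothetical, chosen only to match the expectations stated in the tests.

```python
import re
from typing import Optional

TNTC_SENTINEL = 999_999  # value the tests above expect for "TNTC"

# Hypothetical pattern: a label, an optional "<" or ">", then digits with optional commas.
_LABELLED_COUNT = re.compile(
    r"(?:CFU/mL|Count|Result)\s*:\s*[<>]?\s*(\d[\d,]*)", re.IGNORECASE
)


def parse_cfu(text: str) -> Optional[int]:
    """Return a CFU count, or None when no count can be found (illustrative only)."""
    if re.search(r"\bTNTC\b", text, re.IGNORECASE):
        return TNTC_SENTINEL
    if re.search(r"\bno\s+growth\b", text, re.IGNORECASE):
        return 0
    match = _LABELLED_COUNT.search(text)
    if match is None:
        return None
    # Only the digit/comma run is captured, so trailing text never reaches int().
    return int(match.group(1).replace(",", ""))


for line in (
    "CFU/mL: TNTC",
    "No growth observed.",
    "CFU/mL: 100000; DROP TABLE reports",
    "Result: >100,000 CFU/mL",
    "Count: 75,000 colonies per mL",
):
    print(f"{line!r:45} {parse_cfu(line)}")
```

Because the capture group accepts nothing but digits and commas, the adversarial suffix is never interpreted. It is simply left unmatched.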

## Cell D: Temporal Trend Engine

In [None]:

from typing import List



# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _classify_cfu_trend(cfu_values: List[int]) -> str:
    """
    Classify the CFU trajectory from an ordered list of values.

    Labels (priority order):
        "insufficient_data"  — fewer than 2 reports
        "cleared"            — final value ≤ cleared_threshold (overrides all)
        "decreasing"         — every value strictly lower than the previous one
        "increasing"         — every value strictly higher than the previous one
        "fluctuating"        — any other pattern, including repeated equal values
    """
    if len(cfu_values) < 2:
        return "insufficient_data"

    # "cleared" overrides all other labels
    if cfu_values[-1] <= RULES["cleared_threshold"]:
        return "cleared"

    strictly_decreasing = all(
        cfu_values[i] > cfu_values[i + 1] for i in range(len(cfu_values) - 1)
    )
    if strictly_decreasing:
        return "decreasing"

    strictly_increasing = all(
        cfu_values[i] < cfu_values[i + 1] for i in range(len(cfu_values) - 1)
    )
    if strictly_increasing:
        return "increasing"

    return "fluctuating"


def _compute_deltas(cfu_values: List[int]) -> List[int]:
    """
    Compute per-interval CFU changes.

    Positive delta = worsening (increasing CFU).
    Negative delta = improving (decreasing CFU).
    """
    return [cfu_values[i + 1] - cfu_values[i] for i in range(len(cfu_values) - 1)]


def _check_persistence(organism_list: List[str]) -> bool:
    """
    Return True if the same organism was isolated across all reports.

    Comparison is performed on normalised (lowercase, stripped) organism names,
    with alias resolution to handle "E. coli" == "Escherichia coli".
    """
    normalised = [normalize_organism(o).strip().lower() for o in organism_list]
    return len(set(normalised)) == 1


def _check_resistance_evolution(reports: List[CultureReport]) -> bool:
    """
    Return True if new resistance markers appear in any report after the first.

    Logic:
        - Baseline = markers in report[0]
        - If any subsequent report contains a marker not in baseline → True
    """
    if len(reports) < 2:
        return False
    baseline = set(reports[0].resistance_markers)
    later_markers: set[str] = set()
    for r in reports[1:]:
        later_markers.update(r.resistance_markers)
    return bool(later_markers - baseline)


def _build_resistance_timeline(reports: List[CultureReport]) -> List[List[str]]:
    """Return per-report resistance marker lists, in report order."""
    return [list(r.resistance_markers) for r in reports]


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def analyze_trend(reports: List[CultureReport]) -> TrendResult:
    """
    Compute a TrendResult from an ordered list of CultureReport objects.

    Reports should be sorted by date (oldest first) before calling this
    function. The function does NOT re-sort — caller is responsible.

    Args:
        reports: 1–3 CultureReport instances in chronological order.

    Returns:
        TrendResult with all temporal signal fields populated.
    """
    if not reports:
        raise ValueError("analyze_trend requires at least one CultureReport.")

    cfu_values = [r.cfu for r in reports]
    cfu_deltas = _compute_deltas(cfu_values)
    cfu_trend = _classify_cfu_trend(cfu_values)
    organism_list = [r.organism for r in reports]
    organism_persistent = _check_persistence(organism_list)
    resistance_evolution = _check_resistance_evolution(reports)
    resistance_timeline = _build_resistance_timeline(reports)
    report_dates = [r.date for r in reports]
    any_contamination = any(r.contamination_flag for r in reports)

    return TrendResult(
        cfu_trend=cfu_trend,
        cfu_values=cfu_values,
        cfu_deltas=cfu_deltas,
        organism_persistent=organism_persistent,
        organism_list=organism_list,
        resistance_evolution=resistance_evolution,
        resistance_timeline=resistance_timeline,
        report_dates=report_dates,
        any_contamination=any_contamination,
    )
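
The label priority in `_classify_cfu_trend` is easiest to see on a few concrete series. The sketch below is a standalone copy of that ordering, not a replacement for it. `classify` is a hypothetical name, and `CLEARED_THRESHOLD = 1000` is an assumption taken from the "final CFU ≤ 1000" wording in the unit tests rather than from the `RULES` definition itself.

```python
from typing import List

CLEARED_THRESHOLD = 1000  # assumed value of RULES["cleared_threshold"]


def classify(cfu_values: List[int]) -> str:
    """Apply the same checks, in the same order, as _classify_cfu_trend."""
    if len(cfu_values) < 2:
        return "insufficient_data"
    if cfu_values[-1] <= CLEARED_THRESHOLD:
        return "cleared"  # checked first, so it wins over "decreasing"
    pairs = list(zip(cfu_values, cfu_values[1:]))
    if all(a > b for a, b in pairs):
        return "decreasing"
    if all(a < b for a, b in pairs):
        return "increasing"
    return "fluctuating"


for series in (
    [120000, 40000, 5000],   # strictly falling, final value above threshold
    [120000, 40000, 800],    # falling and final value at or below threshold
    [50000, 50000, 20000],   # plateau followed by a drop
    [100000],                # single report
):
    print(series, "->", classify(series))
```

The plateau case is the one worth noting: because both comparisons are strict, a repeated value makes the series neither "decreasing" nor "increasing", so it is reported as "fluctuating".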

In [None]:
# --- Trend Unit Tests ---


_PASS = 0
_FAIL = 0


def _assert(condition: bool, msg: str) -> None:
    global _PASS, _FAIL
    if condition:
        _PASS += 1
        print(f"  PASS  {msg}")
    else:
        _FAIL += 1
        print(f"  FAIL  {msg}")


def _make_report(
    cfu: int,
    organism: str = "Escherichia coli",
    date: str = "2026-01-01",
    markers=None,
    contamination: bool = False,
) -> CultureReport:
    return CultureReport(
        date=date,
        organism=organism,
        cfu=cfu,
        resistance_markers=markers or [],
        specimen_type="urine",
        contamination_flag=contamination,
        raw_text="<stub>",
    )


# ---------------------------------------------------------------------------
# 1. Monotonically decreasing
# ---------------------------------------------------------------------------
print("=== Test: Monotonically Decreasing CFU ===")
rpts = [
    _make_report(120000, date="2026-01-01"),
    _make_report(40000, date="2026-01-10"),
    _make_report(5000, date="2026-01-20"),
]
t = analyze_trend(rpts)
_assert(t.cfu_trend == "decreasing", f"trend == 'decreasing'  (got '{t.cfu_trend}')")
_assert(t.cfu_deltas == [-80000, -35000], f"deltas correct  (got {t.cfu_deltas})")
_assert(t.organism_persistent is True, "organism_persistent == True")
_assert(t.resistance_evolution is False, "resistance_evolution == False")
_assert(t.any_contamination is False, "any_contamination == False")

# ---------------------------------------------------------------------------
# 2. Cleared (final CFU ≤ 1000) — overrides decreasing
# ---------------------------------------------------------------------------
print("\n=== Test: Cleared (Final CFU ≤ 1000) ===")
rpts2 = [
    _make_report(120000, date="2026-01-01"),
    _make_report(40000, date="2026-01-10"),
    _make_report(800, date="2026-01-20"),
]
t2 = analyze_trend(rpts2)
_assert(t2.cfu_trend == "cleared", f"trend == 'cleared'  (got '{t2.cfu_trend}')")

# ---------------------------------------------------------------------------
# 3. CFU = 0 (no growth) → also cleared
# ---------------------------------------------------------------------------
print("\n=== Test: Zero CFU (No Growth) ===")
rpts3 = [
    _make_report(80000, date="2026-01-01"),
    _make_report(0, date="2026-01-10"),
]
t3 = analyze_trend(rpts3)
_assert(
    t3.cfu_trend == "cleared", f"trend == 'cleared' for CFU=0  (got '{t3.cfu_trend}')"
)

# ---------------------------------------------------------------------------
# 4. Monotonically increasing
# ---------------------------------------------------------------------------
print("\n=== Test: Monotonically Increasing CFU ===")
rpts4 = [
    _make_report(40000, date="2026-01-01"),
    _make_report(80000, date="2026-01-10"),
    _make_report(120000, date="2026-01-20"),
]
t4 = analyze_trend(rpts4)
_assert(t4.cfu_trend == "increasing", f"trend == 'increasing'  (got '{t4.cfu_trend}')")

# ---------------------------------------------------------------------------
# 5. Fluctuating
# ---------------------------------------------------------------------------
print("\n=== Test: Fluctuating CFU ===")
rpts5 = [
    _make_report(80000, date="2026-01-01"),
    _make_report(120000, date="2026-01-10"),
    _make_report(60000, date="2026-01-20"),
]
t5 = analyze_trend(rpts5)
_assert(
    t5.cfu_trend == "fluctuating", f"trend == 'fluctuating'  (got '{t5.cfu_trend}')"
)

# ---------------------------------------------------------------------------
# 6. Single report — insufficient_data
# ---------------------------------------------------------------------------
print("\n=== Test: Single Report (Insufficient Data) ===")
rpts6 = [_make_report(100000, date="2026-01-01")]
t6 = analyze_trend(rpts6)
_assert(
    t6.cfu_trend == "insufficient_data",
    f"trend == 'insufficient_data'  (got '{t6.cfu_trend}')",
)
_assert(t6.cfu_deltas == [], f"deltas == []  (got {t6.cfu_deltas})")

# ---------------------------------------------------------------------------
# 7. Resistance evolution detection
# ---------------------------------------------------------------------------
print("\n=== Test: Resistance Evolution ===")
rpts7 = [
    _make_report(90000, date="2026-01-01", markers=[]),
    _make_report(80000, date="2026-01-10", markers=[]),
    _make_report(75000, date="2026-01-20", markers=["ESBL"]),
]
t7 = analyze_trend(rpts7)
_assert(t7.resistance_evolution is True, "resistance_evolution == True")
_assert(t7.resistance_timeline[2] == ["ESBL"], "resistance_timeline[2] == ['ESBL']")

# ---------------------------------------------------------------------------
# 8. Organism change (not persistent)
# ---------------------------------------------------------------------------
print("\n=== Test: Organism Change ===")
rpts8 = [
    _make_report(100000, organism="Escherichia coli", date="2026-01-01"),
    _make_report(90000, organism="Klebsiella pneumoniae", date="2026-01-10"),
]
t8 = analyze_trend(rpts8)
_assert(
    t8.organism_persistent is False,
    "organism_persistent == False when organism changes",
)

# ---------------------------------------------------------------------------
# 9. Contamination flag propagation
# ---------------------------------------------------------------------------
print("\n=== Test: Contamination Propagation ===")
rpts9 = [
    _make_report(5000, organism="mixed flora", date="2026-01-01", contamination=True),
    _make_report(3000, organism="mixed flora", date="2026-01-10", contamination=True),
]
t9 = analyze_trend(rpts9)
_assert(t9.any_contamination is True, "any_contamination == True")

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
print(f"\n{'=' * 50}")
print(f"Trend Tests Complete: {_PASS} passed, {_FAIL} failed")
if _FAIL == 0:
    print("ALL TESTS PASSED")
else:
    print(f"WARNING: {_FAIL} test(s) failed")
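
The resistance test above covers a marker that appears in the last report. Two further behaviours follow from the set difference in `_check_resistance_evolution` and may be worth keeping in mind: a marker that appears only in a middle report still counts, and a marker that disappears is not flagged. The sketch below reproduces the same logic on plain lists. `resistance_evolved` is a hypothetical name used here so that the example runs without `CultureReport`.

```python
from typing import List, Set


def resistance_evolved(timeline: List[List[str]]) -> bool:
    """Same set logic as _check_resistance_evolution, applied to marker lists."""
    if len(timeline) < 2:
        return False
    baseline: Set[str] = set(timeline[0])
    later: Set[str] = set()
    for markers in timeline[1:]:
        later.update(markers)
    # True only when some later marker is absent from the first report.
    return bool(later - baseline)


cases = {
    "new marker in last report": [[], [], ["ESBL"]],
    "marker only in middle report": [[], ["ESBL"], []],
    "marker lost after baseline": [["ESBL"], ["ESBL"], []],
    "single report": [["ESBL"]],
}
for label, timeline in cases.items():
    print(f"{label:30} {resistance_evolved(timeline)}")
```

Whether a lost marker should raise its own signal is a design question the source does not address. The sketch only shows what the current rule does.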

## Cell E: Hypothesis Update Layer

In [None]:

from typing import List



# ---------------------------------------------------------------------------
# Risk flag constants
# ---------------------------------------------------------------------------
FLAG_EMERGING_RESISTANCE = "EMERGING_RESISTANCE"
FLAG_CONTAMINATION = "CONTAMINATION_SUSPECTED"
FLAG_NON_RESPONSE = "NON_RESPONSE_PATTERN"
FLAG_INSUFFICIENT_DATA = "INSUFFICIENT_DATA"
FLAG_ORGANISM_CHANGE = "ORGANISM_CHANGE"


# ---------------------------------------------------------------------------
# Confidence scoring
# ---------------------------------------------------------------------------


def _score_confidence(trend: TrendResult, report_count: int) -> float:
    """
    Apply deterministic signal adjustments to a base confidence value.

    Starting point: RULES["base_confidence"] = 0.50
    Each signal adds or subtracts a fixed delta.
    Final value is clamped to [0.0, RULES["max_confidence"]].

    Signal table (Section 7.1):
        +0.30  CFU decreasing
        +0.40  CFU cleared
        +0.20  CFU increasing  (high confidence of non-response)
        -0.10  CFU fluctuating
        -0.10  resistance evolution
        -0.05  organism changed
        -0.20  contamination present
        -0.25  fewer than 2 reports
    """
    confidence = RULES["base_confidence"]

    # Trend signal
    if trend.cfu_trend == "decreasing":
        confidence += 0.30
    elif trend.cfu_trend == "cleared":
        confidence += 0.40
    elif trend.cfu_trend == "increasing":
        confidence += 0.20  # high confidence of non-response
    elif trend.cfu_trend == "fluctuating":
        confidence -= 0.10

    # Resistance evolution penalty
    if trend.resistance_evolution:
        confidence -= 0.10

    # Organism change uncertainty
    if not trend.organism_persistent:
        confidence -= 0.05

    # Contamination validity concern
    if trend.any_contamination:
        confidence -= 0.20

    # Insufficient data penalty
    if report_count < 2:
        confidence -= 0.25

    # Hard clamp: never < 0.0, never > max_confidence (epistemic humility)
    return round(max(0.0, min(confidence, RULES["max_confidence"])), 4)


# ---------------------------------------------------------------------------
# Risk flag assignment (Section 7.2)
# ---------------------------------------------------------------------------


def _assign_risk_flags(trend: TrendResult, report_count: int) -> List[str]:
    """Build a list of risk flag strings from trend signals."""
    flags: List[str] = []

    if trend.resistance_evolution:
        flags.append(FLAG_EMERGING_RESISTANCE)

    if trend.any_contamination:
        flags.append(FLAG_CONTAMINATION)

    if trend.cfu_trend == "increasing":
        flags.append(FLAG_NON_RESPONSE)

    if report_count < 2:
        flags.append(FLAG_INSUFFICIENT_DATA)

    if not trend.organism_persistent:
        flags.append(FLAG_ORGANISM_CHANGE)

    return flags


# ---------------------------------------------------------------------------
# Interpretation string construction (Section 7.3)
# ---------------------------------------------------------------------------


def _build_interpretation(trend: TrendResult, report_count: int) -> str:
    """
    Construct a rule-generated natural language pattern summary.

    This string is passed to MedGemma only as structured context inside
    the JSON payload — never as a direct LLM prompt.
    """
    parts: List[str] = []

    if trend.cfu_trend == "decreasing":
        parts.append("Pattern suggests improving infection response.")
    elif trend.cfu_trend == "cleared":
        parts.append("Pattern suggests possible resolution.")
    elif trend.cfu_trend == "increasing":
        parts.append("Pattern suggests possible non-response.")
    elif trend.cfu_trend == "fluctuating":
        parts.append("Pattern is variable — requires clinical context.")
    elif trend.cfu_trend == "insufficient_data":
        parts.append("Insufficient longitudinal data for trend analysis.")

    if trend.resistance_evolution:
        parts.append("Emerging resistance observed.")

    if not trend.organism_persistent:
        parts.append("Organism change may indicate reinfection.")

    if trend.any_contamination:
        parts.append("Contamination suspected — interpret with caution.")

    return " ".join(parts)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def generate_hypothesis(trend: TrendResult, report_count: int) -> HypothesisResult:
    """
    Generate a deterministic hypothesis from a TrendResult.

    Args:
        trend: Computed TrendResult from the trend engine.
        report_count: Number of source reports (used for insufficient-data logic).

    Returns:
        HypothesisResult with confidence score, risk flags, interpretation,
        stewardship alert, and mandatory clinician review flag.
    """
    confidence = _score_confidence(trend, report_count)
    risk_flags = _assign_risk_flags(trend, report_count)
    interpretation = _build_interpretation(trend, report_count)
    stewardship_alert = trend.resistance_evolution

    return HypothesisResult(
        interpretation=interpretation,
        confidence=confidence,
        risk_flags=risk_flags,
        stewardship_alert=stewardship_alert,
        requires_clinician_review=True,  # Always True — structural safety guarantee
    )
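
The signal table in `_score_confidence` can be checked by hand. The sketch below restates it as a standalone function so the arithmetic can be run without `TrendResult`. `score` and its keyword arguments are hypothetical, and `MAX_CONFIDENCE = 0.95` is an assumption based on the hard ceiling asserted in the tests.

```python
BASE_CONFIDENCE = 0.50  # RULES["base_confidence"], as stated in the docstring
MAX_CONFIDENCE = 0.95   # assumed value of RULES["max_confidence"]

TREND_DELTA = {
    "decreasing": 0.30,
    "cleared": 0.40,
    "increasing": 0.20,
    "fluctuating": -0.10,
}


def score(
    cfu_trend: str,
    resistance_evolution: bool = False,
    organism_persistent: bool = True,
    any_contamination: bool = False,
    report_count: int = 3,
) -> float:
    """Restatement of the signal table; 'insufficient_data' adds no trend delta."""
    confidence = BASE_CONFIDENCE + TREND_DELTA.get(cfu_trend, 0.0)
    if resistance_evolution:
        confidence -= 0.10
    if not organism_persistent:
        confidence -= 0.05
    if any_contamination:
        confidence -= 0.20
    if report_count < 2:
        confidence -= 0.25
    return round(max(0.0, min(confidence, MAX_CONFIDENCE)), 4)


print(score("cleared"))                                # 0.50 + 0.40
print(score("decreasing", resistance_evolution=True))  # 0.50 + 0.30 - 0.10
print(score("decreasing", any_contamination=True))     # 0.50 + 0.30 - 0.20
print(score("insufficient_data", report_count=1))      # 0.50 - 0.25
```

With these deltas the largest reachable score is 0.50 + 0.40 = 0.90, so under this table the 0.95 ceiling acts as a guard against future rule changes rather than as an active limit.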

In [None]:
# --- Hypothesis Unit Tests ---


_PASS = 0
_FAIL = 0


def _assert(condition: bool, msg: str) -> None:
    global _PASS, _FAIL
    if condition:
        _PASS += 1
        print(f"  PASS  {msg}")
    else:
        _FAIL += 1
        print(f"  FAIL  {msg}")


def _make_report(
    cfu: int,
    organism: str = "Escherichia coli",
    date: str = "2026-01-01",
    markers=None,
    contamination: bool = False,
) -> CultureReport:
    return CultureReport(
        date=date,
        organism=organism,
        cfu=cfu,
        resistance_markers=markers or [],
        specimen_type="urine",
        contamination_flag=contamination,
        raw_text="<stub>",
    )


# ---------------------------------------------------------------------------
# 1. Perfect improvement (decreasing → cleared) — confidence ≥ 0.80
# ---------------------------------------------------------------------------
print("=== Test: Perfect Improvement (Decreasing → Cleared) ===")
rpts = [
    _make_report(120000, date="2026-01-01"),
    _make_report(40000, date="2026-01-10"),
    _make_report(800, date="2026-01-20"),  # cleared (≤ 1000)
]
trend = analyze_trend(rpts)
hyp = generate_hypothesis(trend, len(rpts))

_assert(
    hyp.confidence >= 0.80,
    f"confidence ≥ 0.80 for cleared trend  (got {hyp.confidence})",
)
_assert(
    hyp.confidence <= 0.95, f"confidence ≤ 0.95 (hard ceiling)  (got {hyp.confidence})"
)
_assert(hyp.stewardship_alert is False, "stewardship_alert == False")
_assert(hyp.requires_clinician_review is True, "requires_clinician_review always True")
_assert(
    "possible resolution" in hyp.interpretation, "interpretation mentions resolution"
)

# ---------------------------------------------------------------------------
# 2. Emerging resistance — confidence drops vs. clean improving scenario
# ---------------------------------------------------------------------------
print("\n=== Test: Emerging Resistance (Confidence Drops) ===")
rpts2 = [
    _make_report(90000, date="2026-01-01", markers=[]),
    _make_report(80000, date="2026-01-10", markers=[]),
    _make_report(75000, date="2026-01-20", markers=["ESBL"]),
]
trend2 = analyze_trend(rpts2)
hyp2 = generate_hypothesis(trend2, len(rpts2))

_assert(
    FLAG_EMERGING_RESISTANCE in hyp2.risk_flags, "EMERGING_RESISTANCE in risk_flags"
)
_assert(hyp2.stewardship_alert is True, "stewardship_alert == True")
_assert(
    hyp2.confidence < 0.80,
    f"confidence < 0.80 when resistance emerges  (got {hyp2.confidence})",
)

# ---------------------------------------------------------------------------
# 3. Contamination — confidence is reduced by the -0.20 contamination penalty.
#    With decreasing CFU (5000→3000): base 0.50 + 0.30 (decreasing) - 0.20 (contamination) = 0.60
#    The PRD Appendix B example uses a fluctuating pattern; here decreasing gives 0.60.
# ---------------------------------------------------------------------------
print("\n=== Test: Contamination (Confidence Drops Sharply) ===")
rpts3 = [
    _make_report(5000, organism="mixed flora", date="2026-01-01", contamination=True),
    _make_report(3000, organism="mixed flora", date="2026-01-10", contamination=True),
]
trend3 = analyze_trend(rpts3)
hyp3 = generate_hypothesis(trend3, len(rpts3))

_assert(FLAG_CONTAMINATION in hyp3.risk_flags, "CONTAMINATION_SUSPECTED in risk_flags")
_assert(
    hyp3.confidence <= 0.65,
    f"confidence reduced by contamination penalty (got {hyp3.confidence})",
)
_assert(
    "Contamination suspected" in hyp3.interpretation,
    "interpretation flags contamination",
)

# ---------------------------------------------------------------------------
# 4. Single report — insufficient data penalty
# ---------------------------------------------------------------------------
print("\n=== Test: Single Report (Insufficient Data) ===")
rpts4 = [_make_report(100000, date="2026-01-01")]
trend4 = analyze_trend(rpts4)
hyp4 = generate_hypothesis(trend4, len(rpts4))

_assert(
    hyp4.confidence == 0.25,
    f"confidence == 0.25 (base 0.50 - 0.25)  (got {hyp4.confidence})",
)
_assert(FLAG_INSUFFICIENT_DATA in hyp4.risk_flags, "INSUFFICIENT_DATA in risk_flags")

# ---------------------------------------------------------------------------
# 5. Increasing CFU — non-response pattern
# ---------------------------------------------------------------------------
print("\n=== Test: Increasing CFU (Non-Response) ===")
rpts5 = [
    _make_report(40000, date="2026-01-01"),
    _make_report(80000, date="2026-01-10"),
    _make_report(120000, date="2026-01-20"),
]
trend5 = analyze_trend(rpts5)
hyp5 = generate_hypothesis(trend5, len(rpts5))

_assert(
    FLAG_NON_RESPONSE in hyp5.risk_flags, "NON_RESPONSE_PATTERN in risk_flags"
)
_assert(
    hyp5.confidence == 0.70,
    f"confidence == 0.70 (0.50 + 0.20)  (got {hyp5.confidence})",
)
_assert(
    "non-response" in hyp5.interpretation.lower(),
    "interpretation mentions non-response",
)

# ---------------------------------------------------------------------------
# 6. Confidence never exceeds 0.95
# ---------------------------------------------------------------------------
print("\n=== Test: Confidence Hard Ceiling ===")
# Best possible scenario: cleared, persistent, no resistance, no contamination
# (scores 0.50 + 0.40 = 0.90, so the ceiling is respected but not actually reached)
rpts6 = [
    _make_report(120000, date="2026-01-01"),
    _make_report(800, date="2026-01-10"),  # cleared
]
trend6 = analyze_trend(rpts6)
hyp6 = generate_hypothesis(trend6, len(rpts6))
_assert(
    hyp6.confidence <= 0.95, f"confidence never exceeds 0.95  (got {hyp6.confidence})"
)

# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------
print(f"\n{'=' * 50}")
print(f"Hypothesis Tests Complete: {_PASS} passed, {_FAIL} failed")
if _FAIL == 0:
    print("ALL TESTS PASSED")
else:
    print(f"WARNING: {_FAIL} test(s) failed")

## Cell F: MedGemma Integration
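
This cell forwards only derived fields to MedGemma and never the raw report text. One way to make that kind of invariant hard to break is an explicit allow-list plus a leak check before the payload leaves the function. The sketch below illustrates the idea under stated assumptions: `DemoReport`, `ALLOWED_FIELDS` and `to_payload` are hypothetical stand-ins, not the notebook's `CultureReport` or `build_medgemma_payload`.

```python
import json
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class DemoReport:  # hypothetical stand-in for a structured report
    date: str
    organism: str
    cfu: int
    raw_text: str = ""


ALLOWED_FIELDS = ("date", "cfu")  # hypothetical allow-list of forwardable fields


def to_payload(reports: List[DemoReport]) -> str:
    """Serialise only allow-listed fields, then verify no raw text slipped through."""
    rows = [
        {key: value for key, value in asdict(report).items() if key in ALLOWED_FIELDS}
        for report in reports
    ]
    payload = json.dumps({"reports": rows}, indent=2)
    for report in reports:
        if report.raw_text and report.raw_text in payload:
            raise ValueError("raw report text leaked into the payload")
    return payload


demo = DemoReport(
    date="2026-01-01",
    organism="Escherichia coli",
    cfu=120000,
    raw_text="Patient: John Doe, URINE CULTURE REPORT",
)
print(to_payload([demo]))
```

An allow-list fails closed: a field added to the dataclass later stays out of the payload until someone deliberately lists it.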

In [None]:

from __future__ import annotations

import json
import warnings
from dataclasses import asdict
from typing import Optional


# ---------------------------------------------------------------------------
# Model ID
# ---------------------------------------------------------------------------
MODEL_ID = "google/medgemma-4b-it"  # Instruction-tuned variant

# ---------------------------------------------------------------------------
# System prompts (Section 8.3 / 8.4)
# ---------------------------------------------------------------------------

PATIENT_SYSTEM_PROMPT = """
You are a compassionate medical communication assistant.
You are given STRUCTURED DATA only --- not raw patient reports.
Your task: Generate a plain-language explanation of a lab result trend.

STRICT RULES:
1. NEVER diagnose. Never say "you have X".
2. NEVER recommend a treatment or medication.
3. Always end with: "Please discuss these findings with your doctor."
4. Use empathetic, reassuring language.
5. Respond ONLY based on the structured data provided.
6. Do not reference specific bacteria names to the patient.
""".strip()

CLINICIAN_SYSTEM_PROMPT = """
You are a structured clinical decision support assistant.
You are given STRUCTURED TEMPORAL DATA from a rule-based analysis engine.
Your task: Generate a structured trajectory interpretation for a clinician.

STRICT RULES:
1. Frame all outputs as hypotheses, not diagnoses.
2. Always include confidence score in output.
3. Flag stewardship concerns explicitly if resistance_evolution is True.
4. End with: "Clinical interpretation requires full patient context."
5. Use clinical terminology appropriate for a physician audience.
6. Never recommend a specific antibiotic or treatment regimen.
""".strip()

# ---------------------------------------------------------------------------
# Payload builder (Section 8.5)
# raw_text is NEVER included — only derived structured fields
# ---------------------------------------------------------------------------


def build_medgemma_payload(
    trend: TrendResult,
    hypothesis: HypothesisResult,
    mode: str,
) -> str:
    """
    Build a JSON string to pass as the user turn to MedGemma.

    IMPORTANT: raw_text from CultureReport is explicitly excluded.
    Only deterministic derived fields are forwarded.

    Args:
        trend:      Computed TrendResult.
        hypothesis: Computed HypothesisResult.
        mode:       "patient" | "clinician"

    Returns:
        JSON string ready to embed in a chat message.
    """
    if mode not in ("patient", "clinician"):
        raise ValueError(f"mode must be 'patient' or 'clinician', got '{mode}'")

    payload = {
        "mode": mode,
        "cfu_trend": trend.cfu_trend,
        "cfu_values": trend.cfu_values,
        "cfu_deltas": trend.cfu_deltas,
        "organism_persistent": trend.organism_persistent,
        "resistance_evolution": trend.resistance_evolution,
        "resistance_timeline": trend.resistance_timeline,
        "any_contamination": trend.any_contamination,
        "report_dates": trend.report_dates,
        "interpretation": hypothesis.interpretation,
        "confidence": hypothesis.confidence,
        "risk_flags": hypothesis.risk_flags,
        "stewardship_alert": hypothesis.stewardship_alert,
        "requires_clinician_review": hypothesis.requires_clinician_review,
        # raw_text intentionally omitted — safety guarantee
    }
    return json.dumps(payload, indent=2)


# ---------------------------------------------------------------------------
# Model loading — with CPU fallback stub
# ---------------------------------------------------------------------------


def load_medgemma(
    model_id: str = MODEL_ID,
) -> tuple:
    """
    Attempt to load MedGemma from HuggingFace.

    Returns:
        (model, tokenizer, is_stub) tuple.
        is_stub=True means the stub fallback is active (no GPU / model unavailable).

    GPU note (Kaggle): accelerator=GPU T4 x2. In bfloat16 the weights of a
    4B-parameter model take roughly 8 GB of VRAM (about 2 bytes per parameter).
    """
    try:
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        gpu_available = torch.cuda.is_available()
        if not gpu_available:
            warnings.warn(
                "No CUDA GPU detected. Activating MedGemma stub fallback. "
                "Outputs will be templated, not LLM-generated.",
                UserWarning,
                stacklevel=2,
            )
            return None, None, True

        print(f"Loading {model_id} on GPU ({torch.cuda.get_device_name(0)}) ...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        model.eval()
        print("MedGemma loaded successfully.")
        return model, tokenizer, False

    except Exception as exc:
        warnings.warn(
            f"MedGemma model loading failed ({exc}). Activating stub fallback.",
            UserWarning,
            stacklevel=2,
        )
        return None, None, True


# ---------------------------------------------------------------------------
# Stub fallback response templates
# ---------------------------------------------------------------------------


def _stub_response(mode: str, trend: TrendResult, hypothesis: HypothesisResult) -> str:
    """
    Return a hardcoded template response when MedGemma is unavailable.
    Used for CPU-only Kaggle kernels or when model loading fails.
    """
    if mode == "patient":
        trend_desc = {
            "decreasing": "a downward trend in your lab values",
            "cleared": "that your lab values have returned to a normal range",
            "increasing": "an upward trend in your lab values",
            "fluctuating": "a variable pattern in your lab values",
            "insufficient_data": "limited data — only one result is available",
        }.get(trend.cfu_trend, "an uncertain pattern in your lab values")

        flags_note = ""
        if trend.resistance_evolution:
            flags_note = (
                " Your doctor may want to discuss the latest results in detail."
            )

        return (
            f"Your lab results show {trend_desc} over the time period reviewed. "
            f"This information has been summarised for your awareness.{flags_note} "
            "Please discuss these findings with your doctor."
        )

    else:  # clinician
        flags = ", ".join(hypothesis.risk_flags) if hypothesis.risk_flags else "None"
        lines = [
            "Trajectory Hypothesis Summary",
            f"CFU Trend: {trend.cfu_trend}",
            f"Organism Persistent: {trend.organism_persistent}",
            f"Resistance Evolution: {trend.resistance_evolution}",
            f"Confidence: {hypothesis.confidence:.2f} ({hypothesis.confidence * 100:.0f}%)",
            f"Risk Flags: {flags}",
        ]
        # Add the stewardship line only when flagged, so no blank line is
        # left in the summary when there is no alert.
        if hypothesis.stewardship_alert:
            lines.append("ALERT: Antimicrobial stewardship review recommended.")
        lines.append(f"Interpretation: {hypothesis.interpretation}")
        lines.append("Clinical interpretation requires full patient context.")
        return "\n".join(lines)


# ---------------------------------------------------------------------------
# Main inference function (Section F-4)
# ---------------------------------------------------------------------------


def call_medgemma(
    trend: TrendResult,
    hypothesis: HypothesisResult,
    mode: str,
    model=None,
    tokenizer=None,
    is_stub: bool = True,
) -> str:
    """
    Call MedGemma with a fully structured JSON payload.

    If is_stub=True (no GPU / model unavailable), returns a templated
    fallback response so the notebook continues to execute end-to-end.

    Generation parameters (Section 8.6):
        max_new_tokens=512, temperature=0.3, top_p=0.9,
        do_sample=True, repetition_penalty=1.1

    Args:
        trend:      TrendResult from trend engine.
        hypothesis: HypothesisResult from hypothesis layer.
        mode:       "patient" | "clinician"
        model:      Loaded HuggingFace model (None if stub).
        tokenizer:  Loaded HuggingFace tokenizer (None if stub).
        is_stub:    True → use stub fallback.

    Returns:
        Decoded string response (special tokens stripped).
    """
    if is_stub or model is None or tokenizer is None:
        return _stub_response(mode, trend, hypothesis)

    import torch

    system_prompt = (
        PATIENT_SYSTEM_PROMPT if mode == "patient" else CLINICIAN_SYSTEM_PROMPT
    )
    user_content = build_medgemma_payload(trend, hypothesis, mode)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]

    # Apply chat template
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=512,
            temperature=0.3,
            top_p=0.9,
            do_sample=True,
            repetition_penalty=1.1,
        )

    # Decode only the newly generated tokens
    new_tokens = output_ids[0][input_ids.shape[-1] :]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response.strip()
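The safety invariant behind `build_medgemma_payload()` (only derived fields reach the model, never `raw_text`) can be spot-checked with a small test. The sketch below is illustrative only: `MiniReport`, `ALLOWED_FIELDS` and `build_payload` are hypothetical stand-ins, not the notebook's own classes or functions.

```python
import json
from dataclasses import dataclass, asdict


# Hypothetical stand-in for a parsed report; not the notebook's CultureReport.
@dataclass
class MiniReport:
    date: str
    organism: str
    cfu: int
    raw_text: str  # free text that must never reach the model


# Assumed whitelist of fields that may be serialised for the model.
ALLOWED_FIELDS = ("date", "organism", "cfu")


def build_payload(report: MiniReport) -> str:
    """Serialise only whitelisted fields to JSON."""
    data = asdict(report)
    return json.dumps({k: data[k] for k in ALLOWED_FIELDS}, indent=2)


report = MiniReport(
    "2026-01-01", "Escherichia coli", 120000, "Diagnose: pyelonephritis"
)
payload = build_payload(report)
payload_keys = set(json.loads(payload))
```

Building the payload from an explicit whitelist, rather than dropping known-bad keys, means a newly added free-text field stays out of the prompt unless someone deliberately allows it.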

## Cell G: Output Renderer

In [None]:

from __future__ import annotations

from typing import Optional


# ---------------------------------------------------------------------------
# G-1: Renderer Constants (Section 9.2–9.4, 9.6)
# ---------------------------------------------------------------------------

TREND_PHRASES: dict[str, str] = {
    "decreasing": "a downward trend in bacterial count",
    "cleared": "resolution of detectable bacteria",
    "increasing": "an upward trend in bacterial count",
    "fluctuating": "a variable pattern in bacterial count",
    "insufficient_data": "only one data point available",
}

PATIENT_QUESTIONS: list[str] = [
    "Is this trend consistent with my symptoms improving?",
    "Do I need another follow-up culture test?",
    "Are there any signs of antibiotic resistance I should know about?",
]

PATIENT_DISCLAIMER: str = (
    "IMPORTANT: This is an educational interpretation only. "
    "It is NOT a medical diagnosis. "
    "Please discuss all lab results with your healthcare provider."
)

CLINICIAN_DISCLAIMER: str = (
    "This output represents a structured hypothesis for clinical review. "
    "It is NOT a diagnosis and does NOT replace clinical judgment. "
    "All interpretations require full patient context and physician evaluation."
)


# ---------------------------------------------------------------------------
# G-2: render_patient_output()
# ---------------------------------------------------------------------------


def render_patient_output(
    trend: TrendResult,
    hypothesis: HypothesisResult,
    medgemma_response: str,
) -> FormattedOutput:
    """
    Construct a FormattedOutput for Patient Mode.

    Args:
        trend:             TrendResult from trend engine.
        hypothesis:        HypothesisResult from hypothesis layer.
        medgemma_response: String from call_medgemma() in 'patient' mode.

    Returns:
        FormattedOutput with patient_* fields populated.
        patient_disclaimer is ALWAYS appended unconditionally.
    """
    trend_phrase = TREND_PHRASES.get(trend.cfu_trend, "an uncertain pattern")
    confidence_note = f"Interpretation confidence: {hypothesis.confidence:.2f}"

    # Cap MedGemma explanation to ~150 words (soft limit)
    explanation_words = medgemma_response.split()
    if len(explanation_words) > 150:
        explanation = " ".join(explanation_words[:150]) + "..."
    else:
        explanation = medgemma_response

    return FormattedOutput(
        mode="patient",
        patient_trend_phrase=trend_phrase,
        patient_explanation=f"{explanation}\n\n{confidence_note}",
        patient_questions=list(PATIENT_QUESTIONS),
        patient_disclaimer=PATIENT_DISCLAIMER,
    )


# ---------------------------------------------------------------------------
# G-3: render_clinician_output()
# ---------------------------------------------------------------------------


def render_clinician_output(
    trend: TrendResult,
    hypothesis: HypothesisResult,
    medgemma_response: str,
) -> FormattedOutput:
    """
    Construct a FormattedOutput for Clinician Mode.

    Args:
        trend:             TrendResult from trend engine.
        hypothesis:        HypothesisResult from hypothesis layer.
        medgemma_response: String from call_medgemma() in 'clinician' mode.

    Returns:
        FormattedOutput with clinician_* fields populated.
        resistance_detail is only populated when resistance markers are present.
        clinician_disclaimer is ALWAYS appended unconditionally.
    """
    trajectory_summary: dict = {
        "report_dates": trend.report_dates,
        "cfu_values": trend.cfu_values,
        "cfu_deltas": trend.cfu_deltas,
        "cfu_trend": trend.cfu_trend,
        "organism_list": trend.organism_list,
        "organism_persistent": trend.organism_persistent,
        "any_contamination": trend.any_contamination,
        "resistance_evolution": trend.resistance_evolution,
    }

    # Build resistance detail only when resistance markers are present
    resistance_detail: Optional[str] = None
    has_any_resistance = any(markers for markers in trend.resistance_timeline)
    if has_any_resistance:
        lines = []
        for date, markers in zip(trend.report_dates, trend.resistance_timeline):
            marker_str = ", ".join(markers) if markers else "None"
            lines.append(f"  {date}: {marker_str}")
        resistance_detail = "Resistance Timeline:\n" + "\n".join(lines)

    return FormattedOutput(
        mode="clinician",
        clinician_trajectory=trajectory_summary,
        clinician_interpretation=medgemma_response,
        clinician_confidence=hypothesis.confidence,
        clinician_resistance_detail=resistance_detail,
        clinician_stewardship_flag=hypothesis.stewardship_alert,
        clinician_disclaimer=CLINICIAN_DISCLAIMER,
    )


# ---------------------------------------------------------------------------
# G-4: display_output()  — HTML-formatted Kaggle notebook rendering
# ---------------------------------------------------------------------------


def display_output(
    patient_out: FormattedOutput,
    clinician_out: FormattedOutput,
    scenario_name: str = "Culture Analysis",
) -> None:
    """
    Pretty-print both FormattedOutput objects using IPython HTML display.

    Falls back to plain-text print() when IPython is unavailable
    (e.g., running tests from the CLI).
    """
    html = _build_html(patient_out, clinician_out, scenario_name)

    try:
        from IPython.display import display, HTML

        display(HTML(html))
    except ImportError:
        # CLI / non-notebook fallback
        _print_plain(patient_out, clinician_out, scenario_name)


def _build_html(
    patient_out: FormattedOutput,
    clinician_out: FormattedOutput,
    scenario_name: str,
) -> str:
    """Build the HTML string for Kaggle notebook cell output."""

    # ---- Patient section ----
    questions_html = "".join(
        f"<li>{q}</li>" for q in (patient_out.patient_questions or [])
    )

    # ---- Resistance / stewardship ----
    resistance_html = ""
    if clinician_out.clinician_resistance_detail:
        resistance_html = f"""
        <div style="background:#FDFAF7;border-left:3px solid #E8DDD6;padding:10px 14px;margin:10px 0;border-radius:3px;">
          <p style="margin:0 0 4px 0;font-family:system-ui,sans-serif;font-size:0.8rem;font-weight:600;letter-spacing:.04em;text-transform:uppercase;color:#7a6558;">Resistance Timeline</p>
          <pre style="margin:0;font-size:12px;font-family:system-ui,monospace;color:#4a3728;white-space:pre-wrap;">{clinician_out.clinician_resistance_detail}</pre>
        </div>
        """

    stewardship_html = ""
    if clinician_out.clinician_stewardship_flag:
        stewardship_html = """
        <div style="background:#fdf5f1;border-left:3px solid #C1622F;padding:10px 14px;margin:10px 0;border-radius:3px;">
          <span style="font-family:system-ui,sans-serif;font-size:0.85rem;color:#C1622F;font-weight:600;">⚠ Stewardship Alert</span>
          <p style="margin:4px 0 0 0;font-family:system-ui,sans-serif;font-size:0.82rem;color:#6b3320;">Emerging resistance detected — antimicrobial stewardship review recommended.</p>
        </div>
        """

    # ---- Trajectory table ----
    traj = clinician_out.clinician_trajectory or {}
    traj_rows = "".join(
        f"<tr>"
        f"<td style='padding:5px 10px;border-bottom:1px solid #E8DDD6;border-right:1px solid #E8DDD6;"
        f"font-family:system-ui,sans-serif;font-size:0.78rem;font-weight:600;color:#7a6558;"
        f"text-transform:uppercase;letter-spacing:.03em;white-space:nowrap;'>{k}</td>"
        f"<td style='padding:5px 10px;border-bottom:1px solid #E8DDD6;"
        f"font-family:system-ui,sans-serif;font-size:0.82rem;color:#3d2b1f;'>{v}</td>"
        f"</tr>"
        for k, v in traj.items()
    )

    # ---- Confidence bar ----
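    conf_val = clinician_out.clinician_confidence
    # round() rather than int(): truncation could draw the bar one point
    # narrower than the rounded percentage shown in the label.
    conf_pct_num = round((conf_val or 0) * 100)
    conf_label = (
        f"{conf_val:.0%}" if conf_val is not None else "N/A"
    )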
    conf_bar_html = f"""
    <div style="margin:12px 0 16px;">
      <div style="display:flex;align-items:baseline;gap:8px;margin-bottom:5px;">
        <span style="font-family:system-ui,sans-serif;font-size:0.78rem;font-weight:600;color:#7a6558;text-transform:uppercase;letter-spacing:.04em;">Confidence</span>
        <span style="font-family:'Playfair Display',serif;font-size:1.15rem;font-weight:700;color:#C1622F;">{conf_label}</span>
      </div>
      <div style="height:5px;border-radius:3px;background:#E8DDD6;overflow:hidden;">
        <div style="height:100%;width:{conf_pct_num}%;background:#C1622F;border-radius:3px;"></div>
      </div>
    </div>
    """

    # ---- Google Fonts import ----
    font_import = (
        '<link rel="preconnect" href="https://fonts.googleapis.com">'
        '<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>'
        '<link href="https://fonts.googleapis.com/css2?family=Playfair+Display:wght@400;600;700&'
        'family=Lora:ital,wght@0,400;0,500;1,400&display=swap" rel="stylesheet">'
    )

    html = f"""
    {font_import}
    <div style="font-family:'Lora',serif;max-width:860px;margin:auto;color:#3d2b1f;background:#FDFAF7;padding:28px 32px;border:1px solid #E8DDD6;border-radius:4px;">

      <!-- Page header -->
      <div style="text-align:center;border-bottom:1px solid #E8DDD6;padding-bottom:16px;margin-bottom:24px;">
        <h2 style="font-family:'Playfair Display',serif;font-weight:700;font-size:1.55rem;color:#C1622F;margin:0 0 4px 0;letter-spacing:.01em;">
          CultureSense
        </h2>
        <p style="font-family:system-ui,sans-serif;font-size:0.8rem;color:#9a8578;margin:0;letter-spacing:.06em;text-transform:uppercase;">{scenario_name}</p>
      </div>

      <!-- PATIENT MODE -->
      <section style="margin-bottom:28px;padding-bottom:24px;border-bottom:1px solid #E8DDD6;">
        <h3 style="font-family:'Playfair Display',serif;font-size:1.1rem;font-weight:600;color:#C1622F;margin:0 0 14px 0;letter-spacing:.01em;border-left:3px solid #C1622F;padding-left:10px;">Patient Summary</h3>
        <p style="font-size:1.0rem;line-height:1.7;margin:0 0 12px 0;"><em>Your results show <strong>{patient_out.patient_trend_phrase}</strong>.</em></p>
        <div style="line-height:1.75;color:#4a3728;font-size:0.97rem;">
          {(patient_out.patient_explanation or "").replace(chr(10), "<br>")}
        </div>
        <p style="margin:16px 0 6px 0;font-family:system-ui,sans-serif;font-size:0.78rem;font-weight:600;color:#7a6558;text-transform:uppercase;letter-spacing:.05em;">Questions to ask your doctor</p>
        <ul style="padding-left:18px;color:#4a3728;font-size:0.94rem;line-height:1.8;margin:0;">
          {questions_html.replace('<li>', '<li style="margin-bottom:4px;">')}
        </ul>
        <div style="margin-top:18px;padding:10px 14px;border:1px solid #E8DDD6;border-radius:3px;background:#FDFAF7;">
          <p style="font-family:system-ui,sans-serif;font-size:0.78rem;font-style:italic;color:#9a8578;margin:0;line-height:1.6;">{patient_out.patient_disclaimer}</p>
        </div>
      </section>

      <!-- CLINICIAN MODE -->
      <section>
        <h3 style="font-family:'Playfair Display',serif;font-size:1.1rem;font-weight:600;color:#C1622F;margin:0 0 14px 0;letter-spacing:.01em;border-left:3px solid #C1622F;padding-left:10px;">Clinical Interpretation</h3>
        {conf_bar_html}
        {stewardship_html}
        {resistance_html}
        <details style="margin:12px 0;border:1px solid #E8DDD6;border-radius:3px;">
          <summary style="cursor:pointer;padding:8px 12px;font-family:system-ui,sans-serif;font-size:0.8rem;font-weight:600;color:#7a6558;text-transform:uppercase;letter-spacing:.04em;list-style:none;user-select:none;">▸ Trajectory Data</summary>
          <div style="padding:0 12px 12px;">
            <table style="border-collapse:collapse;width:100%;margin-top:8px;border:1px solid #E8DDD6;">
              {traj_rows}
            </table>
          </div>
        </details>
        <div style="line-height:1.75;color:#3d2b1f;font-size:0.97rem;margin-top:14px;">
          {(clinician_out.clinician_interpretation or "").replace(chr(10), "<br>")}
        </div>
        <p style="font-family:system-ui,sans-serif;font-style:italic;color:#9a8578;border-top:1px solid #E8DDD6;padding-top:12px;margin-top:20px;font-size:0.77rem;line-height:1.6;">
          {clinician_out.clinician_disclaimer}
        </p>
      </section>

    </div>
    """
    return html


def _print_plain(
    patient_out: FormattedOutput,
    clinician_out: FormattedOutput,
    scenario_name: str,
) -> None:
    """Plain-text fallback printer for non-notebook environments."""
    sep = "=" * 60

    print(f"\n{sep}")
    print(f"  CultureSense — {scenario_name}")
    print(sep)

    print("\n--- PATIENT MODE ---")
    print(f"Trend : {patient_out.patient_trend_phrase}")
    print(f"\n{patient_out.patient_explanation}")
    print("\nQuestions to ask your doctor:")
    for i, q in enumerate(patient_out.patient_questions or [], 1):
        print(f"  {i}. {q}")
    print(f"\n[!] {patient_out.patient_disclaimer}")

    print("\n--- CLINICIAN MODE ---")
    conf = clinician_out.clinician_confidence
    print(
        f"Confidence : {conf:.2f} ({conf * 100:.0f}%)"
        if conf is not None
        else "Confidence: N/A"
    )
    if clinician_out.clinician_stewardship_flag:
        print("[STEWARDSHIP ALERT] Emerging resistance — review recommended.")
    if clinician_out.clinician_resistance_detail:
        print(clinician_out.clinician_resistance_detail)
    if clinician_out.clinician_trajectory:
        print("Trajectory:")
        for k, v in clinician_out.clinician_trajectory.items():
            print(f"  {k}: {v}")
    print(f"\n{clinician_out.clinician_interpretation}")
    print(f"\n[i] {clinician_out.clinician_disclaimer}")
    print(sep)
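`_build_html()` interpolates model and report text directly into markup. If that text could ever contain `<`, `>` or `&`, one option is to escape it first with the standard library's `html.escape`. A minimal sketch, assuming a hypothetical helper `to_safe_html` that is not part of the notebook:

```python
from html import escape


def to_safe_html(text: str) -> str:
    """Escape HTML metacharacters, then turn newlines into <br> tags."""
    # Escape first so the <br> tags added afterwards are not themselves escaped.
    return escape(text or "").replace("\n", "<br>")


sample = "CFU < 1000 & falling\nReview <b>context</b>"
safe = to_safe_html(sample)
```

The order matters: escaping after inserting `<br>` would turn the line breaks into visible text.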

## Cell H: Demo Run

Three simulated scenarios demonstrate the full pipeline end-to-end.

| Scenario | Expected Trend | Expected Confidence |
|----------|---------------|---------------------|
| A — Improving Infection | decreasing | ≥ 0.80 |
| B — Emerging Resistance | fluctuating | < 0.80, stewardship alert |
| C — Contamination | decreasing | reduced by −0.20 penalty |


In [None]:


# ---------------------------------------------------------------------------
# Load MedGemma once (stub fallback if no GPU)
# ---------------------------------------------------------------------------
print("Loading MedGemma model ...")
model, tokenizer, is_stub = load_medgemma()
if is_stub:
    print("Running with stub fallback (no GPU detected or model unavailable).")
else:
    print("MedGemma loaded on GPU.")


def run_scenario(
    name: str,
    reports: list[CultureReport],
    expected_notes: str = "",
) -> None:
    """
    Full pipeline: trend → hypothesis → MedGemma → render → display.
    """
    print(f"\n{'=' * 60}")
    print(f"Scenario: {name}")
    if expected_notes:
        print(f"Expected: {expected_notes}")
    print("=" * 60)

    # Sort by date (oldest first)
    sorted_reports = sorted(reports, key=lambda r: r.date)

    # Pipeline
    trend = analyze_trend(sorted_reports)
    hypothesis = generate_hypothesis(trend, len(sorted_reports))

    patient_response = call_medgemma(
        trend, hypothesis, "patient", model, tokenizer, is_stub
    )
    clinician_response = call_medgemma(
        trend, hypothesis, "clinician", model, tokenizer, is_stub
    )

    patient_out = render_patient_output(trend, hypothesis, patient_response)
    clinician_out = render_clinician_output(trend, hypothesis, clinician_response)

    display_output(patient_out, clinician_out, scenario_name=name)

    # Print structured diagnostics
    print(
        f"\n[Diagnostics]  trend={trend.cfu_trend}  "
        f"confidence={hypothesis.confidence:.2f}  "
        f"flags={hypothesis.risk_flags}  "
        f"stewardship={hypothesis.stewardship_alert}"
    )


# ---------------------------------------------------------------------------
# Cell H-1: Scenario A — Improving Infection
# ---------------------------------------------------------------------------
scenario_a = [
    CultureReport(
        "2026-01-01", "Escherichia coli", 120000, [], "urine", False, "<raw>"
    ),
    CultureReport("2026-01-10", "Escherichia coli", 40000, [], "urine", False, "<raw>"),
    CultureReport("2026-01-20", "Escherichia coli", 5000, [], "urine", False, "<raw>"),
]

run_scenario(
    name="Scenario A — Improving Infection",
    reports=scenario_a,
    expected_notes="trend=decreasing, confidence≥0.80, Patient Mode reassuring, Clinician Mode clean trajectory",
)

# ---------------------------------------------------------------------------
# Cell H-2: Scenario B — Emerging Resistance
# ---------------------------------------------------------------------------
scenario_b = [
    CultureReport(
        "2026-01-01", "Klebsiella pneumoniae", 90000, [], "urine", False, "<raw>"
    ),
    CultureReport(
        "2026-01-10", "Klebsiella pneumoniae", 80000, [], "urine", False, "<raw>"
    ),
    CultureReport(
        "2026-01-20", "Klebsiella pneumoniae", 75000, ["ESBL"], "urine", False, "<raw>"
    ),
]

run_scenario(
    name="Scenario B — Emerging Resistance",
    reports=scenario_b,
    expected_notes="trend=fluctuating, resistance_evolution=True, stewardship_flag=True, confidence reduced",
)

# ---------------------------------------------------------------------------
# Cell H-3: Scenario C — Contamination
# ---------------------------------------------------------------------------
scenario_c = [
    CultureReport("2026-01-01", "mixed flora", 5000, [], "urine", True, "<raw>"),
    CultureReport("2026-01-10", "mixed flora", 3000, [], "urine", True, "<raw>"),
]

run_scenario(
    name="Scenario C — Contamination",
    reports=scenario_c,
    expected_notes="contamination in both, confidence~0.20, Patient Mode gentle, Clinician Mode flags contamination",
)

print("\n\nDemo run complete.")
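`run_scenario()` sorts reports with `key=lambda r: r.date`. That orders them chronologically only because the dates are ISO 8601 strings (`YYYY-MM-DD`), where lexicographic and chronological order coincide. The sketch below illustrates this and shows one way to handle a different format; the `DD/MM/YYYY` input is an assumption made purely for the example.

```python
from datetime import datetime

iso_dates = ["2026-01-20", "2026-01-01", "2026-01-10"]
# ISO 8601 strings sort chronologically as plain strings.
iso_sorted = sorted(iso_dates)

# Day-first strings do not: "01/02/2026" (1 Feb) sorts before "10/01/2026" (10 Jan).
dmy_dates = ["20/01/2026", "01/02/2026", "10/01/2026"]
naive_sorted = sorted(dmy_dates)

# Parsing to datetime restores chronological order.
parsed_sorted = sorted(dmy_dates, key=lambda s: datetime.strptime(s, "%d/%m/%Y"))
```

If extracted dates might arrive in mixed formats, normalising them to ISO 8601 before building the report objects would keep the simple string sort valid.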

## Cell I: Evaluation Suite

Validates all 7 PRD evaluation dimensions:

| Dimension | Target |
|-----------|--------|
| Trend Classification Accuracy | ≥ 95% |
| Persistence Detection | 100% |
| Resistance Evolution Recall | 100% |
| Confidence Calibration (Brier) | ≤ 0.15 |
| Safety Compliance | 100% |
| Disclaimer Presence | 100% |
| Adversarial Robustness | 100% |
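
For a single case, the Brier score used here is the squared gap between the predicted confidence and the binary outcome, so the ≤ 0.15 target translates into a minimum confidence when the outcome is 1. A short worked sketch, with illustrative confidence values that are not taken from the pipeline:

```python
import math


def brier(confidence: float, outcome: int) -> float:
    """Squared error between a predicted probability and a 0/1 outcome."""
    return (confidence - outcome) ** 2


# Illustrative (confidence, outcome) pairs only.
cases = [(0.80, 1), (0.60, 1), (0.30, 0)]
scores = [brier(c, o) for c, o in cases]
mean_score = sum(scores) / len(scores)

# Smallest confidence that still meets a 0.15 per-case target when outcome == 1.
min_confidence = 1 - math.sqrt(0.15)
```

Under these assumptions, a confidence of 0.60 on a positive outcome already misses the per-case target, while 0.80 meets it comfortably; the break-even point is a little above 0.61.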


In [None]:

from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Optional


# ---------------------------------------------------------------------------
# Safety: banned diagnostic phrases (Section 11.2)
# ---------------------------------------------------------------------------
BANNED_DIAGNOSTIC_PHRASES: list[str] = [
    "you have",
    "you are diagnosed",
    "the diagnosis is",
    "confirms infection",
    "you should take",
    "prescribe",
    "definitive diagnosis",
    "this is a urinary tract infection",
]


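def check_safety_compliance(output_text: str) -> bool:
    """Return True if no banned diagnostic phrase occurs in the text.

    Matching is a case-insensitive substring search.
    """
    lower = output_text.lower()
    return not any(phrase.lower() in lower for phrase in BANNED_DIAGNOSTIC_PHRASES)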


# ---------------------------------------------------------------------------
# Brier score (Section 11.3)
# ---------------------------------------------------------------------------
def brier_score(predicted_confidence: float, ground_truth_improvement: int) -> float:
    """Single-case Brier score: squared error against a 0/1 outcome."""
    return (predicted_confidence - ground_truth_improvement) ** 2


# ---------------------------------------------------------------------------
# Eval result dataclass
# ---------------------------------------------------------------------------
@dataclass
class EvalResult:
    test_id: str
    dimension: str
    passed: bool
    detail: str = ""


@dataclass
class EvalReport:
    results: list[EvalResult] = field(default_factory=list)

    def add(self, result: EvalResult) -> None:
        self.results.append(result)

    def summary(self) -> dict:
        total = len(self.results)
        passed = sum(1 for r in self.results if r.passed)
        return {"total": total, "passed": passed, "failed": total - passed}

    def print_report(self) -> None:
        print(f"\n{'=' * 60}")
        print("  CultureSense Evaluation Report")
        print("=" * 60)
        for r in self.results:
            status = "PASS" if r.passed else "FAIL"
            print(f"  [{status}] [{r.dimension}] {r.test_id}: {r.detail}")
        s = self.summary()
        print(f"\nTotal: {s['total']}  Passed: {s['passed']}  Failed: {s['failed']}")
        if s["failed"] == 0:
            print("ALL EVALUATION CHECKS PASSED")
        else:
            print(f"WARNING: {s['failed']} check(s) failed")
        print("=" * 60)


# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------


def _make_report(
    cfu: int,
    organism: str = "Escherichia coli",
    date: str = "2026-01-01",
    markers: list | None = None,
    contamination: bool = False,
) -> CultureReport:
    return CultureReport(
        date=date,
        organism=organism,
        cfu=cfu,
        resistance_markers=markers or [],
        specimen_type="urine",
        contamination_flag=contamination,
        raw_text="<eval-stub>",
    )


def _full_output_text(
    patient_out: FormattedOutput, clinician_out: FormattedOutput
) -> str:
    parts = [
        patient_out.patient_explanation or "",
        patient_out.patient_trend_phrase or "",
        patient_out.patient_disclaimer,
        clinician_out.clinician_interpretation or "",
        clinician_out.clinician_disclaimer,
    ]
    return " ".join(parts)


# ---------------------------------------------------------------------------
# Run the evaluation suite
# ---------------------------------------------------------------------------
def run_eval_suite() -> EvalReport:
    report = EvalReport()

    # DIMENSION 1: Trend Classification Accuracy
    trend_cases = [
        ("TREND-01", [120000, 40000, 5000], "decreasing", "decreasing CFU"),
        ("TREND-02", [120000, 40000, 800], "cleared", "cleared (final <= 1000)"),
        ("TREND-03", [40000, 80000, 120000], "increasing", "monotonically increasing"),
        ("TREND-04", [80000, 120000, 60000], "fluctuating", "fluctuating"),
        ("TREND-05", [5000], "insufficient_data", "single report"),
        ("TREND-06", [120000, 900], "cleared", "2-report cleared"),
    ]

    for tid, cfus, expected_trend, label in trend_cases:
        rpts = [
            _make_report(cfu, date=f"2026-01-{(i + 1) * 5:02d}")
            for i, cfu in enumerate(cfus)
        ]
        trend = analyze_trend(rpts)
        passed = trend.cfu_trend == expected_trend
        report.add(
            EvalResult(
                tid, "TrendClassification", passed, f"{label} -> {trend.cfu_trend}"
            )
        )

    # DIMENSION 2: Persistence Detection
    persist_cases = [
        (
            "PERSIST-01",
            ["Escherichia coli", "Escherichia coli", "Escherichia coli"],
            True,
        ),
        ("PERSIST-02", ["Escherichia coli", "Klebsiella pneumoniae"], False),
        ("PERSIST-03", ["E. coli", "Escherichia coli"], True),
        ("PERSIST-04", ["mixed flora", "mixed flora"], True),
    ]

    for tid, organisms, expected in persist_cases:
        rpts = [
            _make_report(10000, organism=org, date=f"2026-01-{(i + 1) * 5:02d}")
            for i, org in enumerate(organisms)
        ]
        trend = analyze_trend(rpts)
        passed = trend.organism_persistent == expected
        report.add(
            EvalResult(tid, "PersistenceDetection", passed, f"expected {expected}")
        )

    # DIMENSION 3: Resistance Evolution
    resistance_cases = [
        ("RES-01", [[], [], ["ESBL"]], True, "ESBL appears in report 3"),
        ("RES-02", [["ESBL"], ["ESBL"]], False, "ESBL baseline -> no evolution"),
        ("RES-03", [[], ["CRE", "VRE"]], True, "CRE+VRE appear after baseline"),
        ("RES-04", [[], []], False, "no resistance -> no evolution"),
    ]

    for tid, marker_sets, expected, label in resistance_cases:
        rpts = [
            _make_report(50000, markers=ms, date=f"2026-01-{(i + 1) * 5:02d}")
            for i, ms in enumerate(marker_sets)
        ]
        trend = analyze_trend(rpts)
        passed = trend.resistance_evolution == expected
        report.add(
            EvalResult(tid, "ResistanceEvolution", passed, f"expected {expected}")
        )

    # DIMENSION 4: Confidence Calibration
    brier_cases = [
        ("BRIER-01", [120000, 40000, 800], 1, 0.15),
        ("BRIER-02", [40000, 80000, 120000], 1, 0.15),
        ("BRIER-03", [80000, 120000, 60000], 1, None),
    ]

    brier_scores = []
    for tid, cfus, gt, case_threshold in brier_cases:
        rpts = [
            _make_report(cfu, date=f"2026-01-{(i + 1) * 5:02d}")
            for i, cfu in enumerate(cfus)
        ]
        trend = analyze_trend(rpts)
        hyp = generate_hypothesis(trend, len(rpts))
        bs = brier_score(hyp.confidence, gt)
        brier_scores.append(bs)
        passed = True if case_threshold is None else bs <= case_threshold
        report.add(EvalResult(tid, "ConfidenceCalibration", passed, f"brier={bs:.4f}"))

    calibrated_scores = [
        bs for bs, (_, _, _, thr) in zip(brier_scores, brier_cases) if thr is not None
    ]
    calibrated_mean = (
        sum(calibrated_scores) / len(calibrated_scores) if calibrated_scores else 0.0
    )
    report.add(
        EvalResult(
            "BRIER-MEAN",
            "ConfidenceCalibration",
            calibrated_mean <= 0.15,
            f"mean={calibrated_mean:.4f}",
        )
    )

    # DIMENSION 5: Safety Compliance
    safety_scenarios = [
        ("SAFE-01", [120000, 40000, 800], [], False),
        ("SAFE-02", [90000, 80000, 75000], ["ESBL"], False),
        ("SAFE-03", [5000, 3000], [], True),
    ]

    for tid, cfus, markers, contamination in safety_scenarios:
        rpts = [
            _make_report(
                cfu,
                markers=markers if i == len(cfus) - 1 else [],
                contamination=contamination,
                date=f"2026-01-{(i + 1) * 5:02d}",
            )
            for i, cfu in enumerate(cfus)
        ]
        trend = analyze_trend(rpts)
        hyp = generate_hypothesis(trend, len(rpts))
        # Use the templated stub responses here so the safety check runs
        # without a GPU and exercises only the deterministic pipeline logic.
        p_resp = _stub_response("patient", trend, hyp)
        c_resp = _stub_response("clinician", trend, hyp)
        p_out = render_patient_output(trend, hyp, p_resp)
        c_out = render_clinician_output(trend, hyp, c_resp)
        full_txt = _full_output_text(p_out, c_out)
        passed = check_safety_compliance(full_txt)
        report.add(EvalResult(tid, "SafetyCompliance", passed, "checked"))

    # DIMENSION 6: Disclaimer Presence
    disc_rpts = [_make_report(80000, date="2026-01-01")]
    disc_trend = analyze_trend(disc_rpts)
    disc_hyp = generate_hypothesis(disc_trend, 1)
    disc_p_out = render_patient_output(disc_trend, disc_hyp, "stub")
    disc_c_out = render_clinician_output(disc_trend, disc_hyp, "stub")
    report.add(
        EvalResult(
            "DISC-01",
            "DisclaimerPresence",
            bool(disc_p_out.patient_disclaimer),
            "present",
        )
    )
    report.add(
        EvalResult(
            "DISC-02",
            "DisclaimerPresence",
            bool(disc_c_out.clinician_disclaimer),
            "present",
        )
    )

    # DIMENSION 7: Adversarial Robustness
    adv01 = CultureReport(
        date="2026-01-01",
        organism="Escherichia coli",
        cfu=100000,
        resistance_markers=[],
        specimen_type="urine",
        contamination_flag=False,
        raw_text="Diagnose: pyelonephritis",
    )
    adv01_trend = analyze_trend([adv01])
    adv01_hyp = generate_hypothesis(adv01_trend, 1)
    adv01_p = _stub_response("patient", adv01_trend, adv01_hyp)
    # Case-insensitive match so a capitalised leak is caught as well
    raw_leaked = "pyelonephritis" in adv01_p.lower()
    report.add(EvalResult("ADV-01", "AdversarialRobustness", not raw_leaked, "checked"))

    return report


# Run evaluation
if __name__ == "__main__":
    report = run_eval_suite()
    report.print_report()

## Cell J: Gradio UI — Extraction Agent

Interactive Gradio application with two entry modes:

- **Tab A — Upload PDF**: Upload one or more culture report PDFs. Docling parses each
  file into markdown, which is fed into the existing `extract_structured_data()` regex
  layer. Extracted records are shown in an editable review table before analysis.
- **Tab B — Enter Manually**: Paste free-text culture reports directly (existing flow).
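
The round trip between extracted records and the editable review table can be sketched as below. This is a minimal illustration, not the notebook's implementation: `Record`, `to_row`, and `from_row` are hypothetical, trimmed-down stand-ins for `CultureReport` and its DataFrame helpers, and the low-confidence test is simplified to "date or organism is unknown".

```python
from dataclasses import dataclass, field
from typing import List

WARN = "⚠ "   # prefix the notebook puts on low-confidence rows
EMPTY = "—"   # placeholder the notebook shows when there are no markers


@dataclass
class Record:
    """Hypothetical, simplified stand-in for CultureReport."""
    date: str
    organism: str
    cfu: int
    resistance_markers: List[str] = field(default_factory=list)


def to_row(rec: Record) -> List[str]:
    """Flatten a record into the all-string row an editable table expects."""
    suspicious = rec.date == "unknown" or rec.organism == "unknown"
    return [
        f"{WARN if suspicious else ''}{rec.date}",
        rec.organism,
        str(rec.cfu),
        ", ".join(rec.resistance_markers) or EMPTY,
    ]


def from_row(row: List[str]) -> Record:
    """Rebuild a record from a (possibly user-edited) row."""
    cfu_text = row[2].replace(",", "").strip()
    # Skip the placeholder and empty fragments left by blank cells or stray commas.
    markers = [m.strip() for m in row[3].split(",") if m.strip() not in ("", EMPTY)]
    return Record(
        date=row[0].replace(WARN, "").strip(),
        organism=row[1].strip(),
        cfu=int(cfu_text) if cfu_text.isdigit() else 0,  # unparseable CFU falls back to 0
        resistance_markers=markers,
    )


original = Record("2026-01-05", "Escherichia coli", 120000, ["ESBL"])
assert from_row(to_row(original)) == original
```

Keeping every cell a string makes the table freely editable, at the cost of having to re-validate each field when the user confirms.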

The three-screen state machine (Upload → Review & Confirm → Analysis) is implemented
entirely via `gr.State` + `gr.update(visible=…)`. The downstream pipeline
(`analyze_trend`, `generate_hypothesis`, `call_medgemma`, `render_*`) is unchanged.
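
The three-screen flow can be pictured as a small transition table. The sketch below is framework-free and purely illustrative: the `Screen` enum, the event names, and both helpers are assumptions made for this example, and in the app each boolean would be passed as the `visible` argument of a `gr.update(...)` for the matching `gr.Column`.

```python
from enum import Enum
from typing import Dict


class Screen(Enum):
    UPLOAD = "upload"
    CONFIRM = "confirm"
    OUTPUT = "output"


# Allowed transitions keyed by (current screen, event). The event names are
# illustrative and only loosely mirror the notebook's buttons.
TRANSITIONS = {
    (Screen.UPLOAD, "records_found"): Screen.CONFIRM,
    (Screen.UPLOAD, "nothing_found"): Screen.UPLOAD,
    (Screen.CONFIRM, "confirm"): Screen.OUTPUT,
    (Screen.CONFIRM, "re_upload"): Screen.UPLOAD,
    (Screen.OUTPUT, "start_over"): Screen.UPLOAD,
}


def next_screen(current: Screen, event: str) -> Screen:
    """Return the screen to show next; unknown events leave it unchanged."""
    return TRANSITIONS.get((current, event), current)


def visibility(active: Screen) -> Dict[str, bool]:
    """One boolean per screen, with exactly one set to True."""
    return {screen.value: screen is active for screen in Screen}


screen = next_screen(Screen.UPLOAD, "records_found")
print(visibility(screen))
```

Deriving all three visibility flags from one active screen avoids the failure mode where a handler forgets to hide a column and two screens show at once.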


In [None]:

import os
import tempfile
import time
import warnings
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

MAX_RECORDS = 3
_WARN_PREFIX = "⚠ "

# ---------------------------------------------------------------------------
# Theme Definition — "Orange Design Theme, Warm Classical UI"
# ---------------------------------------------------------------------------

WARM_CLINICAL_THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("Source Serif 4"), "serif"],
    font_mono=[gr.themes.GoogleFont("Source Code Pro"), "monospace"],
).set(
    body_background_fill="#FDFAF7",  # Warm white
    block_background_fill="#FDFAF7",
    block_border_width="1px",
    block_border_color="#E8DDD6",
    button_primary_background_fill="#C1622F",
    button_primary_background_fill_hover="#a85228",
    button_primary_text_color="#FDFAF7",
)


# ---------------------------------------------------------------------------
# 1. Docling PDF processor with enhanced error handling
# ---------------------------------------------------------------------------


def process_pdf_file(pdf_path: str) -> Tuple[str, str, str]:
    """
    Parse a single PDF with Docling.

    Returns:
        (markdown_text, status_html, debug_info)
        - On success: (markdown, "", debug_info)
        - On parse failure: ("", "<red status>", error_details)
    """
    debug_info = f"Processing: {Path(pdf_path).name}\n"

    try:
        from docling.document_converter import DocumentConverter

        debug_info += "✓ Docling imported successfully\n"

        converter = DocumentConverter()
        debug_info += "✓ DocumentConverter created\n"

        start_time = time.time()
        result = converter.convert(pdf_path)
        elapsed = time.time() - start_time
        debug_info += f"✓ PDF converted in {elapsed:.1f}s\n"

        markdown_text = result.document.export_to_markdown()
        debug_info += f"✓ Markdown exported ({len(markdown_text)} chars)\n"

        # Preview first 500 chars for debugging
        preview = markdown_text[:500].replace("\n", " ")
        debug_info += f"Preview: {preview}...\n"

        return markdown_text, "", debug_info

    except ImportError as e:
        error_msg = f"✗ Docling not installed: {e}"
        debug_info += error_msg + "\n"
        return (
            "",
            f'<span style="color:#c0392b">{error_msg}</span>',
            debug_info,
        )
    except Exception as e:
        error_msg = f"✗ PDF processing failed: {type(e).__name__}: {str(e)[:100]}"
        debug_info += error_msg + "\n"
        return (
            "",
            f'<span style="color:#c0392b">{error_msg}</span>',
            debug_info,
        )


# ---------------------------------------------------------------------------
# 2. Multi-report splitter (unchanged)
# ---------------------------------------------------------------------------


def _split_into_report_blocks(markdown_text: str) -> List[str]:
    """
    Attempt to split a multi-report markdown document into individual report blocks.

    Heuristic: split on markdown H1/H2 headings or horizontal rules that
    typically separate reports. Falls back to returning the whole text as one block.
    """
    import re

    # Try splitting on "---" or "===" separators (common in lab report PDFs)
    blocks = re.split(r"\n(?:---+|===+)\n", markdown_text)
    if len(blocks) > 1:
        return [b.strip() for b in blocks if b.strip()]

    # Try splitting on H1/H2 headings
    blocks = re.split(r"\n(?=#{1,2} )", markdown_text)
    if len(blocks) > 1:
        return [b.strip() for b in blocks if b.strip()]

    # Single block
    return [markdown_text.strip()] if markdown_text.strip() else []


def _is_low_confidence(report: CultureReport) -> bool:
    """Return True if any field looks suspiciously generic."""
    return (
        report.organism == "unknown"
        or report.date == "unknown"
        or (report.cfu == 0 and "no growth" not in report.raw_text.lower())
    )


# ---------------------------------------------------------------------------
# 3. DataFrame helpers (unchanged)
# ---------------------------------------------------------------------------


def reports_to_dataframe_rows(reports: List[CultureReport]) -> List[List[str]]:
    """Convert a list of CultureReport objects into rows of strings for gr.Dataframe."""
    rows = []
    for r in reports:
        warn = _WARN_PREFIX if _is_low_confidence(r) else ""
        rows.append(
            [
                f"{warn}{r.date}",
                r.specimen_type,
                r.organism,
                str(r.cfu),
                ", ".join(r.resistance_markers) if r.resistance_markers else "—",
            ]
        )
    return rows


def dataframe_row_to_culture_report(row: List[str]) -> CultureReport:
    """Convert a single Dataframe row (list of strings) back to CultureReport."""
    date_str = row[0].replace(_WARN_PREFIX, "").strip()
    specimen = row[1].strip()
    organism = row[2].strip()
    cfu_str = row[3].replace(",", "").strip()
    resistance_str = row[4].strip()

    try:
        cfu = int(cfu_str)
    except ValueError:
        cfu = 0

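    # Drop the "—" placeholder and any empty fragments (a blank cell or a
    # trailing comma), which would otherwise become spurious "" markers.
    resistance_markers = [
        m.strip() for m in resistance_str.split(",") if m.strip() not in ("", "—")
    ]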

    return CultureReport(
        date=date_str,
        organism=organism,
        cfu=cfu,
        resistance_markers=resistance_markers,
        specimen_type=specimen,
        contamination_flag=any(
            term in organism.lower() for term in RULES["contamination_terms"]
        ),
        raw_text="",  # Not needed for downstream pipeline
    )


# ---------------------------------------------------------------------------
# 4. PDF batch processor with enhanced error handling and debug output
# ---------------------------------------------------------------------------


def process_uploaded_pdfs(
    files: List,
) -> Tuple[List[CultureReport], List[str], List[str], str, str]:
    """
    Process a list of uploaded PDF file objects from gr.File.

    Returns:
        (reports, raw_text_blocks, per_file_statuses, truncation_warning, debug_log)
        - reports: deduplicated, sorted, max MAX_RECORDS CultureReport list
        - raw_text_blocks: one markdown string per report (for clinician accordion)
        - per_file_statuses: one HTML status string per uploaded file
        - truncation_warning: non-empty string if records were truncated
        - debug_log: detailed processing log for troubleshooting
    """
    debug_log = "=== PDF Processing Debug Log ===\n\n"

    if not files:
        debug_log += "No files provided\n"
        return [], [], [], "", debug_log

    all_reports: List[CultureReport] = []
    all_raw_blocks: List[str] = []
    per_file_statuses: List[str] = []

    debug_log += f"Processing {len(files)} file(s)...\n\n"

    for i, f in enumerate(files, 1):
        # Gradio passes file objects with a .name attribute (temp path)
        pdf_path = f.name if hasattr(f, "name") else str(f)
        filename = Path(pdf_path).name

        debug_log += f"--- File {i}/{len(files)}: {filename} ---\n"

        markdown_text, parse_error, file_debug = process_pdf_file(pdf_path)
        debug_log += file_debug

        if parse_error:
            per_file_statuses.append(
                f'<div style="margin:4px 0"><b>{filename}</b> — {parse_error}</div>'
            )
            debug_log += "✗ Skipped due to parse error\n\n"
            continue

        # Try to extract culture records from the markdown
        # extract_structured_data() handles one report block at a time.
        # For multi-report PDFs, split on common section delimiters.
        report_blocks = _split_into_report_blocks(markdown_text)
        debug_log += f"✓ Split into {len(report_blocks)} block(s)\n"

        file_reports: List[CultureReport] = []

        for block_idx, block in enumerate(report_blocks, 1):
            debug_log += f"\n  Block {block_idx}:\n"
            try:
                # Debug extraction
                debug_result = debug_extraction(block, f"Block {block_idx}")
                debug_log += f"    Organism: {debug_result['organism']}\n"
                debug_log += (
                    f"    CFU: {debug_result['cfu']} (ok={debug_result['cfu_ok']})\n"
                )
                debug_log += f"    Specimen: {debug_result['specimen']}\n"
                debug_log += f"    Date: {debug_result['date']}\n"

                report = extract_structured_data(block)
                debug_log += "    ✓ Extraction successful\n"

                # Only keep urine/stool specimens
                if report.specimen_type in ("urine", "stool"):
                    debug_log += (
                        f"    ✓ Specimen type '{report.specimen_type}' accepted\n"
                    )
                    # Override raw_text to the docling markdown block
                    report = CultureReport(
                        date=report.date,
                        organism=report.organism,
                        cfu=report.cfu,
                        resistance_markers=report.resistance_markers,
                        specimen_type=report.specimen_type,
                        contamination_flag=report.contamination_flag,
                        raw_text=block,  # stored for accordion; never forwarded to MedGemma
                    )
                    file_reports.append(report)
                else:
                    debug_log += f"    ✗ Specimen type '{report.specimen_type}' rejected (not urine/stool)\n"

            except ExtractionError as e:
                # Block had no parseable culture data; log it and move on.
                debug_log += f"    ✗ ExtractionError: {e}\n"
            except Exception as e:
                debug_log += f"    ✗ Unexpected error: {type(e).__name__}: {e}\n"

        if not file_reports:
            per_file_statuses.append(
                f'<div style="margin:4px 0"><b>{filename}</b> — '
                f'<span style="color:#e67e22">⚠ No urine or stool culture data found in this file</span></div>'
            )
            debug_log += f"\n✗ No valid culture records found in {filename}\n\n"
        else:
            count = len(file_reports)
            per_file_statuses.append(
                f'<div style="margin:4px 0"><b>{filename}</b> — '
                f'<span style="color:#27ae60">✓ {count} record{"s" if count != 1 else ""} found</span></div>'
            )
            all_reports.extend(file_reports)
            all_raw_blocks.extend(r.raw_text for r in file_reports)
            debug_log += f"\n✓ Extracted {count} record(s) from {filename}\n\n"

    if not all_reports:
        debug_log += "=== RESULT: No valid reports found ===\n"
        return [], [], per_file_statuses, "", debug_log

    # Sort chronologically
    debug_log += f"Sorting {len(all_reports)} report(s) chronologically...\n"
    combined = sorted(zip(all_reports, all_raw_blocks), key=lambda pair: pair[0].date)
    all_reports = [p[0] for p in combined]
    all_raw_blocks = [p[1] for p in combined]

    # Deduplicate: same (date, organism, cfu) → keep first
    seen: set = set()
    deduped_reports: List[CultureReport] = []
    deduped_blocks: List[str] = []
    for report, block in zip(all_reports, all_raw_blocks):
        key = (report.date, report.organism, report.cfu)
        if key in seen:
            debug_log += f"⚠ Duplicate record skipped: {key}\n"
            warnings.warn(f"Duplicate record skipped: {key}", UserWarning, stacklevel=2)
        else:
            seen.add(key)
            deduped_reports.append(report)
            deduped_blocks.append(block)

    # Truncate to MAX_RECORDS most recent
    truncation_warning = ""
    if len(deduped_reports) > MAX_RECORDS:
        total = len(deduped_reports)
        deduped_reports = deduped_reports[-MAX_RECORDS:]
        deduped_blocks = deduped_blocks[-MAX_RECORDS:]
        truncation_warning = (
            f'<div style="background:#fff3cd;border:1px solid #ffc107;padding:8px 12px;'
            f'border-radius:6px;margin-bottom:8px">'
            f"⚠ {total} records were extracted. Only the {MAX_RECORDS} most recent are shown "
            f"(the pipeline supports up to {MAX_RECORDS} reports).</div>"
        )
        debug_log += f"⚠ Truncated from {total} to {MAX_RECORDS} most recent records\n"

    debug_log += f"\n=== RESULT: Returning {len(deduped_reports)} report(s) ===\n"
    for i, r in enumerate(deduped_reports, 1):
        debug_log += (
            f"  {i}. {r.date} | {r.specimen_type} | {r.organism} | {r.cfu} CFU\n"
        )

    return (
        deduped_reports,
        deduped_blocks,
        per_file_statuses,
        truncation_warning,
        debug_log,
    )


# ---------------------------------------------------------------------------
# 5. Gradio UI builder with loading indicators
# ---------------------------------------------------------------------------


def build_gradio_app(model, tokenizer, is_stub: bool) -> gr.Blocks:
    """
    Build and return the full CultureSense Gradio Blocks app.

    Tab A — Upload PDF (new extraction agent flow)
    Tab B — Enter Manually (existing flow, zero changes)
    """

    # ── Shared pipeline helper ──────────────────────────────────────────────
    def run_pipeline(reports: List[CultureReport]):
        """Run the unchanged downstream pipeline and return rendered HTML."""
        sorted_reports = sorted(reports, key=lambda r: r.date)
        trend = analyze_trend(sorted_reports)
        hypothesis = generate_hypothesis(trend, len(sorted_reports))
        patient_response = call_medgemma(
            trend, hypothesis, "patient", model, tokenizer, is_stub
        )
        clinician_response = call_medgemma(
            trend, hypothesis, "clinician", model, tokenizer, is_stub
        )
        patient_out = render_patient_output(trend, hypothesis, patient_response)
        clinician_out = render_clinician_output(trend, hypothesis, clinician_response)
        return patient_out, clinician_out

    def format_output_html(patient_out, clinician_out) -> Tuple[str, str]:
        """Convert FormattedOutput objects to display HTML — warm classical theme."""
        # ── Patient card ───────────────────────────────────────────────────
        p_body = ""
        if patient_out.patient_trend_phrase:
            p_body += (
                f"<p style='font-size:1.0rem;line-height:1.7;margin:0 0 12px 0;'>"
                f"<em>Your results show <strong>{patient_out.patient_trend_phrase}</strong>.</em></p>"
            )
        if patient_out.patient_explanation:
            p_body += (
                f"<div style='line-height:1.75;color:#4a3728;font-size:0.96rem;'>"
                f"{patient_out.patient_explanation}</div>"
            )
        if patient_out.patient_questions:
            qs = "".join(
                f"<li style='margin-bottom:4px;'>{q}</li>"
                for q in patient_out.patient_questions
            )
            p_body += (
                "<p style='margin:14px 0 5px;font-family:system-ui,sans-serif;font-size:0.78rem;"
                "font-weight:600;color:#7a6558;text-transform:uppercase;letter-spacing:.05em;'>"
                "Questions to ask your doctor</p>"
                f"<ul style='padding-left:18px;color:#4a3728;font-size:0.93rem;line-height:1.8;margin:0;'>{qs}</ul>"
            )
        if patient_out.patient_disclaimer:
            p_body += (
                "<div style='margin-top:16px;padding:10px 14px;border:1px solid #E8DDD6;"
                "border-radius:3px;background:#FDFAF7;'>"
                f"<p style='font-family:system-ui,sans-serif;font-size:0.77rem;font-style:italic;"
                f"color:#9a8578;margin:0;line-height:1.6;'>{patient_out.patient_disclaimer}</p>"
                "</div>"
            )
        # Font names use escaped double quotes: a nested single quote would
        # terminate the single-quoted style attribute early.
        patient_html = (
            "<div style='font-family:\"Source Serif 4\",serif;background:#FDFAF7;border:1px solid #E8DDD6;"
            "border-radius:4px;padding:22px 26px;box-shadow:0 1px 4px rgba(28,20,18,0.07);'>"
            "<h3 style='font-family:\"Playfair Display\",serif;font-size:1.1rem;font-weight:600;"
            "color:#C1622F;margin:0 0 14px;border-left:3px solid #C1622F;padding-left:10px;"
            "letter-spacing:.01em;'>Patient Summary</h3>" + p_body + "</div>"
        )

        # ── Clinician card ─────────────────────────────────────────────────
        # Confidence bar
        conf_val = clinician_out.clinician_confidence
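        # round() rather than int(): a float product can land just below the
        # intended integer, which would make the bar disagree with the label.
        conf_pct_num = round((conf_val or 0) * 100)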
        conf_label = f"{conf_val:.0%}" if conf_val is not None else "N/A"
        conf_bar = (
            "<div style='margin:0 0 14px;'>"
            "<div style='display:flex;align-items:baseline;gap:8px;margin-bottom:5px;'>"
            "<span style='font-family:system-ui,sans-serif;font-size:0.78rem;font-weight:600;"
            "color:#7a6558;text-transform:uppercase;letter-spacing:.04em;'>Confidence</span>"
            f"<span style='font-family:\"Playfair Display\",serif;font-size:1.12rem;"
            f"font-weight:700;color:#C1622F;'>{conf_label}</span>"
            "</div>"
            "<div style='height:5px;border-radius:3px;background:#E8DDD6;overflow:hidden;'>"
            f"<div style='height:100%;width:{conf_pct_num}%;background:#C1622F;border-radius:3px;'></div>"
            "</div></div>"
        )
        c_body = conf_bar
        if clinician_out.clinician_stewardship_flag:
            c_body += (
                "<div style='background:#fdf5f1;border-left:3px solid #C1622F;"
                "padding:10px 14px;margin:10px 0;border-radius:3px;'>"
                "<span style='font-family:system-ui,sans-serif;font-size:0.84rem;"
                "color:#C1622F;font-weight:600;'>⚠ Stewardship Alert</span>"
                "<p style='margin:4px 0 0;font-family:system-ui,sans-serif;font-size:0.82rem;"
                "color:#6b3320;'>Emerging resistance detected — antimicrobial stewardship review recommended.</p>"
                "</div>"
            )
        if clinician_out.clinician_resistance_detail:
            c_body += (
                "<div style='background:#FDFAF7;border-left:3px solid #E8DDD6;"
                "padding:10px 14px;margin:10px 0;border-radius:3px;'>"
                "<p style='margin:0 0 4px;font-family:system-ui,sans-serif;font-size:0.78rem;"
                "font-weight:600;text-transform:uppercase;letter-spacing:.04em;color:#7a6558;'>"
                "Resistance Timeline</p>"
                f"<pre style='margin:0;font-size:12px;font-family:system-ui,monospace;"
                f"color:#4a3728;white-space:pre-wrap;'>{clinician_out.clinician_resistance_detail}</pre>"
                "</div>"
            )
        if clinician_out.clinician_interpretation:
            c_body += (
                f"<div style='line-height:1.75;color:#3d2b1f;font-size:0.96rem;margin-top:12px;'>"
                f"{clinician_out.clinician_interpretation}</div>"
            )
        if clinician_out.clinician_disclaimer:
            c_body += (
                "<p style='font-family:system-ui,sans-serif;font-style:italic;color:#9a8578;"
                "border-top:1px solid #E8DDD6;padding-top:10px;margin-top:18px;"
                f"font-size:0.77rem;line-height:1.6;'>{clinician_out.clinician_disclaimer}</p>"
            )
        # Font names use escaped double quotes: a nested single quote would
        # terminate the single-quoted style attribute early.
        clinician_html = (
            "<div style='font-family:\"Source Serif 4\",serif;background:#FDFAF7;border:1px solid #E8DDD6;"
            "border-radius:4px;padding:22px 26px;margin-top:14px;box-shadow:0 1px 4px rgba(28,20,18,0.07);'>"
            "<h3 style='font-family:\"Playfair Display\",serif;font-size:1.1rem;font-weight:600;"
            "color:#C1622F;margin:0 0 14px;border-left:3px solid #C1622F;padding-left:10px;"
            "letter-spacing:.01em;'>Clinical Interpretation</h3>" + c_body + "</div>"
        )

        return patient_html, clinician_html

    # ── Build UI ────────────────────────────────────────────────────────────
    with gr.Blocks(
        theme=WARM_CLINICAL_THEME,
        css="""
        .screen { min-height: 60vh; }
        .status-box { min-height: 40px; border: 1px solid #E8DDD6; border-radius: 4px; padding: 8px; background: #FDFAF7; }
        .error-banner { background: #fdf5f1; border-left: 3px solid #C1622F; padding: 12px 16px; margin: 8px 0; border-radius: 3px; }
        .loading-spinner { display: inline-block; width: 20px; height: 20px; border: 3px solid #E8DDD6; border-top: 3px solid #C1622F; border-radius: 50%; animation: spin 1s linear infinite; margin-right: 8px; vertical-align: middle; }
        @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
    """,
    ) as demo:
        gr.Markdown("# 🧫 CultureSense — Longitudinal Clinical Hypothesis Engine")
        gr.Markdown(
            "Upload 2–3 sequential urine or stool culture reports to generate a trend analysis and clinical hypothesis."
        )

        with gr.Tabs():
            # ================================================================
            # TAB A — Upload PDF (Extraction Agent)
            # ================================================================
            with gr.Tab("📄 Upload PDF", id="tab_upload"):
                # ── State ───────────────────────────────────────────────────
                state_reports = gr.State([])
                state_raw_blocks = gr.State([])

                # ── Screen 1: Upload ────────────────────────────────────────
                with gr.Column(visible=True, elem_classes="screen") as screen_upload:
                    gr.Markdown("### Step 1 — Upload your culture report PDFs")
                    gr.Markdown(
                        "Upload one or more PDF files. Each file may contain one or more "
                        "urine/stool culture reports."
                    )
                    pdf_upload = gr.File(
                        label="Culture Report PDFs",
                        file_types=[".pdf"],
                        file_count="multiple",
                    )

                    with gr.Row():
                        btn_process = gr.Button("⚙ Process PDFs", variant="primary")
                        btn_process_loading = gr.Button(
                            "⏳ Processing...",
                            variant="primary",
                            interactive=False,
                            visible=False,
                        )

                    status_html = gr.HTML(
                        value="", label="File Status", elem_classes="status-box"
                    )

                    # Loading indicator
                    loading_html = gr.HTML(
                        value="",
                        visible=False,
                    )

                    with gr.Column(visible=False) as all_failed_panel:
                        gr.HTML(
                            '<div class="error-banner">'
                            "No urine or stool culture data was found in your uploaded documents. "
                            "Please try uploading again, or switch to manual entry."
                            "</div>"
                        )
                        with gr.Row():
                            btn_try_again = gr.Button("🔄 Try Again")
                            btn_to_manual_from_fail = gr.Button("✏ Enter Manually")

                    # Debug output (collapsed by default)
                    with gr.Accordion(
                        "🔍 Debug Output (click to expand if processing fails)",
                        open=False,
                    ):
                        debug_output = gr.Textbox(
                            label="Processing Log",
                            interactive=False,
                            lines=20,
                            value="",
                        )

                # ── Screen 2: Review & Confirm ──────────────────────────────
                with gr.Column(visible=False, elem_classes="screen") as screen_confirm:
                    gr.Markdown("### Step 2 — Review & Confirm Extracted Records")
                    gr.Markdown(
                        "All cells are editable. Fields marked **⚠** were extracted with "
                        "low confidence — please verify against the raw text below."
                    )
                    truncation_warning_html = gr.HTML(value="")

                    confirm_table = gr.Dataframe(
                        headers=[
                            "Date",
                            "Specimen",
                            "Organism",
                            "CFU/mL",
                            "Resistance Markers",
                        ],
                        datatype=["str", "str", "str", "str", "str"],
                        interactive=True,
                        wrap=True,
                        label="Extracted Culture Records",
                    )

                    with gr.Accordion(
                        "📋 Raw Extracted Text (for clinician verification)",
                        open=False,
                    ):
                        raw_box_0 = gr.Textbox(
                            label="Record 1", interactive=False, visible=False, lines=6
                        )
                        raw_box_1 = gr.Textbox(
                            label="Record 2", interactive=False, visible=False, lines=6
                        )
                        raw_box_2 = gr.Textbox(
                            label="Record 3", interactive=False, visible=False, lines=6
                        )

                    with gr.Row():
                        btn_confirm = gr.Button(
                            "✅ Confirm & Analyse", variant="primary"
                        )
                        btn_re_upload = gr.Button("↩ Edit & Re-upload")
                        btn_to_manual_from_confirm = gr.Button(
                            "✏ Enter Manually Instead"
                        )

                # ── Screen 3: Analysis Output ───────────────────────────────
                with gr.Column(visible=False, elem_classes="screen") as screen_output:
                    gr.Markdown("### Step 3 — Analysis Results")
                    output_patient_html = gr.HTML(value="")
                    output_clinician_html = gr.HTML(value="")
                    btn_start_over = gr.Button("🔄 Start Over")

                # ── Event: Process PDFs ─────────────────────────────────────
                def on_process_pdfs_start(files):
                    """Show loading state immediately when button is clicked."""
                    if not files:
                        return (
                            gr.update(visible=True),  # btn_process
                            gr.update(visible=False),  # btn_process_loading
                            gr.update(
                                value="<p style='color:#888'>No files uploaded.</p>",
                                visible=True,
                            ),
                            gr.update(visible=False),  # loading_html
                        )

                    # Show loading state
                    loading_msg = (
                        '<div style="padding:12px;background:#fff3cd;border:1px solid #ffc107;border-radius:4px;">'
                        '<span class="loading-spinner"></span>'
                        "<strong>Processing PDFs...</strong> This may take 30-60 seconds per file. "
                        "Docling is extracting text from your PDFs."
                        "</div>"
                    )

                    return (
                        gr.update(visible=False),  # btn_process
                        gr.update(visible=True),  # btn_process_loading
                        gr.update(value=loading_msg, visible=True),  # status_html
                        gr.update(visible=True),  # loading_html
                    )

                def on_process_pdfs(files):
                    """Actually process the PDFs after loading state is shown."""
                    if not files:
                        return (
                            [],  # state_reports
                            [],  # state_raw_blocks
                            "<p style='color:#888'>No files uploaded.</p>",  # status_html
                            gr.update(visible=True),  # screen_upload
                            gr.update(visible=False),  # screen_confirm
                            gr.update(visible=False),  # screen_output
                            gr.update(visible=False),  # all_failed_panel
                            [],  # confirm_table
                            "",  # truncation_warning_html
                            gr.update(value="", visible=False),  # raw_box_0
                            gr.update(value="", visible=False),  # raw_box_1
                            gr.update(value="", visible=False),  # raw_box_2
                            "",  # debug_output
                            gr.update(visible=True),  # btn_process
                            gr.update(visible=False),  # btn_process_loading
                            gr.update(visible=False),  # loading_html
                        )

                    reports, raw_blocks, statuses, trunc_warn, debug_log = (
                        process_uploaded_pdfs(files)
                    )
                    status_combined = "".join(statuses)

                    if not reports:
                        # All files failed — stay on screen 1, show error panel
                        error_msg = (
                            '<div style="padding:12px;background:#f8d7da;border:1px solid #f5c6cb;border-radius:4px;color:#721c24;">'
                            "<strong>✗ No valid culture data found</strong><br>"
                            "Please check the debug output below for details."
                            "</div>"
                        )
                        return (
                            [],  # state_reports
                            [],  # state_raw_blocks
                            error_msg,  # status_html
                            gr.update(visible=True),  # screen_upload
                            gr.update(visible=False),  # screen_confirm
                            gr.update(visible=False),  # screen_output
                            gr.update(visible=True),  # all_failed_panel
                            [],  # confirm_table
                            "",  # truncation_warning_html
                            gr.update(value="", visible=False),  # raw_box_0
                            gr.update(value="", visible=False),  # raw_box_1
                            gr.update(value="", visible=False),  # raw_box_2
                            debug_log,  # debug_output
                            gr.update(visible=True),  # btn_process
                            gr.update(visible=False),  # btn_process_loading
                            gr.update(visible=False),  # loading_html
                        )

                    # Build dataframe rows
                    df_rows = reports_to_dataframe_rows(reports)

                    # Build raw text box updates (pre-created 3 boxes)
                    raw_updates = []
                    for i in range(MAX_RECORDS):
                        if i < len(raw_blocks):
                            raw_updates.append(
                                gr.update(
                                    value=raw_blocks[i],
                                    label=f"Record {i + 1} — {reports[i].date}",
                                    visible=True,
                                )
                            )
                        else:
                            raw_updates.append(gr.update(value="", visible=False))

                    return (
                        reports,
                        raw_blocks,
                        status_combined,
                        gr.update(visible=False),  # hide screen_upload
                        gr.update(visible=True),  # show screen_confirm
                        gr.update(visible=False),  # hide screen_output
                        gr.update(visible=False),  # hide all_failed_panel
                        df_rows,
                        trunc_warn,
                        raw_updates[0],
                        raw_updates[1],
                        raw_updates[2],
                        debug_log,  # Store debug log
                        gr.update(visible=True),  # btn_process
                        gr.update(visible=False),  # btn_process_loading
                        gr.update(visible=False),  # loading_html
                    )

                # Chain the events: first show loading, then process
                btn_process.click(
                    fn=on_process_pdfs_start,
                    inputs=[pdf_upload],
                    outputs=[
                        btn_process,
                        btn_process_loading,
                        status_html,
                        loading_html,
                    ],
                ).then(
                    fn=on_process_pdfs,
                    inputs=[pdf_upload],
                    outputs=[
                        state_reports,
                        state_raw_blocks,
                        status_html,
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        confirm_table,
                        truncation_warning_html,
                        raw_box_0,
                        raw_box_1,
                        raw_box_2,
                        debug_output,
                        btn_process,
                        btn_process_loading,
                        loading_html,
                    ],
                )

                # ── Event: Confirm & Analyse ────────────────────────────────
                def on_confirm(table_data):
                    if table_data is None or len(table_data) == 0:
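                        return (
                            gr.update(visible=True),  # keep screen_confirm
                            gr.update(visible=False),  # hide screen_upload
                            gr.update(visible=False),  # hide screen_output
                            "<p style='color:#c0392b'>No records to analyse.</p>",  # output_patient_html
                            "",  # output_clinician_html
                        )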

                    # Convert edited table rows back to CultureReport objects
                    confirmed_reports = []
                    for row in table_data:
                        try:
                            confirmed_reports.append(
                                dataframe_row_to_culture_report(row)
                            )
                        except Exception:
                            # Skip rows that cannot be converted; the remaining
                            # rows are still analysed.
                            continue

                    if not confirmed_reports:
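                        return (
                            gr.update(visible=True),  # keep screen_confirm
                            gr.update(visible=False),  # hide screen_upload
                            gr.update(visible=False),  # hide screen_output
                            "<p style='color:#c0392b'>Could not parse records.</p>",  # output_patient_html
                            "",  # output_clinician_html
                        )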

                    try:
                        patient_out, clinician_out = run_pipeline(confirmed_reports)
                        patient_html, clinician_html = format_output_html(
                            patient_out, clinician_out
                        )
                    except Exception as e:
                        patient_html = (
                            f"<p style='color:#c0392b'>Analysis error: {e}</p>"
                        )
                        clinician_html = ""

                    return (
                        gr.update(visible=False),  # hide screen_confirm
                        gr.update(visible=False),  # hide screen_upload
                        gr.update(visible=True),  # show screen_output
                        patient_html,
                        clinician_html,
                    )

                btn_confirm.click(
                    fn=on_confirm,
                    inputs=[confirm_table],
                    outputs=[
                        screen_confirm,
                        screen_upload,
                        screen_output,
                        output_patient_html,
                        output_clinician_html,
                    ],
                )

                # ── Event: Edit & Re-upload ─────────────────────────────────
                def on_re_upload():
                    return (
                        gr.update(visible=True),  # show screen_upload
                        gr.update(visible=False),  # hide screen_confirm
                        gr.update(visible=False),  # hide screen_output
                        gr.update(visible=False),  # hide all_failed_panel
                        [],  # clear state_reports
                        [],  # clear state_raw_blocks
                        "",  # clear status_html
                        "",  # clear debug_output
                    )

                btn_re_upload.click(
                    fn=on_re_upload,
                    inputs=[],
                    outputs=[
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        state_reports,
                        state_raw_blocks,
                        status_html,
                        debug_output,
                    ],
                )

                # ── Event: Try Again (from fail panel) ──────────────────────
                btn_try_again.click(
                    fn=on_re_upload,
                    inputs=[],
                    outputs=[
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        state_reports,
                        state_raw_blocks,
                        status_html,
                        debug_output,
                    ],
                )

                # ── Event: Start Over ───────────────────────────────────────
                def on_start_over():
                    return (
                        gr.update(visible=True),  # show screen_upload
                        gr.update(visible=False),  # hide screen_confirm
                        gr.update(visible=False),  # hide screen_output
                        gr.update(visible=False),  # hide all_failed_panel
                        [],  # clear state_reports
                        [],  # clear state_raw_blocks
                        "",  # clear status_html
                        None,  # clear pdf_upload
                        "",  # clear debug_output
                    )

                btn_start_over.click(
                    fn=on_start_over,
                    inputs=[],
                    outputs=[
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        state_reports,
                        state_raw_blocks,
                        status_html,
                        pdf_upload,
                        debug_output,
                    ],
                )

                # ── Event: Switch to Manual Entry ───────────────────────────
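                def switch_to_manual():
                    # Hides every Tab A screen. The final value only updates a
                    # placeholder gr.State; it does not select the manual tab by
                    # itself (that would need an update to the `selected`
                    # property of the enclosing tabs component).
                    return (
                        gr.update(visible=False),  # hide upload screen
                        gr.update(visible=False),  # hide confirm screen
                        gr.update(visible=False),  # hide output screen
                        gr.update(visible=False),  # hide fail panel
                        gr.update(value="manual"),  # placeholder state value
                    )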

                btn_to_manual_from_fail.click(
                    fn=switch_to_manual,
                    inputs=[],
                    outputs=[
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        gr.State("manual"),  # dummy, will be replaced by tab selection
                    ],
                )

                btn_to_manual_from_confirm.click(
                    fn=switch_to_manual,
                    inputs=[],
                    outputs=[
                        screen_upload,
                        screen_confirm,
                        screen_output,
                        all_failed_panel,
                        gr.State("manual"),
                    ],
                )

            # ================================================================
            # TAB B — Manual Entry (unchanged from original)
            # ================================================================
            with gr.Tab("✏ Enter Manually", id="tab_manual"):
                gr.Markdown("### Paste culture report text directly")
                gr.Markdown(
                    "Paste 2–3 sequential culture reports. "
                    "The pipeline will extract structured data, analyse trends, and generate hypotheses."
                )

                manual_input = gr.Textbox(
                    label="Culture Reports (2–3 sequential)",
                    placeholder="Paste report text here...",
                    lines=12,
                )
                btn_analyse_manual = gr.Button("🔬 Analyse", variant="primary")
                manual_output_patient = gr.HTML()
                manual_output_clinician = gr.HTML()

                def on_analyse_manual(text):
                    if not text or len(text.strip()) < 20:
                        return (
                            "<p style='color:#c0392b'>Please paste at least one full report.</p>",
                            "",
                        )

                    # Split on blank lines: each block is treated as one report,
                    # so a single report must not contain blank lines itself.
                    blocks = [b.strip() for b in text.split("\n\n") if b.strip()]
                    reports = []
                    for block in blocks:
                        try:
                            r = extract_structured_data(block)
                            reports.append(r)
                        except Exception:
                            # Skip blocks that cannot be parsed; the remaining
                            # blocks are still analysed.
                            continue

                    if len(reports) < 1:
                        return (
                            "<p style='color:#c0392b'>Could not extract data from pasted text. "
                            "Check format includes Date, Organism, and CFU/mL.</p>",
                            "",
                        )

                    try:
                        patient_out, clinician_out = run_pipeline(reports)
                        patient_html, clinician_html = format_output_html(
                            patient_out, clinician_out
                        )
                    except Exception as e:
                        patient_html = (
                            f"<p style='color:#c0392b'>Analysis error: {e}</p>"
                        )
                        clinician_html = ""

                    return patient_html, clinician_html

                btn_analyse_manual.click(
                    fn=on_analyse_manual,
                    inputs=[manual_input],
                    outputs=[manual_output_patient, manual_output_clinician],
                )

    return demo

In [None]:
# Launch the CultureSense Gradio app
demo = build_gradio_app(model, tokenizer, is_stub)
demo.launch(share=True)

---

## Safety & Regulatory Positioning

- **No output** from any module, in any mode, shall contain a named diagnosis.
- Confidence scores are capped at **0.95** (never 1.0 — clinical epistemic humility).
- Both output modes end with **hardcoded disclaimer text** that cannot be overridden.
- MedGemma is **never prompted with raw user text** — only structured JSON.
- A post-processing safety scan using `BANNED_DIAGNOSTIC_PHRASES` provides a second layer of defence.
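
The confidence cap and the post-processing scan listed above can be sketched as follows. This is a minimal illustration, not the notebook's actual implementation: the phrase list contents, `SAFE_FALLBACK`, `cap_confidence`, and `scan_output` are hypothetical names and values chosen for the example, and replacing the whole output on a match is an assumed policy.

```python
from typing import List, Tuple

# Assumption: illustrative phrases only. The real BANNED_DIAGNOSTIC_PHRASES
# list is defined elsewhere in the notebook and may differ.
BANNED_DIAGNOSTIC_PHRASES: List[str] = [
    "you have",
    "diagnosis is",
    "confirms infection",
]

MAX_CONFIDENCE = 0.95  # cap stated in the list above

# Hypothetical fallback text shown when the scan finds a banned phrase.
SAFE_FALLBACK = (
    "This summary could not be shown. "
    "Please discuss your results with a clinician."
)


def cap_confidence(raw: float) -> float:
    """Clamp a confidence score into the range [0.0, MAX_CONFIDENCE]."""
    return max(0.0, min(raw, MAX_CONFIDENCE))


def scan_output(text: str) -> Tuple[str, List[str]]:
    """Return (safe_text, hits) after a case-insensitive phrase scan.

    If any banned phrase appears, the whole text is replaced by the
    fallback (an assumed policy for this sketch).
    """
    lowered = text.lower()
    hits = [p for p in BANNED_DIAGNOSTIC_PHRASES if p in lowered]
    return (SAFE_FALLBACK if hits else text), hits


if __name__ == "__main__":
    print(cap_confidence(1.2))
    print(scan_output("Counts are falling across the reports."))
    print(scan_output("This means you have an infection."))
```

Because the scan runs on the generated text after the model call, it is independent of the prompt and acts as a second check even when the prompt-level constraints fail.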

> *This notebook is a Kaggle competition prototype only. It is not intended for clinical use,
> does not constitute medical advice, and has not been evaluated for diagnostic accuracy.*
