# Experiment: Nature Evidence 06 - Posterior Failure Modes and Protocol Design

Objective:
- Stress-test posterior inference under model mismatch, using biofouled multiphysics traces that fall outside the base model's training assumptions.
- Demonstrate that reliability diagnostics can detect out-of-family measurements.
- Provide an operational guardrail for deployment: when to trust inference and when to escalate.

Success criteria:
- Reliability-score distributions for in-distribution and stress-test cases are measurably separated.
- A threshold-based rule achieves high recall on stress cases while keeping the false-positive rate on in-distribution cases low.
- Representative failure traces are visualized and documented.


In [None]:
# Setup: imports and reproducibility
from __future__ import annotations

import json
import os
from pathlib import Path

import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from ecsfm.fm.eval_classical import _resolve_model_geometry
from ecsfm.fm.model import VectorFieldNet
from ecsfm.fm.posterior import CEMPosteriorConfig, PosteriorInferenceConfig, infer_parameter_posterior
from ecsfm.fm.train import (
    MODEL_META_FILENAME,
    NORMALIZERS_FILENAME,
    load_model_metadata,
    load_saved_normalizers,
)
from ecsfm.sim.multiphysics import MultiPhysicsConfig, simulate_multiphysics_electrochem

# Legacy global seed; the per-case draws in later cells use explicit np.random.default_rng generators.
np.random.seed(2026)

# Output directory for figures and the JSON report. It can be redirected without editing the
# notebook (e.g. onto a persistent volume), in the same way as ECSFM_CHECKPOINT /
# ECSFM_DATASET_CHUNK; an unset or empty variable falls back to the /tmp default.
ARTIFACT_DIR = Path(os.getenv('ECSFM_ARTIFACT_DIR') or '/tmp/ecsfm/notebook_nature_06')
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
print(f'Artifacts: {ARTIFACT_DIR}')


## Plan

- Build paired evaluation cases: for each sampled dataset row, an in-distribution trace and a biofouled multiphysics stress trace driven by the same applied signal and base parameters.
- Run posterior inference with identical settings on both sets.
- Compare reliability/calibration behavior and fit a simple acceptance rule.
- Visualize the stress case with the lowest reliability score as a representative mismatch.


In [None]:
# User configuration and model/data loading
def _first_existing(candidates: list[Path]) -> Path | None:
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def _discover_latest(pattern: str, roots: list[Path]) -> Path | None:
    found: list[Path] = []
    for root in roots:
        if root.exists():
            found.extend(root.rglob(pattern))
    if not found:
        return None
    return max(found, key=lambda p: p.stat().st_mtime)


checkpoint_override = os.getenv('ECSFM_CHECKPOINT')
dataset_override = os.getenv('ECSFM_DATASET_CHUNK')

# Known locations, tried in order; the first one that exists wins.
checkpoint_candidates = [
    Path('/tmp/ecsfm/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/vol/artifacts/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/tmp/ecsfm/surrogate_model.eqx'),
]
dataset_candidates = [
    Path('/tmp/ecsfm/dataset_balanced_742k/chunk_0.npz'),
    Path('/vol/datasets/dataset_balanced_742k/chunk_0.npz'),
    Path('/tmp/ecsfm/dataset_massive/chunk_0.npz'),
]

# Resolution order: explicit env override -> first existing known location ->
# most recently modified match found by a recursive search of the roots.
# (Path objects are always truthy, so `or` only falls through on None.)
CHECKPOINT = (
    Path(checkpoint_override)
    if checkpoint_override
    else (
        _first_existing(checkpoint_candidates)
        or _discover_latest('surrogate_model.eqx', [Path('/tmp/ecsfm'), Path('/vol/artifacts')])
    )
)

if dataset_override:
    DATASET_CHUNK = Path(dataset_override)
else:
    # Prefer a first chunk; accept any chunk file only when no chunk_0 can be found.
    DATASET_CHUNK = (
        _first_existing(dataset_candidates)
        or _discover_latest('chunk_0.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])
        or _discover_latest('chunk_*.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])
    )

# Fail fast (checkpoint first, then dataset) when nothing usable was resolved.
for resolved, message in (
    (CHECKPOINT, 'Checkpoint not found. Set ECSFM_CHECKPOINT.'),
    (DATASET_CHUNK, 'Dataset chunk not found. Set ECSFM_DATASET_CHUNK.'),
):
    if resolved is None or not resolved.exists():
        raise FileNotFoundError(message)

# Number of dataset rows to evaluate; each row yields one in-distribution and one stress run.
N_CASES = 6
# Std of the additive Gaussian noise put on observed currents; also passed to the posterior
# as its observation-noise level.
EVAL_NOISE_STD = 5e-2

# Posterior-inference knobs, mapped onto CEMPosteriorConfig / PosteriorInferenceConfig fields
# (n_particles, n_iterations, elite_fraction, n_mc_per_particle, n_integration_steps).
INFERENCE_CFG = dict(
    n_particles=64,
    n_iters=4,
    elite_frac=0.25,
    n_mc=2,
    n_steps=60,
)

# Normalizers and model metadata live next to the checkpoint; geometry is derived from both.
normalizers = load_saved_normalizers(CHECKPOINT.parent / NORMALIZERS_FILENAME)
meta = load_model_metadata(CHECKPOINT.parent / MODEL_META_FILENAME)
geometry = _resolve_model_geometry(normalizers, meta)

# Build an architecture template from geometry/metadata (with fallback defaults) and fill its
# leaves from the checkpoint; the PRNG key only initialises the template that gets overwritten.
key = jax.random.PRNGKey(np.uint32(2026))
model_key = jax.random.split(key)[1]
model = eqx.tree_deserialise_leaves(
    CHECKPOINT,
    VectorFieldNet(
        state_dim=int(geometry['state_dim']),
        hidden_size=int(meta.get('hidden_size', 128)),
        depth=int(meta.get('depth', 3)),
        cond_dim=int(meta.get('cond_dim', 32)),
        phys_dim=int(geometry['phys_dim']),
        signal_channels=int(geometry['signal_channels']),
        key=model_key,
    ),
)

# Copy the arrays out while the archive is open: 'i' -> currents, 'e' -> signals,
# 'p' -> base params (float32), plus integer task/stage labels.
with np.load(DATASET_CHUNK) as chunk:
    currents, signals, params_base = (np.asarray(chunk[name], dtype=np.float32) for name in ('i', 'e', 'p'))
    task_ids, stage_ids = (np.asarray(chunk[name], dtype=np.int32) for name in ('task_id', 'stage_id'))

# Sample up to N_CASES distinct rows with a fixed seed.
rng = np.random.default_rng(2028)
case_indices = rng.choice(
    currents.shape[0],
    size=min(N_CASES, currents.shape[0]),
    replace=False,
).astype(np.int32)

# Layout sizes of the core parameter vector (base block, task one-hot, stage one-hot) and species count.
phys_dim_base, phys_dim_core, n_tasks, n_stages, max_species = (
    int(geometry[name])
    for name in ('phys_dim_base', 'phys_dim_core', 'n_tasks', 'n_stages', 'max_species')
)


def compose_core_params(base_row: np.ndarray, task_idx: int, stage_idx: int) -> np.ndarray:
    """Assemble a float32 core-parameter vector of length ``phys_dim_core``.

    Layout: the first ``phys_dim_base`` entries of *base_row*, then a task one-hot of
    width ``n_tasks``, then a stage one-hot of width ``n_stages`` (a one-hot is omitted
    when its width is 0). Indices are clipped into range; any remaining trailing
    entries stay zero. Sizes come from module-level ``phys_dim_base``,
    ``phys_dim_core``, ``n_tasks`` and ``n_stages``.
    """
    core = np.zeros(phys_dim_core, dtype=np.float32)
    core[:phys_dim_base] = base_row[:phys_dim_base]

    offset = phys_dim_base
    for width, hot_index in ((n_tasks, task_idx), (n_stages, stage_idx)):
        if width <= 0:
            continue
        onehot = np.zeros(width, dtype=np.float32)
        onehot[int(np.clip(hot_index, 0, width - 1))] = 1.0
        core[offset : offset + width] = onehot
        offset += width
    return core


# Ground-truth core parameter vectors for the sampled cases (row k corresponds to case_indices[k]).
p_core_true = np.vstack([
    compose_core_params(params_base[idx], int(task_ids[idx]), int(stage_ids[idx]))
    for idx in case_indices
])

# Only the last two normalizer entries are used here. NOTE(review): presumably the mean/std of
# the physics-parameter vector used for normalization — confirm against load_saved_normalizers.
_, _, _, _, p_mean, p_std = normalizers
p_mean, p_std = (np.asarray(stat, dtype=np.float32) for stat in (p_mean, p_std))

print('cases:', case_indices.tolist())


In [None]:
# Build stress-test observations from multiphysics biofouling dynamics
def decode_base_params(base_row: np.ndarray, m: int) -> dict[str, np.ndarray]:
    """Split a flat base-parameter row into per-species arrays of length *m*.

    The row is read as seven consecutive segments of width *m*, in the order
    D_ox, D_red, C_ox, C_red, E0, k0, alpha. The transforms applied are:

    - D_ox, D_red, k0: exponentiated (so the row presumably stores their logs).
    - C_ox: clipped below at 1e-5.
    - C_red: clipped below at 0.
    - alpha: clipped to [0.05, 0.95].
    - E0: returned as-is.
    """
    flat = np.asarray(base_row, dtype=np.float32)

    def segment(j: int) -> np.ndarray:
        return flat[j * m : (j + 1) * m]

    log_d_ox, log_d_red, c_ox, c_red, e0, log_k0, alpha = (segment(j) for j in range(7))
    return {
        'D_ox': np.exp(log_d_ox),
        'D_red': np.exp(log_d_red),
        'C_ox': np.clip(c_ox, 1e-5, None),
        'C_red': np.clip(c_red, 0.0, None),
        'E0': e0,
        'k0': np.exp(log_k0),
        'alpha': np.clip(alpha, 0.05, 0.95),
    }


def simulate_biofouled_current(signal: np.ndarray, base_row: np.ndarray) -> np.ndarray:
    """Run the multiphysics simulator with fixed fouling settings for one applied signal.

    Species parameters come from ``decode_base_params(base_row, max_species)``, using
    module-level ``max_species``. The MultiPhysicsConfig values, ``t_max=8.0``, ``nx=24``
    and ``save_every=0`` are fixed for this stress test.

    Returns element 4 of the simulator output as float32. NOTE(review): assumed to be the
    current trace (per this function's name) — confirm the index against
    simulate_multiphysics_electrochem.
    """
    species = decode_base_params(base_row, max_species)
    fouling = MultiPhysicsConfig(
        initial_theta=0.78,
        k_ads=4200.0,
        k_des=8.0e-5,
        k_clean=0.02,
        k_reaction=0.004,
        Rfilm_theta_max_ohm=1800.0,
        cdl_theta_fraction=0.82,
        area_floor_fraction=0.09,
        k0_theta_coeff=2.8,
        electrode_area_cm2=0.01,
    )

    def _f32(values):
        return jnp.asarray(values, dtype=jnp.float32)

    # Simulator keyword -> decoded species field.
    kinetics = {
        'D_ox': _f32(species['D_ox']),
        'D_red': _f32(species['D_red']),
        'C_bulk_ox': _f32(species['C_ox']),
        'C_bulk_red': _f32(species['C_red']),
        'E0': _f32(species['E0']),
        'k0': _f32(species['k0']),
        'alpha': _f32(species['alpha']),
    }
    result = simulate_multiphysics_electrochem(
        E_array=_f32(signal),
        t_max=8.0,
        nx=24,
        config=fouling,
        save_every=0,
        **kinetics,
    )
    return np.asarray(result[4], dtype=np.float32)


def random_mask(length: int, keep_prob: float, rng: np.random.Generator, min_points: int = 24) -> np.ndarray:
    """Draw a float32 0/1 observation mask of size *length*.

    Each point is kept independently with probability *keep_prob*. If fewer than
    *min_points* survive, ``min(min_points, length)`` distinct random indices are
    additionally switched on (they may overlap already-kept points). Draws come
    only from *rng*, so a seeded generator gives a reproducible mask.
    """
    keep = rng.random(length) < keep_prob
    if not int(keep.sum()) >= min_points:
        forced = rng.choice(length, size=min(min_points, length), replace=False)
        keep[forced] = True
    return keep.astype(np.float32)


# Entries from phys_dim_base onward (the task/stage one-hot slots) are flagged True; the base
# block stays False. NOTE(review): presumably True means "fixed to the known value" in
# infer_parameter_posterior — confirm against its known_p_mask contract.
known_mask = np.arange(phys_dim_core) >= phys_dim_base

# Identical inference settings are used for both splits; the observation-noise level matches
# the noise injected into the evaluation traces.
posterior_cfg = PosteriorInferenceConfig(
    obs_noise_std=float(EVAL_NOISE_STD),
    n_integration_steps=int(INFERENCE_CFG['n_steps']),
    n_mc_per_particle=int(INFERENCE_CFG['n_mc']),
    cem=CEMPosteriorConfig(
        elite_fraction=float(INFERENCE_CFG['elite_frac']),
        n_iterations=int(INFERENCE_CFG['n_iters']),
        n_particles=int(INFERENCE_CFG['n_particles']),
    ),
)

# One row per (case, split) with scalar metrics, plus full traces keyed '<split>_<case_index>'.
runs: list[dict[str, float | str | int]] = []
examples: dict[str, dict[str, np.ndarray | float | int | str]] = {}

for local_idx, row_idx in enumerate(case_indices):
    # Independent, reproducible generator per case for observation noise and masks.
    rng_case = np.random.default_rng(5000 + int(row_idx))
    signal = np.asarray(signals[row_idx], dtype=np.float32)
    base_row = np.asarray(params_base[row_idx], dtype=np.float32)

    # Normalized ground truth for the base block only (the part scored by param_nrmse). It is the
    # same for both splits, so compute it once per case. Slicing before dividing keeps any
    # degenerate p_std entries in the one-hot tail out of the arithmetic.
    # NOTE(review): assumes p_mean/p_std have at least phys_dim_base entries.
    true_norm = (p_core_true[local_idx, :phys_dim_base] - p_mean[:phys_dim_base]) / p_std[:phys_dim_base]

    # Clean traces: the dataset current (in-distribution) vs. a biofouled multiphysics simulation
    # driven by the same applied signal and base parameters (stress test).
    id_current = np.asarray(currents[row_idx], dtype=np.float32)
    ood_current = simulate_biofouled_current(signal, base_row)

    # Same noise level and mask density for both splits (draw order: noise id, noise ood, masks).
    id_obs = id_current + rng_case.normal(0.0, EVAL_NOISE_STD, size=id_current.shape[0]).astype(np.float32)
    ood_obs = ood_current + rng_case.normal(0.0, EVAL_NOISE_STD, size=ood_current.shape[0]).astype(np.float32)

    id_mask = random_mask(id_obs.shape[0], keep_prob=0.40, rng=rng_case)
    ood_mask = random_mask(ood_obs.shape[0], keep_prob=0.40, rng=rng_case)

    splits = (
        ('in_distribution', id_obs, id_mask),
        ('stress_biofouled', ood_obs, ood_mask),
    )
    # Seed offset per split: in_distribution -> +0, stress_biofouled -> +1.
    for split_offset, (split_name, obs_current, obs_mask) in enumerate(splits):
        post = infer_parameter_posterior(
            model=model,
            normalizers=normalizers,
            geometry=geometry,
            observed_current=obs_current,
            applied_signal=signal,
            known_p_core=p_core_true[local_idx],
            known_p_mask=known_mask,
            obs_mask=obs_mask,
            config=posterior_cfg,
            seed=15000 + local_idx * 31 + split_offset,
        )

        # Parameter error: RMSE between posterior mean and truth over the base block, normalized space.
        post_mean_norm = np.asarray(post['posterior_mean_norm'], dtype=np.float32)[:phys_dim_base]
        param_nrmse = float(np.sqrt(np.mean((post_mean_norm - true_norm) ** 2)))
        rel = dict(post['reliability'])

        runs.append({
            'split': split_name,
            'case_index': int(row_idx),
            'param_nrmse': param_nrmse,
            'reliability_score': float(rel['reliability_score']),
            'pred_nrmse': float(rel['nrmse']),
            'pred_nll': float(rel['nll']),
            'pred_calibration_error': float(rel['calibration_error']),
            'pred_sharpness': float(rel['sharpness']),
        })

        examples[f'{split_name}_{int(row_idx)}'] = {
            'case_index': int(row_idx),
            'split': split_name,
            'observed_current': np.asarray(post['observed_current'], dtype=np.float32),
            'observed_mask': np.asarray(post['observed_mask'], dtype=np.float32),
            'pred_mean': np.asarray(post['predictive_mean_current'], dtype=np.float32),
            'pred_std': np.asarray(post['predictive_std_current'], dtype=np.float32),
            'reliability_score': float(rel['reliability_score']),
        }


def summarize_split(name: str, records: list[dict] | None = None) -> dict[str, float | str]:
    """Average the per-run metrics for one split.

    Args:
        name: Split label to select (matched against each record's ``'split'``).
        records: Run records to summarize; defaults to the module-level ``runs`` list.
            Each selected record must have ``reliability_score``, ``param_nrmse``,
            ``pred_nrmse``, ``pred_nll``, ``pred_calibration_error`` and ``pred_sharpness``.

    Returns:
        Dict with the split name, ``n_cases`` (as float, as before) and one ``*_mean``
        entry per metric. Means are NaN when the split has no records; this avoids
        numpy's empty-slice warning.
    """
    source = runs if records is None else records
    rows = [r for r in source if r['split'] == name]

    def _mean(field: str) -> float:
        if not rows:
            return float('nan')
        return float(np.mean([float(r[field]) for r in rows]))

    return {
        'split': name,
        'n_cases': float(len(rows)),
        'reliability_score_mean': _mean('reliability_score'),
        'param_nrmse_mean': _mean('param_nrmse'),
        'pred_nrmse_mean': _mean('pred_nrmse'),
        'pred_nll_mean': _mean('pred_nll'),
        'pred_calibration_error_mean': _mean('pred_calibration_error'),
        # Sharpness is recorded per run, so summarize it alongside the other metrics.
        'pred_sharpness_mean': _mean('pred_sharpness'),
    }


# Per-split metric means, in the fixed order in_distribution -> stress_biofouled.
split_summary = [summarize_split(split) for split in ('in_distribution', 'stress_biofouled')]
print(split_summary)


In [None]:
# Reliability separation and threshold-based guardrail analysis
id_rows = [r for r in runs if r['split'] == 'in_distribution']
ood_rows = [r for r in runs if r['split'] == 'stress_biofouled']

id_rel = np.asarray([float(r['reliability_score']) for r in id_rows], dtype=float)
ood_rel = np.asarray([float(r['reliability_score']) for r in ood_rows], dtype=float)
id_cal = np.asarray([float(r['pred_calibration_error']) for r in id_rows], dtype=float)
ood_cal = np.asarray([float(r['pred_calibration_error']) for r in ood_rows], dtype=float)

# Tick labels are set on the axis rather than via `boxplot(labels=...)`, which Matplotlib 3.9
# deprecated (renamed `tick_labels`). This form works on both older and newer versions.
split_labels = ['in_distribution', 'stress_biofouled']
fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))
panels = (
    (axes[0], [id_rel, ood_rel], 'Reliability Score Separation', 'Score (0-100)'),
    (axes[1], [id_cal, ood_cal], 'Calibration Error Separation', '|cov1-0.6827| + |cov2-0.9545|'),
)
for ax, data, title, ylabel in panels:
    ax.boxplot(data)
    ax.set_xticks([1, 2])  # boxplot places boxes at positions 1..n by default
    ax.set_xticklabels(split_labels)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.grid(alpha=0.3)

fig.tight_layout()
fig.savefig(ARTIFACT_DIR / 'posterior_failure_mode_separation.png', dpi=180)
plt.show()

# Labels: 0 = in-distribution, 1 = stress; score arrays follow the same id-then-ood order.
y_true = np.asarray([0] * len(id_rows) + [1] * len(ood_rows), dtype=np.int32)
rel_all = np.concatenate([id_rel, ood_rel])
cal_all = np.concatenate([id_cal, ood_cal])


def _guardrail_candidate(rel_thr: float, cal_thr: float) -> dict[str, float]:
    """Score the rule 'flag if reliability < rel_thr OR calibration error > cal_thr' against y_true."""
    flagged = (rel_all < rel_thr) | (cal_all > cal_thr)
    is_stress = y_true == 1
    tp = float(np.sum(flagged & is_stress))
    fp = float(np.sum(flagged & ~is_stress))
    tn = float(np.sum(~flagged & ~is_stress))
    fn = float(np.sum(~flagged & is_stress))
    precision = tp / max(tp + fp, 1e-9)
    recall = tp / max(tp + fn, 1e-9)
    f1 = 2.0 * precision * recall / max(precision + recall, 1e-9)
    return {
        'rel_thr': float(rel_thr),
        'cal_thr': float(cal_thr),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'tpr': float(recall),
        'fpr': float(fp / max(fp + tn, 1e-9)),
    }


def _guardrail_rank(candidate: dict[str, float]) -> tuple[float, float, float]:
    """Selection key: maximize F1; among ties prefer fewer false alarms, then higher recall."""
    return (candidate['f1'], -candidate['fpr'], candidate['recall'])


best = None
for rel_thr in np.linspace(30.0, 85.0, 23):
    for cal_thr in np.linspace(0.08, 0.60, 27):
        candidate = _guardrail_candidate(rel_thr, cal_thr)
        if best is None or _guardrail_rank(candidate) > _guardrail_rank(best):
            best = candidate

# NOTE: thresholds are chosen and scored on the same cases, so these metrics are in-sample
# (optimistic); validate on held-out traces before using the rule operationally.
print('best guardrail:', best)


In [None]:
# Representative stress-case trace visualization
# Pick the stress run with the lowest reliability score.
stress_cases = [r for r in runs if r['split'] == 'stress_biofouled']
worst = min(stress_cases, key=lambda r: float(r['reliability_score']))
# Dedicated name: `key` already holds the JAX PRNG key created when the model was loaded.
example_key = f"stress_biofouled_{int(worst['case_index'])}"
ex = examples[example_key]

obs = np.asarray(ex['observed_current'], dtype=float)
mask = np.asarray(ex['observed_mask'], dtype=float) >= 0.5
pred = np.asarray(ex['pred_mean'], dtype=float)
std = np.maximum(np.asarray(ex['pred_std'], dtype=float), 1e-6)  # floor so bands stay drawable
t = np.arange(obs.shape[0], dtype=float)

# Predictive bands, predictive mean, points kept by the observation mask, and masked-out points.
fig, ax = plt.subplots(figsize=(10, 4.5))
ax.fill_between(t, pred - 2.0 * std, pred + 2.0 * std, alpha=0.22, label='pred ±2σ')
ax.fill_between(t, pred - std, pred + std, alpha=0.35, label='pred ±1σ')
ax.plot(t, pred, lw=1.25, label='pred mean')
ax.plot(t[mask], obs[mask], 'k.', ms=3, label='observed')
ax.plot(t[~mask], obs[~mask], '.', color='gray', ms=2, alpha=0.25, label='missing')
ax.set_title(
    f"Worst stress case #{int(ex['case_index'])} | reliability={float(ex['reliability_score']):.2f}"
)
ax.set_xlabel('Trace index')
ax.set_ylabel('Current')
ax.grid(alpha=0.3)
ax.legend(fontsize=8)
fig.tight_layout()
fig.savefig(ARTIFACT_DIR / 'posterior_worst_stress_trace.png', dpi=180)
plt.show()


In [None]:
# Persist report payload
# Report payload: run configuration, per-split summaries, per-run metrics, chosen guardrail.
payload = {
    'config': {
        'checkpoint': str(CHECKPOINT),
        'dataset_chunk': str(DATASET_CHUNK),
        'case_indices': case_indices.tolist(),
        'eval_noise_std': EVAL_NOISE_STD,
        'inference_config': INFERENCE_CFG,
    },
    'split_summary': split_summary,
    'runs': runs,
    'best_guardrail': best,
}
# Serialize first, then write. If serialization fails, no truncated report is left behind
# (json.dump streams into the already-opened file). The resulting text is identical.
report_text = json.dumps(payload, indent=2)
(ARTIFACT_DIR / 'posterior_failure_mode_report.json').write_text(report_text, encoding='utf-8')

# Last expression of the cell: displays the per-split summary in the notebook.
payload['split_summary']


## Results and reviewer-facing interpretation

- This notebook explicitly tests model mismatch, not only in-distribution success cases.
- Reliability and calibration diagnostics are used as decision variables for accept/reject behavior.
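- The threshold rule provides an operational guardrail for production workflows when traces are likely outside training assumptions. Because the thresholds are selected and scored on the same small set of cases, the reported precision/recall are optimistic; confirm them on held-out traces before relying on the rule in production.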

## Improvement actions from this notebook

- Gate downstream decisions on reliability diagnostics.
- Route low-reliability or high-calibration-error traces to higher-fidelity multiphysics solvers or additional measurements.
- Track guardrail metrics longitudinally as model/dataset revisions are introduced.
