# Experiment: Nature Evidence 04 - Posterior Identifiability and Calibration

Objective:
- Validate latent-parameter posterior inference against held-out synthetic measurements with known ground truth.
- Quantify both parameter recovery and predictive uncertainty calibration under realistic missingness/noise.
- Produce reviewer-facing diagnostics that are falsifiable and reproducible.

Success criteria:
- Posterior predictive coverage tracks the nominal Gaussian targets (≈68.3% within 1σ, ≈95.4% within 2σ) on dense observations.
- Parameter recovery degrades gracefully as observations become sparse/noisy.
- Failure cases are explicitly visualized and interpreted.


In [None]:
# Setup: imports and reproducibility
from __future__ import annotations

import json
import math
import os
import time
from pathlib import Path

import equinox as eqx
import jax
import matplotlib.pyplot as plt
import numpy as np

from ecsfm.fm.eval_classical import _resolve_model_geometry
from ecsfm.fm.model import VectorFieldNet
from ecsfm.fm.posterior import CEMPosteriorConfig, PosteriorInferenceConfig, infer_parameter_posterior
from ecsfm.fm.train import (
    MODEL_META_FILENAME,
    NORMALIZERS_FILENAME,
    load_model_metadata,
    load_saved_normalizers,
)

# Seed NumPy's legacy global RNG. The later cells draw from explicit
# np.random.default_rng(...) generators, which this call does not affect.
np.random.seed(2026)

# Output directory for the JSON report and the PNG figures saved by this notebook.
ARTIFACT_DIR = Path('/tmp/ecsfm/notebook_nature_04')
# Create the directory together with missing parents; an existing one is fine.
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
print(f'Artifacts: {ARTIFACT_DIR}')


## Plan

- Load a trained surrogate and a held-out chunk.
- Build pseudo-measurements from known trajectories under three observation regimes.
- Infer posterior over unknown base electrochemical parameters while fixing known task/stage labels.
- Quantify parameter error, posterior coverage, and predictive calibration.
- Inspect worst-case traces and posterior marginals.


In [None]:
# User-facing configuration
def _first_existing(candidates: list[Path]) -> Path | None:
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return None


def _discover_latest(pattern: str, roots: list[Path]) -> Path | None:
    found: list[Path] = []
    for root in roots:
        if root.exists():
            found.extend(root.rglob(pattern))
    if not found:
        return None
    return max(found, key=lambda p: p.stat().st_mtime)


# Environment overrides take precedence over every built-in search location.
checkpoint_override = os.getenv('ECSFM_CHECKPOINT')
dataset_override = os.getenv('ECSFM_DATASET_CHUNK')

# Well-known locations, probed in priority order before any recursive search.
checkpoint_candidates = [
    Path('/tmp/ecsfm/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/vol/artifacts/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/tmp/ecsfm/surrogate_model.eqx'),
]
dataset_candidates = [
    Path('/tmp/ecsfm/dataset_balanced_742k/chunk_0.npz'),
    Path('/vol/datasets/dataset_balanced_742k/chunk_0.npz'),
    Path('/tmp/ecsfm/dataset_massive/chunk_0.npz'),
]

if checkpoint_override:
    CHECKPOINT = Path(checkpoint_override)
else:
    CHECKPOINT = _first_existing(checkpoint_candidates)
    if CHECKPOINT is None:
        # Last resort: newest matching file anywhere under the search roots.
        CHECKPOINT = _discover_latest('surrogate_model.eqx', [Path('/tmp/ecsfm'), Path('/vol/artifacts')])

if dataset_override:
    DATASET_CHUNK = Path(dataset_override)
else:
    DATASET_CHUNK = _first_existing(dataset_candidates)
    if DATASET_CHUNK is None:
        DATASET_CHUNK = _discover_latest('chunk_0.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])
    if DATASET_CHUNK is None:
        # No chunk_0 anywhere: accept whichever chunk file is newest.
        DATASET_CHUNK = _discover_latest('chunk_*.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])

if CHECKPOINT is None or not CHECKPOINT.exists():
    if checkpoint_override:
        # The generic "nothing found" text would hide that the override itself is wrong.
        raise FileNotFoundError(f'ECSFM_CHECKPOINT points to a missing path: {CHECKPOINT}')
    raise FileNotFoundError(
        'No checkpoint found. Set ECSFM_CHECKPOINT or place a model at one of: '
        + ', '.join(str(p) for p in checkpoint_candidates)
    )
if DATASET_CHUNK is None or not DATASET_CHUNK.exists():
    if dataset_override:
        # Same reasoning as above: name the bad override instead of suggesting it be set.
        raise FileNotFoundError(f'ECSFM_DATASET_CHUNK points to a missing path: {DATASET_CHUNK}')
    raise FileNotFoundError(
        'No dataset chunk found. Set ECSFM_DATASET_CHUNK or place a chunk at one of: '
        + ', '.join(str(p) for p in dataset_candidates)
    )

# Upper bound on the number of dataset rows sampled for the sweep
# (fewer are used when the chunk holds fewer rows).
N_CASES = 8
# Observation regimes, iterated in this order by the sweep.
#   keep_prob: per-sample probability of keeping a trace point in the observation mask.
#   noise_std: std of the zero-mean Gaussian noise added to the true current; the
#              same value is passed to the inference config as obs_noise_std.
OBSERVATION_REGIMES = {
    'dense_clean': {'keep_prob': 1.00, 'noise_std': 0.02},
    'sparse_clean': {'keep_prob': 0.35, 'noise_std': 0.02},
    'sparse_noisy': {'keep_prob': 0.35, 'noise_std': 0.08},
}
# Inference budget shared by all regimes. The sweep maps these onto the config objects:
# n_particles / n_iters / elite_frac -> CEMPosteriorConfig(n_particles, n_iterations, elite_fraction),
# n_mc / n_steps -> PosteriorInferenceConfig(n_mc_per_particle, n_integration_steps).
INFERENCE_BASE = {
    'n_particles': 64,
    'n_iters': 4,
    'elite_frac': 0.25,
    'n_mc': 2,
    'n_steps': 60,
}

# Echo the resolved inputs so the run log records which artifacts were used.
print('checkpoint:', CHECKPOINT)
print('dataset chunk:', DATASET_CHUNK)


In [None]:
# Load model + dataset and assemble true conditioning vectors
# Normalizers and metadata are read from files stored next to the checkpoint.
normalizers = load_saved_normalizers(CHECKPOINT.parent / NORMALIZERS_FILENAME)
meta = load_model_metadata(CHECKPOINT.parent / MODEL_META_FILENAME)
geometry = _resolve_model_geometry(normalizers, meta)

# Build a skeleton network of the right shape; its initial weights only act as a
# template that the checkpoint values are loaded into below.
key = jax.random.PRNGKey(np.uint32(2026))
_, model_key = jax.random.split(key)
model = VectorFieldNet(
    state_dim=int(geometry['state_dim']),
    # Hyper-parameters fall back to these defaults when the metadata lacks the key.
    hidden_size=int(meta.get('hidden_size', 128)),
    depth=int(meta.get('depth', 3)),
    cond_dim=int(meta.get('cond_dim', 32)),
    phys_dim=int(geometry['phys_dim']),
    signal_channels=int(geometry['signal_channels']),
    key=model_key,
)
model = eqx.tree_deserialise_leaves(CHECKPOINT, model)

# Copy the needed arrays out of the .npz archive so the file can be closed.
# NOTE(review): key meanings ('i' = current, 'e' = applied signal, 'p' = base
# parameters) are inferred from the variable names — confirm against the dataset writer.
with np.load(DATASET_CHUNK) as chunk:
    currents = np.asarray(chunk['i'], dtype=np.float32)
    signals = np.asarray(chunk['e'], dtype=np.float32)
    params_base = np.asarray(chunk['p'], dtype=np.float32)
    task_ids = np.asarray(chunk['task_id'], dtype=np.int32)
    stage_ids = np.asarray(chunk['stage_id'], dtype=np.int32)

# Draw up to N_CASES distinct row indices with a fixed-seed generator.
n_rows = currents.shape[0]
rng = np.random.default_rng(2026)
case_indices = np.asarray(
    rng.choice(n_rows, size=min(N_CASES, n_rows), replace=False),
    dtype=np.int32,
)

# Sizes of the conditioning-vector segments, as reported by the resolved geometry.
phys_dim_base = int(geometry['phys_dim_base'])
phys_dim_core = int(geometry['phys_dim_core'])
n_tasks = int(geometry['n_tasks'])
n_stages = int(geometry['n_stages'])


def compose_core_params(base_row: np.ndarray, task_idx: int, stage_idx: int) -> np.ndarray:
    """Assemble one core conditioning vector of length ``phys_dim_core``.

    Layout: the first ``phys_dim_base`` entries are copied from *base_row*,
    followed by a one-hot block of width ``n_tasks`` for *task_idx* and then a
    one-hot block of width ``n_stages`` for *stage_idx*. A block whose width is
    not positive is omitted, and out-of-range labels are clipped into range.
    Remaining entries stay zero.
    """
    core = np.zeros((phys_dim_core,), dtype=np.float32)
    core[:phys_dim_base] = base_row[:phys_dim_base]

    offset = phys_dim_base
    for width, label in ((n_tasks, task_idx), (n_stages, stage_idx)):
        if width <= 0:
            continue
        indicator = np.zeros((width,), dtype=np.float32)
        indicator[int(np.clip(label, 0, width - 1))] = 1.0
        core[offset : offset + width] = indicator
        offset += width
    return core


# Ground-truth core vectors for the selected cases, one row per entry of
# case_indices (row k corresponds to case_indices[k]).
p_core_true = np.stack(
    [compose_core_params(params_base[i], int(task_ids[i]), int(stage_ids[i])) for i in case_indices],
    axis=0,
)

# Only the parameter statistics are used further down (to normalise the truth);
# the remaining entries are unpacked but not referenced in this notebook.
# NOTE(review): assumes `normalizers` is a 6-tuple in exactly this order — confirm
# against load_saved_normalizers.
x_mean, x_std, e_mean, e_std, p_mean, p_std = normalizers
p_mean = np.asarray(p_mean, dtype=np.float32)
p_std = np.asarray(p_std, dtype=np.float32)

print('selected cases:', case_indices.tolist())
print(
    f"geometry: max_species={geometry['max_species']} nx={geometry['nx']} target_len={geometry['target_len']} "
    f"phys_dim_base={phys_dim_base} phys_dim_core={phys_dim_core}"
)


In [None]:
# Posterior sweep across observation regimes
def build_mask(length: int, keep_prob: float, rng: np.random.Generator, min_points: int = 24) -> np.ndarray:
    """Draw a random 0/1 observation mask of shape ``(length,)`` as float32.

    Each position is kept independently with probability *keep_prob*. If fewer
    than *min_points* positions survive, ``min(min_points, length)`` distinct
    positions are drawn without replacement and switched on as well. A
    *keep_prob* of at least 0.999 means "keep everything" and consumes no
    random numbers from *rng*.
    """
    if keep_prob >= 0.999:
        return np.ones((length,), dtype=np.float32)

    keep = rng.random(length) < keep_prob
    too_sparse = int(keep.sum()) < min_points
    if too_sparse:
        forced = rng.choice(length, size=min(min_points, length), replace=False)
        keep[forced] = True
    return keep.astype(np.float32)


# Accumulators filled by the sweep:
#   records           - one flat metrics row per (regime, case)
#   zscores_by_regime - standardized predictive residuals pooled over all cases of a regime
#   example_payloads  - arrays of the lowest-reliability case of each regime
records: list[dict[str, float | int | str]] = []
zscores_by_regime: dict[str, np.ndarray] = {}
example_payloads: dict[str, dict[str, np.ndarray | float | int]] = {}

# Everything after the base block of the core vector (the task/stage one-hots
# written by compose_core_params) is treated as known; only the first
# phys_dim_base entries are left to be inferred.
known_mask = np.zeros((phys_dim_core,), dtype=bool)
known_mask[phys_dim_base:] = True

for regime_idx, (regime_name, regime_cfg) in enumerate(OBSERVATION_REGIMES.items()):
    regime_rows = []
    regime_z = []
    regime_payload = []

    for local_case_idx, row_idx in enumerate(case_indices):
        # Dedicated generator per (regime, dataset row) so the noise and mask of a
        # case do not depend on how many cases were processed before it.
        case_rng = np.random.default_rng(2026 + regime_idx * 10000 + int(row_idx))
        signal = np.asarray(signals[row_idx], dtype=np.float32)
        truth_current = np.asarray(currents[row_idx], dtype=np.float32)
        # Noise is drawn for every sample first, then the mask is drawn from the
        # same generator; the order of these two draws fixes the random stream.
        noisy_current = truth_current + case_rng.normal(
            0.0,
            float(regime_cfg['noise_std']),
            size=truth_current.shape[0],
        ).astype(np.float32)
        obs_mask = build_mask(
            truth_current.shape[0],
            keep_prob=float(regime_cfg['keep_prob']),
            rng=case_rng,
        )

        # The inference is told the same noise level that was used to corrupt the trace.
        posterior_cfg = PosteriorInferenceConfig(
            cem=CEMPosteriorConfig(
                n_particles=int(INFERENCE_BASE['n_particles']),
                n_iterations=int(INFERENCE_BASE['n_iters']),
                elite_fraction=float(INFERENCE_BASE['elite_frac']),
            ),
            n_mc_per_particle=int(INFERENCE_BASE['n_mc']),
            n_integration_steps=int(INFERENCE_BASE['n_steps']),
            obs_noise_std=float(regime_cfg['noise_std']),
        )

        # Wall-clock time of the inference call only.
        t0 = time.perf_counter()
        post = infer_parameter_posterior(
            model=model,
            normalizers=normalizers,
            geometry=geometry,
            observed_current=noisy_current,
            applied_signal=signal,
            known_p_core=p_core_true[local_case_idx],
            known_p_mask=known_mask,
            obs_mask=obs_mask,
            config=posterior_cfg,
            seed=2026 + regime_idx * 1000 + local_case_idx,
        )
        elapsed_s = float(time.perf_counter() - t0)

        # Parameter recovery is scored in normalized space on the inferred (base) block only.
        # NOTE(review): assumes p_mean / p_std broadcast against a phys_dim_core vector — confirm.
        true_norm = ((p_core_true[local_case_idx] - p_mean) / p_std)[:phys_dim_base]
        post_mean_norm = np.asarray(post['posterior_mean_norm'], dtype=np.float32)[:phys_dim_base]
        post_std_norm = np.asarray(post['posterior_std_norm'], dtype=np.float32)[:phys_dim_base]
        # Floor the std so a collapsed posterior still yields a usable interval width.
        post_std_norm = np.maximum(post_std_norm, 1e-6)

        param_nrmse = float(np.sqrt(np.mean((post_mean_norm - true_norm) ** 2)))
        param_mae = float(np.mean(np.abs(post_mean_norm - true_norm)))
        # Fraction of base parameters whose truth lies within 1 / 2 posterior
        # standard deviations of the posterior mean.
        param_cov1 = float(np.mean(np.abs(post_mean_norm - true_norm) <= post_std_norm))
        param_cov2 = float(np.mean(np.abs(post_mean_norm - true_norm) <= 2.0 * post_std_norm))

        # Predictive diagnostics use the arrays returned by the inference call rather
        # than the local noisy_current / obs_mask.
        # NOTE(review): presumably these are resampled to the model's trace length
        # (hence the `_rs` suffix) — confirm in infer_parameter_posterior.
        rel = dict(post['reliability'])
        pred_mean = np.asarray(post['predictive_mean_current'], dtype=np.float32)
        pred_std = np.maximum(np.asarray(post['predictive_std_current'], dtype=np.float32), 1e-6)
        obs_current = np.asarray(post['observed_current'], dtype=np.float32)
        obs_mask_rs = np.asarray(post['observed_mask'], dtype=np.float32) >= 0.5
        # Standardized residuals at observed points only; pooled later for the calibration curves.
        z = (obs_current[obs_mask_rs] - pred_mean[obs_mask_rs]) / pred_std[obs_mask_rs]

        # Flat, JSON-serialisable metrics row for this (regime, case).
        row = {
            'regime': regime_name,
            'case_index': int(row_idx),
            'elapsed_s': elapsed_s,
            'obs_fraction': float(np.mean(obs_mask_rs)),
            'param_nrmse': param_nrmse,
            'param_mae': param_mae,
            'param_cov_1sigma': param_cov1,
            'param_cov_2sigma': param_cov2,
            'reliability_score': float(rel['reliability_score']),
            'pred_nrmse': float(rel['nrmse']),
            'pred_nll': float(rel['nll']),
            'pred_cov_1sigma': float(rel['coverage_1sigma']),
            'pred_cov_2sigma': float(rel['coverage_2sigma']),
            'pred_calibration_error': float(rel['calibration_error']),
            'pred_sharpness': float(rel['sharpness']),
        }
        regime_rows.append(row)
        regime_z.append(z)
        # Arrays needed to plot this case later; kept for every case, but only the
        # worst one per regime is retained after the inner loop.
        regime_payload.append(
            {
                'case_index': int(row_idx),
                'reliability_score': float(rel['reliability_score']),
                'observed_current': obs_current,
                'observed_mask': obs_mask_rs.astype(np.float32),
                'pred_mean': pred_mean,
                'pred_std': pred_std,
                'samples_raw': np.asarray(post['posterior_samples_raw'], dtype=np.float32),
                'weights': np.asarray(post['posterior_weights'], dtype=np.float32),
                'post_mean_raw': np.asarray(post['posterior_mean_raw'], dtype=np.float32),
            }
        )

    records.extend(regime_rows)
    zscores_by_regime[regime_name] = np.concatenate(regime_z, axis=0)
    # Keep the case with the lowest reliability score as this regime's example.
    worst_idx = int(np.argmin([row['reliability_score'] for row in regime_rows]))
    example_payloads[regime_name] = regime_payload[worst_idx]


def summarize(rows: list[dict[str, float | int | str]]) -> dict[str, float]:
    arr = lambda key: np.asarray([float(r[key]) for r in rows], dtype=float)
    return {
        'n_cases': float(len(rows)),
        'obs_fraction_mean': float(np.mean(arr('obs_fraction'))),
        'elapsed_s_mean': float(np.mean(arr('elapsed_s'))),
        'param_nrmse_mean': float(np.mean(arr('param_nrmse'))),
        'param_cov_1sigma_mean': float(np.mean(arr('param_cov_1sigma'))),
        'param_cov_2sigma_mean': float(np.mean(arr('param_cov_2sigma'))),
        'pred_nrmse_mean': float(np.mean(arr('pred_nrmse'))),
        'pred_nll_mean': float(np.mean(arr('pred_nll'))),
        'pred_cov_1sigma_mean': float(np.mean(arr('pred_cov_1sigma'))),
        'pred_cov_2sigma_mean': float(np.mean(arr('pred_cov_2sigma'))),
        'pred_calibration_error_mean': float(np.mean(arr('pred_calibration_error'))),
        'reliability_score_mean': float(np.mean(arr('reliability_score'))),
    }


# One aggregated metrics dict per regime, in the order the regimes are configured.
summary = {
    regime: summarize([rec for rec in records if rec['regime'] == regime])
    for regime in OBSERVATION_REGIMES
}

print('--- posterior validation summary ---')
for name, row in summary.items():
    fields = [
        f"{name:12s}",
        f"R={row['reliability_score_mean']:.2f}",
        f"param_nrmse={row['param_nrmse_mean']:.3f}",
        f"pred_nrmse={row['pred_nrmse_mean']:.3f}",
        f"cal={row['pred_calibration_error_mean']:.3f}",
        f"time={row['elapsed_s_mean']:.2f}s",
    ]
    print(' '.join(fields))

# Persist the configuration together with the aggregated and per-case metrics
# so the run can be audited without re-executing the sweep.
payload = {
    'config': {
        'checkpoint': str(CHECKPOINT),
        'dataset_chunk': str(DATASET_CHUNK),
        'case_indices': case_indices.tolist(),
        'observation_regimes': OBSERVATION_REGIMES,
        'inference_base': INFERENCE_BASE,
    },
    'summary': summary,
    'records': records,
}
report_path = ARTIFACT_DIR / 'posterior_identifiability_calibration.json'
with report_path.open('w', encoding='utf-8') as f:
    json.dump(payload, f, indent=2)

# Bare expression: shown as the cell's output when run in a notebook.
summary


In [None]:
# Calibration and recovery visualizations
# Regime names in configuration order; this fixes the left-to-right order of every plot.
regime_names = list(OBSERVATION_REGIMES)

def _values(metric: str) -> list[np.ndarray]:
    """Collect *metric* from ``records`` as one float array per regime.

    The arrays follow the order of ``regime_names``; a regime without records
    yields an empty array.
    """
    grouped = []
    for regime in regime_names:
        series = [float(rec[metric]) for rec in records if rec['regime'] == regime]
        grouped.append(np.asarray(series, dtype=float))
    return grouped


fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))

# Tick labels are set with set_xticklabels rather than boxplot(labels=...):
# Matplotlib 3.9 deprecated that keyword (renamed to tick_labels), whereas
# set_xticklabels behaves the same on old and new releases.
axes[0].boxplot(_values('reliability_score'))
axes[0].set_xticklabels(regime_names)
axes[0].set_title('Posterior Reliability Score')
axes[0].set_ylabel('Score (0-100)')
axes[0].grid(alpha=0.3)

axes[1].boxplot(_values('param_nrmse'))
axes[1].set_xticklabels(regime_names)
axes[1].set_title('Parameter Recovery Error')
axes[1].set_ylabel('NRMSE (normalized parameter space)')
axes[1].grid(alpha=0.3)

# Mean predictive calibration error per regime, using the same per-regime
# grouping helper as the box plots.
cal_err = [np.mean(per_regime) for per_regime in _values('pred_calibration_error')]
axes[2].bar(np.arange(len(regime_names)), cal_err)
axes[2].set_xticks(np.arange(len(regime_names)))
axes[2].set_xticklabels(regime_names)
axes[2].set_title('Predictive Calibration Error')
axes[2].set_ylabel('|cov1-0.6827| + |cov2-0.9545|')
axes[2].grid(axis='y', alpha=0.3)

fig.tight_layout()
fig.savefig(ARTIFACT_DIR / 'posterior_recovery_calibration_boxes.png', dpi=180)
plt.show()

# Calibration curve: empirical P(|z| <= t) against the value expected for a
# standard normal, erf(t / sqrt(2)), on a grid of z thresholds.
thresholds = np.linspace(0.5, 2.5, 9)
expected = np.asarray([math.erf(t / math.sqrt(2.0)) for t in thresholds], dtype=float)

plt.figure(figsize=(7.5, 5.5))
plt.plot(thresholds, expected, 'k--', lw=1.3, label='ideal Gaussian coverage')
for name in regime_names:
    # Standardized residuals pooled over all cases of the regime.
    z = np.asarray(zscores_by_regime[name], dtype=float)
    empirical = np.asarray([np.mean(np.abs(z) <= t) for t in thresholds], dtype=float)
    plt.plot(thresholds, empirical, 'o-', label=name)

plt.xlabel('z threshold')
plt.ylabel('Empirical coverage P(|z| <= threshold)')
plt.title('Posterior Predictive Calibration Curves')
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(ARTIFACT_DIR / 'posterior_calibration_curves.png', dpi=180)
plt.show()


In [None]:
# Worst-case inspection for sparse_noisy regime
# `example` is the lowest-reliability case stored for this regime by the sweep.
example = example_payloads['sparse_noisy']
obs = np.asarray(example['observed_current'], dtype=float)
mask = np.asarray(example['observed_mask'], dtype=float) >= 0.5
pred = np.asarray(example['pred_mean'], dtype=float)
std = np.asarray(example['pred_std'], dtype=float)
t = np.arange(obs.shape[0], dtype=float)

fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))

# Panel 0: predictive bands (wider band first) with observed and masked-out points.
ax = axes[0]
for n_sigma, band_alpha, band_label in ((2.0, 0.25, 'pred ±2σ'), (1.0, 0.35, 'pred ±1σ')):
    half_width = n_sigma * std
    ax.fill_between(t, pred - half_width, pred + half_width, alpha=band_alpha, label=band_label)
ax.plot(t, pred, lw=1.2, label='pred mean')
ax.plot(t[mask], obs[mask], 'k.', ms=3, label='observed')
ax.plot(t[~mask], obs[~mask], '.', color='gray', ms=2, alpha=0.25, label='missing')
ax.set_title(f"Worst sparse_noisy case #{example['case_index']}")
ax.set_xlabel('Trace index')
ax.set_ylabel('Current')
ax.grid(alpha=0.3)
ax.legend(fontsize=8)

samples = np.asarray(example['samples_raw'], dtype=float)
weights = np.asarray(example['weights'], dtype=float)
# Renormalise the weights; the floor guards against division by zero.
weights = weights / max(np.sum(weights), 1e-12)
post_mean_raw = np.asarray(example['post_mean_raw'], dtype=float)
# Map the dataset row index back to its position in the truth matrix.
true_idx = int(np.where(case_indices == int(example['case_index']))[0][0])
true_row = p_core_true[true_idx]

# Column offsets of the two plotted parameters, clamped into the base block.
max_species = int(geometry['max_species'])
idx_e0 = min(4 * max_species, phys_dim_base - 1)
idx_logk0 = min(5 * max_species, phys_dim_base - 1)

# Panels 1-2: weighted posterior marginals with truth and posterior mean marked.
marginal_panels = (
    (axes[1], idx_e0, f'Posterior marginal: E0[0] (idx={idx_e0})', 'E0 value'),
    (axes[2], idx_logk0, f'Posterior marginal: log(k0)[0] (idx={idx_logk0})', 'log(k0) value'),
)
for panel, column, panel_title, panel_xlabel in marginal_panels:
    panel.hist(samples[:, column], bins=24, weights=weights, density=True, alpha=0.7)
    panel.axvline(true_row[column], color='k', linestyle='--', label='truth')
    panel.axvline(post_mean_raw[column], color='tab:red', linestyle='-', label='posterior mean')
    panel.set_title(panel_title)
    panel.set_xlabel(panel_xlabel)
    panel.grid(alpha=0.3)
    panel.legend(fontsize=8)

fig.tight_layout()
fig.savefig(ARTIFACT_DIR / 'posterior_sparse_noisy_worst_case.png', dpi=180)
plt.show()


## Results and reviewer-facing interpretation

- This notebook provides a direct inverse-problem validation loop: known truth -> partial observations -> posterior inference -> recovery/calibration metrics.
- Dense measurements should approach nominal calibration and yield the strongest parameter recovery.
- Sparse/noisy conditions expose identifiability limits explicitly, including posterior broadening and degraded coverage.
- The worst-case panel is included intentionally to prevent cherry-picking and document where the model remains weak.

## Improvement actions from this notebook

- Use reliability-aware acceptance criteria (not point-estimate error alone).
- Prefer measurement designs that increase information density in dynamic waveform regions.
- Increase posterior compute budget only where reliability diagnostics indicate under-convergence.
