# Experiment: Nature Evidence 05 - Posterior Ablation and Compute-Quality Frontier

Objective:
- Quantify the compute-vs-quality tradeoff of posterior inference settings.
- Select a production recommendation using reproducible Pareto analysis rather than intuition.
- Evaluate measurement protocol design choices that improve reliability under fixed compute.

Success criteria:
- A ranked set of inference configurations with explicit runtime and quality metrics.
- A recommended "balanced" config justified by measured tradeoffs.
- Protocol-level evidence showing that better observation design improves inference quality.


In [None]:
# Setup: imports and reproducibility
from __future__ import annotations

import json
import os
import time
from pathlib import Path

import equinox as eqx
import jax
import matplotlib.pyplot as plt
import numpy as np

from ecsfm.fm.eval_classical import _resolve_model_geometry
from ecsfm.fm.model import VectorFieldNet
from ecsfm.fm.posterior import CEMPosteriorConfig, PosteriorInferenceConfig, infer_parameter_posterior
from ecsfm.fm.train import (
    MODEL_META_FILENAME,
    NORMALIZERS_FILENAME,
    load_model_metadata,
    load_saved_normalizers,
)

# Seed NumPy's legacy global RNG for any code that still draws from it.
np.random.seed(2026)

# Output directory for artifacts produced by this notebook; created (with parents) if it is missing.
ARTIFACT_DIR = Path('/tmp/ecsfm/notebook_nature_05')
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
print(f'Artifacts: {ARTIFACT_DIR}')


## Plan

- Load one trained model and one held-out chunk.
- Build a fixed evaluation set of partial/noisy observations.
- Run configuration ablations and compute a Pareto frontier.
- Select a recommendation and test multiple observation-mask protocols.


In [None]:
# User-facing configuration and data/model loading
def _first_existing(candidates: list[Path]) -> Path | None:
    """Return the first entry of ``candidates`` that exists on disk, or ``None`` if none do."""
    return next((path for path in candidates if path.exists()), None)


def _discover_latest(pattern: str, roots: list[Path]) -> Path | None:
    """Recursively glob ``pattern`` under every existing root and return the newest match.

    "Newest" means the largest ``st_mtime``. Roots that do not exist are skipped;
    ``None`` is returned when nothing matches.
    """
    matches = [hit for root in roots if root.exists() for hit in root.rglob(pattern)]
    if matches:
        return max(matches, key=lambda hit: hit.stat().st_mtime)
    return None


# Optional explicit locations; when set they take precedence over every default below.
checkpoint_override = os.getenv('ECSFM_CHECKPOINT')
dataset_override = os.getenv('ECSFM_DATASET_CHUNK')

# Default locations, tried in order; the first one that exists wins.
checkpoint_candidates = [
    Path('/tmp/ecsfm/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/vol/artifacts/fullscale_balanced_modal/surrogate_model.eqx'),
    Path('/tmp/ecsfm/surrogate_model.eqx'),
]
dataset_candidates = [
    Path('/tmp/ecsfm/dataset_balanced_742k/chunk_0.npz'),
    Path('/vol/datasets/dataset_balanced_742k/chunk_0.npz'),
    Path('/tmp/ecsfm/dataset_massive/chunk_0.npz'),
]

# Checkpoint resolution: env override -> known paths -> most recently modified file found by search.
if checkpoint_override:
    CHECKPOINT = Path(checkpoint_override)
else:
    CHECKPOINT = _first_existing(checkpoint_candidates)
    if CHECKPOINT is None:
        CHECKPOINT = _discover_latest('surrogate_model.eqx', [Path('/tmp/ecsfm'), Path('/vol/artifacts')])

# Dataset resolution: env override -> known paths -> newest chunk_0.npz -> newest chunk of any index.
if dataset_override:
    DATASET_CHUNK = Path(dataset_override)
else:
    DATASET_CHUNK = _first_existing(dataset_candidates)
    if DATASET_CHUNK is None:
        DATASET_CHUNK = _discover_latest('chunk_0.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])
    if DATASET_CHUNK is None:
        DATASET_CHUNK = _discover_latest('chunk_*.npz', [Path('/tmp/ecsfm'), Path('/vol/datasets')])

# Fail early; this also catches an override that points at a non-existent file.
if CHECKPOINT is None or not CHECKPOINT.exists():
    raise FileNotFoundError('Checkpoint not found. Set ECSFM_CHECKPOINT.')
if DATASET_CHUNK is None or not DATASET_CHUNK.exists():
    raise FileNotFoundError('Dataset chunk not found. Set ECSFM_DATASET_CHUNK.')

# Upper bound on the number of dataset rows evaluated (capped by the chunk size below).
N_CASES = 5
# Std of the Gaussian noise added to the currents; also passed to inference as obs_noise_std.
EVAL_NOISE_STD = 0.08

# Inference settings under comparison; each entry is mapped onto
# CEMPosteriorConfig / PosteriorInferenceConfig fields in the ablation cell.
CONFIGS = [
    {'name': 'fast', 'n_particles': 32, 'n_iters': 3, 'elite_frac': 0.25, 'n_mc': 1, 'n_steps': 40},
    {'name': 'balanced', 'n_particles': 64, 'n_iters': 4, 'elite_frac': 0.25, 'n_mc': 2, 'n_steps': 60},
    {'name': 'robust', 'n_particles': 96, 'n_iters': 6, 'elite_frac': 0.25, 'n_mc': 3, 'n_steps': 100},
]

# Normalizers and metadata are read from files that sit next to the checkpoint.
normalizers = load_saved_normalizers(CHECKPOINT.parent / NORMALIZERS_FILENAME)
meta = load_model_metadata(CHECKPOINT.parent / MODEL_META_FILENAME)
geometry = _resolve_model_geometry(normalizers, meta)

# Build a network skeleton from the saved geometry/hyperparameters (with fallbacks for
# missing metadata keys), then overwrite its leaves with the serialized checkpoint.
key = jax.random.PRNGKey(np.uint32(2026))
_, model_key = jax.random.split(key)
model = VectorFieldNet(
    state_dim=int(geometry['state_dim']),
    hidden_size=int(meta.get('hidden_size', 128)),
    depth=int(meta.get('depth', 3)),
    cond_dim=int(meta.get('cond_dim', 32)),
    phys_dim=int(geometry['phys_dim']),
    signal_channels=int(geometry['signal_channels']),
    key=model_key,
)
model = eqx.tree_deserialise_leaves(CHECKPOINT, model)

# Copy the arrays out of the archive so they remain usable after the file is closed.
with np.load(DATASET_CHUNK) as chunk:
    currents = np.asarray(chunk['i'], dtype=np.float32)
    signals = np.asarray(chunk['e'], dtype=np.float32)
    params_base = np.asarray(chunk['p'], dtype=np.float32)
    task_ids = np.asarray(chunk['task_id'], dtype=np.int32)
    stage_ids = np.asarray(chunk['stage_id'], dtype=np.int32)

# Fixed-seed draw of distinct row indices, so the evaluation set is the same on every run.
rng = np.random.default_rng(2027)
case_indices = np.asarray(rng.choice(currents.shape[0], size=min(N_CASES, currents.shape[0]), replace=False), dtype=np.int32)

# Sizes that define the layout of a core parameter vector (used by compose_core_params).
phys_dim_base = int(geometry['phys_dim_base'])
phys_dim_core = int(geometry['phys_dim_core'])
n_tasks = int(geometry['n_tasks'])
n_stages = int(geometry['n_stages'])


def compose_core_params(base_row: np.ndarray, task_idx: int, stage_idx: int) -> np.ndarray:
    """Assemble one core parameter vector from base parameters plus task/stage one-hots.

    Layout of the ``phys_dim_core``-long float32 result: the first
    ``phys_dim_base`` entries copy ``base_row``, followed by an ``n_tasks``-wide
    one-hot for ``task_idx`` and then an ``n_stages``-wide one-hot for
    ``stage_idx``. A one-hot block is skipped when its width is not positive,
    out-of-range indices are clipped into the valid range, and any remaining
    entries stay zero.
    """
    core = np.zeros((phys_dim_core,), dtype=np.float32)
    core[:phys_dim_base] = base_row[:phys_dim_base]

    offset = phys_dim_base
    for width, raw_idx in ((n_tasks, task_idx), (n_stages, stage_idx)):
        if width > 0:
            hot = np.zeros((width,), dtype=np.float32)
            hot[int(np.clip(raw_idx, 0, width - 1))] = 1.0
            core[offset : offset + width] = hot
            offset += width
    return core


# Ground-truth core parameter vectors, one row per selected case (same order as case_indices).
p_core_true = np.stack(
    [compose_core_params(params_base[i], int(task_ids[i]), int(stage_ids[i])) for i in case_indices],
    axis=0,
)

# Only the last two of the six normalizer entries are needed here; they are used later
# to express the true parameters in normalized units.
_, _, _, _, p_mean, p_std = normalizers
p_mean = np.asarray(p_mean, dtype=np.float32)
p_std = np.asarray(p_std, dtype=np.float32)

print('checkpoint:', CHECKPOINT)
print('dataset chunk:', DATASET_CHUNK)
print('cases:', case_indices.tolist())


In [None]:
# Measurement protocols: random sparse vs transition-focused vs dense
def mask_random(length: int, keep_prob: float, rng: np.random.Generator, min_points: int = 24) -> np.ndarray:
    """Draw a random 0/1 observation mask over ``length`` samples.

    Each sample is kept independently with probability ``keep_prob``. When fewer
    than ``min_points`` samples survive, ``min(min_points, length)`` distinct
    positions are drawn from ``rng`` and switched on in addition.

    Returns a float32 array in which 1.0 marks an observed sample.
    """
    keep = rng.random(length) < keep_prob
    too_sparse = int(np.count_nonzero(keep)) < min_points
    if too_sparse:
        n_forced = min(min_points, length)
        forced = rng.choice(length, size=n_forced, replace=False)
        keep[forced] = True
    return keep.astype(np.float32)


def mask_transition_focused(signal: np.ndarray, keep_fraction: float, min_points: int = 24) -> np.ndarray:
    """Build a 0/1 observation mask concentrated where ``signal`` changes fastest.

    The ``n_keep`` samples with the largest absolute gradient are observed, where
    ``n_keep = max(min_points, round(keep_fraction * length))`` clamped to
    ``[0, length]``. In addition, ``min_points`` evenly spaced anchor samples
    (including both endpoints) are always observed, so the observed fraction can
    exceed ``keep_fraction``.

    Args:
        signal: 1-D sequence used to locate transitions.
        keep_fraction: Target fraction of samples selected by gradient magnitude.
        min_points: Lower bound on gradient-selected samples and number of anchors.

    Returns:
        float32 array of ``signal``'s length with 1.0 marking observed samples.
        An empty signal yields an empty mask.
    """
    signal = np.asarray(signal, dtype=float)
    length = signal.shape[0]
    mask = np.zeros((length,), dtype=bool)
    if length == 0:
        return mask.astype(np.float32)

    n_keep = max(min_points, int(round(keep_fraction * length)))
    n_keep = min(max(n_keep, 0), length)

    # np.gradient needs at least two samples; a lone sample has no transition to rank.
    if length > 1:
        grad = np.abs(np.gradient(signal))
    else:
        grad = np.zeros((length,), dtype=float)

    order = np.argsort(grad)
    # Slice from an explicit start index: ``order[-n_keep:]`` is the *whole* array
    # when n_keep == 0, which would silently observe every sample.
    mask[order[length - n_keep :]] = True

    anchors = np.linspace(0, length - 1, num=min_points, dtype=int)
    mask[anchors] = True
    return mask.astype(np.float32)


# One list of measurement dicts per protocol; entry k of each list belongs to case_indices[k].
protocol_measurements: dict[str, list[dict[str, np.ndarray]]] = {
    name: [] for name in ('dense', 'random_sparse', 'transition_focus')
}

for local_idx, row_idx in enumerate(case_indices):
    # Per-row generator: the noise and the random mask depend only on the dataset row index.
    rng_case = np.random.default_rng(3000 + int(row_idx))
    signal = np.asarray(signals[row_idx], dtype=np.float32)
    truth = np.asarray(currents[row_idx], dtype=np.float32)
    noise = rng_case.normal(0.0, EVAL_NOISE_STD, size=truth.shape[0]).astype(np.float32)
    noisy = truth + noise

    # All protocols share the same signal and noisy current; only the mask differs.
    # The random mask is drawn after the noise, from the same generator.
    masks = {
        'dense': np.ones_like(noisy, dtype=np.float32),
        'random_sparse': mask_random(noisy.shape[0], keep_prob=0.35, rng=rng_case),
        'transition_focus': mask_transition_focused(signal, keep_fraction=0.35),
    }
    for protocol_name, mask in masks.items():
        protocol_measurements[protocol_name].append(
            {'signal': signal, 'obs_current': noisy, 'mask': mask}
        )

print('Prepared measurement protocols for', len(case_indices), 'cases.')


In [None]:
# Config ablation on the random_sparse protocol
# Entries after the base parameters (the task/stage one-hot slots) are handed to
# inference as known; only the first phys_dim_base entries are scored below.
known_mask = np.zeros((phys_dim_core,), dtype=bool)
known_mask[phys_dim_base:] = True

ablation_records: list[dict[str, float | str | int]] = []

for cfg in CONFIGS:
    # The inference settings depend only on the configuration, so build them once per config.
    posterior_cfg = PosteriorInferenceConfig(
        cem=CEMPosteriorConfig(
            n_particles=int(cfg['n_particles']),
            n_iterations=int(cfg['n_iters']),
            elite_fraction=float(cfg['elite_frac']),
        ),
        n_mc_per_particle=int(cfg['n_mc']),
        n_integration_steps=int(cfg['n_steps']),
        obs_noise_std=float(EVAL_NOISE_STD),
    )

    # Untimed warm-up run per configuration. NOTE(review): presumably the inference
    # routine traces/compiles JAX code on first use for a new configuration - confirm.
    # Without the warm-up, any such one-off cost lands in the first timed case and
    # skews the per-config mean runtime used for the frontier.
    if len(case_indices) > 0:
        warm = protocol_measurements['random_sparse'][0]
        jax.block_until_ready(
            infer_parameter_posterior(
                model=model,
                normalizers=normalizers,
                geometry=geometry,
                observed_current=warm['obs_current'],
                applied_signal=warm['signal'],
                known_p_core=p_core_true[0],
                known_p_mask=known_mask,
                obs_mask=warm['mask'],
                config=posterior_cfg,
                seed=9000 + int(cfg['n_particles']),
            )
        )

    for local_idx, row_idx in enumerate(case_indices):
        meas = protocol_measurements['random_sparse'][local_idx]

        t0 = time.perf_counter()
        post = infer_parameter_posterior(
            model=model,
            normalizers=normalizers,
            geometry=geometry,
            observed_current=meas['obs_current'],
            applied_signal=meas['signal'],
            known_p_core=p_core_true[local_idx],
            known_p_mask=known_mask,
            obs_mask=meas['mask'],
            config=posterior_cfg,
            seed=9000 + local_idx + int(cfg['n_particles']),
        )
        # JAX dispatches work asynchronously; wait for any device arrays in the result
        # before stopping the clock so the measurement covers the actual computation.
        # Non-array leaves are passed through untouched.
        jax.block_until_ready(post)
        elapsed = float(time.perf_counter() - t0)

        # Parameter error is measured in normalized units over the base parameters only.
        true_norm = ((p_core_true[local_idx] - p_mean) / p_std)[:phys_dim_base]
        post_mean_norm = np.asarray(post['posterior_mean_norm'], dtype=np.float32)[:phys_dim_base]
        param_nrmse = float(np.sqrt(np.mean((post_mean_norm - true_norm) ** 2)))
        rel = dict(post['reliability'])

        ablation_records.append(
            {
                'config': str(cfg['name']),
                'case_index': int(row_idx),
                'elapsed_s': elapsed,
                'param_nrmse': param_nrmse,
                'reliability_score': float(rel['reliability_score']),
                'pred_nrmse': float(rel['nrmse']),
                'pred_nll': float(rel['nll']),
                'pred_calibration_error': float(rel['calibration_error']),
                'n_particles': int(cfg['n_particles']),
                'n_iters': int(cfg['n_iters']),
                'n_mc': int(cfg['n_mc']),
                'n_steps': int(cfg['n_steps']),
            }
        )


def summarize_config(name: str) -> dict[str, float | str]:
    """Average the per-case ablation metrics recorded for the configuration ``name``.

    Reads the module-level ``ablation_records`` and returns a dict holding the
    config name plus the mean of each reported metric over its records.
    """
    selected = [record for record in ablation_records if record['config'] == name]

    def mean_of(metric: str) -> float:
        values = np.asarray([float(record[metric]) for record in selected], dtype=float)
        return float(np.mean(values))

    summary: dict[str, float | str] = {'config': name}
    summary['elapsed_s_mean'] = mean_of('elapsed_s')
    summary['param_nrmse_mean'] = mean_of('param_nrmse')
    summary['reliability_score_mean'] = mean_of('reliability_score')
    summary['pred_calibration_error_mean'] = mean_of('pred_calibration_error')
    summary['pred_nll_mean'] = mean_of('pred_nll')
    return summary


# One mean-metric summary row per configuration, in CONFIGS order.
config_summary = [summarize_config(cfg['name']) for cfg in CONFIGS]


def is_dominated(a: dict[str, float | str], b: dict[str, float | str]) -> bool:
    """Return True when summary row ``b`` Pareto-dominates summary row ``a``.

    ``b`` dominates ``a`` if its mean reliability score is at least as high and
    its mean runtime at least as low, with a strict improvement in one of the two.
    """
    quality_a = float(a['reliability_score_mean'])
    runtime_a = float(a['elapsed_s_mean'])
    quality_b = float(b['reliability_score_mean'])
    runtime_b = float(b['elapsed_s_mean'])

    no_worse = quality_b >= quality_a and runtime_b <= runtime_a
    strictly_better = quality_b > quality_a or runtime_b < runtime_a
    return no_worse and strictly_better


# A summary row stays on the frontier unless some *other* row dominates it.
pareto = [
    row
    for row in config_summary
    if not any(is_dominated(row, other) for other in config_summary if other is not row)
]

# Recommendation rule: mean reliability score minus 6x the runtime expressed as a
# multiple of the fastest configuration's runtime; the highest utility wins.
baseline_time = min(float(r['elapsed_s_mean']) for r in config_summary)
time_scale = max(baseline_time, 1e-9)  # guards the division against a zero baseline
utilities = [
    (
        float(row['reliability_score_mean']) - 6.0 * (float(row['elapsed_s_mean']) / time_scale),
        row['config'],
    )
    for row in config_summary
]
recommended_config_name = max(utilities)[1]

print('--- config summary ---')
ranked_rows = sorted(config_summary, key=lambda x: -float(x['reliability_score_mean']))
for row in ranked_rows:
    print(
        f"{row['config']:9s} "
        f"R={row['reliability_score_mean']:.2f} "
        f"param_nrmse={row['param_nrmse_mean']:.3f} "
        f"time={row['elapsed_s_mean']:.2f}s "
        f"cal={row['pred_calibration_error_mean']:.3f}"
    )
print('pareto:', [row['config'] for row in pareto])
print('recommended:', recommended_config_name)


In [None]:
# Visualize compute-quality frontier
fig, ax = plt.subplots(figsize=(7.5, 5.5))

# One marker per configuration (separate scatter calls), labelled next to the point.
for row in config_summary:
    runtime_s = float(row['elapsed_s_mean'])
    quality = float(row['reliability_score_mean'])
    ax.scatter(runtime_s, quality, s=90)
    ax.text(runtime_s * 1.02, quality + 0.2, str(row['config']), fontsize=9)

# Connect the non-dominated configurations in order of increasing runtime.
if pareto:
    pareto_sorted = sorted(pareto, key=lambda r: float(r['elapsed_s_mean']))
    frontier_x = [float(r['elapsed_s_mean']) for r in pareto_sorted]
    frontier_y = [float(r['reliability_score_mean']) for r in pareto_sorted]
    ax.plot(frontier_x, frontier_y, 'k--', alpha=0.7, label='Pareto frontier')

ax.set_xlabel('Mean runtime per case (s)')
ax.set_ylabel('Mean reliability score')
ax.set_title('Posterior Inference Compute-Quality Frontier')
ax.grid(alpha=0.3)
ax.legend()
fig.tight_layout()

frontier_png = ARTIFACT_DIR / 'posterior_ablation_frontier.png'
fig.savefig(frontier_png, dpi=180)
plt.show()


In [None]:
# Protocol design study using the recommended config
# Look up the settings of the configuration picked by the recommendation rule.
cfg_map = {cfg['name']: cfg for cfg in CONFIGS}
chosen_cfg = cfg_map[recommended_config_name]

protocol_records: list[dict[str, float | str | int]] = []
for protocol_name, measurements in protocol_measurements.items():
    for local_idx, row_idx in enumerate(case_indices):
        meas = measurements[local_idx]

        cem_cfg = CEMPosteriorConfig(
            n_particles=int(chosen_cfg['n_particles']),
            n_iterations=int(chosen_cfg['n_iters']),
            elite_fraction=float(chosen_cfg['elite_frac']),
        )
        posterior_cfg = PosteriorInferenceConfig(
            cem=cem_cfg,
            n_mc_per_particle=int(chosen_cfg['n_mc']),
            n_integration_steps=int(chosen_cfg['n_steps']),
            obs_noise_std=float(EVAL_NOISE_STD),
        )

        # The seed depends only on the case, so every protocol sees the same seed per case.
        post = infer_parameter_posterior(
            model=model,
            normalizers=normalizers,
            geometry=geometry,
            observed_current=meas['obs_current'],
            applied_signal=meas['signal'],
            known_p_core=p_core_true[local_idx],
            known_p_mask=known_mask,
            obs_mask=meas['mask'],
            config=posterior_cfg,
            seed=12000 + local_idx,
        )

        # Parameter error in normalized units, restricted to the base parameters.
        true_norm = ((p_core_true[local_idx] - p_mean) / p_std)[:phys_dim_base]
        post_mean_norm = np.asarray(post['posterior_mean_norm'], dtype=np.float32)[:phys_dim_base]
        residual = post_mean_norm - true_norm
        param_nrmse = float(np.sqrt(np.mean(residual ** 2)))
        rel = dict(post['reliability'])

        # Fraction of samples whose mask value counts as observed (>= 0.5).
        observed = np.asarray(meas['mask']) >= 0.5
        record = {
            'protocol': protocol_name,
            'case_index': int(row_idx),
            'obs_fraction': float(np.mean(observed)),
            'param_nrmse': param_nrmse,
            'reliability_score': float(rel['reliability_score']),
            'pred_calibration_error': float(rel['calibration_error']),
        }
        protocol_records.append(record)


def summarize_protocol(name: str) -> dict[str, float | str]:
    """Average the per-case metrics recorded for the measurement protocol ``name``.

    Reads the module-level ``protocol_records`` and returns a dict holding the
    protocol name plus the mean of each reported metric over its records.
    """
    selected = [record for record in protocol_records if record['protocol'] == name]

    def mean_of(metric: str) -> float:
        values = np.asarray([float(record[metric]) for record in selected], dtype=float)
        return float(np.mean(values))

    summary: dict[str, float | str] = {'protocol': name}
    summary['obs_fraction_mean'] = mean_of('obs_fraction')
    summary['param_nrmse_mean'] = mean_of('param_nrmse')
    summary['reliability_score_mean'] = mean_of('reliability_score')
    summary['pred_calibration_error_mean'] = mean_of('pred_calibration_error')
    return summary


# One summary row per protocol, ranked from highest to lowest mean reliability score.
protocol_summary = [summarize_protocol(name) for name in protocol_measurements]
protocol_summary = sorted(protocol_summary, key=lambda r: -float(r['reliability_score_mean']))

print('--- protocol summary (using', recommended_config_name, ') ---')
for row in protocol_summary:
    print(
        f"{row['protocol']:16s} "
        f"R={row['reliability_score_mean']:.2f} "
        f"param_nrmse={row['param_nrmse_mean']:.3f} "
        f"obs_frac={row['obs_fraction_mean']:.3f} "
        f"cal={row['pred_calibration_error_mean']:.3f}"
    )

# Persist the inputs, the selection outcome, and both the per-case and summary results.
payload = {
    'config': {
        'checkpoint': str(CHECKPOINT),
        'dataset_chunk': str(DATASET_CHUNK),
        'case_indices': case_indices.tolist(),
        'eval_noise_std': EVAL_NOISE_STD,
        'configs': CONFIGS,
        'recommended_config': recommended_config_name,
    },
    'ablation_summary': config_summary,
    'pareto_configs': [row['config'] for row in pareto],
    'ablation_records': ablation_records,
    'protocol_summary': protocol_summary,
    'protocol_records': protocol_records,
}
with open(ARTIFACT_DIR / 'posterior_ablation_protocols.json', 'w', encoding='utf-8') as f:
    json.dump(payload, f, indent=2)

# Column vectors for the grouped bar chart, in the ranked protocol order.
protocols = [row['protocol'] for row in protocol_summary]
rel = np.asarray([row['reliability_score_mean'] for row in protocol_summary], dtype=float)
prm = np.asarray([row['param_nrmse_mean'] for row in protocol_summary], dtype=float)
cal = np.asarray([row['pred_calibration_error_mean'] for row in protocol_summary], dtype=float)

# Three bars per protocol, all drawn against the same y-axis.
x = np.arange(len(protocols))
width = 0.28
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(x - width, rel, width, label='reliability score')
ax.bar(x, prm, width, label='param nrmse')
ax.bar(x + width, cal, width, label='calibration error')
ax.set_xticks(x)
ax.set_xticklabels(protocols)
ax.set_title('Protocol Design Impact (recommended inference config)')
ax.grid(axis='y', alpha=0.3)
ax.legend(fontsize=8)
fig.tight_layout()
fig.savefig(ARTIFACT_DIR / 'posterior_protocol_design_comparison.png', dpi=180)
plt.show()

# Bare trailing expression: shown as the notebook cell's output.
protocol_summary


## Results and reviewer-facing interpretation

- This notebook makes compute-quality tradeoffs explicit, with a Pareto frontier and a reproducible recommendation rule.
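- The recommended configuration is not assumed; it is selected by a fixed utility rule that trades the measured mean reliability score against relative runtime (calibration error is reported alongside but does not enter the rule directly).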
- Protocol design (where/when observations are taken) can materially improve inverse inference quality under a fixed compute budget.

## Improvement actions from this notebook

- Use the recommended config as default for balanced production runs.
- For resource-constrained runs, retain fast mode but require stricter reliability gating.
- Prefer transition-focused measurement schedules when dense sampling is impractical, provided the protocol summary above shows them outperforming random sparse sampling.
