# Background Model Explorer

Interactively explore the Guardian background channels (position, EM pickup, surface drift, and detector counts) before enabling any collision analysis. Adjust the controls, rerun the simulation, and export the resulting Guardian certificate once the background gate passes; collision analysis stays locked until it does.


In [None]:
import html
import json
from dataclasses import asdict
from datetime import datetime, timezone
from pathlib import Path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Markdown, display
from scipy.signal import welch

from simulation.background_effects_simulator import BackgroundConfig, simulate_background_timeseries
from simulation.guardian_validators.guardian_background_validator import guardian_check_backgrounds
from simulation.guardian_validators.signal_to_background_analyzer import estimate_snr
from simulation.guardian_validators.systematic_effect_mapper import map_systematics_to_mitigations

In [None]:
# Global matplotlib style applied to every figure this notebook renders.
# NOTE(review): the 'seaborn-v0_8' style name only exists in newer matplotlib
# releases — confirm the installed version provides it.
plt.style.use('seaborn-v0_8')

# Named parameter sets offered by the preset dropdown. Every preset supplies
# the same seven keys; update_from_preset() copies them onto the sliders.
PRESETS = {
    'Default (lab nominal)': dict(
        T_kelvin=295.0,
        mains_hz=50.0,
        rf_pickup_rms=0.5,
        patch_potential_rms_mV=5.0,
        patch_corr_length_um=50.0,
        photon_rate_bg_cps=220.0,
        readout_integration_ms=1.0,
    ),
    'Mains 60 Hz emphasis': dict(
        T_kelvin=295.0,
        mains_hz=60.0,
        rf_pickup_rms=1.2,
        patch_potential_rms_mV=4.0,
        patch_corr_length_um=40.0,
        photon_rate_bg_cps=210.0,
        readout_integration_ms=1.0,
    ),
    'Strong patch drift': dict(
        T_kelvin=310.0,
        mains_hz=50.0,
        rf_pickup_rms=0.6,
        patch_potential_rms_mV=12.0,
        patch_corr_length_um=150.0,
        photon_rate_bg_cps=240.0,
        readout_integration_ms=2.0,
    ),
}

# Snapshot of the most recent simulation run. render_dashboard() fills it in
# and export_certificate() reads it; every entry is None until the first run.
LATEST_STATE = dict.fromkeys(("config", "data", "report", "snr"))


def allan_like(series: np.ndarray, dt: float, points: int = 8):
    """Return an overlapping, Allan-style deviation of *series* at log-spaced lags.

    For each lag ``m`` (in samples) the second difference
    ``d[i] = x[i+2m] - 2*x[i+m] + x[i]`` is formed over every valid ``i`` and
    the deviation is ``sqrt(mean(d**2) / (2*m**2))``. The normalisation uses
    ``m`` in samples rather than ``tau`` in seconds, which is why this is only
    "Allan-like".

    Parameters
    ----------
    series : array_like
        1-D sample sequence (converted to float).
    dt : float
        Sampling interval; only used to convert lags into ``tau = m * dt``.
    points : int
        Number of log-spaced lag candidates between 1 and ``len(series) // 20``.
        Candidates that coincide after rounding are merged, so fewer lags may
        be returned.

    Returns
    -------
    (taus, deviations) : tuple of ndarray
        Averaging times ``m * dt`` and the matching deviations. Both arrays are
        empty when *series* has fewer than 4 samples.
    """
    series = np.asarray(series, dtype=float)
    n = series.size
    if n < 4:
        return np.array([]), np.array([])
    # Largest lag is n // 20 samples (~5% of the record), but at least 1.
    max_m = max(1, n // 20)
    # Round, don't truncate: an integer-dtype logspace truncates, which maps
    # e.g. 1.93 -> 1 (collapsing distinct lags) and can turn the endpoint into
    # max_m - 1 when 10**log10(max_m) lands just below max_m.
    candidates = np.rint(np.logspace(0, np.log10(max_m), points)).astype(int)
    m_values = np.unique(np.clip(candidates, 1, max_m))
    taus = m_values * dt
    deviations = []
    for m in m_values:
        diff = series[2 * m:] - 2 * series[m:-m] + series[:-2 * m]
        if diff.size == 0:
            deviations.append(np.nan)
            continue
        deviations.append(np.sqrt(np.mean(diff**2) / (2 * m**2)))
    return taus, np.array(deviations)

In [None]:
# Parameter sliders start at the nominal lab preset; the last three sliders
# control the simulation run itself (length, time step, RNG seed).
_nominal = PRESETS['Default (lab nominal)']

temperature = widgets.FloatSlider(
    value=_nominal['T_kelvin'], min=100.0, max=600.0, step=5.0, description='T [K]'
)
mains = widgets.FloatSlider(
    value=_nominal['mains_hz'], min=0.0, max=120.0, step=1.0, description='Mains [Hz]'
)
rf_pickup = widgets.FloatSlider(
    value=_nominal['rf_pickup_rms'], min=0.0, max=5.0, step=0.1, description='RF rms [mV]'
)
patch_rms = widgets.FloatSlider(
    value=_nominal['patch_potential_rms_mV'], min=0.0, max=30.0, step=0.5, description='Patch rms [mV]'
)
patch_corr = widgets.FloatSlider(
    value=_nominal['patch_corr_length_um'], min=5.0, max=500.0, step=5.0, description='Patch corr [μm]'
)
photon_rate = widgets.FloatSlider(
    value=_nominal['photon_rate_bg_cps'], min=10.0, max=1000.0, step=10.0, description='BG cps'
)
readout_ms = widgets.FloatSlider(
    value=_nominal['readout_integration_ms'], min=0.1, max=10.0, step=0.1, description='t_int [ms]'
)
samples = widgets.IntSlider(value=8000, min=500, max=20000, step=500, description='Samples')
dt = widgets.FloatLogSlider(value=1e-4, base=10, min=-5, max=-1, step=0.1, description='dt [s]')
seed = widgets.IntSlider(value=0, min=0, max=9999, step=1, description='Seed')

del _nominal

# Which simulated channels appear in the time-series panel (key -> checkbox).
channel_checks = {
    key: widgets.Checkbox(value=shown, description=label)
    for key, label, shown in (
        ('position', 'Position', True),
        ('em_pickup', 'EM pickup', True),
        ('surface_drift', 'Surface drift', True),
        ('detector_counts', 'Detector counts', False),
    )
}

preset = widgets.Dropdown(options=list(PRESETS.keys()), value='Default (lab nominal)', description='Preset')
run_button = widgets.Button(description='Run background simulation', button_style='primary', icon='play')
# Starts locked; render_dashboard() unlocks it only when the guardian gate passes.
collision_button = widgets.Button(description='Collision analysis locked', disabled=True, button_style='danger', icon='lock')
certificate_button = widgets.Button(description='Export Guardian certificate', button_style='success', icon='download')

# Output surfaces updated by the callbacks.
guardian_banner, metrics_html, systematics_html, certificate_status = (widgets.HTML() for _ in range(4))
output_area = widgets.Output()

In [None]:
def update_from_preset(change):
    """Copy the selected preset's values onto the parameter sliders, then re-render.

    ``change`` is the change dict passed by ``preset.observe``; only its
    ``'new'`` entry (the newly selected preset name) is used. The sliders have
    no observers of their own, so the dashboard is re-rendered explicitly once
    every value has been set.
    """
    chosen = PRESETS[change['new']]
    slider_for_key = (
        (temperature, 'T_kelvin'),
        (mains, 'mains_hz'),
        (rf_pickup, 'rf_pickup_rms'),
        (patch_rms, 'patch_potential_rms_mV'),
        (patch_corr, 'patch_corr_length_um'),
        (photon_rate, 'photon_rate_bg_cps'),
        (readout_ms, 'readout_integration_ms'),
    )
    for slider, key in slider_for_key:
        slider.value = chosen[key]
    render_dashboard()


def _current_background_config():
    """Build a BackgroundConfig from the current slider positions."""
    return BackgroundConfig(
        T_kelvin=temperature.value,
        mains_hz=mains.value,
        rf_pickup_rms=rf_pickup.value,
        em_coupling_coeff=1e-3,
        patch_potential_rms_mV=patch_rms.value,
        patch_corr_length_um=patch_corr.value,
        photon_rate_bg_cps=photon_rate.value,
        readout_integration_ms=readout_ms.value,
    )


def _draw_background_figure(data, time_axis, selected):
    """Draw the time-series, EM-pickup PSD and surface-drift Allan panels and display them."""
    step = data['metadata']['dt_s']
    fig, (ax_ts, ax_psd, ax_allan) = plt.subplots(3, 1, figsize=(10, 12))

    # Panel 1: the channels ticked in the checkbox row.
    if not selected:
        ax_ts.text(0.5, 0.5, 'Select at least one channel to display', ha='center', va='center')
        ax_ts.set_axis_off()
    else:
        for channel in selected:
            ax_ts.plot(time_axis, np.asarray(data[channel]), label=channel.replace('_', ' ').title())
        ax_ts.set_xlabel('Time [s]')
        ax_ts.set_ylabel('Amplitude')
        ax_ts.legend(loc='upper right')

    # Panel 2: Welch PSD of the EM pickup channel; the tiny offset keeps the
    # log axis finite where the PSD is exactly zero.
    freqs, psd = welch(np.asarray(data['em_pickup']), fs=1.0 / step)
    ax_psd.semilogy(freqs, psd + 1e-18, color='tab:orange')
    ax_psd.set_xlabel('Frequency [Hz]')
    ax_psd.set_ylabel('PSD [arb/Hz]')
    ax_psd.set_title('EM pickup PSD (Welch)')

    # Panel 3: Allan-like deviation of the surface drift channel.
    taus, allan = allan_like(np.asarray(data['surface_drift']), step)
    if not taus.size:
        ax_allan.text(0.5, 0.5, 'Not enough samples for Allan deviation', ha='center', va='center')
        ax_allan.set_axis_off()
    else:
        ax_allan.loglog(taus, np.abs(allan) + 1e-18, marker='o', color='tab:green')
        ax_allan.set_xlabel('Averaging time τ [s]')
        ax_allan.set_ylabel('Allan-like deviation [mV]')

    plt.tight_layout()
    display(fig)
    plt.close(fig)


def _metrics_markup(report, snr_val):
    """Return the HTML summary of the guardian checks and variance contributions."""
    def mark(flag):
        return '✅' if flag else '⚠️'

    contributions = report.get('contributions', {})
    return (
        "<h4>Guardian metrics</h4>"
        "<ul>"
        f"<li>Inventory complete: {mark(report['inventory_ok'])}</li>"
        f"<li>Null hypothesis (95%): {mark(report['null_95_ok'])}</li>"
        f"<li>SNR ≥ 10: {mark(report['snr_10_ok'])} (estimate={snr_val:.2f})</li>"
        "</ul>"
        f"<p><strong>Variance contributions</strong>: {contributions}</p>"
    )


def _mitigation_markup():
    """Return the HTML list of systematic effects and their mitigations."""
    entries = ''.join(
        f"<li><strong>{effect}</strong>: {remedy}</li>"
        for effect, remedy in map_systematics_to_mitigations().items()
    )
    return "<h4>Mitigation playbook</h4><ul>" + entries + "</ul>"


def _apply_guardian_gate(passed):
    """Set the banner and the collision-analysis button from the gate outcome."""
    if passed:
        banner = "<div style='padding:10px;background-color:#14532d;color:white;font-weight:600;'>Guardian PASS ✅ Background gate cleared.</div>"
        label, locked, style, icon = 'Collision analysis unlocked', False, 'success', 'unlock'
    else:
        banner = "<div style='padding:10px;background-color:#7f1d1d;color:white;font-weight:600;'>Guardian FAIL 🚫 Background gate not yet satisfied.</div>"
        label, locked, style, icon = 'Collision analysis locked', True, 'danger', 'lock'
    guardian_banner.value = banner
    collision_button.description = label
    collision_button.disabled = locked
    collision_button.button_style = style
    collision_button.icon = icon


def render_dashboard(*_):
    """Simulate the background with the current controls and refresh the whole UI.

    Extra positional arguments (e.g. the clicked button) are ignored, so the
    function can be registered directly as a widget callback. The config,
    simulated data, guardian report and SNR estimate are stored in
    ``LATEST_STATE`` before any plotting happens.
    """
    cfg = _current_background_config()
    data = simulate_background_timeseries(
        n_samples=int(samples.value),
        dt_s=float(dt.value),
        cfg=cfg,
        seed=int(seed.value),
    )
    report = guardian_check_backgrounds(data)
    snr_val = float(estimate_snr(data))
    LATEST_STATE.update(config=cfg, data=data, report=report, snr=snr_val)

    meta = data['metadata']
    time_axis = np.arange(meta['n_samples']) * meta['dt_s']
    selected = [key for key, box in channel_checks.items() if box.value]

    with output_area:
        output_area.clear_output()
        _draw_background_figure(data, time_axis, selected)
        metrics_html.value = _metrics_markup(report, snr_val)
        systematics_html.value = _mitigation_markup()

    _apply_guardian_gate(report['guardian_pass'])


def _json_default(obj):
    """Make NumPy values JSON-serialisable for ``json.dumps(default=...)``.

    NumPy scalars (e.g. ``np.bool_``, ``np.int64``) become the equivalent
    Python scalar and arrays become nested lists. Any other object raises
    ``TypeError`` just as :mod:`json` would, so unsupported data is not hidden.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


def export_certificate(_):
    """Write the most recent simulation run to a timestamped JSON certificate.

    Button ``on_click`` handler; its argument is ignored. The certificate holds
    the config, guardian report, SNR estimate and run metadata taken from
    ``LATEST_STATE``, and is written under ``artifacts/notebook``. It is
    written whether or not the gate passed; the report's own fields record the
    outcome. The result (missing run, saved path, or write failure) is shown in
    ``certificate_status``.
    """
    if not LATEST_STATE['report']:
        certificate_status.value = '<em>Run the simulation before exporting.</em>'
        return
    # Timezone-aware UTC clock; the naive datetime.utcnow() is deprecated.
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')
    metadata = LATEST_STATE['data']['metadata']
    payload = {
        'timestamp_utc': timestamp,
        'config': asdict(LATEST_STATE['config']),
        'report': LATEST_STATE['report'],
        'snr_estimate': float(LATEST_STATE['snr']),
        'metadata': {
            'n_samples': int(metadata['n_samples']),
            'dt_s': float(metadata['dt_s']),
            'seed': int(metadata['seed']),
        },
    }
    # The report may carry NumPy scalars/arrays (e.g. np.bool_) that plain
    # json.dumps rejects; convert them instead of crashing the callback.
    text = json.dumps(payload, indent=2, default=_json_default)

    artifact_dir = Path('artifacts/notebook')
    path = artifact_dir / f'guardian_certificate_{timestamp}.json'
    try:
        artifact_dir.mkdir(parents=True, exist_ok=True)
        path.write_text(text, encoding='utf-8')
    except OSError as exc:
        # Surface filesystem problems in the UI; an exception raised inside a
        # widget callback would otherwise not be visible next to the button.
        certificate_status.value = (
            "<span style='color:#7f1d1d;'>Failed to save certificate: "
            f"{html.escape(str(exc))}</span>"
        )
        return
    certificate_status.value = f"<span style='color:#14532d;'>Certificate saved to {path}</span>"

In [None]:
# Callbacks: choosing a preset re-renders with the preset values, the run
# button simulates, and the certificate button exports the latest run.
preset.observe(update_from_preset, names='value')
run_button.on_click(render_dashboard)
certificate_button.on_click(export_certificate)


def _channel_change(change):
    """Re-render when a channel checkbox is toggled.

    Note this re-runs the full simulation with the current slider values; it
    does not merely redraw the previous data.
    """
    render_dashboard()


for widget in channel_checks.values():
    widget.observe(_channel_change, names='value')

# Left column: every input control, the channel toggles and the run button.
controls = widgets.VBox(
    [
        preset,
        temperature, mains, rf_pickup, patch_rms, patch_corr, photon_rate, readout_ms,
        samples, dt, seed,
        widgets.HBox(list(channel_checks.values())),
        run_button,
    ]
)

# Right column: gate banner, plots, metrics, mitigations and the action row.
right_column = widgets.VBox(
    [
        guardian_banner,
        output_area,
        metrics_html,
        systematics_html,
        widgets.HBox([certificate_button, collision_button]),
        certificate_status,
    ],
    layout=widgets.Layout(width='100%'),
)

app = widgets.HBox(
    [controls, right_column],
    layout=widgets.Layout(align_items='flex-start', width='100%', gap='20px'),
)

# Show the app, then populate it with an initial run.
display(app)
render_dashboard()