# 04_results

Plot inversion results: updated models (with true-model comparison), resistivity vs depth/x, and compare synthetics to real data. Optionally, generate synthetics from a chosen model by running the finite-difference (FD) forward-modelling code.

Run with Voila (`--strip_sources=True`).

In [None]:
from pathlib import Path
import json
import os
import re
import shutil
import signal
import subprocess
import sys
import traceback

# Project root. Defaults to the original author's checkout; set the
# EM_INVERSION_WORKSHOP_ROOT environment variable to run the notebook from a
# different location without editing it.
ROOT = Path(
    os.environ.get(
        'EM_INVERSION_WORKSHOP_ROOT',
        '/Users/wiktorweibull/Rockem_projects/EM_inversion_workshop',
    )
).expanduser().resolve()
# Make the project packages (scripts.*, third_party.*) importable.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import numpy as np
try:
    import ipywidgets as ipw
    import plotly.graph_objects as go
except Exception as exc:
    raise RuntimeError('Install: pip install voila ipywidgets plotly numpy') from exc

from scripts.modules.fd_visualization import (
    load_rss_traces,
    compute_amp_phase_for_fd_outputs,
    build_trace_index,
)
from scripts.modules.segy import write_resistivity_to_segy_from_template

# --- Project layout ---------------------------------------------------------
FDMODEL_DIR = ROOT / 'FDmodel'
DATA_DIR = FDMODEL_DIR / 'Data'
# Fallback location of observed Hx/Hz data when a run folder has no copy.
INV_INPUT_DIR = ROOT / 'InversionInput'
# Setup metadata JSON; read here for 'flist_hz' and 'segy_template_path'.
SETUP_META = FDMODEL_DIR / 'setup_metadata.json'
# Run folders are named InversionRun<N>; group 1 captures N.
RUN_DIR_PATTERN = re.compile(r'^InversionRun(\d+)$')
# Upper bound for the MPI process count (never below 2).
MAX_CPUS = max(2, os.cpu_count() or 2)
# FD forward-modelling executable, launched through mpirun.
MPI_EMMOD_BIN = os.path.expanduser('~/software/rockem-suite/bin/mpiEmmodADITE2d')
# True conductivity model, plotted as the reference next to the selected model.
SG_TRUE_PATH = FDMODEL_DIR / 'sg.rss'
# PID files of Voila servers. Not referenced in the visible code; presumably
# used by the quit handler defined later (TODO confirm).
VOILA_PID_FILES = [
    ROOT / '.voila_server.pid',
    ROOT / '.voila_results_server.pid',
]

# Mutable GUI state shared by the widget callbacks.
state = {
    'last_messages': [],         # rolling status log (see push_message)
    'initial_load': True,        # while True, update_2d_plot/update_rho_plots return immediately
    'fd_process': None,          # Popen handle of the synthetic-modelling job, if any
    'real_result': None,         # amp/phase result computed from observed data
    'syn_result': None,          # amp/phase result computed from synthetic data
    'current_run_dir': None,     # selected run folder (path string)
    'current_model_path': None,  # selected model file (path string) or None
    'current_x': None,           # default x position for the depth slice
    'current_z': None,           # default z position for the horizontal slice
}

def push_message(msg):
    """Append *msg* to the rolling status log and refresh the status widget.

    Only the 12 most recent messages are kept in ``state['last_messages']``.
    The ``status_out`` Textarea is created in a later cell, so it is looked up
    in module globals at call time. Messages pushed before it exists are kept
    and displayed by the first call made after it is created.
    """
    messages = state['last_messages']
    messages.append(msg)
    if len(messages) > 12:
        # Trim in place so the list stored in ``state`` stays the same object.
        del messages[:-12]
    # ``dir()`` without arguments lists *local* names only, so it can never see
    # the module-level widget; consult the module globals instead.
    widget = globals().get('status_out')
    if widget is not None:
        widget.value = '\n'.join(messages)

def list_run_dirs(root_dir):
    """Return ``(run_number, path)`` pairs for the run folders in *root_dir*.

    Only direct child directories whose names match ``RUN_DIR_PATTERN`` are
    included. The result is ordered by run number.
    """
    found = []
    for entry in Path(root_dir).iterdir():
        match = RUN_DIR_PATTERN.match(entry.name) if entry.is_dir() else None
        if match is not None:
            found.append((int(match.group(1)), entry))
    return sorted(found, key=lambda pair: pair[0])

def latest_sg_up_file(run_dir):
    """Return the ``sg_up.rss-<N>`` model with the highest iteration N, or None.

    Both ``<run_dir>/Results/`` and *run_dir* itself are searched. Files whose
    suffix is not an integer rank below every numbered file; ties are broken
    by file name.
    """
    run_dir = Path(run_dir)
    candidates = list(run_dir.glob('Results/sg_up.rss-*')) + list(run_dir.glob('sg_up.rss-*'))
    if not candidates:
        return None

    def _suffix_num(p):
        # Single backslashes: this is a raw string. A doubled-backslash pattern
        # never matches and degrades to lexicographic ordering ('-9' > '-10').
        m = re.search(r'sg_up\.rss-(\d+)$', p.name)
        return int(m.group(1)) if m else -1

    return sorted(candidates, key=lambda p: (_suffix_num(p), p.name))[-1]

def build_model_list(run_dir):
    """Return ``(label, path)`` options for the model selector of one run.

    The order is:
    1. ``sg0.rss`` and ``sg_ls.rss``, when present in *run_dir*.
    2. Every ``Results/sg_up.rss-<N>``, ordered by N ascending.
    3. The true model (``SG_TRUE_PATH``), when it exists.

    Returns an empty list when *run_dir* is None.
    """
    if run_dir is None:
        return []
    run_dir = Path(run_dir)
    options = []
    for label in ('sg0', 'sg_ls'):
        candidate = run_dir / f'{label}.rss'
        if candidate.exists():
            options.append((label, candidate))

    def _num(p):
        # Raw string with single backslashes; '(\\d+)$' would never match.
        m = re.search(r'(\d+)$', p.name)
        return int(m.group(1)) if m else 0

    # Name as secondary key keeps the order deterministic for equal numbers.
    for p in sorted(run_dir.glob('Results/sg_up.rss-*'), key=lambda q: (_num(q), q.name)):
        options.append((p.name, p))
    if SG_TRUE_PATH.exists():
        options.append(('True (FDmodel/sg.rss)', SG_TRUE_PATH))
    return options

def get_real_data_paths(run_dir):
    """Return the ``(Hx, Hz)`` observed-data file paths for *run_dir*.

    The run folder's own ``Hx_data.rss``/``Hz_data.rss`` are preferred when
    both exist. Otherwise the paths inside ``INV_INPUT_DIR`` are returned,
    without checking that they exist (callers do that).
    """
    file_names = ('Hx_data.rss', 'Hz_data.rss')
    if run_dir:
        in_run = tuple(Path(run_dir) / fname for fname in file_names)
        if all(p.exists() for p in in_run):
            return in_run
    return tuple(INV_INPUT_DIR / fname for fname in file_names)

def _read_rss_model(path):
    """Read a 2-D RSS model file and return ``(x, z, grid)``.

    ``grid`` is the squeezed file data, transposed so that rows run along the
    second data axis (used as z) and columns along the first (used as x).
    Axis coordinates come from the header origins and samplings, with a zero
    sampling replaced by 1.0. The z header axis is index 2 when the header
    declares a non-empty third dimension, and index 1 otherwise.

    Raises:
        ValueError: if the squeezed data is not two-dimensional.
    """
    from third_party.rockseis.io.rsfile import rsfile

    reader = rsfile()
    reader.read(str(path))
    values = np.squeeze(np.asarray(reader.data, dtype=float))
    if values.ndim != 2:
        raise ValueError(f'Expected 2D model for {path}, got shape {values.shape}')
    n_x, n_z = (int(s) for s in values.shape)
    grid = np.asarray(values.T, dtype=float)

    def _axis(header_idx, count):
        # origin + step * [0..count-1]; a zero/empty step falls back to 1.0.
        step = float(reader.geomD[header_idx]) if reader.geomD[header_idx] else 1.0
        origin = float(reader.geomO[header_idx])
        return origin + step * np.arange(count)

    z_header = 1
    if len(reader.geomN) > 2 and int(reader.geomN[2]) > 0:
        z_header = 2
    x = _axis(0, n_x)
    z = _axis(z_header, n_z)
    return x, z, grid

def _extract_positions(run_dir):
    """Return unique transmitter/receiver coordinates for a run.

    The Hx observed-data file is read with ``load_rss_traces``: the run's own
    ``Hx_data.rss``, falling back to ``INV_INPUT_DIR``. The result is
    ``(tx_x, tx_z, rx_x, rx_z)``, with duplicate positions (after rounding to
    6 decimals) removed. Four empty arrays are returned when no file is found
    or reading fails, so the plotting code can simply skip the markers.
    """
    def _nothing():
        return np.array([]), np.array([]), np.array([]), np.array([])

    hx_path = Path(run_dir) / 'Hx_data.rss'
    if not hx_path.exists():
        hx_path = INV_INPUT_DIR / 'Hx_data.rss'
        if not hx_path.exists():
            return _nothing()
    try:
        traces = load_rss_traces(hx_path)
        tx = np.column_stack((traces['src_x'], traces['src_z']))
        rx = np.column_stack((traces['rx_x'], traces['rx_z']))
        tx_unique = np.unique(np.round(tx, 6), axis=0)
        rx_unique = np.unique(np.round(rx, 6), axis=0)
        return tx_unique[:, 0], tx_unique[:, 1], rx_unique[:, 0], rx_unique[:, 1]
    except Exception:
        return _nothing()

def _align_axes_to_survey(x, z, tx_x, tx_z, rx_x, rx_z):
    """Translate model axes onto the survey when they are clearly offset.

    The survey extent is the union of TX and RX positions. Each axis is
    handled independently: when the survey centre lies more than half the axis
    span away from the axis centre, the whole axis is shifted so the two
    centres coincide. Axes come back unchanged (as float arrays) when there
    are no survey positions or the axis span is zero.
    """
    x = np.asarray(x, dtype=float)
    z = np.asarray(z, dtype=float)
    survey_x = np.concatenate([tx_x, rx_x]) if (tx_x.size + rx_x.size) > 0 else np.array([])
    survey_z = np.concatenate([tx_z, rx_z]) if (tx_z.size + rx_z.size) > 0 else np.array([])
    if survey_x.size == 0 or survey_z.size == 0:
        return x, z

    def _recentre(axis, points):
        if not axis.size:
            return axis
        lo = np.nanmin(axis)
        hi = np.nanmax(axis)
        span = float(hi - lo)
        axis_mid = float(0.5 * (lo + hi))
        survey_mid = float(0.5 * (np.nanmin(points) + np.nanmax(points)))
        offset = survey_mid - axis_mid
        if span > 0.0 and abs(offset) > 0.5 * span:
            return axis + offset
        return axis

    return _recentre(x, survey_x), _recentre(z, survey_z)

def conductivity_to_resistivity(grid, min_sigma=1e-12):
    """Return the elementwise reciprocal of a conductivity grid.

    Values are clipped to ``[min_sigma, 1e12]`` first, so zero or negative
    conductivities map to a large but finite resistivity (the plots label
    the result in Ohm.m).
    """
    upper = 1e12
    clipped = np.clip(np.asarray(grid, dtype=float), min_sigma, upper)
    return 1.0 / clipped

def resistivity_slice_at_x(x_arr, z_arr, grid, x_pick):
    """Return ``(z, rho)`` for the grid column nearest to *x_pick*.

    *grid* is indexed ``[iz, ix]`` (columns follow *x_arr*). The selected
    column is converted from conductivity to resistivity.
    """
    column = int(np.abs(np.asarray(x_arr) - x_pick).argmin())
    profile = conductivity_to_resistivity(grid[:, column])
    return np.asarray(z_arr), np.asarray(profile)

def resistivity_slice_at_z(x_arr, z_arr, grid, z_pick):
    """Return ``(x, rho)`` for the grid row nearest to *z_pick*.

    *grid* is indexed ``[iz, ix]`` (rows follow *z_arr*). The selected row is
    converted from conductivity to resistivity.
    """
    row = int(np.abs(np.asarray(z_arr) - z_pick).argmin())
    profile = conductivity_to_resistivity(grid[row, :])
    return np.asarray(x_arr), np.asarray(profile)

def default_frequencies():
    """Return the default frequency list (Hz) for the Freqs box.

    ``flist_hz`` from ``SETUP_META`` is used when the file exists, parses as
    JSON and yields a non-empty list of floats. Any read or parse problem is
    ignored and ``[2e3, 4e3, 6e3]`` is returned instead.
    """
    fallback = [2e3, 4e3, 6e3]
    if not SETUP_META.exists():
        return fallback
    try:
        listed = json.loads(SETUP_META.read_text()).get('flist_hz', [])
        freqs = [float(v) for v in listed]
    except Exception:
        return fallback
    return freqs if freqs else fallback

def parse_freqs(text):
    """Parse a comma-separated frequency list (Hz) into a float array.

    Blank entries are ignored.

    Raises:
        ValueError: if the list is empty, an entry is not a number, or any
            value is non-finite or not strictly positive.
    """
    parts = [p.strip() for p in str(text).split(',') if p.strip()]
    if not parts:
        raise ValueError('Frequency list cannot be empty.')
    try:
        vals = np.asarray([float(p) for p in parts], dtype=float)
    except ValueError:
        raise ValueError(f'Invalid frequency list {text!r}: entries must be numbers.') from None
    # float() accepts 'nan'/'inf', and NaN also slips past a `<= 0` test.
    if not np.all(np.isfinite(vals)):
        raise ValueError('Frequencies must be finite numbers.')
    if np.any(vals <= 0):
        raise ValueError('Frequencies must be positive.')
    return vals

def infer_iteration_from_model_path(model_path, run_dir):
    """Map a model file name to the inversion iteration it represents.

    The mapping is:
    - ``sg_up.rss-<N>`` maps to N.
    - ``sg0.rss`` maps to 0.
    - ``sg_ls.rss`` maps to the largest N among lines starting with
      ``Iteration <N>`` in ``<run_dir>/progress.log``, or 0 if there are none
      or the log is missing.
    - Any other file name maps to 0.
    """
    name = Path(model_path).name
    numbered = re.search(r'sg_up\.rss-(\d+)$', name)
    if numbered:
        return int(numbered.group(1))
    if name != 'sg_ls.rss':
        return 0  # sg0.rss and unrecognised names
    progress = Path(run_dir) / 'progress.log'
    if not progress.exists():
        return 0
    matches = (
        re.match(r'^Iteration\s+(\d+)', ln.strip())
        for ln in progress.read_text(errors='replace').splitlines()
    )
    return max((int(m.group(1)) for m in matches if m), default=0)


def available_synthetic_pairs(run_dir):
    """List the complete Hx/Hz synthetic-data file pairs in *run_dir*.

    Returns ``(index, hx_path, hz_path)`` tuples sorted by index. Indexed files
    (``<name>.rss-<N>``) use N as the index. Un-suffixed pairs get index -2
    (``data_Hx_mod``/``data_Hz_mod``) or -1 (``data_mod_HX``/``data_mod_HZ``).
    A pair is listed only when both files exist. For equal indices, the
    ``data_Hx_mod`` naming comes first.
    """
    run_dir = Path(run_dir)
    # (Hx base name, Hz base name, index used for the un-suffixed pair).
    # The first scheme is what this GUI writes; the second is sometimes
    # produced by inversion binaries. Order sets precedence for equal indices.
    naming_schemes = (
        ('data_Hx_mod.rss', 'data_Hz_mod.rss', -2),
        ('data_mod_HX.rss', 'data_mod_HZ.rss', -1),
    )
    collected = []
    seen = set()

    def _record(idx, hx, hz):
        key = (int(idx), str(hx), str(hz))
        if key not in seen:
            seen.add(key)
            collected.append((int(idx), Path(hx), Path(hz)))

    # Indexed pairs for every scheme first ...
    for hx_base, hz_base, _ in naming_schemes:
        suffix_re = re.compile(re.escape(hx_base) + r'-(\d+)$')
        for hx in run_dir.glob(f'{hx_base}-*'):
            found = suffix_re.search(hx.name)
            if found is None:
                continue
            idx = int(found.group(1))
            hz = run_dir / f'{hz_base}-{idx}'
            if hz.exists():
                _record(idx, hx, hz)

    # ... then the plain (non-indexed) pairs.
    for hx_base, hz_base, plain_idx in naming_schemes:
        hx = run_dir / hx_base
        hz = run_dir / hz_base
        if hx.exists() and hz.exists():
            _record(plain_idx, hx, hz)

    return sorted(collected, key=lambda item: item[0])


def find_synthetic_in_run(run_dir, iteration=None):
    """Pick one synthetic Hx/Hz pair from *run_dir*.

    If *iteration* is given, return the first pair whose index equals it, or
    ``(None, None, None)`` when there is none. It deliberately does not fall
    back to another iteration, which would be misleading in the GUI. Without
    *iteration*, return the highest-index pair. Always returns
    ``(None, None, None)`` when the run has no synthetic pairs.
    """
    missing = (None, None, None)
    pairs = available_synthetic_pairs(run_dir)
    if not pairs:
        return missing
    if iteration is None:
        chosen = pairs[-1]
    else:
        chosen = next((p for p in pairs if p[0] == int(iteration)), missing)
    idx, hx, hz = chosen
    return idx, hx, hz


# Seed the status log; the Status textarea (status_out) is created in the next cell.
push_message('Results GUI loaded. Select a run and model.')

In [None]:
from IPython.display import display, clear_output
from plotly.subplots import make_subplots

# --- Run / model selection ---------------------------------------------------
runs = list_run_dirs(ROOT)
# Dropdown values are path strings; a single None-valued entry when no run exists.
run_options = [(p.name, str(p)) for _, p in runs] if runs else [('No runs', None)]

# Defaults to the first (lowest-numbered) run, or None when there are no runs.
run_selector = ipw.Dropdown(options=run_options, value=run_options[0][1] if run_options and run_options[0][1] else None, description='Run:', layout=ipw.Layout(width='400px'))
# Options are filled by refresh_model_options.
model_selector = ipw.Dropdown(options=[], value=None, description='Model:', layout=ipw.Layout(width='400px'))
refresh_models_btn = ipw.Button(description='Refresh model list')
export_segy_btn = ipw.Button(description='Export selected model to SEGY')
export_status = ipw.HTML(value='')

# --- Model images and resistivity slices -------------------------------------
model_plot_out = ipw.Output(layout=ipw.Layout(width='100%', height='500px'))
# A value of 0.0 in both inputs is treated by update_2d_plot as "not set yet".
x_slice_input = ipw.FloatText(value=0.0, description='x (m):', layout=ipw.Layout(width='140px'))
z_slice_input = ipw.FloatText(value=0.0, description='z (m):', layout=ipw.Layout(width='140px'))
rho_z_out = ipw.Output(layout=ipw.Layout(width='100%', height='320px'))
rho_x_out = ipw.Output(layout=ipw.Layout(width='100%', height='320px'))

# --- Data loading and synthetic modelling ------------------------------------
freqs_input = ipw.Text(value=','.join([f'{v:g}' for v in default_frequencies()]), description='Freqs (Hz):', layout=ipw.Layout(width='320px'))
load_real_btn = ipw.Button(description='Load real data')
load_syn_from_run_btn = ipw.Button(description='Load synthetics from run')
generate_syn_btn = ipw.Button(description='Generate synthetics from model', button_style='primary')
refresh_fd_btn = ipw.Button(description='Refresh FD status / load synthetics')
quit_btn = ipw.Button(description='Quit GUI server', button_style='danger', layout=ipw.Layout(width='170px'))
# MPI process count, bounded to [2, MAX_CPUS]; default 4 (or MAX_CPUS if smaller).
nproc_input = ipw.BoundedIntText(value=min(4, MAX_CPUS), min=2, max=MAX_CPUS, description='nproc', layout=ipw.Layout(width='180px'))
fd_status = ipw.HTML(value='')

# Data comparison controls (match visualize_gui modes)
component_select = ipw.Dropdown(options=[('Hx', 'Hx'), ('Hz', 'Hz')], value='Hx', description='component')
metric_select = ipw.Dropdown(
    options=[
        ('Amplitude vs rx (local index)', 'amp_vs_rx'),
        ('Phase vs rx (deg, local index)', 'phase_vs_rx_deg'),
        ('Amplitude vs tx (fixed local rx)', 'amp_vs_tx'),
        ('Phase vs tx (deg, fixed local rx)', 'phase_vs_tx_deg'),
        ('Amplitude vs frequency', 'amp_vs_freq'),
        ('Phase vs frequency (deg)', 'phase_vs_freq_deg'),
    ],
    value='amp_vs_rx',
    description='plot',
)
# Placeholder options/ranges; presumably repopulated by
# refresh_data_compare_controls (defined later, not shown) once data is loaded.
comp_freq = ipw.Dropdown(options=[('n/a', 0)], value=0, description='frequency')
tx_select = ipw.Dropdown(options=[('n/a', 0)], value=0, description='tx')
rx_local_select = ipw.IntSlider(value=0, min=0, max=0, step=1, description='local rx', continuous_update=False, layout=ipw.Layout(width='260px'))
trace_idx = ipw.IntSlider(value=0, min=0, max=0, step=1, description='trace idx', continuous_update=False, layout=ipw.Layout(width='320px'))

data_compare_out = ipw.Output(layout=ipw.Layout(width='100%', height='460px'))
# Rolling status log written by push_message.
status_out = ipw.Textarea(value='', description='Status', layout=ipw.Layout(width='100%', height='130px'))

def refresh_model_options():
    """Rebuild the model dropdown for the currently selected run.

    The first entry is always a placeholder (value None), so the inverted
    panel stays empty until a model is picked. The previously selected model
    is re-selected if it is still offered. ``state['current_run_dir']`` and
    ``state['current_model_path']`` are kept in sync with the widgets.
    """
    placeholder = ('Select inverted/linesearch model...', None)
    run_dir = run_selector.value
    # Capture the previous selection and publish the new run *before* touching
    # the dropdown. Assigning ``options``/``value`` may fire the model-selection
    # observer, which overwrites state['current_model_path'] and redraws using
    # state['current_run_dir'].
    prev = state.get('current_model_path')
    state['current_run_dir'] = run_dir
    if not run_dir:
        state['current_model_path'] = None
        model_selector.options = [placeholder]
        model_selector.value = None
        return

    selector_options = [placeholder]
    selector_options.extend((lbl, str(p)) for lbl, p in build_model_list(Path(run_dir)))
    model_selector.options = selector_options

    valid_values = {val for _, val in selector_options}
    model_selector.value = prev if prev in valid_values else None
    state['current_model_path'] = model_selector.value

def update_2d_plot():
    """Redraw the side-by-side resistivity images (selected model | true model).

    Does nothing while ``state['initial_load']`` is set, and clears the output
    when no run is selected. The true model (``SG_TRUE_PATH``) fixes the shared
    colour limits and is always drawn on the right. The selected model is drawn
    on the left unless it is missing or is the true model itself; in that case
    an empty placeholder with a hint is shown. TX/RX positions are overlaid on
    both panels. When both slice inputs are still 0.0 they are initialised to
    the mean x/z of the displayed model. Errors are reported via push_message.
    """
    if state.get('initial_load'):
        return
    path = state.get('current_model_path')
    run_dir = state.get('current_run_dir')
    if not run_dir:
        with model_plot_out:
            clear_output(wait=True)
        return

    try:
        run_dir = Path(run_dir)
        tx_x, tx_z, rx_x, rx_z = _extract_positions(run_dir)

        if not SG_TRUE_PATH.exists():
            with model_plot_out:
                clear_output(wait=True)
                display(ipw.HTML('True model not found at FDmodel/sg.rss.'))
            return

        # True model defines fixed color limits and right panel reference.
        x_true, z_true, g_true = _read_rss_model(SG_TRUE_PATH)
        x_true, z_true = _align_axes_to_survey(x_true, z_true, tx_x, tx_z, rx_x, rx_z)
        rho_true = conductivity_to_resistivity(g_true)
        zmin = float(np.nanmin(rho_true))
        zmax = float(np.nanmax(rho_true))

        fig = make_subplots(rows=1, cols=2, subplot_titles=['Selected inverted/linesearch model', 'True model'], horizontal_spacing=0.10)

        selected_ready = False
        if path:
            p = Path(path)
            # The true model is not repeated in the left panel.
            if p.exists() and str(p) != str(SG_TRUE_PATH):
                x_sel, z_sel, g_sel = _read_rss_model(p)
                x_sel, z_sel = _align_axes_to_survey(x_sel, z_sel, tx_x, tx_z, rx_x, rx_z)
                rho_sel = conductivity_to_resistivity(g_sel)
                fig.add_trace(go.Heatmap(x=x_sel, y=z_sel, z=rho_sel, colorscale='Viridis', zmin=zmin, zmax=zmax, showscale=False), row=1, col=1)
                selected_ready = True
                state['current_x'] = float(np.mean(x_sel))
                state['current_z'] = float(np.mean(z_sel))

        if not selected_ready:
            # All-NaN heatmap on the true-model axes keeps the left panel's
            # extent identical to the right one while showing nothing.
            fig.add_trace(go.Heatmap(x=x_true, y=z_true, z=np.full_like(rho_true, np.nan), colorscale='Viridis', zmin=zmin, zmax=zmax, showscale=False, hoverinfo='skip'), row=1, col=1)
            x_mid = float(np.mean(x_true)) if x_true.size else 0.0
            z_mid = float(np.mean(z_true)) if z_true.size else 0.0
            fig.add_annotation(
                x=x_mid,
                y=z_mid,
                text='Select an inverted/linesearch model from the list to display it here.',
                showarrow=False,
                xref='x',
                yref='y',
                font=dict(size=13),
                bgcolor='rgba(255,255,255,0.75)',
            )

        # Only the right panel shows the (shared) colour bar.
        fig.add_trace(go.Heatmap(x=x_true, y=z_true, z=rho_true, colorscale='Viridis', zmin=zmin, zmax=zmax, showscale=True, colorbar=dict(title='Ohm.m', x=1.02, len=0.90)), row=1, col=2)

        # Survey overlay on both panels; legend entries only once (col 1).
        for col, xx, zz in [(1, tx_x, tx_z), (2, tx_x, tx_z)]:
            if xx.size:
                fig.add_trace(go.Scatter(x=xx, y=zz, mode='markers', marker=dict(symbol='triangle-up', size=7, color='red'), name='TX', showlegend=(col == 1)), row=1, col=col)
        for col, xx, zz in [(1, rx_x, rx_z), (2, rx_x, rx_z)]:
            if xx.size:
                fig.add_trace(go.Scatter(x=xx, y=zz, mode='markers', marker=dict(symbol='circle', size=6, color='cyan'), name='RX', showlegend=(col == 1)), row=1, col=col)

        # Link the panels' axes so zoom/pan stays in sync; depth increases downwards.
        fig.update_xaxes(title_text='x (m)', row=1, col=1)
        fig.update_xaxes(title_text='x (m)', row=1, col=2, matches='x')
        fig.update_yaxes(title_text='z (m)', autorange='reversed', row=1, col=1)
        fig.update_yaxes(title_text='z (m)', autorange='reversed', row=1, col=2, matches='y')
        fig.update_layout(
            title='Resistivity models with survey overlay',
            height=520,
            margin=dict(t=80, b=60, l=40, r=110),
            legend=dict(orientation='h', y=-0.08),
            dragmode='zoom',
        )

        with model_plot_out:
            clear_output(wait=True)
            display(fig)

        # Both slice inputs at 0.0 is treated as "not set yet": fall back to
        # the centre of the true model when no selected model was drawn.
        if (not selected_ready) and (x_slice_input.value == 0.0 and z_slice_input.value == 0.0):
            state['current_x'] = float(np.mean(x_true)) if x_true.size else 0.0
            state['current_z'] = float(np.mean(z_true)) if z_true.size else 0.0

        # NOTE(review): a user who really wants the slices at x=0 and z=0 gets
        # them overwritten here, because 0.0/0.0 doubles as the "unset" sentinel.
        if x_slice_input.value == 0.0 and z_slice_input.value == 0.0:
            x_slice_input.value = state.get('current_x', 0.0)
            z_slice_input.value = state.get('current_z', 0.0)
    except Exception as exc:
        push_message(f'2D plot error: {exc}')

def update_rho_plots():
    """Redraw the resistivity-vs-depth and resistivity-vs-x line plots.

    The selected model is sliced at the x and z positions in the slice inputs
    (nearest grid column/row). When the true model exists and is not the
    selected file, the same slices of the true model are overlaid. Does
    nothing during initial load or when no run/model is selected. Errors are
    reported via push_message.
    """
    if state.get('initial_load'):
        return
    path = state.get('current_model_path')
    run_dir = state.get('current_run_dir')
    if not path or not run_dir:
        return
    try:
        path = Path(path)
        run_dir = Path(run_dir)
        x, z, grid = _read_rss_model(path)
        tx_x, tx_z, rx_x, rx_z = _extract_positions(run_dir)
        x, z = _align_axes_to_survey(x, z, tx_x, tx_z, rx_x, rx_z)
        x_pick = float(x_slice_input.value)
        z_pick = float(z_slice_input.value)

        z_ax, rho_z = resistivity_slice_at_x(x, z, grid, x_pick)
        x_ax, rho_x = resistivity_slice_at_z(x, z, grid, z_pick)

        fig_z = go.Figure()
        fig_z.add_trace(go.Scatter(x=rho_z, y=z_ax, mode='lines', name='Selected'))

        fig_x = go.Figure()
        fig_x.add_trace(go.Scatter(x=x_ax, y=rho_x, mode='lines', name='Selected'))

        if SG_TRUE_PATH.exists() and str(path) != str(SG_TRUE_PATH):
            xt, zt, gt = _read_rss_model(SG_TRUE_PATH)
            xt, zt = _align_axes_to_survey(xt, zt, tx_x, tx_z, rx_x, rx_z)
            zt_ax, rho_tz = resistivity_slice_at_x(xt, zt, gt, x_pick)
            xt_ax, rho_tx = resistivity_slice_at_z(xt, zt, gt, z_pick)
            fig_z.add_trace(go.Scatter(x=rho_tz, y=zt_ax, mode='lines', name='True'))
            fig_x.add_trace(go.Scatter(x=xt_ax, y=rho_tx, mode='lines', name='True'))

        # Axis ranges are left to Plotly's autorange; no manual limits needed.
        fig_z.update_layout(
            title=f'Resistivity vs depth at x={x_pick:.1f} m',
            xaxis_title='Resistivity (Ohm.m)',
            yaxis_title='z (m)',
            yaxis_autorange='reversed',
            xaxis_autorange=True,
            yaxis_fixedrange=False,
            xaxis_fixedrange=False,
            height=300,
            margin=dict(t=50, b=40, l=50, r=20),
            legend=dict(orientation='h', y=-0.2),
        )
        with rho_z_out:
            clear_output(wait=True)
            display(fig_z)

        fig_x.update_layout(
            title=f'Resistivity vs x at z={z_pick:.1f} m',
            xaxis_title='x (m)',
            yaxis_title='Resistivity (Ohm.m)',
            yaxis_autorange=True,
            xaxis_autorange=True,
            xaxis_fixedrange=False,
            yaxis_fixedrange=False,
            height=300,
            margin=dict(t=50, b=40, l=50, r=20),
            legend=dict(orientation='h', y=-0.2),
        )
        with rho_x_out:
            clear_output(wait=True)
            display(fig_x)
    except Exception as exc:
        push_message(f'Slice plot error: {exc}')

def on_run_selected(change):
    """Run-dropdown observer: reload the model list, then redraw both plot areas.

    Only ``value`` changes are handled; other trait notifications are ignored.
    """
    if change.get('name') == 'value':
        refresh_model_options()
        update_2d_plot()
        update_rho_plots()

def on_model_selected(change):
    """Model-dropdown observer: remember the new selection and redraw the plots.

    Only ``value`` changes are handled; other trait notifications are ignored.
    """
    if change.get('name') != 'value':
        return
    state.update(current_model_path=change.get('new'))
    update_2d_plot()
    update_rho_plots()

def on_load_real(_):
    """Button handler: load the observed Hx/Hz data for the selected run.

    Amplitude/phase is computed at the frequencies in the Freqs box (via
    ``compute_amp_phase_for_fd_outputs`` with ``n_pairs=3``) and stored in
    ``state['real_result']``. The comparison controls and plot are then
    refreshed. Problems are reported via push_message.
    """
    selected = run_selector.value
    if not selected:
        push_message('Select a run first.')
        return
    hx_path, hz_path = get_real_data_paths(Path(selected))
    if not (hx_path.exists() and hz_path.exists()):
        push_message('Real data files not found.')
        return
    try:
        state['real_result'] = compute_amp_phase_for_fd_outputs(
            hx_path,
            hz_path,
            freqs=parse_freqs(freqs_input.value),
            n_pairs=3,
        )
        refresh_data_compare_controls()
        push_message('Real data loaded.')
        update_data_compare_plot()
    except Exception as exc:
        push_message(f'Load real failed: {exc}')

def _set_cfg_line(text, key, value):
    """Set ``key = "value";`` in config *text*.

    The first line assigning *key* is replaced. A matching line may have
    leading spaces/tabs and must end with ``;``, optionally followed by
    spaces/tabs. When no such line exists, the assignment is appended.
    Returns the new text.
    """
    # [ \t]* rather than \s*: \s would let a match run across line breaks.
    pat = re.compile(rf'^[ \t]*{re.escape(key)}[ \t]*=.*?;[ \t]*$', re.MULTILINE)
    line = f'{key} = "{value}";'
    if pat.search(text):
        # A callable replacement is inserted verbatim; a string template would
        # interpret backslashes in *value* (e.g. Windows paths) as escapes.
        return pat.sub(lambda _m: line, text, count=1)
    return text + '\n' + line + '\n'


def _write_sg_with_ep_geometry(model_path, ep_path, output_path):
    """Write conductivity model using exact Ep geometry header fields.

    The ep file is read and its data replaced by the selected model's values,
    so every header field of the ep file is kept. The model samples are
    reshaped into the ep data shape in column-major (``order='F'``) order and
    cast to the ep dtype. The result is written to *output_path*, creating
    parent folders as needed.

    Raises:
        FileNotFoundError: if the model or the ep file does not exist.
        ValueError: if the two files hold a different number of samples.
    """
    from third_party.rockseis.io.rsfile import rsfile

    model_path, ep_path, output_path = Path(model_path), Path(ep_path), Path(output_path)

    for required, what in (
        (model_path, 'Missing selected model'),
        (ep_path, 'Missing ep model for geometry reference'),
    ):
        if not required.exists():
            raise FileNotFoundError(f'{what}: {required}')

    model_file = rsfile()
    model_file.read(str(model_path))
    ep_file = rsfile()
    ep_file.read(str(ep_path))

    model_values = np.asarray(model_file.data)
    ep_values = np.asarray(ep_file.data)
    if model_values.size != ep_values.size:
        raise ValueError(
            f'Selected model size ({model_values.size}) does not match ep geometry size ({ep_values.size}).'
        )

    # Keep ep header geometry exactly, replace only conductivity values.
    relaid = np.reshape(np.asarray(model_values, dtype=ep_values.dtype), ep_values.shape, order='F')
    ep_file.data = np.asfortranarray(relaid)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    ep_file.write(str(output_path))


def _resolve_segy_template_path():
    """Locate the SEG-Y file used as the geometry template for export.

    Returns ``(path, source_label)``, trying in order:
    1. ``segy_template_path`` from ``SETUP_META`` (relative paths are resolved
       against ``ROOT``), if that file exists.
    2. ``ROOT/input.segy``.
    3. ``(None, None)`` if neither is available.

    Metadata read errors and a listed-but-missing template are reported via
    push_message rather than raised.
    """
    if SETUP_META.exists():
        try:
            listed = json.loads(SETUP_META.read_text()).get('segy_template_path')
            if listed:
                candidate = Path(listed).expanduser()
                if not candidate.is_absolute():
                    candidate = (ROOT / candidate).resolve()
                if candidate.exists():
                    return candidate, 'setup_metadata.json'
                push_message(f'SEG-Y template listed in setup metadata does not exist: {candidate}')
        except Exception as exc:
            push_message(f'Failed reading setup metadata for SEG-Y template: {exc}')

    fallback = ROOT / 'input.segy'
    if not fallback.exists():
        return None, None
    return fallback, 'fallback ROOT/input.segy'


def _export_name_for_model(model_path, run_dir):
    """Return ``(segy_file_name, human_description)`` for exporting a model.

    The mapping is:
    - ``sg_up.rss-<N>`` becomes ``sg_up_iterNNN.segy``.
    - ``sg0.rss`` becomes ``sg0_iter000.segy``.
    - ``sg_ls.rss`` becomes ``sg_ls_iterNNN.segy``, where NNN is the largest
      ``Iteration <N>`` in ``<run_dir>/progress.log``; it becomes
      ``sg_ls_iterunknown.segy`` when that cannot be determined.
    - Anything else uses the file stem with dots replaced by underscores.
    """
    model_path = Path(model_path)
    run_dir = Path(run_dir)
    name = model_path.name

    numbered = re.search(r'sg_up\.rss-(\d+)$', name)
    if numbered:
        it = int(numbered.group(1))
        return f'sg_up_iter{it:03d}.segy', f'iteration {it}'
    if name == 'sg0.rss':
        return 'sg0_iter000.segy', 'iteration 0 (initial model)'
    if name != 'sg_ls.rss':
        return f"{model_path.stem.replace('.', '_')}.segy", f'model {model_path.name}'

    progress = run_dir / 'progress.log'
    seen_iters = []
    if progress.exists():
        for raw in progress.read_text(errors='replace').splitlines():
            hit = re.match(r'^Iteration\s+(\d+)', raw.strip())
            if hit:
                seen_iters.append(int(hit.group(1)))
    if not seen_iters:
        return 'sg_ls_iterunknown.segy', 'line-search model (iteration unknown)'
    last = max(seen_iters)
    return f'sg_ls_iter{last:03d}.segy', f'line-search model near iteration {last}'


def on_export_selected_segy(_):
    """Button handler: export the selected model as a resistivity SEG-Y file.

    The model's conductivity is converted to resistivity and written with
    ``write_resistivity_to_segy_from_template`` (linear method) into
    ``<run>/SEGY_output/<name>``, where the name comes from
    _export_name_for_model. The template comes from
    _resolve_segy_template_path. Progress and outcome are reported in the
    status log and the ``export_status`` line.
    """
    run_dir = run_selector.value
    model_path = model_selector.value
    if not run_dir:
        push_message('Select a run first.')
        return
    if not model_path or not Path(model_path).exists():
        push_message('Select an existing model first.')
        return

    run_dir = Path(run_dir)
    model_path = Path(model_path)
    export_status.value = '<span style="color:#333333">SEGY export in progress...</span>'
    template_path, template_src = _resolve_segy_template_path()
    if template_path is None:
        push_message('No SEG-Y template found. Run 01_fw_setup and finalize setup, or provide ROOT/input.segy.')
        export_status.value = '<span style="color:#aa0000">SEGY export failed: no template SEG-Y found.</span>'
        return

    try:
        # NOTE(review): x/z are the raw model coordinates; unlike the plots,
        # _align_axes_to_survey is not applied here — confirm this is intended.
        x, z, sigma = _read_rss_model(model_path)
        rho = conductivity_to_resistivity(sigma)
        out_dir = run_dir / 'SEGY_output'
        out_name, model_desc = _export_name_for_model(model_path, run_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # The same model/iteration always maps to the same file name, so a
        # re-export presumably overwrites the earlier file (writer not shown).
        out_path = out_dir / out_name
        info = write_resistivity_to_segy_from_template(
            template_segy_path=template_path,
            output_segy_path=out_path,
            resistivity_grid=rho,
            x_model=x,
            z_model=z,
            method='linear',
        )
        push_message(f'SEG-Y template: {template_path} ({template_src})')
        if info.get('interpolated'):
            push_message('Export grid differs from template; interpolated resistivity to exact SEG-Y template geometry.')
        else:
            push_message('Export grid already matches template geometry; wrote values without interpolation.')
        push_message('SEGY export SUCCESS.')
        push_message(f'Exported model: {model_desc}')
        push_message(f'Output file name: {out_path.name}')
        push_message(f'Output folder: {out_path.parent}')
        push_message('Export mode: overwrite enabled for same model/iteration filename.')
        export_status.value = (
            f'<span style="color:#006400">SEGY export succeeded. '
            f'File: {out_path.name} | Folder: {out_path.parent}</span>'
        )
    except Exception as exc:
        push_message(f'SEG-Y export failed: {exc}')
        export_status.value = f'<span style="color:#aa0000">SEGY export failed: {exc}</span>'


def on_load_syn_from_run(_):
    """Button handler: load existing synthetic Hx/Hz data from the selected run.

    The synthetic pair is matched to the iteration of the selected model (see
    infer_iteration_from_model_path and find_synthetic_in_run). When no model
    is selected, the newest available pair is used. The amplitude/phase result
    is stored in ``state['syn_result']`` and the comparison plot is refreshed.
    Progress and errors go to the status log and the FD status line.
    """
    run_dir = run_selector.value
    if not run_dir:
        push_message('Select a run first.')
        fd_status.value = '<span style="color:#aa0000">Select a run first.</span>'
        return
    run_dir = Path(run_dir)
    model_value = model_selector.value
    # With no model selected there is no iteration to match. Passing '' would
    # map to iteration 0 and hide every other iteration, so request the
    # newest pair (iteration=None) instead.
    desired_iter = infer_iteration_from_model_path(model_value, run_dir) if model_value else None
    target = 'latest' if desired_iter is None else desired_iter
    fd_status.value = f'Loading synthetics from {run_dir.name}...'
    push_message(f'Searching synthetics in {run_dir.name} (target iter={target}).')
    push_message('Accepted patterns: data_Hx_mod.rss-<iter>/data_Hz_mod.rss-<iter>, data_mod_HX.rss-<iter>/data_mod_HZ.rss-<iter>, and plain data_Hx_mod/data_Hz_mod or data_mod_HX/data_mod_HZ.')

    idx, hx_path, hz_path = find_synthetic_in_run(run_dir, iteration=desired_iter)
    if hx_path is None or hz_path is None:
        if desired_iter is None:
            msg = f'No synthetic files found in {run_dir.name}.'
        else:
            msg = f'No synthetic files found for selected model iteration {desired_iter}.'
        push_message(msg)
        fd_status.value = f'<span style="color:#aa0000">{msg}</span>'
        return
    try:
        freqs = parse_freqs(freqs_input.value)
        state['syn_result'] = compute_amp_phase_for_fd_outputs(hx_path, hz_path, freqs=freqs, n_pairs=3)
        refresh_data_compare_controls()
        push_message(f'Loaded synthetic pair iter={idx} from run.')
        push_message(f'Synthetic files: {hx_path.name}, {hz_path.name}')
        fd_status.value = f'Synthetic loaded: {hx_path.name}, {hz_path.name}'
        if state.get('real_result') is None:
            push_message('Synthetic loaded; load real data to compare.')
        update_data_compare_plot()
    except Exception as exc:
        push_message(f'Load synthetics from run failed: {exc}')
        fd_status.value = f'<span style="color:#aa0000">Synthetic load failed: {exc}</span>'


def on_generate_synthetics(_):
    """Button handler: start FD forward modelling for the selected model.

    Steps:
    1. Copy any missing ``Survey.rss``/``ep.rss``/``wav2d.rss`` from FDmodel
       into the run folder.
    2. Rewrite the selected model onto the ep geometry as
       ``sg_mod_iter<N>.rss``.
    3. Write ``mod_synth_iter<N>.cfg`` from FDmodel/mod.cfg.
    4. Launch ``mpirun -np <nproc> MPI_EMMOD_BIN <cfg>`` in the run folder
       without waiting for it.

    The Popen handle and the expected output paths are stored in ``state``
    for a later status refresh. Only one job may run at a time.
    """
    run_dir = run_selector.value
    model_path = model_selector.value
    if not run_dir:
        push_message('Select a run first.')
        return
    if not model_path or not Path(model_path).exists():
        push_message('Select an existing model first.')
        return
    if state.get('fd_process') and state['fd_process'].poll() is None:
        push_message('Synthetic modelling already in progress.')
        return

    run_dir = Path(run_dir)
    model_path = Path(model_path)
    # NOTE(review): infer_iteration_from_model_path returns 0 for any
    # unrecognised name, including the true model FDmodel/sg.rss. Synthetics of
    # the true model therefore use the same *-0 output names as sg0.rss and
    # overwrite/mislabel them — consider a distinct tag for the true model.
    iter_idx = infer_iteration_from_model_path(model_path, run_dir)

    hx_out = run_dir / f'data_Hx_mod.rss-{iter_idx}'
    hz_out = run_dir / f'data_Hz_mod.rss-{iter_idx}'
    sg_model = run_dir / f'sg_mod_iter{iter_idx}.rss'
    cfg_path = run_dir / f'mod_synth_iter{iter_idx}.cfg'

    try:
        survey = run_dir / 'Survey.rss'
        if (not survey.exists()) and (FDMODEL_DIR / 'Survey.rss').exists():
            shutil.copy2(FDMODEL_DIR / 'Survey.rss', survey)

        for fname in ['ep.rss', 'wav2d.rss']:
            p = run_dir / fname
            if (not p.exists()) and (FDMODEL_DIR / fname).exists():
                shutil.copy2(FDMODEL_DIR / fname, p)

        required = [run_dir / 'ep.rss', run_dir / 'wav2d.rss', run_dir / 'Survey.rss']
        missing = [p.name for p in required if not p.exists()]
        if missing:
            raise FileNotFoundError('Missing required files in run folder: ' + ', '.join(missing))

        # Rebuild model with Ep geometry to satisfy solver requirement.
        _write_sg_with_ep_geometry(model_path, run_dir / 'ep.rss', sg_model)

        mod_template = FDMODEL_DIR / 'mod.cfg'
        if not mod_template.exists():
            raise FileNotFoundError(f'Missing FD template cfg: {mod_template}')

        # File names in the cfg are relative: the solver runs with cwd=run_dir.
        cfg_txt = mod_template.read_text()
        cfg_txt = _set_cfg_line(cfg_txt, 'Sg', sg_model.name)
        cfg_txt = _set_cfg_line(cfg_txt, 'Ep', 'ep.rss')
        cfg_txt = _set_cfg_line(cfg_txt, 'Wavelet', 'wav2d.rss')
        cfg_txt = _set_cfg_line(cfg_txt, 'Survey', 'Survey.rss')
        cfg_txt = _set_cfg_line(cfg_txt, 'Hxrecordfile', hx_out.name)
        cfg_txt = _set_cfg_line(cfg_txt, 'Hzrecordfile', hz_out.name)
        cfg_path.write_text(cfg_txt)

        nproc = max(2, min(int(nproc_input.value), MAX_CPUS))
        fd_status.value = 'Synthetic modelling running...'
        # Non-blocking launch; completion is checked later via the stored handle.
        # NOTE(review): MPI_EMMOD_BIN's existence is not checked, so a missing
        # binary only surfaces as a non-zero exit code of mpirun.
        state['fd_process'] = subprocess.Popen(
            ['mpirun', '-np', str(nproc), MPI_EMMOD_BIN, cfg_path.name],
            cwd=str(run_dir),
        )
        state['fd_hx_out'] = str(hx_out)
        state['fd_hz_out'] = str(hz_out)
        state['fd_iter_idx'] = int(iter_idx)
        push_message(f'Started synthetic modelling in {run_dir.name} for iter={iter_idx}.')
        push_message(f'Expected output targets in run folder: {hx_out.name}, {hz_out.name}')
        push_message('Also checking alternate solver names: data_mod_HX*.rss / data_mod_HZ*.rss when loading.')
    except Exception as exc:
        push_message(f'Generate failed: {exc}')
        fd_status.value = ''


def on_refresh_fd_status(_):
    proc = state.get('fd_process')
    if proc is None:
        return
    rc = proc.poll()
    if rc is None:
        fd_status.value = 'Synthetic modelling still running...'
        return

    state['fd_process'] = None
    fd_status.value = f'Synthetic modelling finished (exit {rc}).'

    run_dir = Path(run_selector.value) if run_selector.value else None
    hx_out = Path(state.get('fd_hx_out', '')) if state.get('fd_hx_out') else None
    hz_out = Path(state.get('fd_hz_out', '')) if state.get('fd_hz_out') else None
    iter_idx = state.get('fd_iter_idx')

    def _resolve_generated_pair():
        if hx_out is not None and hz_out is not None and hx_out.exists() and hz_out.exists():
            return hx_out, hz_out
        if run_dir is None:
            return None, None

        candidates = []
        if iter_idx is not None:
            candidates.extend([
                (run_dir / f'data_Hx_mod.rss-{iter_idx}', run_dir / f'data_Hz_mod.rss-{iter_idx}'),
                (run_dir / f'data_mod_HX.rss-{iter_idx}', run_dir / f'data_mod_HZ.rss-{iter_idx}'),
            ])
        candidates.extend([
            (run_dir / 'data_Hx_mod.rss', run_dir / 'data_Hz_mod.rss'),
            (run_dir / 'data_mod_HX.rss', run_dir / 'data_mod_HZ.rss'),
        ])

        for hx_c, hz_c in candidates:
            if hx_c.exists() and hz_c.exists():
                return hx_c, hz_c
        return None, None

    hx_found, hz_found = _resolve_generated_pair()

    if rc == 0 and hx_found is not None and hz_found is not None:
        try:
            freqs = parse_freqs(freqs_input.value)
            state['syn_result'] = compute_amp_phase_for_fd_outputs(hx_found, hz_found, freqs=freqs, n_pairs=3)
            refresh_data_compare_controls()
            push_message(f'Synthetics loaded for iter={iter_idx}: {hx_found.name}, {hz_found.name}')
            if hx_out is not None and hz_out is not None and (hx_found != hx_out or hz_found != hz_out):
                push_message(f'Note: solver wrote alternate names, loaded {hx_found.name}/{hz_found.name}.')
            update_data_compare_plot()
        except Exception as exc:
            push_message(f'Load synthetics failed: {exc}')
            fd_status.value = f'<span style="color:#aa0000">Synthetic load failed: {exc}</span>'
    elif rc == 0:
        push_message('Synthetic modelling completed but output files were not found (checked data_Hx_mod/data_Hz_mod and data_mod_HX/data_mod_HZ patterns).')
        fd_status.value = '<span style="color:#aa0000">Run finished, but no synthetic file pair was detected.</span>'

def refresh_data_compare_controls():
    base = state.get('real_result') or state.get('syn_result')
    if base is None:
        comp_freq.options = [('n/a', 0)]
        tx_select.options = [('n/a', 0)]
        rx_local_select.min = 0
        rx_local_select.max = 0
        rx_local_select.value = 0
        trace_idx.max = 0
        trace_idx.value = 0
        return

    geo = base['geometry']
    comp_key = component_select.value if component_select.value in ('Hx', 'Hz') else 'Hx'
    comp_data = base.get(comp_key, base.get('Hx', {}))
    freqs = np.asarray(comp_data.get('freqs', []), dtype=float)

    if freqs.size == 0:
        comp_freq.options = [('n/a', 0)]
        comp_freq.value = 0
        push_message('No frequencies available in loaded data.')
    else:
        comp_freq.options = [(f'{f:g} Hz', i) for i, f in enumerate(freqs)]
        comp_freq.value = int(min(comp_freq.value if comp_freq.value is not None else 0, freqs.size - 1))

    tx_vals = np.unique(np.asarray(geo.get('tx_idx_per_trace', []), dtype=int))
    if tx_vals.size:
        tx_select.options = [(f'Tx {int(v)}', int(v)) for v in tx_vals]
        tx_select.value = int(tx_vals[0])
    else:
        tx_select.options = [('n/a', 0)]
        tx_select.value = 0

    rx_local = np.asarray(geo.get('rx_local_idx_per_trace', geo.get('rx_idx_per_trace', [])), dtype=int)
    if rx_local.size:
        rx_local_select.min = int(np.nanmin(rx_local))
        rx_local_select.max = int(np.nanmax(rx_local))
        rx_local_select.value = int(rx_local_select.min)
    else:
        rx_local_select.min = 0
        rx_local_select.max = 0
        rx_local_select.value = 0

    ntr = int(np.asarray(geo.get('tx_idx_per_trace', [])).size)
    if ntr > 0:
        trace_idx.max = ntr - 1
        trace_idx.value = 0
    else:
        trace_idx.max = 0
        trace_idx.value = 0
        push_message('No traces available in loaded data.')


def _same_geometry(real, syn):
    if real is None or syn is None:
        return False
    rg = real['geometry']
    sg = syn['geometry']
    return (
        np.asarray(rg['tx_idx_per_trace']).shape == np.asarray(sg['tx_idx_per_trace']).shape
        and np.asarray(rg['rx_idx_per_trace']).shape == np.asarray(sg['rx_idx_per_trace']).shape
    )


def _phase_deg(arr):
    return np.rad2deg(np.asarray(arr, dtype=float))


def update_data_compare_plot(*_):
    real = state.get('real_result')
    syn = state.get('syn_result')

    with data_compare_out:
        clear_output(wait=True)
        if real is None and syn is None:
            return
        if real is None:
            display(ipw.HTML('Synthetic loaded. Load real data to compare.'))
            return

        comp = component_select.value
        metric = metric_select.value
        fidx = int(comp_freq.value) if comp_freq.value is not None else 0
        tx_id = int(tx_select.value) if tx_select.value is not None else 0
        rx_local_target = int(rx_local_select.value)
        tr_pick = int(trace_idx.value)

        comp_data = real.get(comp)
        if comp_data is None:
            display(ipw.HTML(f'Missing {comp} data in real result.'))
            return

        geo = real['geometry']
        tx_arr = np.asarray(geo.get('tx_idx_per_trace', []), dtype=int)
        rx_arr = np.asarray(geo.get('rx_idx_per_trace', []), dtype=int)
        rx_local = np.asarray(geo.get('rx_local_idx_per_trace', rx_arr), dtype=int)
        freqs = np.asarray(comp_data.get('freqs', []), dtype=float)

        if tx_arr.size == 0 or rx_local.size == 0:
            display(ipw.HTML('Loaded data has no trace geometry to plot.'))
            return
        if freqs.size == 0:
            display(ipw.HTML('Loaded data has no frequencies to plot.'))
            return

        fidx = int(max(0, min(fidx, freqs.size - 1)))

        have_syn = syn is not None and _same_geometry(real, syn) and comp in syn
        syn_comp = syn.get(comp) if have_syn else None

        fig = go.Figure()
        title = ''

        if metric == 'amp_vs_rx':
            idx = np.where(tx_arr == tx_id)[0]
            if idx.size == 0:
                display(ipw.HTML('No traces for selected tx.'))
                return
            x = rx_local[idx]
            order = np.argsort(x)
            x = x[order]
            y_real = np.asarray(comp_data['amp_mean'][fidx, idx], dtype=float)[order]
            fig.add_trace(go.Scatter(x=x, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                y_syn = np.asarray(syn_comp['amp_mean'][fidx, idx], dtype=float)[order]
                fig.add_trace(go.Scatter(x=x, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} amplitude vs local rx (Tx {tx_id}, f={freqs[fidx]:g} Hz)'
            fig.update_xaxes(title_text='Local rx index')
            fig.update_yaxes(title_text='Amplitude')

        elif metric == 'phase_vs_rx_deg':
            idx = np.where(tx_arr == tx_id)[0]
            if idx.size == 0:
                display(ipw.HTML('No traces for selected tx.'))
                return
            x = rx_local[idx]
            order = np.argsort(x)
            x = x[order]
            y_real = _phase_deg(comp_data['phi_mean_rad'][fidx, idx])[order]
            fig.add_trace(go.Scatter(x=x, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                y_syn = _phase_deg(syn_comp['phi_mean_rad'][fidx, idx])[order]
                fig.add_trace(go.Scatter(x=x, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} phase vs local rx (Tx {tx_id}, f={freqs[fidx]:g} Hz)'
            fig.update_xaxes(title_text='Local rx index')
            fig.update_yaxes(title_text='Phase (deg)')

        elif metric == 'amp_vs_tx':
            tx_vals = np.unique(tx_arr)
            xs, y_real, y_syn = [], [], []
            for t in tx_vals:
                cand = np.where((tx_arr == int(t)) & (rx_local == rx_local_target))[0]
                if cand.size == 0:
                    continue
                i = int(cand[0])
                xs.append(int(t))
                y_real.append(float(comp_data['amp_mean'][fidx, i]))
                if syn_comp is not None:
                    y_syn.append(float(syn_comp['amp_mean'][fidx, i]))
            if not xs:
                display(ipw.HTML('No traces for selected local rx across tx.'))
                return
            fig.add_trace(go.Scatter(x=xs, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                fig.add_trace(go.Scatter(x=xs, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} amplitude vs tx (local rx {rx_local_target}, f={freqs[fidx]:g} Hz)'
            fig.update_xaxes(title_text='Tx index')
            fig.update_yaxes(title_text='Amplitude')

        elif metric == 'phase_vs_tx_deg':
            tx_vals = np.unique(tx_arr)
            xs, y_real, y_syn = [], [], []
            for t in tx_vals:
                cand = np.where((tx_arr == int(t)) & (rx_local == rx_local_target))[0]
                if cand.size == 0:
                    continue
                i = int(cand[0])
                xs.append(int(t))
                y_real.append(float(_phase_deg(comp_data['phi_mean_rad'][fidx, i])))
                if syn_comp is not None:
                    y_syn.append(float(_phase_deg(syn_comp['phi_mean_rad'][fidx, i])))
            if not xs:
                display(ipw.HTML('No traces for selected local rx across tx.'))
                return
            fig.add_trace(go.Scatter(x=xs, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                fig.add_trace(go.Scatter(x=xs, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} phase vs tx (local rx {rx_local_target}, f={freqs[fidx]:g} Hz)'
            fig.update_xaxes(title_text='Tx index')
            fig.update_yaxes(title_text='Phase (deg)')

        elif metric == 'amp_vs_freq':
            tr_pick = max(0, min(tr_pick, tx_arr.size - 1))
            y_real = np.asarray(comp_data['amp_mean'][:, tr_pick], dtype=float)
            fig.add_trace(go.Scatter(x=freqs, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                y_syn = np.asarray(syn_comp['amp_mean'][:, tr_pick], dtype=float)
                fig.add_trace(go.Scatter(x=freqs, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} amplitude vs frequency (trace {tr_pick}, tx={tx_arr[tr_pick]}, local_rx={rx_local[tr_pick]})'
            fig.update_xaxes(title_text='Frequency (Hz)', type='log')
            fig.update_yaxes(title_text='Amplitude')

        elif metric == 'phase_vs_freq_deg':
            tr_pick = max(0, min(tr_pick, tx_arr.size - 1))
            y_real = _phase_deg(comp_data['phi_mean_rad'][:, tr_pick])
            fig.add_trace(go.Scatter(x=freqs, y=y_real, mode='lines+markers', name='Real'))
            if syn_comp is not None:
                y_syn = _phase_deg(syn_comp['phi_mean_rad'][:, tr_pick])
                fig.add_trace(go.Scatter(x=freqs, y=y_syn, mode='lines+markers', name='Synthetic'))
            title = f'{comp} phase vs frequency (trace {tr_pick}, tx={tx_arr[tr_pick]}, local_rx={rx_local[tr_pick]})'
            fig.update_xaxes(title_text='Frequency (Hz)', type='log')
            fig.update_yaxes(title_text='Phase (deg)')

        else:
            display(ipw.HTML('Unknown plotting mode.'))
            return

        y_all = []
        for tr in fig.data:
            if getattr(tr, 'y', None) is not None:
                y_all.append(np.asarray(tr.y, dtype=float))
        y_vals = np.concatenate(y_all) if y_all else np.array([])
        y_vals = y_vals[np.isfinite(y_vals)] if y_vals.size else y_vals

        fig.update_layout(
            title=title,
            height=420,
            margin=dict(t=60, b=50, l=60, r=20),
            legend=dict(orientation='h', y=-0.2),
            xaxis_fixedrange=False,
            yaxis_fixedrange=False,
        )

        # Keep horizontal axis behavior unchanged (as before).
        fig.update_xaxes(autorange=True)

        # Vertical axis policy requested by user:
        # - amplitude: start at 0 with broad upper range
        # - phase: fixed to [-180, 180]
        if metric in ('phase_vs_rx_deg', 'phase_vs_tx_deg', 'phase_vs_freq_deg'):
            fig.update_yaxes(range=[-180.0, 180.0])
        else:
            if y_vals.size:
                ymax = float(np.nanmax(y_vals))
                if not np.isfinite(ymax):
                    ymax = 1.0
                upper = max(1.0, ymax * 2.0)
                fig.update_yaxes(range=[0.0, upper])
            else:
                fig.update_yaxes(range=[0.0, 1.0])

        display(fig)

        if syn is not None and syn_comp is None:
            display(ipw.HTML('<span style="color:#b26a00">Synthetic geometry or component differs from real data; only real curve shown for this mode.</span>'))

def on_quit_gui(_):
    """Shut down the results GUI: stop the FD job, signal Voila servers, end the kernel.

    Every step is best-effort. Failures are reported through ``push_message``
    and never prevent the later steps from running.

    1. Terminate the background solver in ``state['fd_process']`` (terminate,
       then kill if it has not exited within 5 s).
    2. Send SIGTERM to each PID recorded in ``VOILA_PID_FILES``, skipping this
       process. A missing, empty, unreadable, malformed or non-positive PID
       entry only skips that file.
    3. Ask the IPython kernel to shut down, then SIGTERM this process as a
       final fallback.
    """
    push_message('Shutting down results GUI server...')
    fd_status.value = 'Shutting down GUI server...'
    try:
        proc = state.get('fd_process')
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait(timeout=5)
    except Exception as exc:
        push_message(f'Quit warning while stopping FD process: {exc}')

    try:
        signaled = False
        own_pid = os.getpid()
        for pid_file in VOILA_PID_FILES:
            # Parse each file independently so one bad file cannot stop the
            # remaining servers from being signalled.
            try:
                if not pid_file.exists():
                    continue
                pid_text = pid_file.read_text().strip()
                if not pid_text:
                    continue
                pid = int(pid_text)
            except (OSError, ValueError) as exc:
                push_message(f'Could not read Voila PID from {pid_file.name}: {exc}')
                continue

            # kill() with pid 0 targets our whole process group, and a negative
            # pid targets a group or every signalable process. Never do that
            # from a stale/corrupt PID file.
            if pid <= 0:
                push_message(f'Ignoring invalid Voila PID {pid} in {pid_file.name}.')
                continue
            if pid == own_pid:
                continue

            try:
                os.kill(pid, signal.SIGTERM)
                push_message(f'Sent SIGTERM to Voila server PID {pid} ({pid_file.name}).')
                signaled = True
            except (OSError, OverflowError) as exc:
                push_message(f'Could not signal Voila PID {pid} from {pid_file.name}: {exc}')

        if not signaled:
            push_message('No external Voila PID found to terminate; shutting down kernel only.')
    except Exception as exc:
        push_message(f'Quit warning: {exc}')

    # Always request kernel shutdown, then force-exit as final fallback.
    try:
        from IPython import get_ipython
        ip = get_ipython()
        if ip and getattr(ip, 'kernel', None):
            ip.kernel.do_shutdown(restart=False)
    except Exception as exc:
        push_message(f'Kernel shutdown warning: {exc}')

    try:
        os.kill(os.getpid(), signal.SIGTERM)
    except Exception:
        pass


def bind_button_with_feedback(button, handler, action_label, success_message=None):
    """Register ``handler`` as ``button``'s click callback, wrapped with status-log feedback.

    On each click the wrapper logs ``'<action_label>...'`` and runs
    ``handler(click_arg)``:

    * if the handler raises, ``'<action_label> failed: <exc>'`` is logged and
      the exception is re-raised;
    * if the handler finishes without logging anything itself,
      ``success_message`` (default ``'<action_label> completed successfully.'``)
      is logged. Handlers that log their own outcome are not overridden.
    """
    def _wrapped(click_arg):
        push_message(f'{action_label}...')
        # Compare against a snapshot, not the log length. push_message trims the
        # log to a fixed size, so once it is full the length stops growing and a
        # length check would report success even after the handler logged an error.
        snapshot = list(state.get('last_messages', []))
        try:
            handler(click_arg)
            if state.get('last_messages', []) == snapshot:
                push_message(success_message or f'{action_label} completed successfully.')
        except Exception as exc:
            push_message(f'{action_label} failed: {exc}')
            raise
        finally:
            # NOTE(review): push_message guards its widget update with
            # `'status_out' in dir()`, which inspects push_message's own locals,
            # so it may never refresh the status box. Sync it once per click.
            box = globals().get('status_out')
            if box is not None:
                box.value = '\n'.join(state.get('last_messages', []))

    button.on_click(_wrapped)


# Wire widgets to handlers. Selection widgets use traitlets observers on
# 'value'. Buttons go through bind_button_with_feedback, which logs a start
# message and a success/failure message via push_message around each click.
run_selector.observe(on_run_selected, names='value')
model_selector.observe(on_model_selected, names='value')
# The tuple expression simply runs the three refresh calls in order.
bind_button_with_feedback(refresh_models_btn, lambda _: (refresh_model_options(), update_2d_plot(), update_rho_plots()), 'Refreshing model list', 'Model list refreshed.')
bind_button_with_feedback(export_segy_btn, on_export_selected_segy, 'Exporting selected model to SEGY')
# Changing either slice position redraws the resistivity profile plots.
x_slice_input.observe(lambda _: update_rho_plots(), names='value')
z_slice_input.observe(lambda _: update_rho_plots(), names='value')
bind_button_with_feedback(load_real_btn, on_load_real, 'Loading real data')
bind_button_with_feedback(load_syn_from_run_btn, on_load_syn_from_run, 'Loading synthetics from run')
bind_button_with_feedback(generate_syn_btn, on_generate_synthetics, 'Starting synthetic modelling')
bind_button_with_feedback(refresh_fd_btn, on_refresh_fd_status, 'Refreshing FD status')
bind_button_with_feedback(quit_btn, on_quit_gui, 'Shutting down results GUI server')

def on_component_change(change):
    """React to a new Hx/Hz selection: rebuild the dependent selectors, then redraw.

    ``change`` is the traitlets change record. Only 'value' changes are handled;
    any other trait name is ignored.
    """
    if change.get('name') == 'value':
        refresh_data_compare_controls()
        update_data_compare_plot()


# A component change must rebuild the frequency/tx/rx selectors before redrawing.
component_select.observe(on_component_change, names='value')
# Every other compare control only needs a redraw.
for w in [metric_select, comp_freq, tx_select, rx_local_select, trace_idx]:
    w.observe(update_data_compare_plot, names='value')

# Show the messages queued so far in the status box.
# NOTE(review): push_message checks `'status_out' in dir()`, which looks at its
# own locals, so it may never update this widget itself; confirm.
status_out.value = '\n'.join(state['last_messages'])

# Model section: run/model pickers, SEGY export, 2-D model plot and the two
# resistivity profile plots.
model_section = ipw.VBox([
    ipw.HTML('<h3>Model plots</h3>'),
    ipw.HBox([run_selector, model_selector, refresh_models_btn, export_segy_btn]),
    export_status,
    model_plot_out,
    ipw.HTML('<b>Resistivity slices</b> (choose x for depth plot, z for x plot)'),
    ipw.HBox([x_slice_input, z_slice_input]),
    rho_z_out,
    rho_x_out,
])

# Data section: loading/generation controls, compare selectors, FD status and plot.
data_section = ipw.VBox([
    ipw.HTML('<h3>Synthetic vs real data</h3>'),
    ipw.HBox([freqs_input, load_real_btn, load_syn_from_run_btn, nproc_input, generate_syn_btn, refresh_fd_btn]),
    ipw.HBox([component_select, metric_select, comp_freq, tx_select]),
    ipw.HBox([rx_local_select, trace_idx]),
    fd_status,
    data_compare_out,
])

# Page layout with the quit button right-aligned at the top and the status log at the bottom.
layout = ipw.VBox([
    ipw.HTML('<h2>04_results</h2>'),
    ipw.HBox([quit_btn], layout=ipw.Layout(justify_content='flex-end')),
    model_section,
    data_section,
    status_out,
])
display(layout)

state['initial_load'] = False
# Populate the model list for a run that is already selected at startup.
if run_selector.value:
    refresh_model_options()

# Keep startup lightweight to avoid Voila init hangs.
# Plots are rendered on demand; only a placeholder is shown initially.
with model_plot_out:
    clear_output(wait=True)
    display(ipw.HTML('Select a model to render the model comparison plot.'))
with rho_z_out:
    clear_output(wait=True)
with rho_x_out:
    clear_output(wait=True)