# 03_inversion (Voila)

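This notebook prepares the inversion inputs and runs the inversion locally, keeping inputs and outputs in separate folders:
- persistent/portable inputs in `InversionInput/`
- runtime outputs/logs in one numbered folder per run, `InversionRunN/` (`InversionRun0/`, `InversionRun1/`, ...)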

Run all cells, then launch with Voila (`--strip_sources=True`).

In [None]:
from pathlib import Path
import json
import os
import re
import shutil
import signal
import sys
import threading
import time
import traceback
import subprocess

# Project root. Defaults to the original author's checkout; set the
# EM_WORKSHOP_ROOT environment variable to run the notebook from another
# location without editing this cell.
ROOT = Path(
    os.environ.get('EM_WORKSHOP_ROOT', '/Users/wiktorweibull/Rockem_projects/EM_inversion_workshop')
).expanduser().resolve()
# Make the project's packages (`scripts.*`, `third_party.*`) importable.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

try:
    import ipywidgets as ipw
    import numpy as np
    import plotly.graph_objects as go
except Exception as exc:
    raise RuntimeError(
        'Missing GUI dependencies. Install with: python3 -m pip install voila ipywidgets plotly numpy'
    ) from exc

from scripts.modules.inversion import prepare_inversion_inputs, stage_run_directory
from scripts.modules.fd_visualization import load_rss_traces

# --- Project layout -------------------------------------------------------------
FDMODEL_DIR = ROOT / 'FDmodel'  # source for inversion inputs; holds the true model sg.rss
SCRIPTS_DIR = ROOT / 'scripts'
TEMPLATES_DIR = SCRIPTS_DIR / 'templates'
INV_TEMPLATE, RUNINV_TEMPLATE = TEMPLATES_DIR / 'inv.cfg', TEMPLATES_DIR / 'runinv.sh'
# Upper bound offered by the nproc selector; never below 2.
MAX_CPUS = max(2, os.cpu_count() or 2)

INV_INPUT_DIR = ROOT / 'InversionInput'
INVERSION_SETUP_METADATA = INV_INPUT_DIR / 'inversion_setup_metadata.json'
VOILA_PID_FILE = ROOT / '.voila_inversion_server.pid'
RUN_DIR_PATTERN = re.compile(r'^InversionRun(\d+)$')  # run folders: InversionRun<N>
PCT_RE = re.compile(r'(\d+(?:\.\d+)?)%')              # "<number>%" in worker log lines

# Mutable GUI/session state shared by all callbacks (and the refresh thread).
state = dict(
    process=None,                # subprocess.Popen of the running inversion, if any
    refresh_loop=False,          # keeps the background refresh loop alive while True
    monitor_thread=None,         # the background refresh thread
    last_messages=[],            # recent status lines shown in status_out
    active_run_dir=None,         # Path of the run folder shown in the GUI
    active_run_log=None,         # Path of that run's log file
    last_sg_ls_mtime=None,       # reset on each new run; not read elsewhere in this notebook
    model_plot_key=None,         # (path, mtime) key of the files currently plotted
    updating_run_selector=False, # True while the dropdown is changed programmatically
    refresh_in_progress=False,   # re-entrancy guard for refresh_run_status
    last_plot_run_dir=None,      # run folder of the current figure (str)
    last_progress_event=None,    # last progress.log event signature seen
    initial_load=False,          # True while the notebook populates the GUI on load
)

def push_message(msg):
    """Append ``msg`` to the status history, keep the newest 12 lines and redraw ``status_out``."""
    history = state['last_messages']
    history.append(msg)
    overflow = len(history) - 12
    if overflow > 0:
        history = state['last_messages'] = history[overflow:]
    status_out.value = '\n'.join(history)


def bind_button_with_feedback(button, handler, action_label, success_message=None):
    """Register ``handler`` as the click callback of ``button`` with status feedback.

    Before the handler runs, ``'<action_label>...'`` is pushed to the status
    log. If the handler pushes no status message of its own, a generic
    success message is pushed afterwards. If the handler raises, a failure
    message is pushed and the exception is re-raised.

    Args:
        button: widget exposing ``on_click(callback)`` (an ``ipywidgets.Button`` here).
        handler: callable taking the click-event argument.
        action_label: short human-readable description of the action.
        success_message: optional replacement for
            ``'<action_label> completed successfully.'``.
    """
    def _wrapped(event):
        start_marker = f'{action_label}...'
        push_message(start_marker)
        try:
            handler(event)
        except Exception as exc:
            push_message(f'{action_label} failed: {exc}')
            raise
        # The handler "reported something" iff our own marker object is no
        # longer the newest entry. Comparing history lengths does not work:
        # push_message() caps the history, so once it is full the length
        # stays constant and a handler's error message would be followed by
        # a misleading success message.
        messages = state.get('last_messages', [])
        if messages and messages[-1] is start_marker:
            push_message(success_message or f'{action_label} completed successfully.')

    button.on_click(_wrapped)


def read_tail(path, n_lines=50):
    """Return the last ``n_lines`` lines of the text file at ``path``.

    Returns ``'No log file yet.'`` when ``path`` is None or does not exist,
    and ``'Log file is empty.'`` when the file has no lines. Undecodable
    bytes are replaced rather than raising.
    """
    if path is None or not Path(path).exists():
        return 'No log file yet.'
    content = Path(path).read_text(errors='replace').splitlines()
    if not content:
        return 'Log file is empty.'
    return '\n'.join(content[-n_lines:])


def read_nonempty_last_line(path):
    """Return the last non-blank line of ``path``, stripped; ``''`` if there is none or the file is missing."""
    p = Path(path)
    if not p.exists():
        return ''
    stripped = (line.strip() for line in reversed(p.read_text(errors='replace').splitlines()))
    return next((text for text in stripped if text), '')


def list_run_dirs(root_dir):
    """Return ``(N, path)`` for every ``InversionRunN`` sub-directory of ``root_dir``, sorted by N."""
    found = []
    for entry in Path(root_dir).iterdir():
        match = RUN_DIR_PATTERN.match(entry.name)
        if match and entry.is_dir():
            found.append((int(match.group(1)), entry))
    return sorted(found, key=lambda item: item[0])


def find_next_run_dir(root_dir):
    """Return the path for the next ``InversionRunN`` folder under ``root_dir``.

    N is one more than the highest existing run index, or 0 when there are
    no runs yet. Gaps left by deleted runs are deliberately not reused.
    ``find_latest_run_dir`` and ``refresh_run_selector(select_latest=True)``
    treat the highest index as the most recent run, so reusing a lower index
    would make a fresh run look older than the existing ones.
    """
    runs = list_run_dirs(root_dir)
    next_idx = runs[-1][0] + 1 if runs else 0
    return Path(root_dir) / f'InversionRun{next_idx}'


def find_latest_run_dir(root_dir):
    """Return the ``InversionRunN`` folder with the highest N under ``root_dir``, or None if there is none."""
    runs = list_run_dirs(root_dir)
    return runs[-1][1] if runs else None


def update_info_panel():
    """Rewrite the ``info`` HTML widget with template/input paths and the active run folder."""
    active = state.get('active_run_dir')
    rows = [
        ('Template', INV_TEMPLATE),
        ('Inputs folder', INV_INPUT_DIR),
        ('Active run folder', str(active) if active else 'Not started yet'),
        ('Run script template', RUNINV_TEMPLATE),
    ]
    info.value = '<br>'.join(f'<b>{label}:</b> {value}' for label, value in rows)


def parse_progress_summary(run_dir):
    """Summarise ``<run_dir>/progress.log`` as multi-line display text.

    Reported values:
      * the configured maximum numbers of iterations / linesearches,
      * the last ``Iteration <k>`` number seen (1 if only linesearches were
        logged so far, 0 if nothing was logged),
      * the number of ``Linesearch`` rows after the most recent
        ``Iteration`` row (0 once the log says the maximum number of
        iterations was performed),
      * the total number of ``Linesearch`` rows,
      * the number of ``Results/sg_up.rss-*`` files,
      * the latest event row and the "maximum iterations performed" note.

    Returns a placeholder string when ``run_dir`` is None or the log is missing.
    """
    if run_dir is None:
        return 'No active run selected.'
    run_dir = Path(run_dir)
    progress_path = run_dir / 'progress.log'
    if not progress_path.exists():
        return 'progress.log not found yet.'

    def _int_after_colon(text):
        # Integer after the first ':'; None when it is not a plain integer.
        try:
            return int(text.split(':', 1)[1].strip())
        except ValueError:
            return None

    max_iter = max_ls = None
    events = []
    last_note = ''
    for line in progress_path.read_text(errors='replace').splitlines():
        if 'Maximum number of iterations:' in line:
            parsed = _int_after_colon(line)
            if parsed is not None:
                max_iter = parsed
        if 'Maximum number of linesearches:' in line:
            parsed = _int_after_colon(line)
            if parsed is not None:
                max_ls = parsed
        if line.startswith(('Linesearch', 'Iteration')):
            events.append(line)
        if 'Maximum number of iterations is performed' in line:
            last_note = line.strip()

    # Linesearch rows since the latest Iteration row (events only hold these two kinds).
    linesearch_in_current_iteration = 0
    for row in reversed(events):
        if row.startswith('Iteration'):
            break
        linesearch_in_current_iteration += 1
    if last_note:
        linesearch_in_current_iteration = 0

    last_iteration = None
    for row in reversed(events):
        found = re.match(r'^Iteration\s+(\d+)', row)
        if found:
            last_iteration = int(found.group(1))
            break

    total_linesearches = sum(1 for row in events if row.startswith('Linesearch'))
    accepted_iters = sum(1 for _ in run_dir.glob('Results/sg_up.rss-*'))

    if last_iteration is not None:
        current_iteration_display = last_iteration
    elif total_linesearches > 0:
        current_iteration_display = 1
    else:
        current_iteration_display = 0

    last_row = events[-1] if events else 'No iteration/linesearch row yet.'

    return (
        f'Progress file: {progress_path.name}\n'
        f'Max iterations: {max_iter if max_iter is not None else "N/A"}\n'
        f'Max linesearches/iteration: {max_ls if max_ls is not None else "N/A"}\n'
        f'Current reached iteration: {current_iteration_display}\n'
        f'Linesearches attempted in current iteration: {linesearch_in_current_iteration}\n'
        f'Total linesearch evaluations: {total_linesearches}\n'
        f'Accepted iteration models in Results/: {accepted_iters}\n'
        f'Latest progress row:\n{last_row}\n'
        f'{last_note}'
    )


def parse_mpiqueue_summary(run_dir):
    """Return ``(summary, tail)`` texts for ``<run_dir>/mpiqueue.log``.

    ``summary`` lists the job counts from the last line matching
    ``Jobs: <total>, #Remaining jobs: <remaining>``. ``tail`` is the last
    30 lines of the file. Both are the same placeholder when ``run_dir`` is
    None or the file does not exist yet.
    """
    if run_dir is None:
        placeholder = 'No active run selected.'
        return (placeholder, placeholder)
    qpath = Path(run_dir) / 'mpiqueue.log'
    if not qpath.exists():
        placeholder = 'mpiqueue.log not found yet.'
        return (placeholder, placeholder)

    counts = (None, None)
    for line in qpath.read_text(errors='replace').splitlines():
        found = re.search(r'Jobs:\s*(\d+),\s*#Remaining jobs:\s*(\d+)', line)
        if found:
            counts = (int(found.group(1)), int(found.group(2)))
    jobs_total, jobs_remaining = counts

    summary = '\n'.join([
        f'Queue file: {qpath.name}',
        f'Jobs total: {jobs_total if jobs_total is not None else "N/A"}',
        f'Jobs remaining: {jobs_remaining if jobs_remaining is not None else "N/A"}',
    ])
    return (summary, read_tail(qpath, n_lines=30))


def _worker_log_sort_key(path):
    """Sort key placing ``log.txt-<k>`` files in numeric order of k (non-numeric suffixes last, by name)."""
    m = re.search(r'-(\d+)$', path.name)
    return (0, int(m.group(1)), path.name) if m else (1, 0, path.name)


def parse_worker_logs_summary(run_dir):
    """Return one status line per worker log ``<run_dir>/log.txt-*``.

    Logs are listed in numeric order of their suffix, so ``log.txt-10``
    follows ``log.txt-9``, consistent with how ``latest_sg_up_file`` orders
    ``sg_up.rss-*``. Each line shows the last ``<number>%`` found in the
    log's last non-blank line (or ``N/A``) and a status. The status is
    ``done`` when that line contains "completed" (any case) or "100%",
    otherwise ``running``.

    Returns a placeholder when ``run_dir`` is None or no worker logs exist.
    """
    if run_dir is None:
        return 'No active run selected.'
    logs = sorted(Path(run_dir).glob('log.txt-*'), key=_worker_log_sort_key)
    if not logs:
        return 'No worker logs found yet.'

    out = []
    for lp in logs:
        last = read_nonempty_last_line(lp)
        pct_match = PCT_RE.search(last)
        pct = pct_match.group(1) + '%' if pct_match else 'N/A'
        status = 'done' if 'completed' in last.lower() or '100%' in last else 'running'
        out.append(f'{lp.name}: {pct} ({status}) :: {last}')
    return '\n'.join(out)


def latest_sg_up_file(run_dir):
    """Return the ``sg_up.rss-<k>`` file with the largest integer k, or None.

    ``<run_dir>/Results/`` and ``<run_dir>/`` are both searched. Names whose
    suffix is not an integer rank below every numbered file. Ties are broken
    by file name, and then the later candidate (the top-level folder) wins.
    """
    base = Path(run_dir)
    suffix_re = re.compile(r'sg_up\.rss-(\d+)$')

    def rank(candidate):
        hit = suffix_re.search(candidate.name)
        return (int(hit.group(1)) if hit else -1, candidate.name)

    best, best_rank = None, None
    for candidate in [*base.glob('Results/sg_up.rss-*'), *base.glob('sg_up.rss-*')]:
        current = rank(candidate)
        # '>=' keeps the last of equally ranked candidates.
        if best is None or current >= best_rank:
            best, best_rank = candidate, current
    return best


def summarize_linesearch_model(run_dir):
    """Describe ``sg_ls.rss`` and the newest accepted ``sg_up.rss-*`` model of ``run_dir``, with file sizes."""
    if run_dir is None:
        return 'No active run selected.'
    run_dir = Path(run_dir)
    sg_ls = run_dir / 'sg_ls.rss'
    latest = latest_sg_up_file(run_dir)

    if sg_ls.exists():
        ls_line = f'Current linesearch model: {sg_ls.name} ({sg_ls.stat().st_size} bytes)'
    else:
        ls_line = 'Current linesearch model: not available yet (sg_ls.rss missing).'

    if latest is None:
        up_line = 'Latest accepted model: not available yet (Results/sg_up.rss-* missing).'
    else:
        up_line = f'Latest accepted model: {latest.name} ({latest.stat().st_size} bytes)'

    return '\n'.join((ls_line, up_line))


def _read_rss_model(path):
    """Load a 2-D model RSS file and return ``(x, z, grid)`` for plotting.

    ``grid`` is the squeezed data transposed, so rows follow z and columns
    follow x. ``x``/``z`` are built from the file's origin and sampling; a
    zero sampling interval is replaced by 1.0.

    Raises:
        ValueError: if the data is not 2-D after removing singleton axes.
    """
    from third_party.rockseis.io.rsfile import rsfile

    model = rsfile()
    model.read(str(path))

    # Model RSS files are stored as x, y, z; for 2-D models y is singleton.
    # Drop it and transpose so the plot has z along rows and x along columns.
    values = np.squeeze(np.asarray(model.data, dtype=float))
    if values.ndim != 2:
        raise ValueError(f'Expected 2D model RSS after squeeze for {path}, got shape {values.shape}')
    nx, nz = (int(n) for n in values.shape)
    grid = np.asarray(values.T, dtype=float)

    def _axis(dim, count):
        # Coordinates origin + step * i along header dimension ``dim``.
        step = float(model.geomD[dim]) if model.geomD[dim] else 1.0
        return float(model.geomO[dim]) + step * np.arange(count)

    # z is header dimension 2 for x,y,z files; already-2D files fall back to dimension 1.
    z_dim = 2 if (len(model.geomN) > 2 and int(model.geomN[2]) > 0) else 1
    return _axis(0, nx), _axis(z_dim, nz), grid


def _extract_positions(run_dir):
    """Return unique source/receiver positions ``(tx_x, tx_z, rx_x, rx_z)`` of a run.

    Positions come from ``<run_dir>/Hx_data.rss`` via ``load_rss_traces`` and
    are rounded to 6 decimals before de-duplication. Four empty arrays are
    returned when the file is missing or cannot be parsed; the markers only
    decorate the model plot.
    """
    hx_path = Path(run_dir) / 'Hx_data.rss'
    if not hx_path.exists():
        return np.array([]), np.array([]), np.array([]), np.array([])
    try:
        headers = load_rss_traces(hx_path)

        def _unique_xy(x_key, z_key):
            pairs = np.column_stack((
                np.asarray(headers[x_key], dtype=float),
                np.asarray(headers[z_key], dtype=float),
            ))
            return np.unique(np.round(pairs, 6), axis=0)

        sources = _unique_xy('src_x', 'src_z')
        receivers = _unique_xy('rx_x', 'rx_z')
        return sources[:, 0], sources[:, 1], receivers[:, 0], receivers[:, 1]
    except Exception:
        # Best effort: missing/odd headers simply mean no markers.
        return np.array([]), np.array([]), np.array([]), np.array([])


def _align_axes_to_survey(x, z, tx_x, tx_z, rx_x, rx_z):
    """Translate model axes onto the survey positions when they are clearly offset.

    Each axis is handled independently. If the centre of all TX/RX
    coordinates differs from the centre of the model axis by more than half
    the axis span, the axis is shifted so both centres coincide. This only
    affects plotting. Returns float arrays ``(x, z)``, unchanged when there
    are no positions for either axis.
    """
    x = np.asarray(x, dtype=float)
    z = np.asarray(z, dtype=float)
    if tx_x.size + rx_x.size == 0 or tx_z.size + rx_z.size == 0:
        return x, z

    def _shift(axis, points):
        if axis.size == 0:
            return axis
        lo, hi = np.nanmin(axis), np.nanmax(axis)
        span = float(hi - lo)
        offset = float(0.5 * (np.nanmin(points) + np.nanmax(points))) - float(0.5 * (lo + hi))
        if span > 0.0 and abs(offset) > 0.5 * span:
            return axis + offset
        return axis

    return _shift(x, np.concatenate([tx_x, rx_x])), _shift(z, np.concatenate([tx_z, rx_z]))


def _current_iteration_model(run_dir):
    """Return the newest ``sg_up.rss-*`` of ``run_dir``, else its ``sg0.rss`` if present, else None."""
    newest = latest_sg_up_file(run_dir)
    if newest is None:
        start = Path(run_dir) / 'sg0.rss'
        newest = start if start.exists() else None
    return newest


def update_model_plot(run_dir, force=False):
    """Draw the conductivity-model panels of ``run_dir`` into ``model_plot_out``.

    Panels, in this order and only when the file exists:
      * the current iteration model (``_current_iteration_model``),
      * the linesearch model ``sg_ls.rss``,
      * the true model ``FDmodel/sg.rss``,
      * the latest accepted ``sg_up.rss-*`` model.
    All panels share one colour range. It is taken from the true model when
    that could be read, and from the plotted panels otherwise. TX/RX
    positions from ``Hx_data.rss`` are overlaid when available.

    The figure is cached on the (path, mtime) of its input files in
    ``state['model_plot_key']``. It is rebuilt only when that key changes or
    ``force`` is true.

    A model file that cannot be read (e.g. ``sg_ls.rss`` caught while the
    running inversion rewrites it) is skipped and reported in the status log
    instead of aborting the whole refresh. The cache key is then not stored,
    so the next refresh retries.

    Args:
        run_dir: run folder (str or Path); nothing happens when None.
        force: rebuild even if the input files did not change.
    """
    if run_dir is None:
        return
    from IPython.display import clear_output, display

    run_dir = Path(run_dir)

    def _clear_plot():
        with model_plot_out:
            clear_output(wait=True)

    # If user switched runs, clear previous run's plot immediately.
    if state.get('last_plot_run_dir') != str(run_dir):
        _clear_plot()
        state['model_plot_key'] = None

    sg_ls = run_dir / 'sg_ls.rss'
    sg_true = FDMODEL_DIR / 'sg.rss'
    sg_iter = _current_iteration_model(run_dir)
    sg_up = latest_sg_up_file(run_dir)

    if (not sg_ls.exists()) and (sg_iter is None) and (not sg_true.exists()):
        _clear_plot()
        state['last_plot_run_dir'] = str(run_dir)
        return

    key = (
        str(sg_ls) if sg_ls.exists() else None,
        sg_ls.stat().st_mtime if sg_ls.exists() else None,
        str(sg_iter) if sg_iter is not None else None,
        sg_iter.stat().st_mtime if sg_iter is not None else None,
        str(sg_up) if sg_up is not None else None,
        sg_up.stat().st_mtime if sg_up is not None else None,
        sg_true.stat().st_mtime if sg_true.exists() else None,
    )
    if (not force) and state.get('model_plot_key') == key:
        return

    tx_x, tx_z, rx_x, rx_z = _extract_positions(run_dir)

    models = []
    if sg_iter is not None:
        models.append(('Current iteration model', sg_iter))
    if sg_ls.exists():
        models.append(('Current linesearch model (sg_ls.rss)', sg_ls))
    if sg_true.exists():
        models.append(('True conductivity model (FDmodel/sg.rss)', sg_true))
    # NOTE(review): whenever an sg_up.rss-* file exists, _current_iteration_model
    # returns that same file, so this panel duplicates the first one.
    if sg_up is not None:
        models.append((f'Latest accepted model ({sg_up.name})', sg_up))

    grids = []
    true_grid = None
    unreadable = []
    for title, path in models:
        try:
            x, z, g = _read_rss_model(path)
        except Exception as exc:  # reader errors are library-specific; skip this panel only
            unreadable.append(f'{path.name} ({exc})')
            continue
        if path == sg_true:
            true_grid = g
        x, z = _align_axes_to_survey(x, z, tx_x, tx_z, rx_x, rx_z)
        grids.append((title, path, x, z, g))

    if unreadable:
        push_message('Skipped unreadable model file(s): ' + '; '.join(unreadable))
    if not grids:
        return

    # Keep fixed limits from the true model when available so panels are comparable.
    limit_grids = [true_grid] if true_grid is not None else [g for _, _, _, _, g in grids]
    zmin = min(float(np.nanmin(g)) for g in limit_grids)
    zmax = max(float(np.nanmax(g)) for g in limit_grids)

    from plotly.subplots import make_subplots

    nplots = len(grids)
    rows = int(np.ceil(nplots / 2))
    fig = make_subplots(rows=rows, cols=2, subplot_titles=[t for t, *_ in grids] + [''] * (rows * 2 - nplots))

    for i, (title, path, x, z, grid) in enumerate(grids):
        row = i // 2 + 1
        col = i % 2 + 1
        fig.add_trace(
            go.Heatmap(
                x=x,
                y=z,
                z=grid,
                colorscale='Viridis',
                zmin=zmin,
                zmax=zmax,
                showscale=(i == 0),
                colorbar=dict(title='S/m'),
            ),
            row=row,
            col=col,
        )
        if tx_x.size:
            fig.add_trace(
                go.Scatter(
                    x=tx_x,
                    y=tx_z,
                    mode='markers',
                    marker=dict(symbol='triangle-up', size=7, color='red'),
                    name='TX',
                    showlegend=(i == 0),
                ),
                row=row,
                col=col,
            )
        if rx_x.size:
            fig.add_trace(
                go.Scatter(
                    x=rx_x,
                    y=rx_z,
                    mode='markers',
                    marker=dict(symbol='circle', size=6, color='cyan'),
                    name='RX',
                    showlegend=(i == 0),
                ),
                row=row,
                col=col,
            )

        fig.update_xaxes(title_text='x (m)', row=row, col=col)
        fig.update_yaxes(title_text='z (m)', autorange='reversed', row=row, col=col)

    fig.update_layout(
        title='Inversion conductivity models (iteration / linesearch / true / latest update)',
        height=max(420, 360 * rows),
        margin=dict(t=70, b=40, l=40, r=40),
        legend=dict(orientation='h', y=-0.05),
    )

    with model_plot_out:
        clear_output(wait=True)
        display(fig)

    # Cache only complete figures so skipped panels are retried next time.
    if not unreadable:
        state['model_plot_key'] = key
    state['last_plot_run_dir'] = str(run_dir)


def stop_auto_refresh():
    """Ask the background refresh loop to stop and forget its thread.

    Waits at most 0.2 s for the thread to finish. The loop sleeps between
    refreshes, so the thread may still be alive when this returns.
    """
    state['refresh_loop'] = False
    worker = state.get('monitor_thread')
    still_running = worker is not None and worker.is_alive()
    if still_running:
        worker.join(timeout=0.2)
    state['monitor_thread'] = None


def get_progress_event_signature(run_dir):
    """Return ``(last_event_line, event_count)`` for ``<run_dir>/progress.log``.

    Event lines are lines starting with ``Linesearch`` or ``Iteration``; the
    returned line is stripped. Returns ``('none', 0)`` when the log has no
    such line, and None when ``run_dir`` is None or the log is missing.
    Comparing signatures between refreshes reveals new events.
    """
    if run_dir is None:
        return None
    log_path = Path(run_dir) / 'progress.log'
    if not log_path.exists():
        return None

    count = 0
    last = None
    for raw in log_path.read_text(errors='replace').splitlines():
        if raw.startswith(('Linesearch', 'Iteration')):
            count += 1
            last = raw.strip()
    return (last, count) if count else ('none', 0)


def refresh_run_status(_=None, refresh_images=True):
    """Refresh the status, log, progress, queue, worker and model-summary widgets.

    Overlapping calls (the background loop and a button click) return
    immediately via ``state['refresh_in_progress']``. When the inversion
    process has exited, its exit code is shown once, and the process handle
    and the refresh loop flag are cleared.

    When the ``progress.log`` event signature differs from the last one
    seen, it is recorded and announced in the status log. The model plot is
    redrawn only if ``refresh_images`` is true and the notebook is not in
    its initial load. Because the signature is recorded either way, a
    refresh with ``refresh_images=False`` consumes the event.
    """
    if state.get('refresh_in_progress'):
        return

    state['refresh_in_progress'] = True
    try:
        run_dir = state.get('active_run_dir')
        proc = state.get('process')
        if proc is not None:
            rc = proc.poll()
            if rc is None:
                run_status.value = f'Inversion is running (pid={proc.pid}) in {run_dir.name}.'
            else:
                run_status.value = f'Inversion finished with exit code {rc} in {run_dir.name}.'
                state['process'] = None
                state['refresh_loop'] = False
        elif run_dir is None:
            run_status.value = 'No active inversion process in this session.'
        else:
            run_status.value = f'No active process. Last run folder: {run_dir.name}'

        log_out.value = read_tail(state.get('active_run_log'), n_lines=60)
        progress_out.value = parse_progress_summary(run_dir)
        job_progress_out.value, mpiqueue_out.value = parse_mpiqueue_summary(run_dir)
        worker_logs_out.value = parse_worker_logs_summary(run_dir)
        ls_model_out.value = summarize_linesearch_model(run_dir)

        event = get_progress_event_signature(run_dir)
        if event == state.get('last_progress_event'):
            return
        state['last_progress_event'] = event
        if refresh_images and not state.get('initial_load'):
            update_model_plot(run_dir, force=True)
        if event is not None and run_dir is not None:
            push_message(f'Progress event detected in {Path(run_dir).name}: {event[0]}')
    finally:
        state['refresh_in_progress'] = False


def start_auto_refresh(interval_s=5.0):
    """Start a daemon thread that refreshes the text widgets every ``interval_s`` seconds.

    Any previous loop is stopped first. A loop keeps running only while
    ``state['refresh_loop']`` is true *and* it is still the registered
    ``state['monitor_thread']``. ``stop_auto_refresh`` waits just 0.2 s, so a
    thread sleeping through a stop/start cycle would otherwise wake up, see
    the flag set again by the new start, and keep running next to the new
    thread.

    The loop calls ``refresh_run_status(refresh_images=False)``, so it never
    redraws the model plot. Refresh errors do not stop the loop. Each
    distinct error is printed once to the kernel's original stderr instead
    of being discarded silently.
    """
    stop_auto_refresh()
    state['refresh_loop'] = True

    def _loop():
        me = threading.current_thread()
        last_error = None
        while state.get('refresh_loop') and state.get('monitor_thread') is me:
            try:
                refresh_run_status(refresh_images=False)
            except Exception as exc:
                error_text = f'{type(exc).__name__}: {exc}'
                if error_text != last_error:
                    last_error = error_text
                    traceback.print_exc(file=sys.__stderr__)
            time.sleep(interval_s)

    t = threading.Thread(target=_loop, daemon=True)
    # Register before starting so the loop's ownership check passes from the first iteration.
    state['monitor_thread'] = t
    t.start()


def initialize_active_run():
    """Make the highest-numbered ``InversionRunN`` under ROOT the active run, if one exists.

    Delegates to ``set_active_run`` (defined in the next cell, which has run
    before this is called on load). That way the log path always belongs to
    this run, and a log path left in ``state`` by an earlier execution of
    the cells is replaced rather than kept.
    """
    latest = find_latest_run_dir(ROOT)
    if latest is not None:
        set_active_run(latest)



In [None]:
# --- Widgets --------------------------------------------------------------------
title = ipw.HTML('<h2>03_inversion</h2>')

# Starting-model choice; _selected_model_kwargs() passes either uniform_cond
# or uniform_rho to prepare_inversion_inputs depending on this value.
initial_model_mode = ipw.Dropdown(
    options=[
        ('Uniform conductivity model', 'uniform_conductivity'),
        ('Uniform resistivity model (converted to conductivity)', 'uniform_resistivity'),
    ],
    value='uniform_resistivity',
    description='Initial model:',
    layout=ipw.Layout(width='700px'),
)

# Inversion parameters forwarded to prepare_inversion_inputs and recorded in
# the set-up metadata.
uniform_cond = ipw.FloatText(value=0.01, description='sigma (S/m):', layout=ipw.Layout(width='320px'))
uniform_rho = ipw.FloatText(value=100.0, description='rho (Ohm.m):', layout=ipw.Layout(width='320px'))
max_iterations = ipw.IntText(value=20, description='Max iter:', layout=ipw.Layout(width='220px'))
apertx_value = ipw.FloatText(value=60.0, description='apertx (m):', layout=ipw.Layout(width='220px'))
dtx_value = ipw.FloatText(value=6.0, description='dtx:', layout=ipw.Layout(width='220px'))
dtz_value = ipw.FloatText(value=6.0, description='dtz:', layout=ipw.Layout(width='220px'))
# Passed as `clean=` to stage_run_directory.
# NOTE(review): runs always go into a freshly created folder from
# find_next_run_dir, so this option probably has no visible effect - confirm.
clean_run_dir = ipw.Checkbox(value=True, description='Clean selected InversionRunN before run')
# Lists InversionRunN folders; values are str paths (None for the placeholder).
run_selector = ipw.Dropdown(options=[('No runs found', None)], value=None, description='Load run:', layout=ipw.Layout(width='420px'))

# MPI process count, inserted as `mpirun -np <nproc>` into the run script.
nproc_inv_input = ipw.BoundedIntText(
    value=min(4, MAX_CPUS), min=2, max=MAX_CPUS,
    description='nproc', layout=ipw.Layout(width='120px'),
)
gen_inputs_btn = ipw.Button(description='Generate inversion inputs', button_style='primary')
run_inv_btn = ipw.Button(description='Run inversion locally', button_style='success')
refresh_btn = ipw.Button(description='Refresh from progress.log')
stop_btn = ipw.Button(description='Stop run', button_style='warning')
quit_btn = ipw.Button(description='Quit GUI server', button_style='danger')

# Output panels, filled by refresh_run_status / push_message / update_model_plot.
run_status = ipw.HTML(value='No active inversion process in this session.')
status_out = ipw.Textarea(value='', description='Status:', layout=ipw.Layout(width='100%', height='200px'))
log_out = ipw.Textarea(value='No log file yet.', description='Run log:', layout=ipw.Layout(width='100%', height='220px'))
progress_out = ipw.Textarea(value='No progress information yet.', description='Progress:', layout=ipw.Layout(width='100%', height='160px'))
job_progress_out = ipw.Textarea(value='No queue information yet.', description='Queue summary:', layout=ipw.Layout(width='100%', height='120px'))
mpiqueue_out = ipw.Textarea(value='No mpiqueue tail yet.', description='mpiqueue tail:', layout=ipw.Layout(width='100%', height='220px'))
worker_logs_out = ipw.Textarea(value='No worker log information yet.', description='Worker logs:', layout=ipw.Layout(width='100%', height='240px'))
ls_model_out = ipw.Textarea(value='No linesearch model yet.', description='LS model:', layout=ipw.Layout(width='100%', height='140px'))
model_plot_out = ipw.Output(layout=ipw.Layout(width='100%', border='1px solid #ddd', height='1200px', overflow='auto'))


def _selected_model_kwargs():
    """Return the starting-model keyword arguments for ``prepare_inversion_inputs``.

    Only the value matching the selected mode is filled in; the other is None.
    """
    mode = initial_model_mode.value
    use_sigma = mode == 'uniform_conductivity'
    return {
        'initial_model_mode': mode,
        'uniform_conductivity': float(uniform_cond.value) if use_sigma else None,
        'uniform_resistivity': None if use_sigma else float(uniform_rho.value),
    }


def write_setup_metadata(created_files):
    """Write the inversion set-up to ``INVERSION_SETUP_METADATA`` as indented JSON.

    Records the folders, the selected starting model (mode *and* its uniform
    sigma/rho value, so the starting model can be reproduced from the
    metadata), the inversion parameters from the widgets, and the files
    produced by ``prepare_inversion_inputs``.

    Args:
        created_files: mapping of label -> path of the generated input files.
    """
    active = state.get('active_run_dir')
    model = _selected_model_kwargs()
    metadata = {
        'input_dir': str(INV_INPUT_DIR),
        'run_dir': str(active) if active else None,
        'fdmodel_dir': str(FDMODEL_DIR),
        'initial_model_mode': model['initial_model_mode'],
        'uniform_conductivity': model['uniform_conductivity'],
        'uniform_resistivity': model['uniform_resistivity'],
        'max_iterations': int(max_iterations.value),
        'apertx': float(apertx_value.value),
        'dtx': float(dtx_value.value),
        'dtz': float(dtz_value.value),
        'files': {k: str(v) for k, v in created_files.items()},
    }
    INVERSION_SETUP_METADATA.write_text(json.dumps(metadata, indent=2))


def set_active_run(run_dir):
    """Make ``run_dir`` the active run (or clear it when None) and reset per-run caches.

    The run log is ``inversion.log`` if it exists, else
    ``inversion_run.log`` if that exists, else ``inversion.log`` (which may
    not exist yet).
    """
    if run_dir is None:
        state.update(active_run_dir=None, active_run_log=None, last_progress_event=None)
        return

    run_dir = Path(run_dir)
    candidates = (run_dir / 'inversion.log', run_dir / 'inversion_run.log')
    chosen_log = next((lp for lp in candidates if lp.exists()), candidates[0])
    state.update(
        active_run_dir=run_dir,
        active_run_log=chosen_log,
        model_plot_key=None,
        last_plot_run_dir=None,
        last_progress_event=None,
    )


def refresh_run_selector(select_latest=False):
    """Reload the run dropdown from the ``InversionRunN`` folders under ROOT.

    Selection priority:
      1. the highest-numbered run, when ``select_latest`` is true;
      2. otherwise the active run, if it is still listed;
      3. otherwise the first option.

    ``state['updating_run_selector']`` is set while the widget changes, so
    ``on_select_run`` ignores these programmatic updates. It is reset in
    ``finally`` so that a rejected widget update cannot leave the selector
    ignored for the rest of the session.
    """
    runs = list_run_dirs(ROOT)
    options = [(p.name, str(p)) for _, p in runs] or [('No runs found', None)]

    active = state.get('active_run_dir')
    current = str(active) if active is not None else None
    if select_latest and runs:
        target = str(runs[-1][1])
    elif current in [v for _, v in options]:
        target = current
    else:
        target = options[0][1]

    state['updating_run_selector'] = True
    try:
        run_selector.options = options
        run_selector.value = target
    finally:
        state['updating_run_selector'] = False


def on_select_run(change):
    """Dropdown observer: make the chosen run active, unless an inversion is running.

    Programmatic changes (``state['updating_run_selector']``) and the
    'No runs found' placeholder are ignored. While an inversion process is
    alive, the selection is reverted to the active run.
    """
    if change.get('name') != 'value' or state.get('updating_run_selector'):
        return
    selected = change.get('new')
    if not selected:
        return

    proc = state.get('process')
    if proc is not None and proc.poll() is None:
        push_message('Cannot switch run while inversion is active. Stop it first.')
        active = state.get('active_run_dir')
        state['updating_run_selector'] = True
        try:
            run_selector.value = str(active) if active else None
        finally:
            # Always re-enable this observer, even if the revert is rejected
            # (e.g. the active run is no longer among the options).
            state['updating_run_selector'] = False
        return

    set_active_run(Path(selected))
    update_info_panel()
    refresh_run_status()
    push_message(f'Loaded existing run: {Path(selected).name}')


def on_generate_inputs(_):
    """Button handler: generate the inversion inputs in InversionInput/ and record metadata.

    Checks that the inv.cfg template and the FDmodel folder exist, calls
    ``prepare_inversion_inputs`` with the widget values, writes the set-up
    metadata and refreshes the panels. Errors, including the traceback, are
    reported in the status log instead of being raised.
    """
    try:
        requirements = (
            (INV_TEMPLATE, f'Missing inversion template: {INV_TEMPLATE}'),
            (FDMODEL_DIR, f'Missing FDmodel folder: {FDMODEL_DIR}'),
        )
        for required, message in requirements:
            if not required.exists():
                raise FileNotFoundError(message)

        INV_INPUT_DIR.mkdir(parents=True, exist_ok=True)
        settings = dict(
            fdmodel_dir=FDMODEL_DIR,
            template_cfg=INV_TEMPLATE,
            output_dir=INV_INPUT_DIR,
            max_iterations=int(max_iterations.value),
            apertx=float(apertx_value.value),
            dtx=float(dtx_value.value),
            dtz=float(dtz_value.value),
        )
        settings.update(_selected_model_kwargs())
        created = prepare_inversion_inputs(**settings)

        write_setup_metadata(created)
        for note in (
            'Generated inversion inputs in InversionInput/.',
            'Wrote setup metadata: InversionInput/inversion_setup_metadata.json',
        ):
            push_message(note)
        update_info_panel()
        refresh_run_status()
    except Exception as exc:
        push_message(f'ERROR generating inversion inputs: {exc}')
        push_message(traceback.format_exc())


def on_run_inversion(_):
    """Button handler: stage a new ``InversionRunN`` folder and start the inversion.

    Steps:
      1. Refuse if a run is already active.
      2. Check that runinv.sh and InversionInput/inv.cfg exist.
      3. Create the next run folder and copy the inputs into it.
      4. Write ``runinv.sh`` with ``mpirun -np <nproc>`` inserted.
      5. Start it with ``sh`` in the run folder, logging stdout and stderr
         to ``inversion.log``.
      6. Record the run in the metadata JSON.
      7. Select the run in the dropdown and draw its models.
      8. Start the background log refresh.
    Errors are reported in the status log.

    The script is started in its own session/process group
    (``start_new_session=True``, POSIX). This lets ``sh``, ``mpirun`` and
    the MPI ranks be signalled together when the run is stopped; terminating
    only ``sh`` may leave the MPI processes running. As a consequence, the
    run is no longer in the GUI server's process group and must be stopped
    through the GUI (Stop / Quit).
    """
    try:
        if state.get('process') is not None and state['process'].poll() is None:
            raise RuntimeError('Inversion process already running. Stop it first.')
        if not RUNINV_TEMPLATE.exists():
            raise FileNotFoundError(f'Missing run script template: {RUNINV_TEMPLATE}')
        if not (INV_INPUT_DIR / 'inv.cfg').exists():
            raise FileNotFoundError('Missing InversionInput/inv.cfg. Generate inversion inputs first.')

        run_dir = find_next_run_dir(ROOT)
        run_dir.mkdir(parents=True, exist_ok=True)

        staged = stage_run_directory(
            input_dir=INV_INPUT_DIR,
            run_dir=run_dir,
            clean=bool(clean_run_dir.value),
            include_patterns=['*'],
        )

        # Copy the run-script template, pinning the MPI process count on the
        # first `mpirun ` occurrence.
        run_script = run_dir / 'runinv.sh'
        nproc = max(2, min(int(getattr(nproc_inv_input, 'value', 4)), MAX_CPUS))
        runinv_content = RUNINV_TEMPLATE.read_text().replace('mpirun ', f'mpirun -np {nproc} ', 1)
        run_script.write_text(runinv_content)

        run_log = run_dir / 'inversion.log'
        # The child keeps its own copy of the log descriptor, so closing ours is fine.
        with open(run_log, 'w') as logf:
            proc = subprocess.Popen(
                ['sh', str(run_script)],
                cwd=str(run_dir),
                stdout=logf,
                stderr=subprocess.STDOUT,
                start_new_session=True,
            )

        set_active_run(run_dir)
        state['active_run_log'] = run_log
        state['process'] = proc
        state['last_sg_ls_mtime'] = None
        # Sentinel that never equals a real signature, so the first refresh
        # treats the new run's progress.log state as a new event.
        state['last_progress_event'] = ('__new_run__', -1)

        # The process is already running; a broken metadata file must not be
        # reported as a failed start (and must not skip the GUI set-up below).
        if INVERSION_SETUP_METADATA.exists():
            try:
                current = json.loads(INVERSION_SETUP_METADATA.read_text())
                current['run_dir'] = str(run_dir)
                current['run_log'] = str(run_log)
                current['run_script'] = str(run_script)
                INVERSION_SETUP_METADATA.write_text(json.dumps(current, indent=2))
            except (OSError, ValueError, TypeError) as exc:
                push_message(f'WARNING: could not update {INVERSION_SETUP_METADATA.name}: {exc}')

        push_message(f'Staged {len(staged)} input files into {run_dir.name}/.')
        push_message(f'Copied run script to {run_script}.')
        push_message(f'Started inversion process (pid={proc.pid}) in {run_dir.name}.')

        # Point the dropdown at the new run without triggering on_select_run.
        refresh_run_selector(select_latest=False)
        state['updating_run_selector'] = True
        try:
            run_selector.value = str(run_dir)
        finally:
            state['updating_run_selector'] = False

        # Clear previous run figure immediately and render only new-run files.
        with model_plot_out:
            from IPython.display import clear_output
            clear_output(wait=True)

        # Force immediate plot refresh from the new run directory (sg0/sg_ls if present).
        update_model_plot(run_dir, force=True)

        update_info_panel()
        refresh_interval_s = 5.0
        start_auto_refresh(interval_s=refresh_interval_s)
        refresh_run_status()
        # The background loop refreshes text panels only (refresh_images=False),
        # so figures must be updated with the Refresh button.
        push_message(
            f'Auto-refreshing logs every {refresh_interval_s:g}s; '
            'click "Refresh from progress.log" to update the model figures.'
        )
    except Exception as exc:
        push_message(f'ERROR starting inversion: {exc}')
        push_message(traceback.format_exc())


def on_refresh(_):
    """Button handler: refresh all text panels, then always redraw the active run's model plot."""
    refresh_run_status()
    active = state.get('active_run_dir')
    update_model_plot(active, force=True)


def on_stop(_):
    """Button handler: stop the running inversion, then refresh the panels.

    Sends SIGTERM and escalates to SIGKILL if the process has not exited
    after 5 s, then reaps it. When the inversion leads its own process group
    (started with ``start_new_session=True``, i.e. its group id equals its
    pid), the signal goes to the whole group, so ``mpirun`` and its ranks
    stop together with the ``sh`` wrapper. Otherwise only the direct child
    is signalled; the group the GUI itself belongs to is never signalled.
    """
    def _send(proc, force):
        # Prefer the child's own process group (POSIX only).
        if hasattr(os, 'killpg'):
            try:
                pgid = os.getpgid(proc.pid)
            except OSError:
                pgid = None
            if pgid == proc.pid and pgid != os.getpgrp():
                try:
                    os.killpg(pgid, signal.SIGKILL if force else signal.SIGTERM)
                except ProcessLookupError:
                    pass  # group already gone
                return
        if force:
            proc.kill()
        else:
            proc.terminate()

    try:
        proc = state.get('process')
        if proc is None or proc.poll() is not None:
            push_message('No active inversion process to stop.')
            return  # the panels are refreshed in `finally`
        _send(proc, force=False)
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            _send(proc, force=True)
            proc.wait(timeout=5)  # reap, so the process does not linger as a zombie
        push_message('Stopped inversion process.')
    except Exception as exc:
        push_message(f'ERROR stopping inversion: {exc}')
    finally:
        refresh_run_status()


def on_quit_server(_):
    """Button handler: stop auto-refresh and any running inversion, then stop the Voila server.

    A running inversion is stopped through ``on_stop``, so the same
    termination logic as the "Stop run" button applies. The server is
    stopped by sending SIGTERM to the PID stored in ``VOILA_PID_FILE``. If
    no usable PID is found, the user is told so instead of being told the
    server is shutting down.
    """
    try:
        stop_auto_refresh()
        proc = state.get('process')
        if proc is not None and proc.poll() is None:
            on_stop(_)

        pid = None
        if VOILA_PID_FILE.exists():
            pid = int(VOILA_PID_FILE.read_text().strip())
        # Only positive PIDs: 0 or negative values would signal whole process groups.
        if pid is None or pid <= 0:
            push_message(f'No valid Voila server PID in {VOILA_PID_FILE}; stop the server from its terminal.')
            return

        # NOTE(review): a stale PID file could point at an unrelated process;
        # consider verifying the process before signalling it.
        push_message('Shutting down inversion GUI server...')
        os.kill(pid, signal.SIGTERM)
    except Exception as exc:
        push_message(f'ERROR quitting GUI server: {exc}')


# --- Callback wiring --------------------------------------------------------------
for _button, _handler, _label in (
    (gen_inputs_btn, on_generate_inputs, 'Generating inversion inputs'),
    (run_inv_btn, on_run_inversion, 'Starting inversion run'),
    (refresh_btn, on_refresh, 'Refreshing inversion status'),
    (stop_btn, on_stop, 'Stopping inversion run'),
    (quit_btn, on_quit_server, 'Shutting down inversion GUI server'),
):
    bind_button_with_feedback(_button, _handler, _label)
del _button, _handler, _label
run_selector.observe(on_select_run, names='value')

# --- Layout -----------------------------------------------------------------------
controls1 = ipw.HBox([max_iterations, apertx_value, dtx_value, dtz_value])
controls2 = ipw.HBox([uniform_cond, uniform_rho])
run_controls = ipw.HBox([run_selector])
buttons = ipw.HBox([nproc_inv_input, gen_inputs_btn, run_inv_btn, stop_btn, refresh_btn, quit_btn])

# `info` must exist before update_info_panel() writes into it.
info = ipw.HTML(value='')
update_info_panel()

layout = ipw.VBox(
    [
        title, info, initial_model_mode, controls1, controls2, run_controls, clean_run_dir, buttons,
        run_status, progress_out, job_progress_out, ls_model_out, model_plot_out,
        worker_logs_out, mpiqueue_out, status_out, log_out,
    ]
)

# --- Initial population -------------------------------------------------------------
# `display` is not imported here; it relies on the IPython kernel's built-in.
state['initial_load'] = True
initialize_active_run()
refresh_run_selector(select_latest=True)
update_info_panel()
display(layout)
refresh_run_status(refresh_images=False)
state['initial_load'] = False

# Auto-refresh is NOT started on load (avoids a Voila hang). It starts only when
# "Run inversion locally" is clicked. The background thread only updates the log
# text panels, never the model plot; use "Refresh from progress.log" for that.