# 02_visualization

GUI for running FD modelling and visualizing `Hxshot.rss`/`Hzshot.rss` outputs.


In [None]:
from pathlib import Path
import json
import os
import re
import signal
import subprocess
import threading
import time
import traceback
import sys

# Project root; every other path in this notebook is derived from it.
# NOTE(review): this is a hard-coded absolute path on one user's machine --
# the notebook will not find its data or `scripts` package elsewhere unless
# this line is edited.
ROOT = Path('/Users/wiktorweibull/Rockem_projects/EM_inversion_workshop').resolve()
# Put the project root on sys.path so `scripts.modules...` can be imported below.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import numpy as np

try:
    import ipywidgets as ipw
    import plotly.graph_objects as go
except Exception as exc:
    raise RuntimeError(
        'Missing GUI dependencies. Install with: python3 -m pip install voila ipywidgets plotly'
    ) from exc

from scripts.modules.fd_visualization import compute_amp_phase_for_fd_outputs, save_amp_phase_npz

# Working directory of the modelling run (used as `cwd` for mpirun) and the
# files the GUI reads from or writes to underneath it.
FDMODEL_DIR = ROOT / 'FDmodel'
DATA_DIR = FDMODEL_DIR / 'Data'
HX_PATH = DATA_DIR / 'Hxshot.rss'
HZ_PATH = DATA_DIR / 'Hzshot.rss'
MPIQUEUE_LOG = FDMODEL_DIR / 'mpiqueue.log'
# Only checked for existence before a run; the run itself invokes mpirun directly.
RUN_SCRIPT = FDMODEL_DIR / 'runmod.sh'
# Optional JSON file; its 'flist_hz' entry seeds the default frequency list.
SETUP_META = FDMODEL_DIR / 'setup_metadata.json'
OUT_DIR = DATA_DIR / 'processed'
OUT_NPZ = OUT_DIR / 'amp_phase_results.npz'
# PID file read by the "Quit GUI server" button to signal the Voila server.
VOILA_PID_FILE = ROOT / '.voila_visualize_server.pid'
# Upper bound for the nproc input; never below 2, even if cpu_count() is
# unknown or reports a single CPU.
MAX_CPUS = max(2, os.cpu_count() or 2)
MPI_EMMOD_BIN = os.path.expanduser('~/software/rockem-suite/bin/mpiEmmodADITE2d')

# Mutable session state shared by all callbacks and the polling thread.
state = {
    # subprocess.Popen handle of the current run, or None
    'process': None,
    # last value returned by compute_amp_phase_for_fd_outputs, or None
    'result': None,
    # most recent status messages (push_message keeps at most 12)
    'last_messages': [],
    # True while the background polling loop should keep running
    'refresh_loop': False,
    # threading.Thread object of the polling loop, once started
    'monitor_thread': None,
    # set once a run has been launched from this session
    'run_started': False,
}


def push_message(msg):
    """Append *msg* to the status history and refresh the status text area.

    Only the 12 most recent messages are kept; the text area shows them
    joined by newlines, oldest first.
    """
    history = state['last_messages']
    history.append(msg)
    if len(history) > 12:
        # Rebind to a trimmed copy so the history never grows unbounded.
        state['last_messages'] = history[-12:]
    status_area.value = '\n'.join(state['last_messages'])


def bind_button_with_feedback(button, handler, action_label, success_message=None):
    """Attach *handler* to *button* with start/success/failure status messages.

    Parameters:
        button: widget exposing ``on_click(callback)``.
        handler: callable taking the click argument.
        action_label: human-readable action name used in the messages.
        success_message: optional text reported when the handler finishes
            without reporting anything itself; defaults to
            ``'<action_label> completed successfully.'``.

    On click, ``'<action_label>...'`` is pushed, then the handler runs. If the
    handler raises, ``'<action_label> failed: <exc>'`` is pushed and the
    exception is re-raised. If the handler returns without pushing a message
    of its own, the success message is pushed.
    """
    def _wrapped(_):
        start_message = f'{action_label}...'
        push_message(start_message)
        try:
            handler(_)
        except Exception as exc:
            push_message(f'{action_label} failed: {exc}')
            raise
        else:
            # The message history is capped, so its length stops growing once
            # it is full and cannot tell whether the handler reported
            # anything. Instead, check whether our start message is still the
            # newest entry: if so, the handler stayed silent.
            messages = state.get('last_messages', [])
            if messages and messages[-1] == start_message:
                push_message(success_message or f'{action_label} completed successfully.')

    button.on_click(_wrapped)


def read_tail(path, n_lines=30):
    """Return the last *n_lines* lines of the text file at *path*.

    Parameters:
        path: file path (str or Path).
        n_lines: maximum number of trailing lines to return.

    Returns:
        The trailing lines joined by newlines; ``'<path> not found.'`` when
        the file does not exist; ``'(empty)'`` when the file has no lines or
        when *n_lines* is zero or negative.

    Undecodable bytes are replaced rather than raising.
    """
    path = Path(path)
    try:
        # Read directly instead of checking exists() first: the file may be
        # deleted between the check and the read.
        text = path.read_text(errors='replace')
    except FileNotFoundError:
        return f'{path} not found.'

    lines = text.splitlines()
    if not lines:
        return '(empty)'

    count = int(n_lines)
    if count <= 0:
        # lines[-0:] would be the whole file, and a negative count would drop
        # leading lines instead of selecting a tail.
        return '(empty)'
    return '\n'.join(lines[-count:])


def discover_job_logs():
    """List worker log files found directly inside FDMODEL_DIR.

    Returns the ``*.log`` files (excluding ``mpiqueue.log``) in sorted order,
    followed by the ``log.txt-*`` files in sorted order.
    """
    worker_logs = [
        candidate
        for candidate in sorted(FDMODEL_DIR.glob('*.log'))
        if candidate.name != 'mpiqueue.log'
    ]
    worker_logs.extend(sorted(FDMODEL_DIR.glob('log.txt-*')))
    return worker_logs


def _last_nonempty_line(path):
    """Return the last non-blank line of *path*, stripped of whitespace.

    Returns an empty string when the file does not exist or contains only
    blank lines. Undecodable bytes are replaced rather than raising.
    """
    path = Path(path)
    if not path.exists():
        return ''
    content = path.read_text(errors='replace')
    # Walk backwards and stop at the first line with visible content.
    stripped_backwards = (raw.strip() for raw in reversed(content.splitlines()))
    return next((text for text in stripped_backwards if text), '')


def _extract_percent_from_line(line):
    """Extract a completion percentage from a log line.

    Returns the first ``<number>%`` value found as a float. Without such a
    value, a line containing 'complete' or 'finished' (case-insensitive)
    counts as 100.0. Anything else, including an empty line, yields NaN.
    """
    if not line:
        return np.nan

    match = re.search(r'(\d+(?:\.\d+)?)\s*%', line)
    if match is not None:
        try:
            return float(match.group(1))
        except Exception:
            return np.nan

    lowered = line.lower()
    done = 'complete' in lowered or 'finished' in lowered
    return 100.0 if done else np.nan


def summarize_job_progress(logs):
    """Build one progress row per log file.

    Each row is a dict with the file's ``name``, the ``percent`` parsed from
    its last non-blank line (NaN when none can be parsed), and that
    ``last_line`` itself.
    """
    def _row_for(log_path):
        last = _last_nonempty_line(log_path)
        percent = _extract_percent_from_line(last)
        return {'name': log_path.name, 'percent': percent, 'last_line': last}

    return [_row_for(log_path) for log_path in logs]


def all_jobs_complete(rows):
    """Return True only when *rows* is non-empty and every row is at 100%.

    A row counts as incomplete when its ``percent`` is NaN or below 100.
    """
    if not rows:
        return False
    return all(
        not np.isnan(row['percent']) and not (row['percent'] < 100.0)
        for row in rows
    )


def default_frequencies():
    """Return the initial frequency list (Hz) for the GUI.

    Uses the ``flist_hz`` entry of SETUP_META when that file exists, parses,
    and yields a non-empty list of floats; otherwise falls back to
    ``[2e3, 4e3, 6e3]``.
    """
    fallback = [2e3, 4e3, 6e3]
    if not SETUP_META.exists():
        return fallback
    try:
        meta = json.loads(SETUP_META.read_text())
        freqs = [float(value) for value in meta.get('flist_hz', [])]
    except Exception:
        # Best effort: unreadable or malformed metadata means "use defaults".
        return fallback
    return freqs if freqs else fallback


def parse_freqs(text):
    """Parse a comma-separated frequency list into a float array.

    Parameters:
        text: comma-separated numbers; blank entries and surrounding
            whitespace are ignored.

    Returns:
        1-D ``numpy`` float array of the parsed frequencies, in input order.

    Raises:
        ValueError: if the list is empty, an entry is not a number, or any
            frequency is non-finite (``nan``/``inf``) or not strictly positive.
    """
    tokens = [piece.strip() for piece in str(text).split(',') if piece.strip()]
    if not tokens:
        raise ValueError('Frequency list cannot be empty.')
    vals = np.asarray([float(token) for token in tokens], dtype=float)
    # float() happily accepts 'nan' and 'inf', and NaN also slips through the
    # `<= 0` comparison below, so reject non-finite values explicitly.
    if not np.all(np.isfinite(vals)):
        raise ValueError('Frequencies must be finite numbers.')
    if np.any(vals <= 0):
        raise ValueError('Frequencies must be positive.')
    return vals


def clear_stale_run_logs(announce=False):
    """Delete ``mpiqueue.log`` and all worker logs left over from earlier runs.

    Deletion is best effort for every file: a log that has already vanished
    is skipped, and one that cannot be removed is left in place instead of
    aborting the whole clean-up.

    Parameters:
        announce: when true, report removed files (and files that could not
            be removed) through ``push_message``.

    Returns:
        List of the file names that were actually removed.
    """
    removed = []
    failed = []
    for log_path in [MPIQUEUE_LOG, *discover_job_logs()]:
        try:
            log_path.unlink()
        except FileNotFoundError:
            # Nothing to clean up (or another thread removed it first).
            continue
        except OSError:
            failed.append(log_path.name)
            continue
        removed.append(log_path.name)

    if announce and removed:
        push_message('Removed stale logs: ' + ', '.join(removed))
    if announce and failed:
        push_message('Could not remove logs: ' + ', '.join(failed))
    return removed


def refresh_run_status(_=None):
    """Refresh the run-status widgets from the process handle and log files.

    Updates the run-status line, the mpiqueue tail, the per-job progress
    summary and the worker-log tails (first four logs only). A finished
    process handle is dropped from ``state``. When the polling loop is on,
    no process is active and every job log reports completion, the loop flag
    is cleared. The argument is ignored; it exists so the function can be
    used as a button callback.
    """
    proc = state.get('process')
    process_active = False
    if proc is None:
        run_status.value = 'No active run process in this session.'
    else:
        rc = proc.poll()
        if rc is None:
            process_active = True
            run_status.value = f'Run process active (pid={proc.pid})'
        else:
            run_status.value = f'Last run process exited with code {rc}'
            state['process'] = None

    if not (state.get('run_started') or process_active):
        # Nothing was launched from this session: do not show old log files.
        job_progress_out.value = 'Run not started in this session yet.'
        worker_logs_out.value = 'Run not started in this session yet.'
        mpiqueue_out.value = 'Run not started in this session yet.'
        return

    mpiqueue_out.value = read_tail(MPIQUEUE_LOG, n_lines=40)
    logs = discover_job_logs()
    if not logs:
        job_progress_out.value = 'No worker log files found yet.'
        worker_logs_out.value = 'No worker log files found yet.'
        return

    rows = summarize_job_progress(logs)
    progress_lines = []
    for row in rows:
        pct = row['percent']
        if np.isnan(pct):
            pct_text = 'n/a'
        else:
            pct_text = f'{pct:.1f}%'
        progress_lines.append(f"{row['name']}: {pct_text} | {row['last_line']}")
    job_progress_out.value = '\n'.join(progress_lines)

    sections = []
    for log_path in logs[:4]:
        sections.extend([f'--- {log_path.name} ---', read_tail(log_path, n_lines=15)])
    worker_logs_out.value = '\n'.join(sections)

    if state.get('refresh_loop') and (not process_active) and all_jobs_complete(rows):
        state['refresh_loop'] = False
        push_message('All jobs completed. Auto-refresh stopped.')


def start_auto_refresh():
    """Start the background thread that refreshes run status every 2 seconds.

    Does nothing when the refresh flag is already set. The thread is a daemon
    and keeps polling until ``state['refresh_loop']`` becomes false.
    """
    if state.get('refresh_loop'):
        return
    state['refresh_loop'] = True

    def _poll_until_stopped():
        while True:
            if not state.get('refresh_loop'):
                break
            try:
                refresh_run_status()
            except Exception:
                # A failed refresh must not kill the polling thread.
                pass
            time.sleep(2.0)

    monitor = threading.Thread(target=_poll_until_stopped, daemon=True)
    state['monitor_thread'] = monitor
    monitor.start()


def stop_auto_refresh():
    """Clear the flag that keeps the background polling loop running."""
    state.update(refresh_loop=False)


def on_run_model(_):
    """Button callback: launch the MPI modelling run in the background.

    All preconditions (run script present, no run already active, valid
    process count, MPI binary present) are checked before any stale log is
    deleted, and the session is only marked as started once ``mpirun`` has
    actually been spawned, so a failed launch leaves the GUI state untouched.

    Errors are not raised to the caller; they are reported through
    ``push_message`` together with the traceback.
    """
    try:
        if not RUN_SCRIPT.exists():
            raise FileNotFoundError(f'Missing run script: {RUN_SCRIPT}')
        if state.get('process') is not None and state['process'].poll() is None:
            raise RuntimeError('A run process is already active in this GUI session.')

        # Clamp the requested process count to [2, MAX_CPUS].
        nproc = max(2, min(int(getattr(nproc_input, 'value', 4)), MAX_CPUS))
        if not os.path.isfile(MPI_EMMOD_BIN):
            raise FileNotFoundError(f'MPI FD binary not found: {MPI_EMMOD_BIN}')

        # Only now that the run can plausibly start, drop the previous logs.
        clear_stale_run_logs(announce=True)

        proc = subprocess.Popen(
            ['mpirun', '-np', str(nproc), MPI_EMMOD_BIN, 'mod.cfg'],
            cwd=str(FDMODEL_DIR),
        )
        state['process'] = proc

        # Mark the session as started only after a successful spawn; setting
        # this earlier would leave "Waiting..." placeholders behind when
        # mpirun cannot be launched.
        state['run_started'] = True
        job_progress_out.value = 'Waiting for new job logs...'
        worker_logs_out.value = 'Waiting for new worker logs...'
        mpiqueue_out.value = 'Waiting for new mpiqueue.log...'

        push_message(f'Started mpirun -np {nproc} in background (pid={proc.pid}).')
        refresh_run_status()
        start_auto_refresh()
    except Exception as exc:
        push_message(f'Run start failed: {exc}')
        push_message(traceback.format_exc())


def on_stop_model(_):
    """Button callback: interrupt the active run and stop auto-refresh.

    Sends SIGINT to the run process when one is active; otherwise only
    reports that there is nothing to stop. Errors are reported through
    ``push_message`` instead of being raised.
    """
    try:
        proc = state.get('process')
        still_running = proc is not None and proc.poll() is None
        if not still_running:
            push_message('No active run process to stop in this session.')
            stop_auto_refresh()
            return

        os.kill(proc.pid, signal.SIGINT)
        push_message(f'Sent SIGINT to run process pid={proc.pid}.')
        stop_auto_refresh()
        refresh_run_status()
    except Exception as exc:
        push_message(f'Stop run failed: {exc}')
        push_message(traceback.format_exc())


def on_load_outputs(_):
    """Button callback: check that both FD output files exist.

    Updates the load-info line when both the Hx and Hz files are present;
    otherwise reports the missing paths through ``push_message``.
    """
    try:
        missing = []
        for candidate in (HX_PATH, HZ_PATH):
            if not candidate.exists():
                missing.append(str(candidate))
        if missing:
            raise RuntimeError('Missing output files: ' + ', '.join(missing))
        load_info.value = f'Found outputs: Hx={HX_PATH.name}, Hz={HZ_PATH.name}'
        push_message('FD output files detected.')
    except Exception as exc:
        push_message(f'Load outputs failed: {exc}')
        push_message(traceback.format_exc())


def on_compute(_):
    """Button callback: compute amplitude/phase and repopulate the plot controls.

    Reads the frequency list, optional time window and pair count from the
    input widgets, runs ``compute_amp_phase_for_fd_outputs`` on the Hx/Hz
    files, stores the result in ``state['result']``, refreshes the frequency,
    trace, transmitter and local-receiver selectors from it, and redraws the
    plot. Errors are reported through ``push_message`` instead of being raised.
    """
    def _optional_float(widget):
        # An empty (or whitespace-only) text field means "no limit".
        raw = widget.value
        return float(raw) if raw.strip() else None

    try:
        freqs = parse_freqs(freqs_input.value)
        start_t = _optional_float(start_time_input)
        end_t = _optional_float(end_time_input)
        n_pairs = int(n_pairs_input.value)
        result = compute_amp_phase_for_fd_outputs(
            HX_PATH,
            HZ_PATH,
            freqs=freqs,
            start_t=start_t,
            end_t=end_t,
            n_pairs=n_pairs,
        )
        state['result'] = result

        hx_freqs = result['Hx']['freqs']
        geometry = result['geometry']
        nfreq = len(hx_freqs)
        ntr = geometry['ntrace']

        freq_options = []
        for idx, freq in enumerate(hx_freqs):
            freq_options.append((f'{freq:g} Hz', idx))
        comp_freq.options = freq_options
        trace_idx.max = max(0, ntr - 1)

        tx_ids = np.unique(geometry['tx_idx_per_trace'])
        tx_options = []
        for tx in tx_ids:
            tx_options.append((f'tx#{int(tx)}', int(tx)))
        tx_select.options = tx_options
        if len(tx_ids) > 0:
            tx_select.value = int(tx_ids[0])

        rx_local = np.asarray(geometry.get('rx_local_idx_per_trace', []), dtype=int)
        if rx_local.size == 0:
            rx_local_select.max = 0
            rx_local_select.value = 0
        else:
            rx_local_select.max = int(np.max(rx_local))
            rx_local_select.value = min(rx_local_select.value, rx_local_select.max)

        compute_info.value = f'Computed amp/phase for {nfreq} frequencies and {ntr} traces.'
        push_message('Computed two-point amplitude/phase results.')
        update_plot()
    except Exception as exc:
        push_message(f'Compute failed: {exc}')
        push_message(traceback.format_exc())


def traces_for_tx(result, tx_id):
    """Return the indices of the traces whose transmitter index equals *tx_id*.

    *tx_id* is converted with ``int()`` before comparison against
    ``result['geometry']['tx_idx_per_trace']``.
    """
    per_trace_tx = np.asarray(result['geometry']['tx_idx_per_trace'])
    matches = per_trace_tx == int(tx_id)
    return np.nonzero(matches)[0]


def update_plot(_=None):
    """Redraw the figure for the current component/metric/selector values.

    Returns without drawing when nothing has been computed yet, when no
    transmitter is selected, or when the selected transmitter has no traces.
    Three plot families are supported, each as amplitude or phase (degrees):
    versus local receiver index for the selected transmitter, versus
    transmitter index for a fixed local receiver, and versus frequency for a
    single trace. Any metric value that is not one of the five named ones is
    drawn as phase versus frequency. The argument is ignored; it exists so the
    function can be used as a widget observer.
    """
    result = state.get('result')
    if result is None:
        return

    comp = component_select.value
    metric = metric_select.value
    tx_id = tx_select.value
    if tx_id is None:
        return

    freq_i = 0 if comp_freq.value is None else comp_freq.value
    tr_idx = traces_for_tx(result, tx_id)
    if tr_idx.size == 0:
        return

    geo = result['geometry']
    rx_local_ids = np.asarray(geo.get('rx_local_idx_per_trace', geo['rx_idx_per_trace']))
    comp_data = result[comp]
    fig = go.Figure()

    if metric in ('amp_vs_rx', 'phase_vs_rx_deg'):
        # Profile along the receivers of the selected transmitter, sorted by
        # local receiver index.
        rx_of_tx = rx_local_ids[tr_idx]
        order = np.argsort(rx_of_tx)
        x = rx_of_tx[order]
        if metric == 'amp_vs_rx':
            y = comp_data['amp_mean'][freq_i, tr_idx][order]
            title = f'{comp} amplitude vs local rx (tx#{tx_id}, f={comp_data["freqs"][freq_i]:g} Hz)'
            ytitle = 'Amplitude'
        else:
            y = np.degrees(comp_data['phi_mean_rad'][freq_i, tr_idx][order])
            title = f'{comp} phase vs local rx (tx#{tx_id}, f={comp_data["freqs"][freq_i]:g} Hz)'
            ytitle = 'Phase (deg)'
        xtitle = 'Local receiver index'
    elif metric in ('amp_vs_tx', 'phase_vs_tx_deg'):
        # Profile along transmitters at one fixed local receiver index.
        want_amplitude = metric == 'amp_vs_tx'
        fixed_rx = int(rx_local_select.value)
        tx_all = np.asarray(geo['tx_idx_per_trace'])
        at_fixed_rx = rx_local_ids == fixed_rx
        xs = []
        ys = []
        for tx in np.unique(tx_all[at_fixed_rx]):
            tr = np.where((tx_all == tx) & at_fixed_rx)[0]
            if tr.size == 0:
                continue
            if want_amplitude:
                value = np.nanmean(comp_data['amp_mean'][freq_i, tr])
            else:
                # Average phases on the unit circle so wrap-around at +/-pi
                # does not bias the mean.
                phi = comp_data['phi_mean_rad'][freq_i, tr]
                phi_mean = np.angle(np.nanmean(np.exp(1j * phi)))
                value = np.degrees(phi_mean)
            xs.append(int(tx))
            ys.append(float(value))
        x = np.asarray(xs)
        y = np.asarray(ys)
        if want_amplitude:
            title = f'{comp} amplitude vs tx (local rx={fixed_rx}, f={comp_data["freqs"][freq_i]:g} Hz)'
            ytitle = 'Amplitude'
        else:
            title = f'{comp} phase vs tx (local rx={fixed_rx}, f={comp_data["freqs"][freq_i]:g} Hz)'
            ytitle = 'Phase (deg)'
        xtitle = 'Transmitter index'
    else:
        # Spectrum of a single trace; the slider value is clamped to the
        # last available trace.
        global_trace = int(min(trace_idx.value, comp_data['ntrace'] - 1))
        if metric == 'amp_vs_freq':
            y = comp_data['amp_mean'][:, global_trace]
            x = comp_data['freqs']
            title = f'{comp} amplitude vs frequency (trace#{global_trace})'
            ytitle = 'Amplitude'
        else:
            y = np.degrees(comp_data['phi_mean_rad'][:, global_trace])
            x = comp_data['freqs']
            title = f'{comp} phase vs frequency (trace#{global_trace})'
            ytitle = 'Phase (deg)'
        xtitle = 'Frequency (Hz)'

    fig.add_trace(go.Scatter(x=x, y=y, mode='lines+markers', name=comp))
    fig.update_layout(title=title, xaxis_title=xtitle, yaxis_title=ytitle, height=420)
    with plot_out:
        plot_out.clear_output(wait=True)
        fig.show()


def on_save_results(_):
    """Button callback: write the computed amplitude/phase results to OUT_NPZ.

    The output directory is created first if it is missing, so saving works
    on a fresh checkout where the ``processed`` folder does not exist yet.
    Errors (including "nothing computed yet") are reported through
    ``push_message`` instead of being raised.
    """
    try:
        result = state.get('result')
        if result is None:
            raise RuntimeError('Compute results first.')
        # Nothing else in this notebook creates the output folder.
        OUT_NPZ.parent.mkdir(parents=True, exist_ok=True)
        save_amp_phase_npz(OUT_NPZ, result)
        push_message(f'Saved processed results to {OUT_NPZ}')
    except Exception as exc:
        push_message(f'Save failed: {exc}')
        push_message(traceback.format_exc())


def on_quit_gui(_):
    """Button callback: stop polling, interrupt the run and server, shut down the kernel.

    Sends SIGINT to an active run process and to the PID stored in
    VOILA_PID_FILE (unless that PID is this process). Any error on the way is
    reported as a warning; the kernel shutdown in the ``finally`` block is
    attempted regardless.
    """
    try:
        stop_auto_refresh()

        # Interrupt the modelling run if it is still alive.
        proc = state.get('process')
        if proc is not None and proc.poll() is None:
            os.kill(proc.pid, signal.SIGINT)
            push_message(f'Sent SIGINT to run process pid={proc.pid}.')

        # Read the server PID, if a non-empty PID file exists.
        # NOTE(review): the PID is trusted as-is; a stale file could name an
        # unrelated process -- confirm how the file is written and cleaned up.
        pid = None
        if VOILA_PID_FILE.exists():
            pid_text = VOILA_PID_FILE.read_text().strip()
            if pid_text:
                pid = int(pid_text)

        if pid is not None:
            # Never signal ourselves here; the kernel is shut down below.
            if pid != os.getpid():
                os.kill(pid, signal.SIGINT)
                push_message(f'Sent SIGINT to Voila server PID {pid}.')
            else:
                push_message('Voila PID matches current kernel process; skipping direct PID kill.')
        else:
            push_message('Voila PID file not found; proceeding with kernel shutdown.')
    except Exception as exc:
        push_message(f'Quit warning: {exc}. Proceeding with kernel shutdown.')
    finally:
        # Best-effort kernel shutdown; silently skipped when not running
        # inside an IPython kernel or when the shutdown call fails.
        try:
            from IPython import get_ipython

            ip = get_ipython()
            if ip and getattr(ip, 'kernel', None):
                ip.kernel.do_shutdown(restart=False)
        except Exception:
            pass


# --- Run controls -----------------------------------------------------------
# Process count for mpirun, bounded to [2, MAX_CPUS].
nproc_input = ipw.BoundedIntText(
    value=min(4, MAX_CPUS), min=2, max=MAX_CPUS,
    description='nproc', layout=ipw.Layout(width='120px'),
)
run_button = ipw.Button(description='Run modelling (background)', button_style='success')
stop_button = ipw.Button(description='Stop run', button_style='warning')
refresh_button = ipw.Button(description='Refresh run status')
run_status = ipw.HTML(value='No active run process in this session.')
# Text areas filled by refresh_run_status().
job_progress_out = ipw.Textarea(value='', description='job progress', layout=ipw.Layout(width='980px', height='180px'))
mpiqueue_out = ipw.Textarea(value='', description='mpiqueue', layout=ipw.Layout(width='980px', height='180px'))
worker_logs_out = ipw.Textarea(value='', description='worker logs', layout=ipw.Layout(width='980px', height='180px'))

# --- Output detection -------------------------------------------------------
load_button = ipw.Button(description='Load FD outputs')
load_info = ipw.HTML(value='Outputs not loaded yet.')

# --- Amplitude/phase computation inputs -------------------------------------
# Frequency list is pre-filled from default_frequencies(); the time-window
# fields are free text where an empty value means "not set".
freqs_input = ipw.Text(value=','.join([f'{v:g}' for v in default_frequencies()]), description='flist (Hz)', layout=ipw.Layout(width='420px'))
start_time_input = ipw.Text(value='', description='start_t (s)', layout=ipw.Layout(width='280px'))
end_time_input = ipw.Text(value='', description='end_t (s)', layout=ipw.Layout(width='280px'))
n_pairs_input = ipw.IntText(value=3, description='n_pairs', layout=ipw.Layout(width='220px'))
compute_button = ipw.Button(description='Compute amplitude/phase', button_style='primary')
compute_info = ipw.HTML(value='Not computed yet.')

# --- Plot selectors ---------------------------------------------------------
component_select = ipw.Dropdown(options=[('Hx', 'Hx'), ('Hz', 'Hz')], value='Hx', description='component')
# The option values are the metric keys that update_plot() branches on.
metric_select = ipw.Dropdown(
    options=[
        ('Amplitude vs rx (local index)', 'amp_vs_rx'),
        ('Phase vs rx (deg, local index)', 'phase_vs_rx_deg'),
        ('Amplitude vs tx (fixed local rx)', 'amp_vs_tx'),
        ('Phase vs tx (deg, fixed local rx)', 'phase_vs_tx_deg'),
        ('Amplitude vs frequency', 'amp_vs_freq'),
        ('Phase vs frequency (deg)', 'phase_vs_freq_deg'),
    ],
    value='amp_vs_rx',
    description='plot',
)
# Placeholder options/ranges below are replaced by on_compute() once results exist.
comp_freq = ipw.Dropdown(options=[('n/a', 0)], value=0, description='frequency')
tx_select = ipw.Dropdown(options=[('n/a', 0)], value=0, description='tx')
rx_local_select = ipw.IntSlider(value=0, min=0, max=0, step=1, description='local rx', continuous_update=False, layout=ipw.Layout(width='320px'))
trace_idx = ipw.IntSlider(value=0, min=0, max=0, step=1, description='trace idx', continuous_update=False, layout=ipw.Layout(width='420px'))
save_button = ipw.Button(description='Save processed results', button_style='info')
quit_button = ipw.Button(description='Quit GUI server', button_style='danger')
plot_out = ipw.Output(layout=ipw.Layout(width='980px', border='1px solid #ddd'))

# Status log written by push_message().
status_area = ipw.Textarea(value='Ready.', description='Status', layout=ipw.Layout(width='980px', height='180px'))

# Button callbacks: each is wrapped so the status area shows when the action
# starts and how it ended.
bind_button_with_feedback(run_button, on_run_model, 'Starting modelling run')
bind_button_with_feedback(stop_button, on_stop_model, 'Stopping modelling run')
bind_button_with_feedback(refresh_button, refresh_run_status, 'Refreshing run status')
bind_button_with_feedback(load_button, on_load_outputs, 'Loading FD outputs')
bind_button_with_feedback(compute_button, on_compute, 'Computing amplitude/phase')
bind_button_with_feedback(save_button, on_save_results, 'Saving processed results')
bind_button_with_feedback(quit_button, on_quit_gui, 'Shutting down visualization GUI server')
# Redraw the plot whenever any selector value changes.
component_select.observe(update_plot, names='value')
metric_select.observe(update_plot, names='value')
comp_freq.observe(update_plot, names='value')
tx_select.observe(update_plot, names='value')
rx_local_select.observe(update_plot, names='value')
trace_idx.observe(update_plot, names='value')

from IPython.display import display

# --- Page layout: header, four numbered sections, then the status log -------
app_header = ipw.HTML(
    '<h2>02_visualization</h2>'
    '<p>Run FD modelling and extract two-point amplitude/phase from Hx/Hz shot records.</p>'
)

run_section = ipw.VBox([
    ipw.HTML('<h3>1) Run modelling</h3>'),
    ipw.HBox([nproc_input, run_button, stop_button, refresh_button]),
    run_status,
    ipw.HTML('<b>Per-job completion (auto-refresh every 2s while running)</b>'),
    job_progress_out,
    mpiqueue_out,
    worker_logs_out,
])

load_section = ipw.VBox([
    ipw.HTML('<h3>2) Load outputs</h3>'),
    ipw.HBox([load_button]),
    load_info,
])

compute_section = ipw.VBox([
    ipw.HTML('<h3>3) Compute two-point amplitude/phase</h3>'),
    ipw.HBox([freqs_input, start_time_input, end_time_input, n_pairs_input]),
    compute_button,
    compute_info,
])

plot_section = ipw.VBox([
    ipw.HTML('<h3>4) Visualization and export</h3>'),
    ipw.HBox([component_select, metric_select, comp_freq, tx_select, rx_local_select]),
    trace_idx,
    ipw.HBox([save_button, quit_button]),
    ipw.HTML('<b>Session control</b>: use Quit GUI server to stop Voila even if browser is closed.'),
    plot_out,
])

gui = ipw.VBox([
    app_header,
    run_section,
    load_section,
    compute_section,
    plot_section,
    status_area,
])

# Startup side effects: opening the notebook deletes mpiqueue.log and every
# worker log in FDMODEL_DIR, then initialises the status widgets and shows
# the GUI.
clear_stale_run_logs(announce=False)
refresh_run_status()
display(gui)
