# 01_fw_setup (Voila)

This notebook provides a code-hidden GUI for generating modelling inputs:
- `sg.rss`, `ep.rss`
- `wav2d.rss`
- `Survey.rss`
- selected `mod.cfg` fields (`dtrec`, `apertx`, file paths)

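Run the notebook in Jupyter to use the GUI directly, or launch it with Voila (`--strip_sources=True`) to show only the GUI with the code hidden; Voila executes all cells itself.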


In [None]:
from pathlib import Path
import json
import os
import signal
import shutil
import tempfile
import traceback
import sys
import math

# Workshop root: an explicit EM_WORKSHOP_ROOT environment variable wins; otherwise use the
# author's default checkout, and if that does not exist on this machine fall back to the
# launch directory (the import error below asks users to run Voila from the workshop root).
_DEFAULT_WORKSHOP_ROOT = Path('/Users/wiktorweibull/Rockem_projects/EM_inversion_workshop')
_root_override = os.environ.get('EM_WORKSHOP_ROOT')
if _root_override:
    ROOT = Path(_root_override).expanduser().resolve()
elif _DEFAULT_WORKSHOP_ROOT.exists():
    ROOT = _DEFAULT_WORKSHOP_ROOT.resolve()
else:
    ROOT = Path.cwd().resolve()
# Make `scripts.modules...` importable.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import numpy as np

try:
    import ipywidgets as ipw
    import plotly.graph_objects as go
except Exception as exc:
    raise RuntimeError(
        'Missing GUI dependencies. Install with: python3 -m pip install voila ipywidgets plotly'
    ) from exc

try:
    from scripts.modules.segy import read_resistivity_from_segy, write_sg_ep_rss, save_resistivity_npz
    from scripts.modules.source import create_wavelet_rss
    from scripts.modules.survey import update_cfg_values as update_survey_cfg, generate_survey_rss
    from scripts.modules.fd import DesignInputs, recommend_design, update_modcfg_for_workshop, interpolate_rss_python, enforce_rss_min_value
except Exception as exc:
    raise RuntimeError(
        'Failed to import workshop modules. Ensure segyio is installed and run Voila from the workshop root.'
    ) from exc

# Workshop layout.
FDMODEL_DIR = ROOT.joinpath('FDmodel')
SCRIPTS_DIR = ROOT.joinpath('scripts')
TEMPLATES_DIR = SCRIPTS_DIR.joinpath('templates')
# Fresh scratch folder for intermediate (pre-interpolation) files, created on every run.
TMP_DIR = Path(tempfile.mkdtemp(prefix='em_workshop_gui_'))

# Final FD input files written into FDmodel/.
DEFAULT_SG = FDMODEL_DIR.joinpath('sg.rss')
DEFAULT_EP = FDMODEL_DIR.joinpath('ep.rss')
DEFAULT_WAV = FDMODEL_DIR.joinpath('wav2d.rss')
DEFAULT_SURVEY_RSS = FDMODEL_DIR.joinpath('Survey.rss')
DEFAULT_MODCFG = FDMODEL_DIR.joinpath('mod.cfg')
# Templates shipped with the workshop scripts.
DEFAULT_SURVEY_CFG_TEMPLATE = TEMPLATES_DIR.joinpath('survey.cfg')
DEFAULT_MODCFG_TEMPLATE = TEMPLATES_DIR.joinpath('mod.cfg')
SETUP_METADATA_PATH = FDMODEL_DIR.joinpath('setup_metadata.json')
VOILA_PID_FILE = ROOT.joinpath('.voila_server.pid')

# Mutable GUI state shared by all button handlers.
state = dict(
    model=None,  # dict returned by read_resistivity_from_segy()
    model_npz=TMP_DIR.joinpath('resistivity_model.npz'),
    sg_raw_path=TMP_DIR.joinpath('sg_raw.rss'),
    ep_raw_path=TMP_DIR.joinpath('ep_raw.rss'),
    wav_raw_path=TMP_DIR.joinpath('wav2d_raw.rss'),
    survey_cfg_path=TMP_DIR.joinpath('survey.cfg'),
    survey_rss_path=TMP_DIR.joinpath('Survey.rss'),
    survey_metadata=None,
    wavelet_params=None,
    fd_design=None,
    last_messages=[],  # rolling status log rendered by push_message()
)

# The survey.cfg template is mandatory; work on a private copy in the scratch folder.
if not DEFAULT_SURVEY_CFG_TEMPLATE.exists():
    raise RuntimeError(f'Missing survey template: {DEFAULT_SURVEY_CFG_TEMPLATE}')
shutil.copyfile(DEFAULT_SURVEY_CFG_TEMPLATE, state['survey_cfg_path'])

# Seed FDmodel/mod.cfg from its template only if there is no mod.cfg yet.
if not DEFAULT_MODCFG.exists():
    if DEFAULT_MODCFG_TEMPLATE.exists():
        DEFAULT_MODCFG.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(DEFAULT_MODCFG_TEMPLATE, DEFAULT_MODCFG)


def push_message(msg):
    """Append *msg* to the status log, keep only the newest 10 entries, and redraw status_area."""
    history = state['last_messages']
    history.append(msg)
    if len(history) > 10:
        history = history[-10:]
        state['last_messages'] = history
    status_area.value = '\n'.join(history)


def bind_button_with_feedback(button, handler, action_label, success_message=None):
    """
    Attach *handler* to *button*, bracketing every click with status-log messages.

    Before the handler runs, "<action_label>..." is pushed. If the handler finishes
    without pushing any status message of its own, *success_message* (or a generic
    "<action_label> completed successfully.") is pushed. If the handler raises, a
    "<action_label> failed: <exc>" line is pushed and the exception is re-raised.
    """
    def _wrapped(event):
        push_message(f'{action_label}...')
        # push_message() keeps only the newest 10 entries, so the log length stops growing once
        # it is full and cannot tell whether the handler pushed anything. Instead, remember the
        # exact message object just pushed: if it is still the newest entry afterwards, the
        # handler was silent.
        log = state.get('last_messages') or []
        marker = log[-1] if log else None
        try:
            handler(event)
            log = state.get('last_messages') or []
            handler_was_silent = marker is not None and bool(log) and log[-1] is marker
            if handler_was_silent:
                push_message(success_message or f'{action_label} completed successfully.')
        except Exception as exc:
            push_message(f'{action_label} failed: {exc}')
            raise

    button.on_click(_wrapped)


def ensure_parent(path):
    """Create the directory that will contain *path*, including missing ancestors."""
    parent_dir = Path(path).parent
    parent_dir.mkdir(parents=True, exist_ok=True)


def parse_flist(text):
    """
    Parse a comma-separated frequency list (Hz) into floats.

    Blank items are ignored, so trailing commas are allowed.

    Raises:
        ValueError: the list is empty, an item is not a number, or any frequency is not
            a positive finite value.
    """
    parts = [p.strip() for p in text.split(',') if p.strip()]
    if not parts:
        raise ValueError('flist cannot be empty')
    freqs = [float(p) for p in parts]
    # Zero, negative, NaN or infinite frequencies cannot define a wavelet period.
    invalid = [f for f in freqs if not (math.isfinite(f) and f > 0.0)]
    if invalid:
        raise ValueError(f'flist values must be positive, finite frequencies (got {invalid})')
    return freqs


def get_model_bounds():
    """
    Return the x/z extent of the loaded model.

    Returns a dict with float keys 'x_min', 'x_max', 'z_min', 'z_max'; NaNs in the
    coordinate axes are ignored. Raises RuntimeError when no model is loaded.
    """
    loaded = state.get('model')
    if loaded is None:
        raise RuntimeError('Load a SEGY model first.')
    x_axis = np.asarray(loaded['x'], dtype=float)
    z_axis = np.asarray(loaded['z'], dtype=float)
    x_lo, x_hi = float(np.nanmin(x_axis)), float(np.nanmax(x_axis))
    z_lo, z_hi = float(np.nanmin(z_axis)), float(np.nanmax(z_axis))
    return {'x_min': x_lo, 'x_max': x_hi, 'z_min': z_lo, 'z_max': z_hi}


def compute_tx0_feasible_range(bounds):
    """
    Return (tx0_low, tx0_high): the range of first-transmitter x positions that keeps every
    transmitter and receiver inside [bounds['x_min'], bounds['x_max']].

    Transmitters sit at tx0 + i*dtx. Receivers are treated as x-offsets from each
    transmitter: rx0 + j*drx. Values come from the ntx/dtx/rx0/drx/nrx widgets.
    Raises ValueError if ntx or nrx is not positive.
    The range is empty (low > high) when the geometry cannot fit.
    """
    n_tx = int(nsx.value)
    n_rx = int(ngx.value)
    if n_tx <= 0 or n_rx <= 0:
        raise ValueError('ntx and nrx must be positive integers.')

    tx_span = (n_tx - 1) * float(ds_x.value)
    rx_first = float(gx0.value)
    rx_last = rx_first + (n_rx - 1) * float(dgx.value)

    # Extreme offsets relative to tx0 (spacings may be negative).
    tx_lo, tx_hi = min(0.0, tx_span), max(0.0, tx_span)
    rx_lo, rx_hi = min(rx_first, rx_last), max(rx_first, rx_last)

    low_from_tx = bounds['x_min'] - tx_lo
    high_from_tx = bounds['x_max'] - tx_hi
    tx0_low = max(low_from_tx, low_from_tx - rx_lo)
    tx0_high = min(high_from_tx, high_from_tx - rx_hi)
    return float(tx0_low), float(tx0_high)


def _check_points_within_bounds(name, xvals, zvals, bounds):
    xvals = np.asarray(xvals, dtype=float)
    zvals = np.asarray(zvals, dtype=float)
    if xvals.size == 0 or zvals.size == 0:
        return []

    x_low = xvals < bounds['x_min']
    x_high = xvals > bounds['x_max']
    z_low = zvals < bounds['z_min']
    z_high = zvals > bounds['z_max']
    bad = x_low | x_high | z_low | z_high

    if not np.any(bad):
        return []

    return [
        f"{name}: {int(np.count_nonzero(bad))} points outside model bounds "
        f"(x:[{bounds['x_min']:.3f},{bounds['x_max']:.3f}], "
        f"z:[{bounds['z_min']:.3f},{bounds['z_max']:.3f}])."
    ]


def text_input(value, description, width='460px'):
    """Build a labelled ipywidgets Text box pre-filled with ``str(value)``."""
    return ipw.Text(
        value=str(value),
        description=description,
        style={'description_width': '150px'},
        layout=ipw.Layout(width=width),
    )


def float_input(value, description, step=0.1, width='320px'):
    """Build a labelled ipywidgets FloatText pre-filled with ``float(value)``."""
    return ipw.FloatText(
        value=float(value),
        description=description,
        step=step,
        style={'description_width': '150px'},
        layout=ipw.Layout(width=width),
    )


def int_input(value, description, width='320px'):
    """Build a labelled ipywidgets IntText pre-filled with ``int(value)``."""
    return ipw.IntText(
        value=int(value),
        description=description,
        style={'description_width': '150px'},
        layout=ipw.Layout(width=width),
    )


def round_down_to_0p5_m(value_m):
    """Floor a length in metres to the nearest multiple of 0.5 m at or below it."""
    # Dividing by 0.5 and multiplying by 2 are both exact in binary floating point.
    half_metre_count = math.floor(2.0 * float(value_m))
    return 0.5 * half_metre_count


def round_down_half_order(value):
    """
    Floor *value* to a multiple of half its decade.

    With 10**k <= value < 10**(k+1), the result is the largest multiple of 0.5*10**k that
    does not exceed *value*, e.g. 3.7e-7 -> 3.5e-7, 1234 -> 1000.0, 0.3 -> 0.3.

    Raises:
        ValueError: *value* is not positive, or is NaN/infinite.
    """
    value = float(value)
    if value <= 0:
        raise ValueError('Sampling value must be positive.')
    if not math.isfinite(value):
        raise ValueError('Sampling value must be finite.')

    exp10 = math.floor(math.log10(value))
    step = 0.5 * (10 ** exp10)
    # value/step lies in [2, 20). Decimal steps like 0.05 are inexact in binary, so exact
    # multiples can land just below an integer (0.3/0.05 == 5.999999999999999). A tiny
    # tolerance stops them being floored a whole step down.
    n_steps = math.floor(value / step + 1e-9)

    # Rebuild n_steps * 0.5 * 10**exp10 == (5*n_steps) * 10**(exp10-1) with integer
    # arithmetic, so the result is the correctly rounded decimal (0.3, not 0.30000000000000004).
    numerator = 5 * n_steps
    shift = exp10 - 1
    if shift >= 0:
        return float(numerator * 10 ** shift)
    return numerator / 10 ** (-shift)


def interpolate_rss(in_path, out_path, d1f=None, d3f=None, method='bspline'):
    """
    Resample an RSS file via interpolate_rss_python() with antialiasing always enabled.

    d1f/d3f are forwarded unchanged as the target sampling for axes 1 and 3 (None leaves it
    to interpolate_rss_python); *method* selects the interpolation kernel.
    """
    options = {'d1f': d1f, 'd3f': d3f, 'method': method, 'antialias': True}
    interpolate_rss_python(in_path, out_path, **options)


def plot_resistivity_interactive(model, survey_metadata=None):
    """
    Build a Plotly heatmap of the resistivity model, optionally overlaid with the survey.

    model: mapping with 'x', 'z' and 'resistivity' (heatmap x, y and z values).
    survey_metadata: optional mapping. 'source_x'/'source_z' are drawn once per unique
        location as red crosses; 'receiver_x'/'receiver_z' as white dots, decimated so at
        most about 5000 markers are drawn.
    Returns the go.Figure, with the depth axis reversed so depth increases downward.
    """
    fig = go.Figure(
        data=go.Heatmap(
            z=np.asarray(model['resistivity']),
            x=np.asarray(model['x']),
            y=np.asarray(model['z']),
            colorscale='Jet',
            colorbar={'title': 'Resistivity (Ohm-m)', 'x': 1.02, 'len': 0.9},
        )
    )

    if survey_metadata is not None:
        def as_float_array(key):
            return np.asarray(survey_metadata.get(key, []), dtype=float)

        src_x, src_z = as_float_array('source_x'), as_float_array('source_z')
        rec_x, rec_z = as_float_array('receiver_x'), as_float_array('receiver_z')

        if src_x.size and src_z.size:
            # Sources repeat once per trace; plot each distinct (x, z) only once.
            unique_sources = np.unique(np.column_stack((src_x, src_z)), axis=0)
            source_marker = {'color': 'red', 'size': 10, 'symbol': 'x'}
            fig.add_trace(
                go.Scatter(
                    x=unique_sources[:, 0],
                    y=unique_sources[:, 1],
                    mode='markers',
                    name='Sources',
                    marker=source_marker,
                )
            )

        if rec_x.size and rec_z.size:
            stride = max(1, int(np.ceil(rec_x.size / 5000)))
            receiver_marker = {'color': 'white', 'size': 4, 'opacity': 0.7}
            fig.add_trace(
                go.Scatter(
                    x=rec_x[::stride],
                    y=rec_z[::stride],
                    mode='markers',
                    name='Receivers',
                    marker=receiver_marker,
                )
            )

    legend_style = {
        'orientation': 'h',
        'x': 0.0,
        'y': -0.2,
        'xanchor': 'left',
        'yanchor': 'top',
        'bgcolor': 'rgba(255,255,255,0.6)',
    }
    fig.update_layout(
        title='Resistivity model + survey geometry',
        xaxis_title='Distance (m)',
        yaxis_title='Depth (m)',
        yaxis={'autorange': 'reversed'},
        legend=legend_style,
        margin={'t': 70, 'r': 90, 'b': 80},
        height=420,
    )
    return fig


def plot_wavelet_and_spectrum(params):
    """
    Return (time_figure, spectrum_figure) for a generated wavelet.

    params must provide 'time_axis_s', 'waveform' and 'dt' (sample interval in seconds).
    The spectrum is |rfft(waveform)| against rfftfreq(n, dt).
    Raises RuntimeError if the waveform has fewer than two samples.
    """
    time_axis = np.asarray(params['time_axis_s'])
    waveform = np.asarray(params['waveform'])
    sample_dt = float(params['dt'])

    if waveform.size < 2:
        raise RuntimeError('Wavelet has too few samples to compute spectrum.')

    def line_figure(x_values, y_values, trace_name, title, x_label):
        figure = go.Figure()
        figure.add_trace(go.Scatter(x=x_values, y=y_values, mode='lines', name=trace_name))
        figure.update_layout(title=title, xaxis_title=x_label, yaxis_title='Amplitude', height=320)
        return figure

    frequencies = np.fft.rfftfreq(waveform.size, d=sample_dt)
    amplitudes = np.abs(np.fft.rfft(waveform))
    return (
        line_figure(time_axis, waveform, 'wavelet', 'Wavelet (time domain)', 'Time (s)'),
        line_figure(frequencies, amplitudes, '|FFT|', 'Wavelet amplitude spectrum', 'Frequency (Hz)'),
    )


def make_file_browser(path_widget, title, start_dir, select_mode='file', filename_for_dir=None):
    """
    Build a collapsed, inline file/folder picker for Voila/ipywidgets.

    Args:
        path_widget: widget whose ``value`` receives the chosen path as a string.
        title: accordion header text.
        start_dir: folder listed initially (``~`` is expanded).
        select_mode: 'file' or 'dir' -- what "Use selected" accepts.
        filename_for_dir: when select_mode == 'dir' and this is set, the path written to
            path_widget is ``<chosen dir>/<filename_for_dir>``.

    Returns:
        An ipw.Accordion (initially collapsed) containing the browser.
    """
    folder_text = ipw.Text(
        value=str(Path(start_dir).expanduser()),
        description='Folder',
        layout=ipw.Layout(width='760px'),
        style={'description_width': '80px'},
    )
    listing = ipw.Select(options=[], rows=10, layout=ipw.Layout(width='980px'))
    status = ipw.HTML(value='')

    up_btn = ipw.Button(description='Up')
    open_btn = ipw.Button(description='Open folder')
    choose_btn = ipw.Button(description='Use selected', button_style='primary')
    refresh_btn = ipw.Button(description='Refresh')
    home_btn = ipw.Button(description='Home')
    root_btn = ipw.Button(description='Workshop root')

    def refresh_entries(_=None):
        """Re-list the folder in the text box: folders first, then files, case-insensitive."""
        try:
            folder = Path(folder_text.value).expanduser().resolve()
            if not folder.is_dir():
                raise RuntimeError(f'Not a folder: {folder}')
            children = sorted(folder.iterdir(), key=lambda p: (not p.is_dir(), p.name.lower()))
            listing.options = [
                (f"[DIR] {child.name}" if child.is_dir() else f"[FILE] {child.name}", str(child))
                for child in children
            ]
            status.value = f'Browsing: <code>{folder}</code>'
        except Exception as exc:
            listing.options = []
            status.value = f'<span style="color:red;">Browser error: {exc}</span>'

    def jump_to(folder):
        """Point the folder box at *folder* and re-list it."""
        folder_text.value = str(folder)
        refresh_entries()

    def on_up(_):
        try:
            here = Path(folder_text.value).expanduser().resolve()
            folder_text.value = str(here.parent)
        except Exception:
            pass  # unusable path text: just re-list whatever the box holds
        refresh_entries()

    def on_open(_):
        picked = listing.value
        if not picked:
            return
        candidate = Path(picked)
        if candidate.is_dir():
            jump_to(candidate)

    def on_choose(_):
        picked = listing.value
        if not picked:
            status.value = '<span style="color:red;">Select an item first.</span>'
            return
        candidate = Path(picked)
        if select_mode == 'file' and candidate.is_dir():
            status.value = '<span style="color:red;">Please select a file, not a folder.</span>'
            return
        if select_mode == 'dir' and not candidate.is_dir():
            status.value = '<span style="color:red;">Please select a folder, not a file.</span>'
            return

        chosen = candidate / filename_for_dir if (select_mode == 'dir' and filename_for_dir) else candidate
        path_widget.value = str(chosen)
        status.value = f'Selected: <code>{chosen}</code>'

    for button, callback in (
        (up_btn, on_up),
        (open_btn, on_open),
        (choose_btn, on_choose),
        (refresh_btn, refresh_entries),
        (home_btn, lambda _: jump_to(Path.home())),
        (root_btn, lambda _: jump_to(ROOT)),
    ):
        button.on_click(callback)

    # Typing a path into the folder box re-lists it immediately.
    folder_text.observe(lambda _change: refresh_entries(), names='value')
    refresh_entries()

    panel = ipw.VBox([
        folder_text,
        ipw.HBox([up_btn, open_btn, choose_btn, refresh_btn, home_btn, root_btn]),
        listing,
        status,
    ])
    accordion = ipw.Accordion(children=[panel], selected_index=None)
    accordion.set_title(0, title)
    return accordion


# --- Section 1: SEGY input ---
# Input SEG-Y resistivity model and the constant value written into ep.rss.
segy_path = text_input(ROOT / 'input.segy', 'SEGY file')
ep_value = float_input(7.0, 'ep constant', step=0.1)

# Section 1 action buttons (handlers are attached further down).
load_model_btn = ipw.Button(button_style='primary', description='Load model')
plot_model_btn = ipw.Button(description='Plot model')
write_sg_ep_btn = ipw.Button(button_style='success', description='Build intermediate sg/ep')


def on_load_model(_):
    """
    Button handler: load the SEG-Y model named in the SEGY widget and prime dependent state.

    Steps: read the model into state['model']; delete raw sg/ep files built from a previous
    model; cache the model as NPZ; initialise rho_min/rho_max from the model's positive
    resistivities (dropping a stored FD design if they change); clamp the survey widgets into
    the model extent; redraw the model plot. Problems are reported in the status log, not raised.
    """
    try:
        model = read_resistivity_from_segy(segy_path.value)
        state['model'] = model

        # Raw sg/ep built from a previously loaded model are stale. ensure_intermediate_files()
        # only rebuilds them when they are missing, so remove them to force a rebuild.
        for key in ('sg_raw_path', 'ep_raw_path'):
            stale = Path(state[key])
            if stale.exists():
                stale.unlink()
                push_message(f'Removed stale intermediate file: {stale}')

        ensure_parent(state['model_npz'])
        save_resistivity_npz(state['model_npz'], model['resistivity'], model['oz'], model['dz'], model['dx'], model['x'], model['z'])
        model_info.value = (
            f"Loaded model: shape={model['resistivity'].shape}, "
            f"ox={model['ox']:.3f}, oz={model['oz']:.3f}, dx={model['dx']:.3f}, dz={model['dz']:.3f}"
        )

        resistivity = np.asarray(model['resistivity'], dtype=float)
        positive = resistivity[resistivity > 0.0]
        if positive.size > 0:
            previous_bounds = (float(rho_min.value), float(rho_max.value))
            rho_min.value = float(np.nanmin(positive))
            rho_max.value = float(np.nanmax(positive))
            push_message(
                f"Initialized FD resistivity bounds from model: rho_min={rho_min.value:.3g}, "
                f"rho_max={rho_max.value:.3g} Ohm-m"
            )
            # The finalize step reuses state['fd_design'] whenever it is set, so a design
            # computed from different resistivity bounds must be discarded.
            new_bounds = (float(rho_min.value), float(rho_max.value))
            if state.get('fd_design') is not None and new_bounds != previous_bounds:
                state['fd_design'] = None
                fd_info.value = 'FD design not computed yet (reset after the resistivity bounds changed).'
                push_message('Cleared previous FD design because rho_min/rho_max changed; it will be recomputed.')

        # A geometry that does not fit is a survey problem, not a load failure: warn and keep
        # the freshly loaded model (and still plot it).
        try:
            bounds = get_model_bounds()
            sz0.value = float(np.clip(sz0.value, bounds['z_min'], bounds['z_max']))
            gz0.value = float(np.clip(gz0.value, bounds['z_min'], bounds['z_max']))
            tx0_min, tx0_max = compute_tx0_feasible_range(bounds)
            if tx0_min > tx0_max:
                push_message(
                    'Warning: current tx/rx geometry cannot fit in model x-bounds. '
                    'Adjust ntx, nrx, dtx, rx0, or drx. tx0 was left unchanged.'
                )
            else:
                sx0.value = float(np.clip(sx0.value, tx0_min, tx0_max))
                push_message(f"Adjusted tx0 to feasible range [{tx0_min:.3f}, {tx0_max:.3f}].")
            push_message(
                f"Clamped tz0/rz0 to z-bounds. "
                f"Current tx0={sx0.value:.3f}, tz0={sz0.value:.3f}, rz0={gz0.value:.3f}"
            )
        except Exception as exc:
            push_message(f"Survey geometry not adjusted: {exc}")

        plot_loaded_model()
        push_message(f"Model loaded. Intermediate NPZ: {state['model_npz']}")
    except Exception as exc:
        push_message(f"SEGY load failed: {exc}")
        push_message(traceback.format_exc())


def plot_loaded_model():
    """Redraw the loaded model (plus any stored survey overlay) in model_plot_out.

    Raises RuntimeError when no model has been loaded.
    """
    if state['model'] is None:
        raise RuntimeError('Load a SEGY model first.')

    overlay = state.get('survey_metadata')
    with model_plot_out:
        model_plot_out.clear_output(wait=True)
        plot_resistivity_interactive(state['model'], survey_metadata=overlay).show()


def on_plot_model(_):
    """Button handler: redraw the model and log success, or log the error with its traceback."""
    try:
        plot_loaded_model()
        push_message('Plotted resistivity model.')
    except Exception as exc:
        for line in (f"Plot failed: {exc}", traceback.format_exc()):
            push_message(line)


def on_write_sg_ep(_):
    """
    Button handler: write the raw (pre-interpolation) sg/ep RSS files for the loaded model.

    ep uses the constant from the 'ep constant' widget. Errors are logged, not raised.
    """
    try:
        model = state['model']
        if model is None:
            raise RuntimeError('Load a SEGY model first.')

        sg_out = state['sg_raw_path']
        ep_out = state['ep_raw_path']
        for target in (sg_out, ep_out):
            ensure_parent(target)

        write_sg_ep_rss(
            model['resistivity'],
            model['dx'],
            model['dz'],
            model['ox'],
            model['oz'],
            sg_out,
            ep_out,
            ep_value=ep_value.value,
        )
        for label, target in (('sg', sg_out), ('ep', ep_out)):
            push_message(f"Built intermediate {label}: {target}")
    except Exception as exc:
        push_message(f"sg/ep write failed: {exc}")
        push_message(traceback.format_exc())


for _button, _handler, _label in (
    (load_model_btn, on_load_model, 'Loading SEG-Y model'),
    (plot_model_btn, on_plot_model, 'Plotting loaded model'),
    (write_sg_ep_btn, on_write_sg_ep, 'Building intermediate sg/ep'),
):
    bind_button_with_feedback(_button, _handler, _label)
del _button, _handler, _label

# Section 1 status line, plot area and SEG-Y file picker.
model_info = ipw.HTML(value='No model loaded yet.')
model_plot_out = ipw.Output(layout=ipw.Layout(border='1px solid #ddd', width='980px'))
segy_browser = make_file_browser(segy_path, 'Browse and select SEGY file', ROOT, select_mode='file')


# --- Section 2: Wavelet ---
# Parameters forwarded to create_wavelet_rss().
flist_input = text_input('2000,4000,6000', 'flist (Hz)')
wavelet_dt = float_input(1e-6, 'dt (s)', step=1e-7)
n_periods = int_input(3, 'n_periods')
alpha = float_input(0.5, 'alpha', step=0.1)

# Buttons, status line and plot area for the wavelet section.
create_wav_btn = ipw.Button(button_style='success', description='Build intermediate wav2d')
plot_wav_btn = ipw.Button(description='Plot wavelet + spectrum')
wavelet_info = ipw.HTML(value='Wavelet not generated yet.')
wavelet_plot_out = ipw.Output(layout=ipw.Layout(border='1px solid #ddd', width='980px'))


def plot_wavelet_from_state():
    """Draw the time-domain wavelet and its amplitude spectrum in wavelet_plot_out.

    Raises RuntimeError when no wavelet has been generated yet.
    """
    params = state['wavelet_params']
    if params is None:
        raise RuntimeError('Create wavelet first.')
    figures = plot_wavelet_and_spectrum(params)
    with wavelet_plot_out:
        wavelet_plot_out.clear_output(wait=True)
        for figure in figures:
            figure.show()


def on_create_wavelet(_):
    """
    Button handler: build the raw wav2d RSS file from the wavelet widgets.

    Also sets f_min/f_max to the extremes of flist. If that changes them, any stored
    FD design is discarded, because it was computed from the old band. Errors are
    logged, not raised.
    """
    try:
        ensure_parent(state['wav_raw_path'])
        flist_vals = parse_flist(flist_input.value)
        params = create_wavelet_rss(
            flist=flist_vals,
            dt=wavelet_dt.value,
            n_periods=n_periods.value,
            alpha=alpha.value,
            wavfile=state['wav_raw_path'],
            show_plot=False,
        )
        state['wavelet_params'] = params
        wavelet_info.value = (
            f"Generated: nt={params['nt']}, periods={params['n_periods']}, "
            f"rec_time={params['rec_time_actual']:.3e}s"
        )

        previous_band = (float(f_min.value), float(f_max.value))
        f_min.value = float(min(flist_vals))
        f_max.value = float(max(flist_vals))
        push_message(f"Initialized FD frequency bounds from wavelet: f_min={f_min.value:.3g}, f_max={f_max.value:.3g} Hz")
        # The finalize step reuses state['fd_design'] whenever it is set; a design computed
        # for a different frequency band would give the wrong dx/dt.
        new_band = (float(f_min.value), float(f_max.value))
        if state.get('fd_design') is not None and new_band != previous_band:
            state['fd_design'] = None
            fd_info.value = 'FD design not computed yet (reset after the wavelet frequencies changed).'
            push_message('Cleared previous FD design because f_min/f_max changed; it will be recomputed.')

        plot_wavelet_from_state()
        push_message(f"Built intermediate wavelet: {state['wav_raw_path']}")
    except Exception as exc:
        push_message(f"Wavelet creation failed: {exc}")
        push_message(traceback.format_exc())


def on_plot_wavelet(_):
    """Button handler: redraw the wavelet plots and log success, or log the error and traceback."""
    try:
        plot_wavelet_from_state()
        push_message('Plotted wavelet and amplitude spectrum.')
    except Exception as exc:
        for line in (f"Wavelet plot failed: {exc}", traceback.format_exc()):
            push_message(line)


for _button, _handler, _label in (
    (create_wav_btn, on_create_wavelet, 'Building intermediate wav2d'),
    (plot_wav_btn, on_plot_wavelet, 'Plotting wavelet and spectrum'),
):
    bind_button_with_feedback(_button, _handler, _label)
del _button, _handler, _label


# --- Section 3: Survey ---
# Transmitter line: first tx x-position, tx depth, tx spacing and tx count.
sx0 = float_input(26.0, 'tx0 (m)')
sz0 = float_input(6050.0, 'tz0 (m)')
ds_x = float_input(4.8, 'dtx (m)')
nsx = int_input(1, 'ntx')
# Receiver line: rx0/drx are treated as x-offsets from each transmitter by
# compute_tx0_feasible_range(); rz0 is clamped to the model depth range like tz0.
gx0 = float_input(-13.1, 'rx0 (m)')
gz0 = float_input(6050.0, 'rz0 (m)')
dgx = float_input(-12.2, 'drx (m)')
ngx = int_input(2, 'nrx')

gen_survey_btn = ipw.Button(button_style='success', description='Generate survey.cfg + Survey.rss')
survey_info = ipw.HTML(value='Survey files not generated yet.')


def on_generate_survey(_):
    """
    Button handler: write survey.cfg from the geometry widgets and build Survey.rss.

    When a model is loaded, tx0/tz0/rz0 are first clamped into the model extent, and the
    generated tx/rx coordinates are checked against the model bounds afterwards. A survey
    that fails the check is deleted from the scratch folder. On success the intermediate
    Survey.rss is copied to FDmodel/Survey.rss and the model plot is redrawn with the survey
    overlay. Errors are logged, not raised.
    """
    try:
        cfg_path = Path(state['survey_cfg_path'])
        out_path = Path(state['survey_rss_path'])
        ensure_parent(out_path)

        if state.get('model') is not None:
            bounds = get_model_bounds()
            tx0_min, tx0_max = compute_tx0_feasible_range(bounds)
            if tx0_min > tx0_max:
                raise ValueError(
                    'Current tx/rx geometry cannot fit in model x-bounds. '
                    'Adjust ntx, nrx, dtx, rx0, or drx.'
                )
            sx0.value = float(np.clip(sx0.value, tx0_min, tx0_max))
            sz0.value = float(np.clip(sz0.value, bounds['z_min'], bounds['z_max']))
            gz0.value = float(np.clip(gz0.value, bounds['z_min'], bounds['z_max']))

        update_survey_cfg(
            cfg_path,
            {
                'sx0': sx0.value,
                'sz0': sz0.value,
                'dsx': ds_x.value,
                'nsx': nsx.value,
                'gx0': gx0.value,
                'gz0': gz0.value,
                'dgx': dgx.value,
                'ngx': ngx.value,
            },
        )

        metadata = generate_survey_rss(
            cfg_path.parent,
            cfg_filename=cfg_path.name,
            output_filename=out_path.name,
        )

        # The generator reports where it actually wrote the file; mirror it to the expected path.
        generated_path = Path(metadata['survey_rss'])
        if generated_path.resolve() != out_path.resolve():
            shutil.copyfile(generated_path, out_path)
            metadata['survey_rss'] = str(out_path)

        if state.get('model') is not None:
            bounds = get_model_bounds()
            violations = []
            violations += _check_points_within_bounds('TX', metadata.get('source_x', []), metadata.get('source_z', []), bounds)
            violations += _check_points_within_bounds('RX', metadata.get('receiver_x', []), metadata.get('receiver_z', []), bounds)
            if violations:
                # Do not leave a rejected Survey.rss behind: ensure_intermediate_files() only
                # regenerates (and re-validates) the survey when this file is missing.
                if out_path.exists():
                    out_path.unlink()
                raise ValueError('Survey coordinates outside model bounds. ' + ' '.join(violations))

        shutil.copyfile(out_path, DEFAULT_SURVEY_RSS)
        state['survey_metadata'] = metadata
        survey_info.value = (
            f"Generated survey.cfg + Survey.rss | sources={metadata['num_sources']}, "
            f"receivers/source={metadata['num_receivers_per_source']}"
        )
        if state.get('model') is not None:
            plot_loaded_model()
            push_message('Updated resistivity plot with survey overlay.')
        push_message(f"Generated survey.cfg at {cfg_path}")
        push_message(f"Generated intermediate Survey.rss at {metadata['survey_rss']}")
        push_message(f"Copied Survey.rss to final FD path: {DEFAULT_SURVEY_RSS}")
    except Exception as exc:
        push_message(f"Survey generation failed: {exc}")
        push_message(traceback.format_exc())


bind_button_with_feedback(
    gen_survey_btn,
    on_generate_survey,
    action_label='Generating survey.cfg and Survey.rss',
)


# --- Section 4: FD design ---
# Frequency band and resistivity range fed to recommend_design().
f_min = float_input(2e3, 'f_min (Hz)', step=100)
f_max = float_input(6e3, 'f_max (Hz)', step=100)
rho_min = float_input(1.0, 'rho_min (Ohm-m)', step=0.1)
rho_max = float_input(100.0, 'rho_max (Ohm-m)', step=1.0)
# Remaining recommend_design() inputs plus the mod.cfg aperture.
dim = ipw.Dropdown(
    options=[('2D', 2), ('3D', 3)],
    value=2,
    description='dimension',
    style={'description_width': '150px'},
    layout=ipw.Layout(width='320px'),
)
points_per_skin = int_input(5, 'points/skin')
k_cfl = float_input(10.0, 'k_cfl', step=1.0)
apertx_value = float_input(60.0, 'apertx (m)', step=1.0)

run_design_btn = ipw.Button(button_style='primary', description='Compute FD design')
fd_info = ipw.HTML(value='FD design not computed yet.')


def compute_fd_design_from_widgets():
    """
    Run recommend_design() on the FD-design widget values and round its outputs.

    Returns:
        (design, outputs). design is a dict with:
        - 'dx_raw_m' / 'dx_m': raw and rounded-down (0.5 m multiple) grid spacing,
        - 'dt_raw_s' / 'dt_adi_s': raw and rounded-down (half-decade multiple) ADI time step,
        - 'delta_min_m' and 'apertx_m'.
        outputs is recommend_design()'s return value.

    Raises:
        ValueError: resistivity or frequency bounds are non-positive or inverted, or the
            aperture is non-positive. All inputs are validated before the design call.
    """
    rho_min_val = float(rho_min.value)
    rho_max_val = float(rho_max.value)
    # `not (x > 0)` also rejects NaN.
    if not (rho_min_val > 0.0 and rho_max_val > 0.0):
        raise ValueError('rho_min and rho_max must be positive.')
    if rho_min_val > rho_max_val:
        raise ValueError('rho_min must be <= rho_max.')

    f_min_val = float(f_min.value)
    f_max_val = float(f_max.value)
    if not (f_min_val > 0.0 and f_max_val > 0.0):
        raise ValueError('f_min and f_max must be positive.')
    if f_min_val > f_max_val:
        raise ValueError('f_min must be <= f_max.')

    aperture = float(apertx_value.value)
    if not aperture > 0.0:
        raise ValueError('apertx (m) must be positive.')

    # Convert resistivity bounds to conductivity bounds for physics equations.
    sigma_min_val = 1.0 / rho_max_val
    sigma_max_val = 1.0 / rho_min_val

    outputs = recommend_design(
        DesignInputs(
            f_min_hz=f_min_val,
            f_max_hz=f_max_val,
            sigma_min_s_per_m=sigma_min_val,
            sigma_max_s_per_m=sigma_max_val,
            dim=dim.value,
            points_per_skin=points_per_skin.value,
            k_cfl=k_cfl.value,
        )
    )

    dx_raw = float(outputs.dx_m)
    dt_raw = float(outputs.dt_adi_s)
    dx_rounded = round_down_to_0p5_m(dx_raw)
    if dx_rounded <= 0.0:
        # NOTE(review): here dx_raw < 0.5 m, so forcing 0.5 m gives a grid coarser than the
        # design asked for (fewer points per skin depth than requested).
        dx_rounded = 0.5
    dt_rounded = round_down_half_order(dt_raw)

    design = {
        'dx_raw_m': dx_raw,
        'dx_m': dx_rounded,
        'dt_raw_s': dt_raw,
        'dt_adi_s': dt_rounded,
        'delta_min_m': outputs.delta_min_m,
        'apertx_m': aperture,
    }
    return design, outputs


def on_run_design(_):
    """Button handler: compute the FD design, store it in state and show a summary in fd_info."""
    try:
        design, outputs = compute_fd_design_from_widgets()
        state['fd_design'] = design
        fd_info.value = ", ".join([
            f"dx raw={design['dx_raw_m']:.4f} m -> rounded={design['dx_m']:.4f} m",
            f"dt raw={design['dt_raw_s']:.3e} s -> rounded={design['dt_adi_s']:.3e} s",
            f"delta_min={outputs.delta_min_m:.4f} m, apertx={design['apertx_m']:.3f} m",
        ])
        push_message('Computed FD design values from resistivity bounds.')
    except Exception as exc:
        for line in (f"FD design failed: {exc}", traceback.format_exc()):
            push_message(line)


bind_button_with_feedback(
    run_design_btn,
    on_run_design,
    action_label='Computing FD design',
)


# --- Section 5: Generate final FD inputs ---
# Finalize button and its status line.
apply_outputs_btn = ipw.Button(
    button_style='success',
    description='Generate FD inputs (Finalize setup)',
)
apply_info = ipw.HTML(value='FD inputs not generated yet.')


def ensure_intermediate_files():
    """
    Regenerate any missing raw (pre-interpolation) inputs in the scratch folder.

    Uses the current widget values to rebuild:
    - sg/ep when either raw file is missing,
    - the wavelet when its raw file is missing or no wavelet parameters are cached
      (this also sets f_min/f_max from flist),
    - survey.cfg + Survey.rss when the intermediate Survey.rss is missing. The survey
      geometry is clamped and validated against the model bounds, and a survey that
      fails validation is deleted.

    Raises:
        RuntimeError: no SEG-Y model has been loaded.
        ValueError: the survey geometry cannot fit in, or falls outside, the model bounds.
    """
    model = state['model']
    if model is None:
        raise RuntimeError('Load SEGY model first (needed for intermediate generation).')

    if (not Path(state['sg_raw_path']).exists()) or (not Path(state['ep_raw_path']).exists()):
        ensure_parent(state['sg_raw_path'])
        ensure_parent(state['ep_raw_path'])
        write_sg_ep_rss(
            model['resistivity'],
            model['dx'],
            model['dz'],
            model['ox'],
            model['oz'],
            state['sg_raw_path'],
            state['ep_raw_path'],
            ep_value=ep_value.value,
        )
        push_message('Auto-generated missing intermediate sg/ep files.')

    if (not Path(state['wav_raw_path']).exists()) or (state['wavelet_params'] is None):
        ensure_parent(state['wav_raw_path'])
        flist_vals = parse_flist(flist_input.value)
        params = create_wavelet_rss(
            flist=flist_vals,
            dt=wavelet_dt.value,
            n_periods=n_periods.value,
            alpha=alpha.value,
            wavfile=state['wav_raw_path'],
            show_plot=False,
        )
        state['wavelet_params'] = params
        # NOTE(review): on_apply_outputs() computes the FD design before calling this function,
        # so that design uses the f_min/f_max values from before this update.
        f_min.value = float(min(flist_vals))
        f_max.value = float(max(flist_vals))
        push_message('Auto-generated missing intermediate wavelet file.')

    out_path = Path(state['survey_rss_path'])
    if not out_path.exists():
        cfg_path = Path(state['survey_cfg_path'])
        ensure_parent(out_path)

        # A model is guaranteed to be loaded here (checked at the top).
        bounds = get_model_bounds()
        tx0_min, tx0_max = compute_tx0_feasible_range(bounds)
        if tx0_min > tx0_max:
            raise ValueError(
                'Current tx/rx geometry cannot fit in model x-bounds. '
                'Adjust ntx, nrx, dtx, rx0, or drx.'
            )
        sx0.value = float(np.clip(sx0.value, tx0_min, tx0_max))
        sz0.value = float(np.clip(sz0.value, bounds['z_min'], bounds['z_max']))
        gz0.value = float(np.clip(gz0.value, bounds['z_min'], bounds['z_max']))

        update_survey_cfg(
            cfg_path,
            {
                'sx0': sx0.value,
                'sz0': sz0.value,
                'dsx': ds_x.value,
                'nsx': nsx.value,
                'gx0': gx0.value,
                'gz0': gz0.value,
                'dgx': dgx.value,
                'ngx': ngx.value,
            },
        )
        metadata = generate_survey_rss(
            cfg_path.parent,
            cfg_filename=cfg_path.name,
            output_filename=out_path.name,
        )

        # As in on_generate_survey(): if the generator wrote elsewhere, copy the file to the
        # expected intermediate path so later steps (and the existence check above) find it.
        generated = metadata.get('survey_rss')
        if generated and Path(generated).resolve() != out_path.resolve():
            shutil.copyfile(generated, out_path)
            metadata['survey_rss'] = str(out_path)

        violations = []
        violations += _check_points_within_bounds('TX', metadata.get('source_x', []), metadata.get('source_z', []), bounds)
        violations += _check_points_within_bounds('RX', metadata.get('receiver_x', []), metadata.get('receiver_z', []), bounds)
        if violations:
            # Remove the rejected file; otherwise the next call would see it, skip this
            # branch, and use an out-of-bounds survey without re-validating it.
            if out_path.exists():
                out_path.unlink()
            raise ValueError('Auto-generated survey has coordinates outside model bounds. ' + ' '.join(violations))
        state['survey_metadata'] = metadata
        push_message('Auto-generated missing intermediate survey files.')


def on_apply_outputs(_):
    """Finalize the setup: resample intermediates onto the FD grid and write FD inputs.

    Click handler for the "Generate FD inputs" button. Steps, in order:

    1. Take the cached FD design from ``state['fd_design']``. If none exists
       yet, compute it from the current widget values and show it in
       ``fd_info``.
    2. Call ``ensure_intermediate_files()`` so the intermediate sg/ep/wavelet/
       survey files are present.
    3. Resample sg/ep onto the target grid spacing (linear) and the wavelet
       onto the target time step (sinc). This writes ``DEFAULT_SG``,
       ``DEFAULT_EP`` and ``DEFAULT_WAV``. A minimum value of 1e-8 is then
       enforced on sg/ep; the number of clamped samples is reported.
    4. Copy the intermediate survey file to ``DEFAULT_SURVEY_RSS``.
    5. Update ``mod.cfg`` with dtrec, apertx and the relative input file names.
    6. Write ``SETUP_METADATA_PATH`` (JSON) with the survey geometry, the
       sampling values and the SEG-Y template reference/grid.

    Because this runs as a widget callback, any exception is reported through
    ``push_message`` (message plus traceback) instead of being propagated.

    Args:
        _: The clicked button widget (unused).
    """
    try:
        design = state['fd_design']
        if design is None:
            design, outputs = compute_fd_design_from_widgets()
            state['fd_design'] = design
            fd_info.value = (
                f"dx raw={design['dx_raw_m']:.4f} m -> rounded={design['dx_m']:.4f} m, "
                f"dt raw={design['dt_raw_s']:.3e} s -> rounded={design['dt_adi_s']:.3e} s, "
                f"delta_min={outputs.delta_min_m:.4f} m, apertx={design['apertx_m']:.3f} m"
            )
            push_message('Auto-computed FD design from current resistivity/frequency values.')

        ensure_intermediate_files()

        model = state['model']
        wav_params = state['wavelet_params']
        if wav_params is None:
            raise RuntimeError('Wavelet parameters are unavailable after auto-generation.')
        if model is None:
            # Without this check a missing model surfaces later as an opaque
            # "'NoneType' object is not subscriptable" error.
            raise RuntimeError('Resistivity model is unavailable after auto-generation.')

        target_dx = float(design['dx_m'])
        target_dt = float(design['dt_adi_s'])

        current_dx = float(model['dx'])
        current_dz = float(model['dz'])
        current_wdt = float(wav_params['dt'])

        # modint d*f arguments are target output sampling intervals, not scale factors.
        dxf = target_dx
        dzf = target_dx
        dtf = target_dt

        # Every comparison with NaN is False, so a plain `<= 0` test would let NaN
        # (and inf) through. Require finite, strictly positive sampling intervals.
        if not all(math.isfinite(v) and v > 0 for v in (dxf, dzf, dtf)):
            raise RuntimeError('Invalid target interpolation sampling values computed from design.')

        # Use positivity-preserving linear interpolation for physical property models.
        interpolate_rss(state['sg_raw_path'], DEFAULT_SG, d1f=dxf, d3f=dzf, method='linear')
        interpolate_rss(state['ep_raw_path'], DEFAULT_EP, d1f=dxf, d3f=dzf, method='linear')
        interpolate_rss(state['wav_raw_path'], DEFAULT_WAV, d1f=dtf, method='sinc')

        n_clamped_sg = enforce_rss_min_value(DEFAULT_SG, min_value=1e-8)
        n_clamped_ep = enforce_rss_min_value(DEFAULT_EP, min_value=1e-8)
        if n_clamped_sg > 0 or n_clamped_ep > 0:
            push_message(
                f"Clamped non-positive samples after interpolation (sg={n_clamped_sg}, ep={n_clamped_ep})."
            )

        shutil.copyfile(state['survey_rss_path'], DEFAULT_SURVEY_RSS)
        push_message(f'Prepared final FD inputs at {FDMODEL_DIR}')

        # dtrec should follow the intermediate wavelet sampling interval.
        modcfg_dtrec = current_wdt
        update_modcfg_for_workshop(
            DEFAULT_MODCFG,
            dtrec_s=modcfg_dtrec,
            apertx_m=design['apertx_m'],
            sg_path='sg.rss',
            ep_path='ep.rss',
            wavelet_path='wav2d.rss',
            survey_path='Survey.rss',
        )

        setup_meta = {
            'flist_hz': [float(v) for v in parse_flist(flist_input.value)],
            'dt_wavelet_s': float(current_wdt),
            'ntx': int(nsx.value),
            'nrx': int(ngx.value),
            'tx0_m': float(sx0.value),
            'tz0_m': float(sz0.value),
            'dtx_m': float(ds_x.value),
            'rx0_m': float(gx0.value),
            'rz0_m': float(gz0.value),
            'drx_m': float(dgx.value),
            'dx_model_target_m': float(target_dx),
            'dt_model_target_s': float(target_dt),
            'dtrec_written_s': float(modcfg_dtrec),
            # Keep original setup SEG-Y reference so results export can reuse exact template headers/grid.
            'segy_template_path': str(Path(segy_path.value).expanduser().resolve()),
            'segy_ox': float(model['ox']),
            'segy_oz': float(model['oz']),
            'segy_dx': float(model['dx']),
            'segy_dz': float(model['dz']),
            'segy_nx': int(np.asarray(model['x']).size),
            'segy_nz': int(np.asarray(model['z']).size),
        }
        SETUP_METADATA_PATH.write_text(json.dumps(setup_meta, indent=2) + '\n')

        apply_info.value = (
            f"Applied outputs with interpolation. target dx={target_dx:.3f} m, target dt={target_dt:.3e} s. "
            f"Current model dx={current_dx:.3f} m, dz={current_dz:.3f} m, wavelet dt={current_wdt:.3e} s. "
            f"mod.cfg dtrec={modcfg_dtrec:.3e} s. Updated mod.cfg at {DEFAULT_MODCFG}"
        )
        push_message('mod.cfg updated with rounded-safe dtrec/apertx and file links.')
        push_message(f'Wrote setup metadata: {SETUP_METADATA_PATH}')
    except Exception as exc:
        # Top-level UI boundary: report the failure (with traceback) in the status area.
        push_message(f"Generate FD inputs failed: {exc}")
        push_message(traceback.format_exc())

# Attach on_apply_outputs to the "Generate FD inputs" button. The third
# argument is the label handed to bind_button_with_feedback (presumably
# shown as feedback while the handler runs -- see its definition).
bind_button_with_feedback(
    apply_outputs_btn,
    on_apply_outputs,
    'Generating FD inputs (finalize setup)',
)

# Button that stops the Voila server and this kernel (handled by on_quit_gui).
quit_btn = ipw.Button(button_style='danger', description='Quit GUI server')


def on_quit_gui(_):
    """Stop the Voila server (if its PID is known) and then shut down this kernel.

    Reads the server PID from ``VOILA_PID_FILE`` and sends it ``SIGINT``. No
    signal is sent when:

    - the file is missing or empty;
    - the PID is not a positive integer; or
    - the PID equals this kernel's own PID.

    Problems while reading the file or signalling the server (for example a
    non-integer PID or a process that no longer exists) are reported through
    ``push_message``.

    Regardless of the outcome, the ``finally`` clause asks the IPython kernel
    to shut down and then sends ``SIGTERM`` to the current process. Errors
    there are deliberately ignored: this is a best-effort teardown of a
    process that is exiting anyway.

    Args:
        _: The clicked button widget (unused).
    """
    try:
        push_message('Quit requested. Stopping Voila server...')

        pid = None
        if VOILA_PID_FILE.exists():
            pid_text = VOILA_PID_FILE.read_text().strip()
            if pid_text:
                pid = int(pid_text)

        if pid is None:
            push_message('Voila PID file not found. Falling back to kernel shutdown only.')
        elif pid <= 0:
            # os.kill() gives 0 and negative PIDs special meaning (the caller's
            # process group / every process we may signal); never forward such a
            # value read from a file.
            push_message(f'Invalid PID {pid} in {VOILA_PID_FILE}; falling back to kernel shutdown only.')
        elif pid == os.getpid():
            push_message('PID file points to current kernel process; skipping server kill.')
        else:
            os.kill(pid, signal.SIGINT)
            push_message(f'Sent SIGINT to Voila server PID {pid}.')

    except Exception as exc:
        push_message(f'Quit warning: {exc}. Falling back to kernel shutdown only.')
    finally:
        # Best-effort kernel teardown; failures are ignored on purpose.
        try:
            from IPython import get_ipython
            ip = get_ipython()
            if ip and getattr(ip, 'kernel', None):
                ip.kernel.do_shutdown(restart=False)
        except Exception:
            pass
        try:
            os.kill(os.getpid(), signal.SIGTERM)
        except Exception:
            pass


# Attach on_quit_gui to the quit button, with its feedback label.
bind_button_with_feedback(
    quit_btn,
    on_quit_gui,
    'Shutting down setup GUI server',
)


# Multi-line text box used as the GUI's status/log panel.
status_area = ipw.Textarea(
    description='Status',
    value='Ready.',
    style=dict(description_width='90px'),
    layout=ipw.Layout(height='180px', width='980px'),
)


from IPython.display import display


def _titled_vbox(title_html, *rows):
    """Return a VBox whose first child is an HTML title, followed by *rows*."""
    return ipw.VBox(children=[ipw.HTML(title_html), *rows])


# Introductory text shown above all sections.
app_header = ipw.HTML(''.join((
    '<h2>01_fw_setup</h2>',
    '<p>Setup GUI for preparing FD modelling inputs. Only input required: a SEG-Y file. Intermediate files are written to a temporary session folder, and final FD modelling files are written to <code>FDmodel</code>.</p>',
    '<p><b>Plot controls:</b> all figures are interactive. Use mouse drag to zoom, scroll to zoom in/out, and double-click to reset.</p>',
)))

# Section 1: SEG-Y path/browser, ep value, model actions, info and plot.
segy_section = _titled_vbox(
    '<h3>1) SEG-Y input and resistivity model</h3>',
    segy_path,
    segy_browser,
    ep_value,
    ipw.HBox(children=[load_model_btn, plot_model_btn, write_sg_ep_btn]),
    model_info,
    model_plot_out,
)

# Section 2: wavelet parameters, create/plot actions, info and plot.
wavelet_section = _titled_vbox(
    '<h3>2) Source wavelet (intermediate)</h3>',
    flist_input,
    wavelet_dt,
    n_periods,
    alpha,
    ipw.HBox(children=[create_wav_btn, plot_wav_btn]),
    wavelet_info,
    wavelet_plot_out,
)

# Section 3: one row of source (tx) fields, one row of receiver (rx) fields.
survey_section = _titled_vbox(
    '<h3>3) Survey geometry</h3>',
    ipw.HBox(children=[sx0, sz0, ds_x, nsx]),
    ipw.HBox(children=[gx0, gz0, dgx, ngx]),
    gen_survey_btn,
    survey_info,
)

# Section 4: FD design inputs and the run button.
fd_section = _titled_vbox(
    '<h3>4) FD design helper (resistivity-based inputs)</h3>',
    ipw.HBox(children=[f_min, f_max, dim]),
    ipw.HBox(children=[rho_min, rho_max]),
    ipw.HBox(children=[points_per_skin, k_cfl, apertx_value]),
    run_design_btn,
    fd_info,
)

# Section 5: finalize outputs, plus session (server shutdown) control.
apply_section = _titled_vbox(
    '<h3>5) Generate FD inputs (Finalize setup)</h3>',
    ipw.HTML(f'<code>mod.cfg</code> target: <code>{DEFAULT_MODCFG}</code><br><code>Intermediate folder</code>: <code>{TMP_DIR}</code>'),
    apply_outputs_btn,
    apply_info,
    ipw.HTML('<b>Session control</b>: use this button to stop the server started by <code>start_gui.sh</code>.'),
    quit_btn,
)

# Top-level page: header, the five sections in order, then the status panel.
gui = ipw.VBox(children=[
    app_header,
    segy_section,
    wavelet_section,
    survey_section,
    fd_section,
    apply_section,
    status_area,
])

display(gui)
