In [21]:
from pathlib import Path
from typing import Dict, Tuple, Iterable, List
from itertools import product
import os
import pickle

import numpy as np
from scipy.stats import norm

import plotly.graph_objects as go
import ipywidgets as widgets

# ---------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------
# Every simulation run lives in its own sub-folder here (names: run_id()).
RUNS_DIR = Path("runs_test4")

# Parameter grid; result keys are ordered (ampar, rin, sigma_eta, sigma_temp).
AMPAR_VALUES: List[float] = [1.0, 0.63]
RIN_VALUES: List[float] = [1.0, 1.22]
ETA_VALUES: List[float] = [0.0, 0.01, 0.1]    # sigma_eta
TEMP_VALUES: List[float] = [0.0, 0.01, 0.1]   # sigma_temp

# Idir levels (descending) and trial indices loaded for every grid point.
IDIR_VALUES: List[float] = [0.8, 0.7, 0.6, 0.55, 0.53, 0.52, 0.51, 0.505, 0.5]
TRIAL_INDICES: List[int] = [*range(20)]        # trial00 … trial19

# Pickle cache for the computed SNR grid (v2; distinct from the older
# "snr_grid_results.pkl").
CACHE_FILE = Path("snr_grid_results_v2.pkl")


In [22]:
def run_id(
    ampar: float,
    rin: float,
    sigma_eta: float,
    sigma_temp: float,
    idir: float,
    trial: int,
) -> str:
    """
    Return the on-disk folder name of a single simulation run.

    Every float is rendered with three decimals and the trial index is
    zero-padded to two digits. The name lists sigmatemp BEFORE sigmaeta,
    i.e. the opposite of the argument order; this is deliberate so the
    names match the run folders that already exist on disk.
    """
    fields = (
        f"g{ampar:.3f}",
        f"rin{rin:.3f}",
        f"sigmatemp{sigma_temp:.3f}",
        f"sigmaeta{sigma_eta:.3f}",
        f"idir{idir:.3f}",
        f"trial{trial:02d}",
    )
    return "_".join(fields)


def load_state_matrix(folder: Path) -> np.ndarray:
    """
    Read ``folder/state_history.npy`` and return it as float32 with the
    axes swapped: the file is stored generations × neurons, the result is
    neurons × generations.

    Raises:
        FileNotFoundError: the .npy file does not exist.
        ValueError: the stored array is not 2-D.
    """
    path = folder / "state_history.npy"
    if not path.exists():
        raise FileNotFoundError(f"Missing state history file: {path}")

    history = np.load(path)  # generations × neurons on disk
    if history.ndim == 2:
        # copy=False: no extra copy when the data is already float32
        return np.transpose(history).astype(np.float32, copy=False)
    raise ValueError(f"Expected 2-D array in {path}, got {history.shape}")


In [23]:
def load_state_tensor_for_params(
    ampar: float,
    rin: float,
    sigma_eta: float,
    sigma_temp: float,
    runs_dir: Path = RUNS_DIR,
    idir_values: Iterable[float] = IDIR_VALUES,
    trials: Iterable[int] = TRIAL_INDICES,
) -> np.ndarray:
    """
    Load state history for a single (ampar, rin, sigma_eta, sigma_temp) combo.

    One run folder (named by ``run_id``) is read per (idir, trial) pair via
    ``load_state_matrix``, and the results are stacked.

    Args:
        ampar, rin, sigma_eta, sigma_temp: parameter combo to load.
        runs_dir: directory holding the run folders (str or Path).
        idir_values: idir levels; their order defines the last axis.
        trials: trial indices; their order defines the third axis.
            Any iterable works, including one-shot iterators.

    Returns:
        state: (neurons, generations, trials, idir)

    Raises:
        ValueError: ``idir_values`` or ``trials`` is empty, or runs have
            mismatched shapes (from ``np.stack``).
        FileNotFoundError / ValueError: propagated from ``load_state_matrix``.
    """
    # Materialise once: `trials` is walked again for every idir, so a
    # generator would otherwise be exhausted after the first idir.
    idir_list = list(idir_values)
    trial_list = list(trials)
    if not idir_list or not trial_list:
        raise ValueError("idir_values and trials must both be non-empty")
    runs_dir = Path(runs_dir)

    idir_blocks = []
    for idir in idir_list:
        trial_blocks = [
            load_state_matrix(
                runs_dir / run_id(ampar, rin, sigma_eta, sigma_temp, idir, trial)
            )
            for trial in trial_list
        ]
        idir_blocks.append(np.stack(trial_blocks, axis=2))  # neurons × generations × trials

    return np.stack(idir_blocks, axis=3)  # neurons × generations × trials × idir


In [24]:
def smooth_state(state: np.ndarray, window_size: int = 5) -> np.ndarray:
    """
    Rolling (boxcar) mean over the generations axis (axis=1); other dims kept.

    Only "valid" windows are returned, so the generations axis shrinks from
    G to G - window_size + 1. Entry k is the mean of original generations
    k .. k + window_size - 1.

    Args:
        state: array with generations on axis 1, e.g.
            (neurons, generations, trials, idir).
        window_size: number of generations per window (>= 1).

    Returns:
        Smoothed array with the same shape except axis 1.

    Raises:
        ValueError: window_size < 1, state has fewer than 2 dims, or
            window_size exceeds the number of generations. The last case
            would otherwise produce an empty axis and NaN means downstream.
    """
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    state = np.asarray(state)
    if state.ndim < 2:
        raise ValueError(f"state must have a generations axis (axis=1), got shape {state.shape}")
    n_gens = state.shape[1]
    if window_size > n_gens:
        raise ValueError(
            f"window_size ({window_size}) exceeds number of generations ({n_gens})"
        )

    # Running-sum trick: window sum at k = cumsum[k] - cumsum[k - window_size].
    # The right-hand side is materialised before assignment, so no aliasing.
    cumsum = np.cumsum(state, axis=1)
    cumsum[:, window_size:] = cumsum[:, window_size:] - cumsum[:, :-window_size]
    return cumsum[:, window_size - 1:] / window_size


In [25]:
def compute_snr_per_delta(
    state: np.ndarray,
    idir_values: np.ndarray,
    eps: float = 1e-12,
    delta_decimals: int = 9,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute SNR for all ΔIdir pairs.

    For every trial and every idir index i, the neuron covariance C_i is
    estimated from the (neurons × generations) slice of that trial/idir.
    Its eigendecomposition is computed once and reused for all pairs
    (i, j) with j > i:

        SNR_ij = δᵀ C_i⁻¹ δ,   δ = μ_i − μ_j   (μ = mean over generations)

    Eigenvalues are clipped at 0 (tiny negative values are numerical noise)
    and shifted by ``eps``, so rank-deficient covariances do not divide by 0.
    Only C_i of the lower-index condition is used, so results depend on the
    order of ``idir_values``.

    Args:
        state: (neurons, generations, trials, idir)
        idir_values: array-like of len = idir
        eps: ridge added to the covariance eigenvalues.
        delta_decimals: ΔIdir is computed in float64 and rounded to this many
            decimals. Nominally equal gaps (e.g. 0.8−0.7 and 0.7−0.6) then
            map to exactly the same value and group correctly with np.unique.

    Returns:
        deltas: (N_measurements,) float32 absolute ΔIdir for each trial/pair
        snrs:   (N_measurements,) float32 δᵀ C⁻¹ δ for each trial/pair

    Raises:
        ValueError: len(idir_values) != state.shape[3], or state is not 4-D.
    """
    neurons, gens, trials, idirs = state.shape
    # float64: float32 differences of equally spaced values disagree in the
    # last bits (float32 0.8-0.7 != 0.7-0.6), which breaks exact grouping.
    idir_values = np.asarray(idir_values, dtype=np.float64)
    if idir_values.shape[0] != idirs:
        raise ValueError("len(idir_values) must match state.shape[3]")

    # mean over generations → (neurons, trials, idir)
    means = state.mean(axis=1)

    deltas: List[float] = []
    snrs: List[float] = []

    for trial in range(trials):
        # The last idir has no higher-index partner, so its covariance is never used.
        for idir_i in range(idirs - 1):
            mu_i = means[:, trial, idir_i]

            # atleast_2d: np.cov squeezes a single-neuron result to 0-d.
            cov_i = np.atleast_2d(np.cov(state[:, :, trial, idir_i], ddof=1))
            evals, evecs = np.linalg.eigh(cov_i)
            evals_safe = np.clip(evals, 0.0, None) + eps

            for idir_j in range(idir_i + 1, idirs):
                delta_idir = round(
                    abs(float(idir_values[idir_i] - idir_values[idir_j])),
                    delta_decimals,
                )
                delta_mu = mu_i - means[:, trial, idir_j]  # (neurons,)

                # Whitening in the eigenbasis: δᵀ C⁻¹ δ = Σ (vₖᵀδ)² / λₖ
                proj = evecs.T @ delta_mu
                deltas.append(delta_idir)
                snrs.append(float(np.sum(proj ** 2 / evals_safe)))

    return np.asarray(deltas, dtype=np.float32), np.asarray(snrs, dtype=np.float32)


In [26]:
def snr_delta_for_params(
    ampar: float,
    rin: float,
    sigma_eta: float,
    sigma_temp: float,
    window_size: int = 5,
    delta_decimals: int = 6,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load state, smooth, compute SNR vs ΔIdir, then average per unique ΔIdir.

    Args:
        ampar, rin, sigma_eta, sigma_temp: parameter combo to analyse.
        window_size: rolling-mean window passed to ``smooth_state``.
        delta_decimals: ΔIdir values are rounded to this many decimals before
            grouping. Float32 differences of nominally equal gaps
            (e.g. 0.8−0.7 vs 0.7−0.6) can differ in the last bits and would
            otherwise land in separate groups.

    Returns:
        unique_deltas: (M,)
        mean_snrs:     (M,)
        err_snrs:      (M,)  SEM across measurements (NaN when a ΔIdir has
                             a single measurement, since SEM is undefined)
    """
    state = load_state_tensor_for_params(ampar, rin, sigma_eta, sigma_temp)
    state_smoothed = smooth_state(state, window_size=window_size)

    deltas, snrs = compute_snr_per_delta(state_smoothed, np.array(IDIR_VALUES))
    deltas = np.round(deltas, delta_decimals)

    unique_deltas = np.unique(deltas)
    mean_snrs = np.empty_like(unique_deltas, dtype=np.float32)
    err_snrs = np.empty_like(unique_deltas, dtype=np.float32)

    for idx, delta in enumerate(unique_deltas):
        vals = snrs[deltas == delta]
        mean_snrs[idx] = vals.mean()
        # standard error of the mean; undefined (NaN) for a single sample
        err_snrs[idx] = (
            vals.std(ddof=1) / np.sqrt(vals.size) if vals.size > 1 else np.nan
        )

    return unique_deltas, mean_snrs, err_snrs


In [None]:
def load_or_compute_snr_grid(
    cache_file: Path = CACHE_FILE,
):
    """
    Compute SNR vs ΔIdir for all (ampar, rin, sigma_eta, sigma_temp) combos,
    store in a dict, and cache to disk. Reuse cache if it exists.

    Args:
        cache_file: pickle cache path (str or Path).

    Returns:
        (AMPAR_VALUES, RIN_VALUES, ETA_VALUES, TEMP_VALUES, results), where
        ``results[(ampar, rin, sigma_eta, sigma_temp)]`` is a dict with keys
        "idir_delta", "snr" and "err". When loaded from cache, the grid lists
        come from the cache file, not from the current config.

    Notes:
        - The cache is written to a temporary file and then renamed over
          ``cache_file``, so an interrupted dump cannot leave a truncated cache.
        - If the cached grid differs from the current config, a warning is
          printed; delete the cache file to force a recompute.
        - ``pickle.load`` can execute arbitrary code: only load caches this
          notebook produced itself.
    """
    cache_file = Path(cache_file)
    current_grid = dict(
        AMPAR_VALUES=AMPAR_VALUES,
        RIN_VALUES=RIN_VALUES,
        ETA_VALUES=ETA_VALUES,
        TEMP_VALUES=TEMP_VALUES,
    )

    if cache_file.exists():
        print(f"Loading cached results from {cache_file}...")
        with open(cache_file, "rb") as f:
            data = pickle.load(f)  # trusted, locally produced file only
        stale = [
            name for name, values in current_grid.items()
            if list(data[name]) != list(values)
        ]
        if stale:
            print(
                f"WARNING: cached grid differs from current config for "
                f"{', '.join(stale)}; delete {cache_file} to recompute."
            )
        return (
            data["AMPAR_VALUES"],
            data["RIN_VALUES"],
            data["ETA_VALUES"],
            data["TEMP_VALUES"],
            data["results"],
        )

    print("Computing SNR grid from scratch...")
    results: Dict[Tuple[float, float, float, float], Dict[str, np.ndarray]] = {}

    for ampar, rin, sigma_eta, sigma_temp in product(
        AMPAR_VALUES, RIN_VALUES, ETA_VALUES, TEMP_VALUES
    ):
        key = (ampar, rin, sigma_eta, sigma_temp)
        print("  Computing", key, "...")
        idir_delta, snr, err = snr_delta_for_params(
            ampar=ampar,
            rin=rin,
            sigma_eta=sigma_eta,
            sigma_temp=sigma_temp,
        )
        results[key] = dict(
            idir_delta=idir_delta,
            snr=snr,
            err=err,
        )

    data = dict(current_grid, results=results)

    # Write-then-rename so a crash mid-dump never leaves a corrupt cache behind.
    tmp_file = cache_file.with_name(cache_file.name + ".tmp")
    with open(tmp_file, "wb") as f:
        pickle.dump(data, f)
    tmp_file.replace(cache_file)

    print(f"Saved SNR grid to {cache_file}")
    return AMPAR_VALUES, RIN_VALUES, ETA_VALUES, TEMP_VALUES, results


AMPAR_VALUES, RIN_VALUES, ETA_VALUES, TEMP_VALUES, results = load_or_compute_snr_grid()


Computing SNR grid from scratch...
  Computing (1.0, 1.0, 0.0, 0.0) ...
  Computing (1.0, 1.0, 0.0, 0.01) ...
  Computing (1.0, 1.0, 0.0, 0.1) ...
  Computing (1.0, 1.0, 0.01, 0.0) ...
  Computing (1.0, 1.0, 0.01, 0.01) ...
  Computing (1.0, 1.0, 0.01, 0.1) ...
  Computing (1.0, 1.0, 0.1, 0.0) ...
  Computing (1.0, 1.0, 0.1, 0.01) ...
  Computing (1.0, 1.0, 0.1, 0.1) ...
  Computing (1.0, 1.22, 0.0, 0.0) ...
  Computing (1.0, 1.22, 0.0, 0.01) ...
  Computing (1.0, 1.22, 0.0, 0.1) ...


In [14]:
import pickle

# NOTE(review): this reads the OLDER cache "snr_grid_results.pkl", not
# CACHE_FILE ("snr_grid_results_v2.pkl"). Run after the grid cell, it
# replaces the grid lists and `results` with the older cached values.
with open("snr_grid_results.pkl", "rb") as fh:
    data = pickle.load(fh)  # trusted, locally produced file only

AMPAR_VALUES, RIN_VALUES = data["AMPAR_VALUES"], data["RIN_VALUES"]
ETA_VALUES, TEMP_VALUES = data["ETA_VALUES"], data["TEMP_VALUES"]
results = data["results"]

# Example: compare control vs FR at some (sigma_eta, sigma_temp)
def get_curve(ampar, rin, sigma_eta, sigma_temp, metric="snr"):
    """
    Return (idir_delta, snr, err) for one parameter combo from ``results``.

    This early helper only serves raw SNR. Any other ``metric`` raises instead
    of silently returning SNR under a different name. A later cell redefines
    ``get_curve`` with performance support and a 4-tuple return.

    Raises:
        ValueError: metric is not "snr".
        KeyError: the combo is not present in ``results``.
    """
    if metric != "snr":
        raise ValueError(
            f"Unknown metric {metric!r}; this get_curve only returns raw SNR"
        )
    rec = results[(ampar, rin, sigma_eta, sigma_temp)]
    return rec["idir_delta"], rec["snr"], rec["err"]

# Control vs FR network, both at sigma_eta = sigma_temp = 0.1.
x_c, snr_c, err_c = get_curve(ampar=1.0, rin=1.0, sigma_eta=0.1, sigma_temp=0.1)
x_f, snr_f, err_f = get_curve(ampar=0.63, rin=1.22, sigma_eta=0.1, sigma_temp=0.1)

# Mean over all ΔIdir points of each curve.
for _label, _snr_curve in (("Control mean SNR:", snr_c), ("FR      mean SNR:", snr_f)):
    print(_label, _snr_curve.mean())

Control mean SNR: 18.42491728006356
FR      mean SNR: 24.182609275064017


In [15]:
# Reload the older cache under its own name. Re-binding CACHE_FILE here would
# replace the config's Path to the v2 cache with a str naming a different file.
LEGACY_CACHE_FILE = Path("snr_grid_results.pkl")

if LEGACY_CACHE_FILE.exists():
    print("Loading cached results...")
    with open(LEGACY_CACHE_FILE, "rb") as f:
        data = pickle.load(f)  # trusted, locally produced file only
    AMPAR_VALUES = data["AMPAR_VALUES"]
    RIN_VALUES   = data["RIN_VALUES"]
    ETA_VALUES   = data["ETA_VALUES"]
    TEMP_VALUES  = data["TEMP_VALUES"]
    results      = data["results"]
else:
    # Make the no-op visible: otherwise whatever `results` is in memory is used silently.
    print(f"{LEGACY_CACHE_FILE} not found; keeping the results already in memory.")


Loading cached results...


In [16]:
def snr_to_perf_and_err(snr: np.ndarray, snr_err: np.ndarray):
    """
    Convert SNR to performance and propagate its error to first order.

    perf = Φ(d' / √2) with d' = √SNR, i.e. Φ(√(SNR / 2)).
    perf_err = |dperf/dSNR| · snr_err with
    dperf/dSNR = φ(z) / (2 √(2 SNR)), z = √(SNR / 2).

    Args:
        snr: SNR values (array-like). SNR is a quadratic form, so negative
            values can only be numerical noise; they are clipped to 0
            (chance performance 0.5).
        snr_err: absolute error of each SNR value (array-like, same shape).
            Upstream this is the standard error of the mean, not a variance.

    Returns:
        (perf, perf_err) as float32. At SNR == 0 the derivative diverges, so
        perf_err is inf there (NaN when snr_err is also 0); the linearised
        error is not meaningful at that point.
    """
    snr = np.clip(np.asarray(snr, dtype=np.float64), 0.0, None)
    snr_err = np.asarray(snr_err, dtype=np.float64)

    z = np.sqrt(snr / 2.0)  # = d' / √2
    perf = norm.cdf(z)

    # Suppress the expected divide-by-zero / inf*0 warnings at SNR == 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        dP_dSNR = norm.pdf(z) / (2.0 * np.sqrt(2.0 * snr))
        perf_err = np.abs(dP_dSNR) * snr_err

    return perf.astype(np.float32), perf_err.astype(np.float32)


def get_curve(
    ampar: float,
    rin: float,
    sigma_eta: float,
    sigma_temp: float,
    metric: str = "perf",
):
    """
    Look up one parameter combo in ``results`` and return plot-ready data.

    Returns a 4-tuple (x, y, yerr, y_label):
      - metric == "perf": y/yerr come from snr_to_perf_and_err, label "Performance";
      - metric == "snr":  y/yerr are the stored SNR and its error, label "SNR".

    Raises:
        KeyError: the combo is missing from ``results``.
        ValueError: any other metric.
    """
    record = results[(ampar, rin, sigma_eta, sigma_temp)]
    x = record["idir_delta"]
    snr, snr_err = record["snr"], record["err"]

    if metric == "perf":
        perf, perf_err = snr_to_perf_and_err(snr, snr_err)
        return x, perf, perf_err, "Performance"
    if metric == "snr":
        return x, snr, snr_err, "SNR"
    raise ValueError(f"Unknown metric {metric!r}")


In [None]:
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, clear_output

# -------- defaults for each trace --------
trace1_defaults = {"ampar": 1.0, "rin": 1.0, "sigma_eta": 0.1, "sigma_temp": 0.1}
trace2_defaults = {"ampar": 0.63, "rin": 1.22, "sigma_eta": 0.1, "sigma_temp": 0.1}

DEFAULT_METRIC = "perf"  # 'perf' or 'snr'


def _param_toggle(values, description, value):
    """ToggleButtons whose labels are str(v) for each option value v."""
    return widgets.ToggleButtons(
        options=[(str(v), v) for v in values],
        description=description,
        value=value,
    )


# -------- widgets for Trace 1 --------
t1_ampar = _param_toggle(AMPAR_VALUES, "ampar 1:", trace1_defaults["ampar"])
t1_rin = _param_toggle(RIN_VALUES, "rin 1:", trace1_defaults["rin"])
t1_eta = _param_toggle(ETA_VALUES, "σ_eta 1:", trace1_defaults["sigma_eta"])
t1_temp = _param_toggle(TEMP_VALUES, "σ_temp 1:", trace1_defaults["sigma_temp"])

# -------- widgets for Trace 2 --------
t2_ampar = _param_toggle(AMPAR_VALUES, "ampar 2:", trace2_defaults["ampar"])
t2_rin = _param_toggle(RIN_VALUES, "rin 2:", trace2_defaults["rin"])
t2_eta = _param_toggle(ETA_VALUES, "σ_eta 2:", trace2_defaults["sigma_eta"])
t2_temp = _param_toggle(TEMP_VALUES, "σ_temp 2:", trace2_defaults["sigma_temp"])

# -------- global metric toggle (shared by both traces) --------
metric_widget = widgets.ToggleButtons(
    options=[("Performance", "perf"), ("SNR", "snr")],
    description="y-axis:",
    value=DEFAULT_METRIC,
)

# -------- output area the figure is rendered into --------
out = widgets.Output()

# -------- callback to update figure when anything changes --------
def _params_text(params):
    """Render a trace's parameter dict as 'ampar=…, rin=…, σ_eta=…, σ_temp=…'."""
    return (
        f"ampar={params['ampar']}, rin={params['rin']}, "
        f"σ_eta={params['sigma_eta']}, σ_temp={params['sigma_temp']}"
    )


def _curve_trace(params, x, y, yerr, trace_no):
    """Build one markers+lines Scatter trace with data error bars."""
    return go.Scatter(
        x=x,
        y=y,
        mode="markers+lines",
        error_y=dict(
            type="data",
            array=yerr,
            visible=True,
            thickness=1.5,
            width=4,
        ),
        name=f"Trace {trace_no} ({_params_text(params)})",
    )


def update_plot(change=None):
    """
    Redraw both traces in ``out`` from the current widget selections.

    Registered as the observer of every selector; ``change`` (the ipywidgets
    change payload) is not used. If a selected combo has no entry in the
    loaded results (e.g. the grid computation was interrupted), a message is
    shown in ``out`` instead of raising inside the callback. Exceptions raised
    there may not be visible in the notebook.
    """
    # read current widget values
    p1 = dict(
        ampar=t1_ampar.value,
        rin=t1_rin.value,
        sigma_eta=t1_eta.value,
        sigma_temp=t1_temp.value,
    )
    p2 = dict(
        ampar=t2_ampar.value,
        rin=t2_rin.value,
        sigma_eta=t2_eta.value,
        sigma_temp=t2_temp.value,
    )
    metric = metric_widget.value

    try:
        x1, y1, yerr1, y_label = get_curve(metric=metric, **p1)
        x2, y2, yerr2, _ = get_curve(metric=metric, **p2)
    except KeyError as exc:
        missing = exc.args[0] if exc.args else exc
        with out:
            clear_output(wait=True)
            print(
                f"No results for parameter combo {missing}; "
                "recompute the SNR grid to include it."
            )
        return

    fig = go.Figure()
    fig.add_trace(_curve_trace(p1, x1, y1, yerr1, 1))
    fig.add_trace(_curve_trace(p2, x2, y2, yerr2, 2))

    fig.update_layout(
        template="plotly_white",
        xaxis_title="idir_delta",
        yaxis_title=y_label,
        title=(
            f"{y_label} vs idir_delta — "
            f"Trace 1: {_params_text(p1)} | "
            f"Trace 2: {_params_text(p2)}"
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1.0,
        ),
    )

    with out:
        clear_output(wait=True)
        fig.show()

# attach callback: any selector change triggers a redraw
_plot_selectors = (
    t1_ampar, t1_rin, t1_eta, t1_temp,
    t2_ampar, t2_rin, t2_eta, t2_temp,
    metric_widget,
)
for selector in _plot_selectors:
    selector.observe(update_plot, names="value")


def _titled_column(title_html, *children):
    """Stack a bold HTML heading above the given widgets."""
    return widgets.VBox([widgets.HTML(title_html), *children])


# layout: one column per trace plus the metric selector, side by side
trace1_box = _titled_column("<b>Trace 1</b>", t1_ampar, t1_rin, t1_eta, t1_temp)
trace2_box = _titled_column("<b>Trace 2</b>", t2_ampar, t2_rin, t2_eta, t2_temp)
metric_box = _titled_column("<b>Metric</b>", metric_widget)

controls = widgets.HBox([trace1_box, trace2_box, metric_box])
display(controls, out)

# draw initial figure once
update_plot()


HBox(children=(VBox(children=(HTML(value='<b>Trace 1</b>'), ToggleButtons(description='ampar 1:', options=(('1…

Output()