In [26]:
from pathlib import Path
from typing import List, Optional
import numpy as np
import matplotlib.pyplot as plt

def load_config(config_path: str | Path) -> dict:
    config_path = Path(config_path)
    with config_path.open("r") as f:
        return yaml.safe_load(f)


In [27]:
import yaml

# Read run settings from YAML; the identifiers and paths below all come from it.
cfg = load_config(r"config_example.yaml")

# Recording identifiers and protocol list.
# NOTE: dataset_id is reassigned to a hard-coded value in the next cell.
dataset_id, session_id, protocol_list = (
    cfg["dataset_id"],
    cfg["session_id"],
    cfg["protocols"],
)

# Data locations from the "paths" section of the config.
raw_root, processed_root, save_dir = (
    cfg["paths"]["raw_root"],
    cfg["paths"]["processed_root"],
    cfg["paths"]["save_dir"],
)

In [173]:
# Hard-coded override of the value read from cfg["dataset_id"];
# delete this line to analyse the dataset named in the YAML config.
dataset_id = "JPCM-09177"

In [174]:
# Alignment constants. T_PRE / T_POST are multiplied by FRAME_RATE to get frame
# counts, so they share a time unit (presumably seconds and Hz — confirm).
FRAME_RATE = 20.0        # frames per unit time
STIM_ONSET_FRAME = 100   # frame index of stimulus onset in each raw trace
T_PRE, T_POST = 2.0, 5.0  # time kept before / after onset when aligning

In [175]:
def _stim_window_frames():
    pre = int(T_PRE * FRAME_RATE)
    post = int(T_POST * FRAME_RATE)
    return STIM_ONSET_FRAME - pre, STIM_ONSET_FRAME + post


In [176]:
def load_all_trial_traces_for_animal(
    dataset_id: str,
    processed_root: str | Path = "data/processed",
    stim_values: tuple[float, ...] | list[float] = (22.0, 42.0),
    trace_basename: str = "roi_trace",
    roi_mode: str = "auto",
    align_to_stimulus: bool = True,
) -> dict[float, np.ndarray]:
    """
    Load all individual ROI traces for one animal across all sessions.

    Every immediate subdirectory of ``processed_root / dataset_id`` is treated
    as a session. For each session and stimulus value the matching trace files
    are loaded with ``_load_traces_for_condition`` and pooled across sessions.

    Parameters
    ----------
    dataset_id : str
        Animal / dataset directory name under ``processed_root``.
    processed_root : str or Path
        Root of the processed-data tree.
    stim_values : sequence of float
        Stimulus values to load; each becomes a key of the result.
    trace_basename : str
        Filename stem passed to ``_load_traces_for_condition``.
    roi_mode : {"auto", "manual"}
        Which ROI trace files to load.
    align_to_stimulus : bool
        If True, crop every trace to the frames given by
        ``_stim_window_frames()``.

    Returns
    -------
    traces_by_stim : dict
        stim_value -> np.ndarray (n_trials, T). T is the full trace length, or
        the stimulus-window length when ``align_to_stimulus`` is True.

    Raises
    ------
    FileNotFoundError
        If ``processed_root / dataset_id`` does not exist.
    RuntimeError
        If there are no session directories, or no traces for a stimulus
        (``_load_traces_for_condition`` raises this per session).
    ValueError
        If pooled traces have different lengths, or the stimulus window does
        not fit inside the traces.
    """
    processed_root = Path(processed_root)
    dataset_dir = processed_root / dataset_id

    if not dataset_dir.exists():
        raise FileNotFoundError(f"No processed data found for {dataset_id}")

    # Every immediate subdirectory is treated as one session.
    session_ids = sorted(p.name for p in dataset_dir.iterdir() if p.is_dir())
    if not session_ids:
        raise RuntimeError(f"No sessions found for {dataset_id}")

    # Pool per-session trace lists for each stimulus value.
    pooled: dict[float, list[np.ndarray]] = {stim: [] for stim in stim_values}
    for sess in session_ids:
        for stim in stim_values:
            pooled[stim].extend(
                _load_traces_for_condition(
                    processed_root=processed_root,
                    dataset_id=dataset_id,
                    session_id=sess,
                    trace_basename=trace_basename,
                    stim_value=stim,
                    roi_mode=roi_mode,
                )
            )

    # The window does not depend on the stimulus, so compute it once.
    window = _stim_window_frames() if align_to_stimulus else None

    traces_by_stim: dict[float, np.ndarray] = {}
    for stim, traces in pooled.items():
        # Defensive: _load_traces_for_condition already raises when a session
        # yields no traces, so this only fires if that contract changes.
        if not traces:
            raise RuntimeError(f"No traces found for {dataset_id}, stim={stim}")

        lengths = sorted({len(tr) for tr in traces})
        if len(lengths) > 1:
            raise ValueError(
                f"Traces for {dataset_id}, stim={stim} have unequal lengths "
                f"{lengths}; cannot stack into (n_trials, T)"
            )
        stacked = np.stack(traces, axis=0)  # (n_trials, T)

        if window is not None:
            f0, f1 = window
            n_frames = stacked.shape[1]
            # Out-of-range bounds would make slicing silently return a shorter
            # or wrong window instead of failing.
            if f0 < 0 or f1 > n_frames:
                raise ValueError(
                    f"Stimulus window [{f0}, {f1}) does not fit traces of "
                    f"length {n_frames} for {dataset_id}, stim={stim}"
                )
            stacked = stacked[:, f0:f1]

        traces_by_stim[stim] = stacked

    return traces_by_stim


In [177]:
def _load_traces_for_condition(
    processed_root: str | Path,
    dataset_id: str,
    session_id: str | None,
    trace_basename: str,
    stim_value: float,
    roi_mode: str = "auto",
) -> List[np.ndarray]:
    """
    Load every 1-D trace file for one stimulus condition.

    Searches ``processed_root / dataset_id`` (or ``.../session_id`` when a
    session is given) recursively for files named by
    ``_trace_filename(trace_basename, stim_value, roi_mode)``. Each match is
    loaded with ``np.load`` and squeezed.

    Returns
    -------
    list of np.ndarray
        One 1-D array per usable file, in sorted path order. Files that are
        not 1-D after squeezing are skipped with a warning.

    Raises
    ------
    RuntimeError
        If no usable trace was found.
    ValueError
        Propagated from ``_trace_filename`` for an invalid ``roi_mode``.
    """
    import warnings

    base = Path(processed_root) / dataset_id
    if session_id is not None:
        base = base / session_id

    fname = _trace_filename(trace_basename, stim_value, roi_mode)
    paths = sorted(base.glob(f"**/{fname}"))

    traces = []
    for p in paths:
        tr = np.squeeze(np.load(p))
        if tr.ndim == 1:
            traces.append(tr)
        else:
            # Do not drop silently: a skipped file is a missing trial.
            warnings.warn(
                f"Skipping {p}: expected a 1-D trace after squeeze, "
                f"got shape {tr.shape}",
                stacklevel=2,
            )

    if not traces:
        raise RuntimeError(
            f"No traces found for stim={stim_value} "
            f"(roi_mode='{roi_mode}'); searched {base} for '{fname}' "
            f"({len(paths)} matching file(s))"
        )

    return traces

In [178]:
def _trace_filename(
    trace_basename: str,
    stim: float,
    roi_mode: str,
) -> str:
    """
    Construct exact trace filename.
    """
    if roi_mode == "auto":
        return f"{trace_basename}_stim_{stim}.npy"
    elif roi_mode == "manual":
        return f"{trace_basename}_manual_stim_{stim}.npy"
    else:
        raise ValueError("roi_mode must be 'auto' or 'manual'")

In [179]:
# Load every manual-ROI trial for this animal, pooled over sessions.
traces = load_all_trial_traces_for_animal(
    dataset_id=dataset_id,
    processed_root=processed_root,
    roi_mode="manual",
)

# One line per condition; each entry is (n_trials, T).
print(*(traces[stim].shape for stim in (22.0, 42.0)), sep="\n")


(25, 140)
(25, 140)


In [180]:
np.shape(traces[22.0])  # (n_trials, T) for the 22.0 condition

(25, 140)

In [181]:
# Per-trial baseline and response levels for the 22.0 condition.
# Indices are relative to the stimulus-aligned window produced by
# load_all_trial_traces_for_animal (align_to_stimulus=True), which starts at
# absolute frame _stim_window_frames()[0]; onset therefore sits at index
# STIM_ONSET_FRAME - start (40 with the current constants).
onset_idx = STIM_ONSET_FRAME - _stim_window_frames()[0]
# Response window: first 2 time units after onset (40 frames at FRAME_RATE = 20.0).
response_end_idx = onset_idx + int(round(2.0 * FRAME_RATE))

# nanmean for BOTH windows so a NaN frame is treated the same way in each;
# np.mean on the baseline would turn the whole trial's baseline into NaN.
base_lines = np.nanmean(traces[22.0][:, :onset_idx], axis=1)
# NOTE: "peaks" is the mean over the response window, not a maximum.
peaks = np.nanmean(traces[22.0][:, onset_idx:response_end_idx], axis=1)

In [182]:
from scipy.stats import ttest_rel
import numpy as np

def paired_baseline_vs_peak_test(
    baseline_vals: np.ndarray,
    peak_vals: np.ndarray,
    alternative: str = "two-sided",  # "two-sided", "greater", "less"
):
    """
    Paired t-test comparing per-trial peak (response) values against baseline.

    Calls ``scipy.stats.ttest_rel(peak_vals, baseline_vals)``, so the test is
    on ``peak_vals - baseline_vals``. A positive ``t_stat`` means the response
    exceeds baseline, and ``alternative="greater"`` tests peak > baseline.
    scipy's default ``nan_policy`` ("propagate") is used, so any NaN pair
    yields NaN statistics.

    Parameters
    ----------
    baseline_vals : array-like, shape (n_trials,)
    peak_vals : array-like, shape (n_trials,)
        Paired element-wise with ``baseline_vals`` (same trial order).
    alternative : {"two-sided", "greater", "less"}
        Hypothesis direction, forwarded to ``ttest_rel``.

    Returns
    -------
    results : dict
        {
            "t_stat": float,
            "p_value": float,
            "mean_baseline": float,
            "mean_peak": float,
            "mean_diff": float,   # mean of (peak - baseline)
            "n": int,             # number of paired trials
        }

    Raises
    ------
    ValueError
        If the inputs differ in shape, are not 1-D, or contain fewer than two
        pairs (a paired t-test is undefined for n < 2).
    """
    baseline_vals = np.asarray(baseline_vals)
    peak_vals = np.asarray(peak_vals)

    if baseline_vals.shape != peak_vals.shape:
        raise ValueError("baseline_vals and peak_vals must have same shape")
    # 2-D input would make ttest_rel return arrays and float() fail below.
    if baseline_vals.ndim != 1:
        raise ValueError(
            f"expected 1-D arrays of per-trial values, got shape {baseline_vals.shape}"
        )
    if baseline_vals.size < 2:
        raise ValueError(
            f"need at least 2 paired trials for a paired t-test, got {baseline_vals.size}"
        )

    t_stat, p_val = ttest_rel(
        peak_vals,
        baseline_vals,
        alternative=alternative,
    )

    return {
        "t_stat": float(t_stat),
        "p_value": float(p_val),
        "mean_baseline": float(np.mean(baseline_vals)),
        "mean_peak": float(np.mean(peak_vals)),
        "mean_diff": float(np.mean(peak_vals - baseline_vals)),
        "n": baseline_vals.size,
    }


In [183]:
# Per-trial baseline and response levels for the 22.0 condition.
# Indices are relative to the stimulus-aligned window produced by
# load_all_trial_traces_for_animal (align_to_stimulus=True), which starts at
# absolute frame _stim_window_frames()[0]; onset therefore sits at index
# STIM_ONSET_FRAME - start (40 with the current constants).
onset_idx = STIM_ONSET_FRAME - _stim_window_frames()[0]
# Response window: first 2 time units after onset (40 frames at FRAME_RATE = 20.0).
response_end_idx = onset_idx + int(round(2.0 * FRAME_RATE))

# nanmean for BOTH windows so a NaN frame is treated the same way in each.
base_lines = np.nanmean(traces[22.0][:, :onset_idx], axis=1)
# NOTE: "peaks" is the mean over the response window, not a maximum.
peaks = np.nanmean(traces[22.0][:, onset_idx:response_end_idx], axis=1)

In [184]:
# Paired t-test of per-trial response vs. baseline (22.0 condition).
result = paired_baseline_vs_peak_test(
    baseline_vals=base_lines,
    peak_vals=peaks,
)
    

In [185]:
# Show the 22.0 test result; t_stat > 0 means the mean response exceeds baseline.
result

{'t_stat': 8.500695893657749,
 'p_value': 1.0599273078336895e-08,
 'mean_baseline': -2.2956457668442398e-05,
 'mean_peak': 0.009928076444925863,
 'mean_diff': 0.009951032902594307,
 'n': 25}

In [186]:
# Per-trial baseline and response levels for the 42.0 condition.
# Indices are relative to the stimulus-aligned window produced by
# load_all_trial_traces_for_animal (align_to_stimulus=True), which starts at
# absolute frame _stim_window_frames()[0]; onset therefore sits at index
# STIM_ONSET_FRAME - start (40 with the current constants).
onset_idx = STIM_ONSET_FRAME - _stim_window_frames()[0]
# Response window: first 2 time units after onset (40 frames at FRAME_RATE = 20.0).
response_end_idx = onset_idx + int(round(2.0 * FRAME_RATE))

# nanmean for BOTH windows so a NaN frame is treated the same way in each.
base_lines = np.nanmean(traces[42.0][:, :onset_idx], axis=1)
# NOTE: "peaks" is the mean over the response window, not a maximum.
peaks = np.nanmean(traces[42.0][:, onset_idx:response_end_idx], axis=1)

In [187]:
# Paired t-test of per-trial response vs. baseline (42.0 condition).
result = paired_baseline_vs_peak_test(
    baseline_vals=base_lines,
    peak_vals=peaks,
)
    

In [188]:
# Show the 42.0 test result; t_stat > 0 means the mean response exceeds baseline.
result

{'t_stat': 11.685873917978057,
 'p_value': 2.160107284765482e-11,
 'mean_baseline': -0.001105807909750638,
 'mean_peak': 0.009034824374908568,
 'mean_diff': 0.010140632284659206,
 'n': 25}