# Single-Animal Baseline Signal Feature Analysis

This notebook loads per-mouse mean/SEM aligned photometry data generated by `SANDBOX_3_mean_sem_grandav.py` and extracts peak amplitude, time-to-peak, and half-width decay metrics for `z_470_Baseline` and `z_560_Baseline` within a configurable post-alignment window.


In [None]:
from pathlib import Path
from typing import Dict, Tuple

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import pandas as pd

POST_ALIGNMENT_WINDOW: Tuple[float, float] = (0.0, 2.0)
EVENT_LABEL = "Apply halt_2s"

# Point this to the experiment-day folder that contains the per-mouse *_processedData directories
EXPERIMENT_DAY_DIR = Path(
    "/Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4"
)
OUTPUT_DIR = EXPERIMENT_DAY_DIR / "aligned_data_feature_summary"

SIGNAL_COLUMNS = [
    "z_470_Baseline",
    "z_560_Baseline",
]

assert EXPERIMENT_DAY_DIR.exists(), f"Experiment day directory not found: {EXPERIMENT_DAY_DIR}"


In [None]:
def find_mouse_csvs(
    experiment_day_dir: Path,
    event_label: str,
    suffix: str = "mean_sem_averaged.csv",
) -> Dict[str, Path]:
    """Locate per-mouse CSVs matching the requested event label."""
    csv_paths: Dict[str, Path] = {}
    pattern = f"*{event_label}_{suffix}"

    for processed_dir in experiment_day_dir.glob("*_processedData"):
        aligned_dir = processed_dir / "aligned_data"
        if not aligned_dir.exists():
            continue

        for csv_path in aligned_dir.glob(pattern):
            stem = csv_path.stem
            suffix_token = f"_{event_label}_{suffix.replace('.csv', '')}"
            if stem.endswith(suffix_token):
                mouse_id = stem[: -len(suffix_token)]
            else:
                mouse_id = stem

            if mouse_id in csv_paths:
                raise ValueError(
                    f"Multiple CSVs found for mouse '{mouse_id}'.\n"
                    f"Existing: {csv_paths[mouse_id]}\nNew: {csv_path}"
                )
            csv_paths[mouse_id] = csv_path

    if not csv_paths:
        raise FileNotFoundError(
            f"No CSV files with pattern '*{event_label}_{suffix}' found in {experiment_day_dir}"
        )

    return csv_paths


def _compute_half_width_decay(
    times: np.ndarray,
    values: np.ndarray,
    peak_index: int,
) -> float:
    """Estimate the time (relative to peak) at which the signal decays to half the peak value."""
    peak_value = values[peak_index]
    half_threshold = peak_value * 0.5

    for i in range(peak_index + 1, len(values)):
        if values[i] <= half_threshold:
            t1, t2 = times[i - 1], times[i]
            v1, v2 = values[i - 1], values[i]
            if v2 == v1:
                crossing_time = t2
            else:
                fraction = (half_threshold - v1) / (v2 - v1)
                crossing_time = t1 + fraction * (t2 - t1)
            return crossing_time - times[peak_index]

    return np.nan


def compute_signal_metrics(
    df: pd.DataFrame,
    signal_column: str,
    window: Tuple[float, float],
    time_column: str = "Time (s)",
) -> Dict[str, float]:
    """Compute peak, time-to-peak, and half-width decay for a signal within a time window."""
    if signal_column not in df.columns:
        return {"peak": np.nan, "time_to_peak": np.nan, "half_width_decay": np.nan}

    window_mask = (df[time_column] >= window[0]) & (df[time_column] <= window[1])
    window_df = df.loc[window_mask, [time_column, signal_column]].dropna()

    if window_df.empty:
        return {"peak": np.nan, "time_to_peak": np.nan, "half_width_decay": np.nan}

    # Identify peak within the window
    peak_idx_in_window = window_df[signal_column].idxmax()
    peak_value = df.loc[peak_idx_in_window, signal_column]
    peak_time = df.loc[peak_idx_in_window, time_column]

    times = window_df[time_column].to_numpy()
    values = window_df[signal_column].to_numpy()
    peak_index = window_df.index.get_loc(peak_idx_in_window)

    half_width_decay = _compute_half_width_decay(times, values, peak_index)

    return {
        "peak": float(peak_value),
        "time_to_peak": float(peak_time),
        "half_width_decay": float(half_width_decay),
    }


def sem(series: pd.Series) -> float:
    """Compute the standard error of the mean for a pandas Series."""
    clean = series.dropna()
    n = len(clean)
    if n <= 1:
        return np.nan
    return float(clean.std(ddof=1) / np.sqrt(n))


def _convert_cm_to_inches(cm: float) -> float:
    """Convert centimetres to inches for consistent figure sizing."""
    return float(cm) * 0.3937007874


def build_mouse_color_map(mouse_labels):
    """Assign consistent colors to each mouse using the gnuplot2 palette."""
    unique_mice = sorted(set(mouse_labels))
    if not unique_mice:
        return {}
    colors = plt.cm.gnuplot2(np.linspace(0, 0.95, len(unique_mice)))
    return {mouse: colors[idx] for idx, mouse in enumerate(unique_mice)}


def format_metric_label(metric: str) -> str:
    mapping = {
        "peak": "Peak Amplitude",
        "time_to_peak": "Time to Peak (s)",
        "half_width_decay": "Decay to Half-Peak (s)",
    }
    return mapping.get(metric, metric.replace("_", " ").title())


In [28]:
mouse_csvs = find_mouse_csvs(EXPERIMENT_DAY_DIR, EVENT_LABEL)
print(f"Found {len(mouse_csvs)} animals:")
for mouse_id, csv_path in mouse_csvs.items():
    print(f"  {mouse_id}: {csv_path}")


Found 4 animals:
  B6J2718: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/B6J2718-2024-12-11T13-49-13_processedData/aligned_data/B6J2718_Apply halt_2s_mean_sem_averaged.csv
  B6J2721: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/B6J2721-2024-12-11T15-05-01_processedData/aligned_data/B6J2721_Apply halt_2s_mean_sem_averaged.csv
  B6J2722: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/B6J2722-2024-12-11T15-42-27_processedData/aligned_data/B6J2722_Apply halt_2s_mean_sem_averaged.csv
  B6J2719: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/B6J2719-2024-12-11T14-26-30_processedData/aligned_data/B6J2719_Apply halt_2s_mean_sem_averaged.csv


In [29]:
mouse_data = {
    mouse_id: pd.read_csv(csv_path)
    for mouse_id, csv_path in mouse_csvs.items()
}

# Quick sanity check on columns
first_mouse = next(iter(mouse_data))
print(f"Columns for {first_mouse}: {mouse_data[first_mouse].columns.tolist()[:10]} ...")


Columns for B6J2718: ['Time (s)', 'Velocity_0X_Baseline', 'Motor_Velocity_Baseline', 'z_470_Baseline', 'z_560_Baseline', 'Velocity_0X', 'Motor_Velocity', 'z_470', 'z_560', 'Velocity_0X_Baseline_SEM'] ...


In [30]:
metrics_records = []
for mouse_id, df in mouse_data.items():
    for signal in SIGNAL_COLUMNS:
        metrics = compute_signal_metrics(df, signal, POST_ALIGNMENT_WINDOW)
        metrics_records.append({"mouse": mouse_id, "signal": signal, **metrics})

metrics_df = pd.DataFrame(metrics_records)
metrics_df


Unnamed: 0,mouse,signal,peak,time_to_peak,half_width_decay
0,B6J2718,z_470_Baseline,-0.000866,0.228,-0.002358
1,B6J2718,z_560_Baseline,1.500513,0.411,0.910165
2,B6J2721,z_470_Baseline,0.006479,0.457,0.005832
3,B6J2721,z_560_Baseline,1.453401,0.472,0.527004
4,B6J2722,z_470_Baseline,0.058165,0.204,0.012986
5,B6J2722,z_560_Baseline,1.933666,0.387,0.707341
6,B6J2719,z_470_Baseline,-0.095165,0.091,-0.525911
7,B6J2719,z_560_Baseline,1.539122,0.324,0.601277


In [31]:
summary_records = []
for signal, group in metrics_df.groupby("signal"):
    for metric in ["peak", "time_to_peak", "half_width_decay"]:
        values = group[metric].dropna()
        summary_records.append(
            {
                "signal": signal,
                "metric": metric,
                "n": len(values),
                "mean": float(values.mean()) if len(values) else np.nan,
                "sem": sem(values) if len(values) else np.nan,
            }
        )

summary_df = pd.DataFrame(summary_records).sort_values(["signal", "metric"])
summary_df


Unnamed: 0,signal,metric,n,mean,sem
2,z_470_Baseline,half_width_decay,4,-0.127363,0.132886
0,z_470_Baseline,peak,4,-0.007847,0.031932
1,z_470_Baseline,time_to_peak,4,0.245,0.076719
5,z_560_Baseline,half_width_decay,4,0.686447,0.083248
3,z_560_Baseline,peak,4,1.606675,0.110397
4,z_560_Baseline,time_to_peak,4,0.3985,0.030606


In [None]:
def plot_signal_metrics(
    metrics: pd.DataFrame,
    summary: pd.DataFrame,
    metric_order=("peak", "time_to_peak", "half_width_decay"),
    signal_order=None,
    save_dir: Path | None = None,
    file_prefix: str = "signal_metrics",
):
    """Plot per-mouse metrics with group mean ± SEM for each signal."""
    if metrics.empty:
        raise ValueError("No metrics available to plot. Did you run the computation cells?")

    signal_order = signal_order or list(metrics["signal"].unique())
    mouse_colors = build_mouse_color_map(metrics["mouse"].unique())

    plt.rcParams.update(
        {
            "font.size": 15,
            "font.family": "sans-serif",
            "font.sans-serif": ["Arial"],
            "axes.titlesize": 15,
            "axes.labelsize": 15,
            "legend.fontsize": 12,
            "xtick.labelsize": 15,
            "ytick.labelsize": 15,
        }
    )

    for signal in signal_order:
        signal_metrics = metrics[metrics["signal"] == signal]
        if signal_metrics.empty:
            continue

        summary_subset = (
            summary[summary["signal"] == signal].set_index("metric")
            if not summary.empty
            else pd.DataFrame()
        )

        fig, ax = plt.subplots(
            figsize=(
                _convert_cm_to_inches(12.0),
                _convert_cm_to_inches(8.0),
            )
        )

        x_positions = np.arange(len(metric_order), dtype=float)

        handles: list = []
        labels: list = []

        # plot individual mouse trajectories across metrics
        for mouse_id, mouse_df in signal_metrics.groupby("mouse"):
            row = mouse_df.iloc[0]
            y_values = [row.get(metric, np.nan) for metric in metric_order]
            ax.plot(
                x_positions,
                y_values,
                color=mouse_colors.get(mouse_id),
                linewidth=1.5,
                alpha=0.7,
                zorder=2,
            )
            scatter = ax.scatter(
                x_positions,
                y_values,
                color=mouse_colors.get(mouse_id),
                edgecolor="black",
                linewidth=0.6,
                s=80,
                zorder=3,
            )
            handles.append(scatter)
            labels.append(mouse_id)

        # overlay mean ± SEM
        group_mean_plotted = False
        for idx, metric in enumerate(metric_order):
            if summary_subset.empty or metric not in summary_subset.index:
                continue
            mean_val = summary_subset.at[metric, "mean"]
            sem_val = summary_subset.at[metric, "sem"]
            ax.errorbar(
                x_positions[idx],
                mean_val,
                yerr=sem_val,
                fmt="o",
                color="black",
                markersize=8,
                linewidth=1.8,
                capsize=6,
                zorder=4,
            )
            group_mean_plotted = True

        ax.set_xticks(x_positions)
        ax.set_xticklabels([format_metric_label(metric) for metric in metric_order], rotation=10)
        ax.set_title(f"{signal} metrics")
        ax.set_xlim(-0.4, len(metric_order) - 0.6)
        ax.grid(True, axis="y", linestyle="--", alpha=0.4)
        ax.set_ylabel("Value")

        if handles:
            legend_handles = handles.copy()
            legend_labels = labels.copy()
            if group_mean_plotted:
                legend_handles.append(
                    Line2D([], [], color="black", marker="o", linestyle="", label="Group mean")
                )
                legend_labels.append("Group mean")
            leg = ax.legend(legend_handles, legend_labels, loc="upper right", frameon=False)
            leg.set_title("Mouse")

        fig.tight_layout()

        if save_dir is not None:
            save_dir = Path(save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            filename = save_dir / f"{file_prefix}_{signal.replace(' ', '_')}.png"
            fig.savefig(filename, dpi=300)
            print(f"Saved plot to {filename}")

        plt.show()



In [None]:
plot_signal_metrics(
    metrics_df,
    summary_df,
    save_dir=OUTPUT_DIR,
    file_prefix=f"{EVENT_LABEL.replace(' ', '_')}_metrics",
)


In [32]:
summary_pivot = summary_df.pivot(index="metric", columns="signal", values=["mean", "sem"])
summary_pivot


Unnamed: 0_level_0,mean,mean,sem,sem
signal,z_470_Baseline,z_560_Baseline,z_470_Baseline,z_560_Baseline
metric,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2
half_width_decay,-0.127363,0.686447,0.132886,0.083248
peak,-0.007847,1.606675,0.031932,0.110397
time_to_peak,0.245,0.3985,0.076719,0.030606


In [None]:
OUTPUT_DIR.mkdir(exist_ok=True)

per_mouse_path = OUTPUT_DIR / f"{EVENT_LABEL.replace(' ', '_')}_per_mouse_metrics.csv"
summary_path = OUTPUT_DIR / f"{EVENT_LABEL.replace(' ', '_')}_summary_metrics.csv"

metrics_df.to_csv(per_mouse_path, index=False)
summary_df.to_csv(summary_path, index=False)

print(f"Saved per-mouse metrics to: {per_mouse_path}")
print(f"Saved summary metrics to: {summary_path}")


Saved per-mouse metrics to: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/aligned_data_feature_summary/Apply_halt_2s_per_mouse_metrics.csv
Saved summary metrics to: /Volumes/sambashare/data/ONIX/20241125_Cohort1_rotation/Visual_mismatch_day4/aligned_data_feature_summary/Apply_halt_2s_summary_metrics.csv
