# MEA analyzer pipeline

## Part 1: Setup & Initialization

### Project root setup, core libraries and project modules — please read the instructions below carefully

Users should not need to change this part, except possibly the paths. Its job is to guarantee:

- Correct project root
- Correct imports
- Reproducibility
- No silent path bugs

In [None]:
# =========================
# PROJECT ROOT SETUP
# =========================
# Locate the project root (the nearest ancestor of the current working
# directory that contains a "src" folder) and make "src" importable so the
# project-module imports further down resolve no matter where the kernel
# was started.
import sys
from pathlib import Path

# Current working directory of the kernel
HERE = Path.cwd().resolve()

# Walk upward from the working directory; the first folder that has a
# "src" sub-directory is taken as the project root.
PROJECT_ROOT = next(
    (p for p in [HERE, *HERE.parents] if (p / "src").is_dir()),
    None,
)
if PROJECT_ROOT is None:
    # A bare next() would raise an opaque StopIteration here.
    raise FileNotFoundError(
        f"Could not find a project root containing a 'src' folder, "
        f"searching upward from {HERE}. Start the notebook from inside the project."
    )

# Put src/ at the front of the import path so project modules
# (config_handler, data_loader, qc, analysis, visualization, ...) import
# directly. NOTE(review): package names that collide with standard-library
# modules the interpreter has already loaded (such as `io`) cannot be
# shadowed this way — such packages need a non-colliding name.
SRC_DIR = PROJECT_ROOT / "src"
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

print("Notebook directory:", HERE)
print("Project root:", PROJECT_ROOT)
print("src exists?", SRC_DIR.exists())

In [None]:
# =========================
# CORE LIBRARIES
# =========================
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns  # optional, used later for aesthetics

# Global plotting defaults: on-screen resolution, export resolution,
# and a lighter axis frame without the top/right spines.
plt.rcParams.update(
    {
        "figure.dpi": 120,
        "savefig.dpi": 300,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
)

In [None]:
# =========================
# PROJECT MODULES
# =========================
# Project-local modules. These resolve only if the project's "src" folder
# is importable (on sys.path or installed into the environment) — TODO
# confirm how this environment provides it.

# Configuration
from config_handler import ConfigHandler

# Data ingestion & organization
from data_loader import load_mea_csv_well_averages
from data_organizer import DataOrganizer

# QC & preprocessing
from qc.outliers import (
    OutlierSpec,
    flag_outliers,
    get_outliers_table,
    apply_outlier_filter
)

# Normalization
from analysis.normalization import baseline_normalize

# Visualization
from visualization.plot_plate_layout import plot_plate_layout
from visualization.timecourse import plot_metric_timecourse

# Export
# NOTE(review): `io` is also a standard-library module, normally already
# loaded when the kernel starts. `io.table_export` is then looked up inside
# the stdlib `io` module and this import fails with ModuleNotFoundError.
# The project package likely needs a non-colliding name (e.g. `mea_io`) —
# confirm against the src/ layout.
from io.table_export import export_metric_tables_wide

print("All project modules imported successfully.")

In [None]:
# =========================
# PROJECT PATHS
# =========================

# Top-level folders, all anchored at the detected project root
CONFIG_DIR = PROJECT_ROOT / "config"
DATA_DIR = PROJECT_ROOT / "data"
OUTPUTS_DIR = PROJECT_ROOT / "outputs"
FIGURES_DIR = PROJECT_ROOT / "figures"

# Destination for processed (wide) tables
PROCESSED_DIR = DATA_DIR / "processed"

# Create the output folders (and any missing parent folders) if needed
for out_folder in (OUTPUTS_DIR, FIGURES_DIR, PROCESSED_DIR):
    out_folder.mkdir(parents=True, exist_ok=True)

for label, folder in (
    ("Config dir:", CONFIG_DIR),
    ("Data dir:", DATA_DIR),
    ("Outputs dir:", OUTPUTS_DIR),
    ("Figures dir:", FIGURES_DIR),
    ("Processed dir:", PROCESSED_DIR),
):
    print(label, folder)

### Edit the next two cells to choose the plate configuration and the global analysis options

In [None]:
# =========================
# LOAD CONFIGURATIONS
# =========================

# Metrics configuration, loaded through the project's ConfigHandler
config_handler = ConfigHandler(PROJECT_ROOT)
metrics_config = config_handler.load_metrics_config()

# Experiment configuration (USER EDITS THIS)
CONFIG_PATH = CONFIG_DIR / "MY_EXPERIMENT.yaml"  # <-- CHANGE THIS TO YOUR PLATE CONFIGURATION FILE

# Stop immediately if the plate configuration file is missing
if CONFIG_PATH.exists():
    print("Using experiment config:", CONFIG_PATH)
else:
    raise FileNotFoundError(f"Experiment config not found: {CONFIG_PATH}")

In [None]:
# =========================
# USER ANALYSIS OPTIONS
# =========================

# Export tables
EXPORT_MODE = "both"          # <-- CHANGE THIS TO YOUR CHOICE: "raw", "normalized", or "both"

# Outlier handling
OUTLIER_METHOD = "zscore"     # <-- CHANGE THIS TO YOUR CHOICE: "zscore", "iqr", or "mad"

OUTLIER_Z = 3.0               # <-- CHANGE THIS TO YOUR CHOICE (threshold passed to the outlier method; for "zscore" it is the Z score cut-off)

REMOVE_OUTLIERS = True        # <-- CHANGE THIS TO YOUR CHOICE: True = remove flagged outliers || False = keep flagged outliers (booleans, not strings)

# Normalization
NORMALIZE_TO_BASELINE = True  # <-- CHANGE THIS TO YOUR CHOICE: True = plots after outlier removal are normalized to baseline || False = they use raw values

BASELINE_PREFIX = "0_"        # baseline files start with "0_" (DO NOT CHANGE)

# Validate the choices here so a typo fails immediately. Otherwise an
# unknown EXPORT_MODE, for example, would silently export nothing.
_VALID_EXPORT_MODES = ("raw", "normalized", "both")
_VALID_OUTLIER_METHODS = ("zscore", "iqr", "mad")

if EXPORT_MODE not in _VALID_EXPORT_MODES:
    raise ValueError(
        f"EXPORT_MODE must be one of {_VALID_EXPORT_MODES}, got {EXPORT_MODE!r}"
    )
if OUTLIER_METHOD not in _VALID_OUTLIER_METHODS:
    raise ValueError(
        f"OUTLIER_METHOD must be one of {_VALID_OUTLIER_METHODS}, got {OUTLIER_METHOD!r}"
    )
if OUTLIER_Z <= 0:
    raise ValueError(f"OUTLIER_Z must be a positive number, got {OUTLIER_Z!r}")
# A string such as "False" is truthy, so insist on real booleans.
if not isinstance(REMOVE_OUTLIERS, bool):
    raise TypeError(f"REMOVE_OUTLIERS must be True or False, got {REMOVE_OUTLIERS!r}")
if not isinstance(NORMALIZE_TO_BASELINE, bool):
    raise TypeError(
        f"NORMALIZE_TO_BASELINE must be True or False, got {NORMALIZE_TO_BASELINE!r}"
    )

print("EXPORT_MODE:", EXPORT_MODE)
print("OUTLIER_METHOD:", OUTLIER_METHOD)
print("OUTLIER_Z:", OUTLIER_Z)
print("REMOVE_OUTLIERS:", REMOVE_OUTLIERS)
print("NORMALIZE_TO_BASELINE:", NORMALIZE_TO_BASELINE)

## Part 2: Validation of experiment details & Configuration of plate and metrics

In [None]:
# =========================
# LOAD EXPERIMENT CONFIG
# =========================
import yaml

# safe_load builds only plain Python objects, which is the right loader
# for a user-edited file. An empty file yields None, so fall back to {}.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    experiment_config = yaml.safe_load(f) or {}

if not isinstance(experiment_config, dict):
    raise ValueError(
        f"Experiment config must be a YAML mapping at the top level: {CONFIG_PATH}"
    )

# "or" also covers keys that are present but left empty (null) in the YAML.
experiment_info = experiment_config.get("experiment") or {}
conditions = experiment_config.get("conditions") or {}
time_points = experiment_config.get("time_points") or []

print("Plate ID:", experiment_info.get("plate_id", "UNKNOWN"))
print("Data directory:", experiment_info.get("data_dir", "UNKNOWN"))
print("Number of conditions:", len(conditions))
print("Conditions:")
for name, info in conditions.items():
    # Every condition needs a "wells" list. Report which one is broken
    # instead of failing with a bare KeyError.
    if not isinstance(info, dict) or "wells" not in info:
        raise ValueError(
            f"Condition {name!r} in {CONFIG_PATH.name} has no 'wells' list."
        )
    print(f"  - {name}: {len(info['wells'])} wells")

In [None]:
# =========================
# VALIDATE WELL ASSIGNMENTS
# =========================
from collections import Counter

# Plate geometry accepted by this notebook: rows A-D x columns 1-6
VALID_ROWS = {"A", "B", "C", "D"}
VALID_COLS = {"1", "2", "3", "4", "5", "6"}


def is_valid_well(well):
    """Return True if *well* is a well name such as "A1" ... "D6".

    Non-string entries (for example a value YAML parsed as a number) are
    rejected instead of raising.
    """
    return (
        isinstance(well, str)
        and len(well) in (2, 3)
        and well[0] in VALID_ROWS
        and well[1:] in VALID_COLS
    )


all_wells = []
errors = []

for cond, info in conditions.items():
    wells = info.get("wells") if isinstance(info, dict) else None
    if wells is None:
        errors.append(f"No 'wells' list defined (condition: {cond})")
        continue
    for w in wells:
        if not is_valid_well(w):
            errors.append(f"Invalid well name: {w} (condition: {cond})")
        all_wells.append(w)

# A well assigned more than once (same or different condition) is ambiguous
counts = Counter(all_wells)
duplicates = [w for w, c in counts.items() if c > 1]

if errors:
    print("❌ WELL FORMAT ERRORS:")
    for e in errors:
        print("  ", e)

if duplicates:
    print("❌ DUPLICATE WELL ASSIGNMENTS:")
    for w in duplicates:
        print("  ", w)

if errors or duplicates:
    # Halt "Run All" here. Otherwise the later cells would load and analyse
    # data using an invalid plate map.
    raise ValueError(
        f"Well assignment validation failed: {len(errors)} error(s), "
        f"{len(duplicates)} duplicate well(s). Fix {CONFIG_PATH.name} and re-run."
    )

print("✓ Well assignments validated successfully.")

In [None]:
# =========================
# PLATE LAYOUT VISUALIZATION
# =========================

# Render the plate map straight from the experiment config so the
# condition/well assignments can be checked by eye before loading data.
fig = plot_plate_layout(CONFIG_PATH)
plt.show()

In [None]:
# =========================
# TIME POINT LABELS
# =========================

# Map each time point index to its readable label (empty when the config
# defines no time points).
timepoint_labels = (
    {tp["index"]: tp["label"] for tp in time_points} if time_points else {}
)

if time_points:
    print("Time point labels:")
    for tp_index, tp_label in timepoint_labels.items():
        print(f"  {tp_index}: {tp_label}")
else:
    print("No time points defined in config.")

In [None]:
# =========================
# USER CONFIRMATION
# =========================

# Manual checkpoint: the printed reminder below is the stop/go gate before
# any data files are read.
QC_CHECKLIST_LINES = (
    "\nQC CHECKLIST:",
    "✓ Plate layout visually confirmed",
    "✓ Conditions and wells verified",
    "✓ Time points verified",
    "\nIf anything above is wrong:",
    "  → Stop here",
    "  → Fix the experiment config",
    "  → Restart the notebook\n",
)
for checklist_line in QC_CHECKLIST_LINES:
    print(checklist_line)

## Part 3: Data loading & Master table

In [None]:
# =========================
# LOCATE RAW DATA FILES
# =========================

if "data_dir" not in experiment_info:
    raise ValueError(f"'experiment: data_dir' is missing from {CONFIG_PATH.name}")

DATA_PATH = PROJECT_ROOT / experiment_info["data_dir"]

if not DATA_PATH.exists():
    raise FileNotFoundError(f"Data directory not found: {DATA_PATH}")


def _file_index(path):
    """Return the integer prefix of '<index>_<name>.csv', or None if there is none."""
    prefix = path.name.split("_", 1)[0]
    try:
        return int(prefix)
    except ValueError:
        return None


# Files are expected to be named 0_..., 1_..., 2_...
# Any CSV without a numeric prefix is reported instead of crashing the
# numeric sort with an unhelpful int() error.
csv_files = []
skipped_files = []
for csv_path in DATA_PATH.glob("*.csv"):
    if _file_index(csv_path) is None:
        skipped_files.append(csv_path)
    else:
        csv_files.append(csv_path)
csv_files.sort(key=_file_index)  # numeric order, so 10_ comes after 9_

print(f"Found {len(csv_files)} CSV files:")
for f in csv_files:
    print(" -", f.name)

if skipped_files:
    print(f"⚠ Skipped {len(skipped_files)} CSV file(s) without a numeric '<index>_' prefix:")
    for f in skipped_files:
        print(" -", f.name)

# Two files sharing an index would make the time-point order ambiguous
_indices = [_file_index(p) for p in csv_files]
_repeated = sorted({i for i in _indices if _indices.count(i) > 1})
if _repeated:
    print(f"⚠ Several CSV files share the same index prefix: {_repeated}")

# Baseline check
if not csv_files or not csv_files[0].name.startswith(BASELINE_PREFIX):
    raise ValueError(
        f"First file must be baseline and start with '{BASELINE_PREFIX}'. "
        f"Found: {csv_files[0].name if csv_files else 'NONE'}"
    )

In [None]:
# =========================
# INITIALIZE DATA ORGANIZER
# =========================

# The organizer is built from the experiment config file together with
# the shared ConfigHandler.
organizer = DataOrganizer(
    config_handler=config_handler,
    experiment_config_path=CONFIG_PATH,
)
print("DataOrganizer initialized successfully.")

In [None]:
# =========================
# BUILD MASTER DATAFRAME
# =========================

# Read the plate's CSV exports through the well-average loader and combine
# them into a single long-format table.
master_df = organizer.create_master_dataframe(
    verbose=True,
    data_loader_func=load_mea_csv_well_averages,
)

print("\nMaster dataframe created.")
print(f"Shape: {master_df.shape}")
print(f"Columns: {list(master_df.columns)}")

# Expected columns: plate_id | time_point | well | condition | condition_color | metric | metric_type | value

In [None]:
# =========================
# SELECT METRICS FOR PIPELINE (AUTOMATIC)
# =========================

# A metric is analysed only if it appears in the data AND is defined in
# the metrics configuration.
AVAILABLE_METRICS = sorted(
    set(master_df["metric"])
    & set(config_handler.get_all_metrics())
)

print(f"Metrics available for analysis ({len(AVAILABLE_METRICS)}):")
for m in AVAILABLE_METRICS:
    print(" -", m)
if not AVAILABLE_METRICS:
    print("⚠ No metric is both present in the data and defined in the metrics config.")

# Optional: restrict metrics (advanced users only)
METRICS_TO_USE = [
    # "Weighted Mean Firing Rate (Hz)",
]

# Catch typos now. An unknown name would otherwise only show up much later,
# e.g. as a skipped metric in the statistics step.
unknown_metrics = [m for m in METRICS_TO_USE if m not in AVAILABLE_METRICS]
if unknown_metrics:
    raise ValueError(
        f"METRICS_TO_USE contains metrics that are not available: {unknown_metrics}. "
        "Choose from the available metrics listed above."
    )

METRICS = list(METRICS_TO_USE) if METRICS_TO_USE else AVAILABLE_METRICS
print(f"\nMetrics that will be plotted/exported ({len(METRICS)}).")


In [None]:
# =========================
# MASTER DF SANITY CHECKS
# =========================

# Quick look at the value distribution and at how rows spread across
# metrics and conditions (dropna=False keeps unassigned rows visible).
SANITY_VIEWS = (
    ("\nValue summary:", lambda frame: frame["value"].describe()),
    ("\nMetric counts:", lambda frame: frame["metric"].value_counts()),
    ("\nCondition counts:", lambda frame: frame["condition"].value_counts(dropna=False)),
)
for heading, make_view in SANITY_VIEWS:
    print(heading)
    display(make_view(master_df))

In [None]:
# =========================
# PREVIEW DATA
# =========================

N_PREVIEW_ROWS = 10  # <-- Change to show more or fewer rows, if needed.
display(master_df.head(N_PREVIEW_ROWS))

## Part 4: QC & Outlier detection

In [None]:
# =========================
# OUTLIER DETECTION SETUP
# =========================

# Columns passed as group_cols. Presumably flag_outliers judges each value
# only against the other rows of its (plate, metric, time point, condition)
# group — confirm in qc.outliers.
OUTLIER_GROUP_COLS = ("plate_id", "metric", "time_point", "condition")

outlier_spec = OutlierSpec(
    method=OUTLIER_METHOD,
    threshold=OUTLIER_Z,
    group_cols=OUTLIER_GROUP_COLS,
    value_col="value",
)

print("Outlier detection spec:")
print(outlier_spec)

In [None]:
# =========================
# FLAG OUTLIERS (NON-DESTRUCTIVE)
# =========================

# Produces df_flagged with an "is_outlier" column; removal happens later
# and only if REMOVE_OUTLIERS is set.
df_flagged = flag_outliers(master_df, spec=outlier_spec)

print("Outlier flagging completed.")
n_flagged = df_flagged["is_outlier"].sum()
print("Total flagged outliers:", n_flagged)

In [None]:
# =========================
# OUTLIERS TABLE
# =========================

outliers_table = get_outliers_table(df_flagged)
n_outlier_rows = len(outliers_table)

print(f"Number of outliers detected: {n_outlier_rows}")
display(outliers_table.head(20))  # first 20 rows only

In [None]:
# =========================
# RAW DATA QC PLOT (WITH OUTLIERS)
# =========================

# One time-course figure per selected metric, drawn from the flagged
# (not yet filtered) data.
for metric_name in METRICS:
    print(f"QC plot (raw data): {metric_name}")
    fig = plot_metric_timecourse(df_flagged, metric=metric_name)
    plt.show()

In [None]:
# =========================
# APPLY OUTLIER FILTER (OPTIONAL)
# =========================

if not REMOVE_OUTLIERS:
    # Keep every value; work on a copy so df_flagged stays untouched
    df_filtered = df_flagged.copy()
    print("Outliers retained.")
else:
    df_filtered = apply_outlier_filter(df_flagged)
    print("Outliers removed (values set to NaN).")
    print("NaNs after filtering:", df_filtered["value"].isna().sum())

In [None]:
# =========================
# SAVE QC TABLES
# =========================

qc_flagged_path = OUTPUTS_DIR / "master_flagged_long.csv"
qc_filtered_path = OUTPUTS_DIR / "master_filtered_long.csv"

# (table, destination) pairs written in order
for qc_table, qc_path in ((df_flagged, qc_flagged_path), (df_filtered, qc_filtered_path)):
    qc_table.to_csv(qc_path, index=False)

print("QC tables saved:")
for saved_path in (qc_flagged_path, qc_filtered_path):
    print(" -", saved_path)

## Part 5: Baseline normalization

- Values are normalized to the baseline (time point 0) on a per-well basis:

- Normalized value = value at a given time point / baseline value of the same well

- Wells with missing or zero baseline values are excluded from normalization, because division would be undefined or biologically meaningless.

In [None]:
# =========================
# BASELINE NORMALIZATION
# =========================

if not NORMALIZE_TO_BASELINE:
    # No normalization: carry the filtered data forward unchanged
    df_norm = df_filtered.copy()
    print("Baseline normalization skipped.")
else:
    normalization_kwargs = dict(
        baseline_time_point=0,
        value_col="value",
        normalized_col="value_norm",
        method="ratio",              # ratio is usually the cleanest (baseline=1)
        exclude_zero_baseline=True,
        keep_excluded_rows=False,
        return_qc_table=True,
    )
    df_norm, baseline_qc = baseline_normalize(df_filtered, **normalization_kwargs)

    print("Baseline normalization completed.")
    print("Wells excluded due to baseline issues:", baseline_qc.shape[0])

In [None]:
# =========================
# BASELINE QC TABLE
# =========================

# baseline_qc only exists when the normalization step actually ran
if NORMALIZE_TO_BASELINE:
    print("Baseline exclusion reasons:")
    display(baseline_qc)

In [None]:
# =========================
# NORMALIZED TIMECOURSE PLOT
# =========================

# Nothing to plot when normalization was skipped
metrics_to_plot = METRICS if NORMALIZE_TO_BASELINE else []

for metric in metrics_to_plot:
    print(f"Normalized timecourse: {metric}")
    fig = plot_metric_timecourse(df_norm, metric=metric, value_col="value_norm")
    plt.show()

In [None]:
# =========================
# SAVE NORMALIZED MASTER TABLE
# =========================

if NORMALIZE_TO_BASELINE:
    norm_master_path = OUTPUTS_DIR / "master_normalized_long.csv"
    df_norm.to_csv(norm_master_path, index=False)
    print(f"Normalized master table saved to:\n{norm_master_path}")

## Part 6: Statistics (Comparisons within timepoint)

### The next cell should be edited by users to set statistics options

In [None]:
# =========================
# STATISTICS OPTIONS
# =========================

# Which data to use for statistics
STATS_MODE = "normalized"            # <-- CHANGE THIS TO YOUR CHOICE "raw" or "normalized" ("normalized" requires NORMALIZE_TO_BASELINE = True)

# Statistical test family
STATS_TEST_FAMILY = "nonparametric"  # <-- CHANGE THIS TO YOUR CHOICE "parametric" or "nonparametric"

# Multiple testing correction
P_ADJUST_METHOD = "fdr_bh"           # <-- CHANGE THIS TO YOUR CHOICE "bonferroni", "holm", "fdr_bh"

# Minimum number of wells per condition
MIN_N_PER_GROUP = 3                  # <-- CHANGE THIS TO YOUR CHOICE (Minimum of 3 wells recommended)

# Validate now. An unrecognised STATS_MODE would otherwise silently fall
# back to raw data in the next cell.
if STATS_MODE not in ("raw", "normalized"):
    raise ValueError(f"STATS_MODE must be 'raw' or 'normalized', got {STATS_MODE!r}")
if STATS_TEST_FAMILY not in ("parametric", "nonparametric"):
    raise ValueError(
        f"STATS_TEST_FAMILY must be 'parametric' or 'nonparametric', got {STATS_TEST_FAMILY!r}"
    )
if isinstance(MIN_N_PER_GROUP, bool) or not isinstance(MIN_N_PER_GROUP, int) or MIN_N_PER_GROUP < 1:
    raise ValueError(f"MIN_N_PER_GROUP must be a positive integer, got {MIN_N_PER_GROUP!r}")

print("STATS_MODE:", STATS_MODE)
print("STATS_TEST_FAMILY:", STATS_TEST_FAMILY)
print("P_ADJUST_METHOD:", P_ADJUST_METHOD)
print("MIN_N_PER_GROUP:", MIN_N_PER_GROUP)

In [None]:
# =========================
# SELECT DATA FOR STATS
# =========================

if STATS_MODE == "normalized":
    # df_norm always exists, but it only has "value_norm" when
    # normalization actually ran.
    if "value_norm" not in df_norm.columns:
        raise ValueError(
            "STATS_MODE = 'normalized' but no normalized values exist "
            "(column 'value_norm' is missing). Set NORMALIZE_TO_BASELINE = True "
            "or use STATS_MODE = 'raw'."
        )
    stats_df = df_norm
    VALUE_COL = "value_norm"
elif STATS_MODE == "raw":
    stats_df = df_filtered
    VALUE_COL = "value"
else:
    raise ValueError(f"STATS_MODE must be 'raw' or 'normalized', got {STATS_MODE!r}")

print("Using value column:", VALUE_COL)

In [None]:
# =========================
# INITIALIZE STATS SPEC
# =========================

# NOTE(review): `statistics` is also the name of a standard-library module.
# This import works only if the project's package is found first on
# sys.path and the stdlib module has not been imported already. Consider
# renaming the project package — confirm.
from statistics.timepoint_tests import (
    TimepointStatsSpec,
    compare_conditions_at_timepoint,
)

# One spec shared by every comparison below
stats_spec = TimepointStatsSpec(
    value_col=VALUE_COL,
    test_family=STATS_TEST_FAMILY,
    p_adjust_method=P_ADJUST_METHOD,
)

print(stats_spec)

### Here you can quickly run statistics on one metric at a single time point (optional)

Replace "Weighted Mean Firing Rate (Hz)" with the desired metric from the list below.

Select one time point to analyze using the file index prefix (1_, 2_, etc.); the baseline (0_) cannot be selected.

### List of available metrics:

Count:
- Number of Active Electrodes
- Number of Bursts
- Number of Network Bursts

Rate:
- Weighted Mean Firing Rate (Hz)
- Burst Frequency - Avg (Hz)
- Network Burst Frequency - Avg (Hz)

Duration:
- Burst Duration - Avg (sec)
- Network Burst Duration - Avg (sec)
- Network IBI Coefficient of Variation

Other:
- Synchrony Index

In [None]:
# =========================
# RUN STATS: SINGLE TIME POINT (EXPLORATORY)
# =========================

METRIC_STATS = "Weighted Mean Firing Rate (Hz)"  # user selects ONE metric
TIME_POINT = 3                                   # user selects ONE timepoint (file index; not the baseline 0)

# Validate the selection up front so a typo gives a clear message here.
_metrics_in_stats = sorted(stats_df["metric"].dropna().unique())
_timepoints_in_stats = sorted(stats_df["time_point"].dropna().unique())

if METRIC_STATS not in _metrics_in_stats:
    raise ValueError(
        f"METRIC_STATS {METRIC_STATS!r} is not in the data. Available: {_metrics_in_stats}"
    )
if TIME_POINT == 0:
    raise ValueError("TIME_POINT cannot be the baseline (0); pick a later time point.")
if TIME_POINT not in _timepoints_in_stats:
    raise ValueError(
        f"TIME_POINT {TIME_POINT} is not in the data. "
        f"Available: {[int(t) for t in _timepoints_in_stats]}"
    )

desc_df, omnibus_df, pairwise_df = compare_conditions_at_timepoint(
    stats_df,
    metric=METRIC_STATS,
    time_point=TIME_POINT,
    spec=stats_spec,
    min_n_per_group=MIN_N_PER_GROUP
)

print(f"Statistics for metric: {METRIC_STATS} at time point {TIME_POINT}")

print("Descriptive statistics:")
display(desc_df)

print("Omnibus test:")
display(omnibus_df)

print("Pairwise comparisons:")
display(pairwise_df)

### No changes are needed in the rest of the pipeline

In [None]:
# =========================
# RUN STATS ACROSS TIME POINTS
# =========================

# With ratio normalization, every well's baseline value_norm is 1, so
# comparing conditions at the baseline is meaningless. Skip it in
# normalized mode. Raw mode keeps the baseline, where a between-condition
# comparison is still informative.
BASELINE_TIME_POINT = 0
skip_time_points = {BASELINE_TIME_POINT} if VALUE_COL == "value_norm" else set()
if skip_time_points:
    print(f"Baseline time point {BASELINE_TIME_POINT} skipped (normalized values are 1 by construction).")

all_timepoints = sorted(stats_df["time_point"].unique())
metrics_in_data = set(stats_df["metric"].unique())  # computed once, not per metric

all_desc = []
all_omnibus = []
all_pairwise = []

for metric in METRICS:
    print(f"Running stats for metric: {metric}")

    if metric not in metrics_in_data:
        print(f"⚠ Metric not found in stats_df: {metric}")
        continue

    for tp in all_timepoints:
        if tp in skip_time_points:
            continue
        # Only the statistics call is guarded: a ValueError there means this
        # (metric, time point) cannot be tested, e.g. too few wells.
        try:
            d, o, p = compare_conditions_at_timepoint(
                stats_df,
                metric=metric,
                time_point=int(tp),
                spec=stats_spec,
                min_n_per_group=MIN_N_PER_GROUP
            )
        except ValueError as e:
            print(f"Skipping metric={metric}, time point={tp}: {e}")
            continue

        all_desc.append(d.assign(metric=metric, time_point=tp))
        all_omnibus.append(o.assign(metric=metric, time_point=tp))
        all_pairwise.append(p.assign(metric=metric, time_point=tp))

if all_desc:
    desc_all_tp = pd.concat(all_desc, ignore_index=True)
    omnibus_all_tp = pd.concat(all_omnibus, ignore_index=True)
    pairwise_all_tp = pd.concat(all_pairwise, ignore_index=True)

    print("✓ Statistics computed across metrics and time points.")
else:
    # Define empty tables so later cells see "no results" instead of a
    # NameError (or stale results from an earlier run).
    desc_all_tp = pd.DataFrame()
    omnibus_all_tp = pd.DataFrame()
    pairwise_all_tp = pd.DataFrame()
    print("⚠ No statistics computed — check metrics, groups, or sample size.")

In [None]:
# =========================
# INSPECT AND SAVE STATISTICS TABLES
# =========================

desc_path = OUTPUTS_DIR / "stats_descriptives_all_metrics.csv"
omnibus_path = OUTPUTS_DIR / "stats_omnibus_all_metrics.csv"
pairwise_path = OUTPUTS_DIR / "stats_pairwise_all_metrics.csv"

# The combined tables may be missing (no comparison succeeded) or empty.
# In that case report it instead of crashing with a NameError.
_stats_ready = all(
    name in globals() for name in ("desc_all_tp", "omnibus_all_tp", "pairwise_all_tp")
) and not desc_all_tp.empty

if _stats_ready:
    print("Omnibus tests (first 20 rows):")
    display(omnibus_all_tp.head(20))

    desc_all_tp.to_csv(desc_path, index=False)
    omnibus_all_tp.to_csv(omnibus_path, index=False)
    pairwise_all_tp.to_csv(pairwise_path, index=False)

    print("Statistics tables saved:")
    print(" -", desc_path)
    print(" -", omnibus_path)
    print(" -", pairwise_path)
else:
    print("⚠ No statistics tables to save — run the statistics cell above and check its warnings.")

## Part 7: Exports & Final outputs

In [None]:
# =========================
# EXPORT CLEAN TABLES (PRISM-FRIENDLY)
# =========================

exported_dirs = []
processed_dir = PROCESSED_DIR  # PROJECT_ROOT/data/processed

# Arguments common to both export flavours
export_kwargs = dict(
    out_dir=processed_dir,
    config_handler=config_handler,
    timepoint_labels=timepoint_labels,
    drop_unassigned_wells=True,
)

# NOTE(review): raw tables are built from master_df, i.e. before outlier
# removal, whereas normalized tables come from df_norm (after outlier
# handling). Confirm this asymmetry is intended.
if EXPORT_MODE in ("raw", "both"):
    raw_dir = export_metric_tables_wide(master_df, mode="raw", **export_kwargs)
    exported_dirs.append(raw_dir)

if EXPORT_MODE in ("normalized", "both"):
    if "value_norm" in df_norm.columns:
        norm_dir = export_metric_tables_wide(df_norm, mode="normalized", **export_kwargs)
        exported_dirs.append(norm_dir)
    else:
        print("⚠ Normalized export requested, but normalization was not performed.")

print("Export complete. Tables written to:")
for d in exported_dirs:
    print(" -", d)

In [None]:
# =========================
# EXPORT FIGURES FOR SELECTED METRICS
# =========================

# df_norm is fixed by now, so check once whether normalized values exist
has_normalized = "value_norm" in df_norm.columns

# show=False keeps the plotting helper from displaying each figure; the
# figures are only saved and closed.
shared_plot_kwargs = dict(
    show_outliers=False,
    timepoint_labels=timepoint_labels,
    show=False,
)

for metric in METRICS:
    print(f"Exporting figures for: {metric}")

    safe_metric = metric.replace(" ", "_").replace("/", "-")

    # (data, filename prefix, use_normalized): raw post-outlier data first,
    # then the normalized version when available
    figure_variants = [(df_filtered, "RAW", False)]
    if has_normalized:
        figure_variants.append((df_norm, "NORM", True))

    for plot_data, prefix, use_norm in figure_variants:
        fig_out = plot_metric_timecourse(
            plot_data,
            metric=metric,
            use_normalized=use_norm,
            **shared_plot_kwargs,
        )
        fig_path = FIGURES_DIR / f"{prefix}_{safe_metric}.png"
        fig_out.savefig(fig_path, dpi=300, bbox_inches="tight")
        plt.close(fig_out)
        print("Saved:", fig_path)

    if not has_normalized:
        print("Normalization not performed → normalized figure not saved.")

In [None]:
# =========================
# EXPORT MASTER TABLES (LONG FORMAT)
# =========================

# File name -> table. Some of these were already written in earlier cells;
# rewriting them here keeps the outputs folder complete in one place.
long_tables = {
    "master_raw_long.csv": master_df,
    "master_flagged_long.csv": df_flagged,
    "master_filtered_long.csv": df_filtered,
}
if "value_norm" in df_norm.columns:
    long_tables["master_normalized_long.csv"] = df_norm

for file_name, long_table in long_tables.items():
    long_table.to_csv(OUTPUTS_DIR / file_name, index=False)

print("Master tables exported to:", OUTPUTS_DIR)

In [None]:
# =========================
# FINAL RUN SUMMARY
# =========================

plate_id = experiment_info.get("plate_id", "UNKNOWN")
n_files = len(csv_files)
n_metrics = master_df["metric"].nunique()
n_timepoints = master_df["time_point"].nunique()
n_assigned_rows = master_df["condition"].notna().sum()

rule = "=" * 70

print(rule)
print("MEA ANALYSIS RUN SUMMARY")
print(rule)
for summary_label, summary_value in (
    ("Plate ID:", plate_id),
    ("CSV files processed:", n_files),
    ("Time points:", n_timepoints),
    ("Metrics detected:", n_metrics),
    ("Rows (assigned wells only):", n_assigned_rows),
):
    print(summary_label, summary_value)

# Outliers
outlier_count = (
    int(df_flagged["is_outlier"].sum())
    if "is_outlier" in df_flagged.columns
    else "(no outlier column found)"
)
print("Outliers flagged:", outlier_count)

# Baseline exclusions (baseline_qc exists only if normalization ran)
if not (NORMALIZE_TO_BASELINE and "baseline_qc" in globals()):
    print("Baseline exclusion table: not generated")
else:
    print("Wells excluded due to baseline issues:", len(baseline_qc))

print("\nOutputs:")
for output_label, output_dir in (
    (" - Long tables:", OUTPUTS_DIR),
    (" - Figures:", FIGURES_DIR),
    (" - Prism tables:", PROCESSED_DIR),
):
    print(output_label, output_dir)

print(rule)
print("Done.")
print(rule)