# Voltage Imaging Analysis Benchmark

## For Empirical Code Evaluation

This notebook defines a benchmark for automated optimization of voltage imaging analysis pipelines, suitable for empirical code evaluation frameworks that iteratively search for high-performing code variants.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/YOUR_USERNAME/voltage-imaging-benchmark/blob/main/voltage_imaging_benchmark.ipynb)

---
# 0. Setup and Installation

## 0.1 Install Dependencies

In [1]:
# Install dependencies (run this cell in Colab or if packages are missing)
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    !pip install -q numpy scipy matplotlib scikit-image scikit-learn pandas seaborn tifffile nd2 opencv-python gdown umap-learn hdbscan

## 0.2 Download Benchmark Data

The benchmark data is hosted on Google Drive. Run the cell below to download it.

In [None]:
import os
from pathlib import Path

# === CONFIGURE YOUR GOOGLE DRIVE FILE ID HERE ===
# To get the file ID from a Google Drive sharing link:
# https://drive.google.com/file/d/FILE_ID_HERE/view?usp=sharing
#                                  ^^^^^^^^^^^^ copy this part

GDRIVE_FILE_ID = "156ASrvQfbyjrHwtAeCCt_dTuWcktS9Cu"  # Replace with your actual file ID
DATA_DIR = Path("./data")
DATA_DIR.mkdir(exist_ok=True)

def download_data(file_id, output_dir):
    """Download and extract benchmark data from Google Drive."""
    import gdown
    import zipfile

    zip_path = output_dir / "benchmark_data.zip"

    # Check if data already exists
    if (output_dir / "video.tif").exists() or (output_dir / "fish1").exists():
        print("Data already downloaded. Skipping.")
        return

    # Download from Google Drive
    url = f"https://drive.google.com/uc?id={file_id}"
    print("Downloading data from Google Drive...")
    gdown.download(url, str(zip_path), quiet=False)

    # Extract if it's a zip file
    if zipfile.is_zipfile(zip_path):
        print("Extracting data...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(output_dir)
        zip_path.unlink()  # Remove zip after extraction
        print("Done!")
    else:
        # Not a zip archive (e.g., a single TIFF); the file is left as-is at zip_path
        print(f"Downloaded file is not a zip archive; saved as {zip_path}")

# Uncomment to download:
# download_data(GDRIVE_FILE_ID, DATA_DIR)

## 0.3 Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage, signal
from skimage import measure, morphology
import tifffile
from pathlib import Path
import json
from typing import List, Dict, Tuple, Optional

# Set plotting style
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['figure.dpi'] = 100

print("Libraries imported successfully!")

---
# 1. Problem Definition

## 1.1 Scientific Context

**Voltage imaging** enables recording of neuronal electrical activity at high temporal resolution (~200 Hz in this dataset) across many neurons simultaneously. Unlike calcium imaging, which reports intracellular calcium as an indirect proxy for spiking, voltage indicators directly report membrane potential changes, capturing:
- **Action potentials (spikes)** - fast (~1-5 ms) electrical events
- **Subthreshold activity** - slower membrane potential fluctuations

## 1.2 The Data

| Property | Value |
|----------|-------|
| **Organism** | Zebrafish (larval) |
| **Indicator** | Voltron-2 (soma-targeted) |
| **Frame rate** | ~200 Hz (5 ms/frame) |
| **File format** | ND2 (Nikon) → TIFF |
| **File size** | 2-10 GB per recording |
| **Duration** | 1,000-10,000 frames (5-50 seconds) |
| **Neurons per FOV** | 10-100 |
| **Neuron size** | ~5-10 µm (~9 pixels) |
| **Neuron appearance** | Ring-like (membrane labeling) |

## 1.3 Signal Characteristics

**Important**: Voltron-2 produces **NEGATIVE deflections** during action potentials.

| Feature | Typical Value |
|---------|---------------|
| Spike duration | 1-5 ms (1-2 frames at 200 Hz) |
| Spike amplitude | 2-10% ΔF/F |
| Spike polarity | **Negative** (downward) |
| Noise type | Broadband (camera/shot noise) |
| SNR | Variable, typically 2-10 |
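
The characteristics in the table can be turned into a synthetic test trace for sanity-checking detectors before running them on real recordings. The sketch below is an assumption-based stand-in, not the benchmark data: spikes are single-frame downward steps, noise is white Gaussian, and SNR is taken as spike amplitude divided by noise standard deviation. `make_synthetic_trace` is a hypothetical helper.

```python
import numpy as np

def make_synthetic_trace(n_frames=2000, fps=200.0, rate_hz=5.0,
                         amplitude=0.05, snr=4.0, seed=0):
    """
    Synthetic Voltron-like ΔF/F trace (hypothetical helper, not real data).

    Spikes are single-frame NEGATIVE deflections of size `amplitude`
    (0.05 = 5% ΔF/F); noise is white Gaussian with std = amplitude / snr.

    Returns:
        trace: 1D float array of length n_frames
        spike_frames: sorted frame indices of the injected spikes
    """
    rng = np.random.default_rng(seed)
    noise_std = amplitude / snr
    trace = rng.normal(0.0, noise_std, n_frames)
    # Each frame independently contains a spike with probability rate / fps
    spike_frames = np.flatnonzero(rng.random(n_frames) < rate_hz / fps)
    trace[spike_frames] -= amplitude  # Voltron-2: spikes point downward
    return trace, spike_frames

trace, true_frames = make_synthetic_trace()
```

Because the true spike frames are known, any detector's output on `trace` can be scored directly against `true_frames`.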

---
# 2. Task Specification

## 2.1 Overall Goal

Given a raw voltage imaging video, produce:
1. **ROI masks** - binary masks identifying each neuron
2. **Spike times** - timestamps of detected action potentials for each neuron
3. **Voltage traces** - cleaned ΔF/F time series for each neuron

## 2.2 Pipeline Components (Optimization Targets)

The analysis pipeline consists of sequential processing stages. Each stage can be independently optimized:

```
┌─────────────────────────────────────────────────────────────────────┐
│                    OPTIMIZATION TARGETS                             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  STAGE 1: MOTION CORRECTION                                         │
│  ├── Method: rigid vs non-rigid                                    │
│  ├── Reference: first frame, mean, median, specific frame          │
│  ├── Algorithm: phase correlation, template matching, optical flow │
│  └── Parameters: max_shift, upsample_factor, smoothing             │
│                                                                     │
│  STAGE 2: DENOISING                                                 │
│  ├── Method: none, temporal filter, spatial filter, PCA, wavelet   │
│  ├── Temporal: lowpass cutoff, Savitzky-Golay window, median       │
│  ├── Spatial: Gaussian sigma, bilateral filter                     │
│  └── PCA: local vs global, number of components, patch size        │
│                                                                     │
│  STAGE 3: ROI SEGMENTATION                                          │
│  ├── Method: threshold, watershed, CNN (Mask R-CNN), NMF           │
│  ├── Features: std projection, correlation image, PCA components   │
│  ├── Constraints: min/max area, circularity, ring-like shape       │
│  └── Post-processing: merge overlapping, remove duplicates         │
│                                                                     │
│  STAGE 4: TRACE EXTRACTION                                          │
│  ├── Aggregation: mean, median, weighted by distance               │
│  ├── Background: none, annulus subtraction, neuropil coefficient   │
│  ├── Baseline: percentile (which?), rolling window size            │
│  └── Detrending: none, linear, polynomial, exponential             │
│                                                                     │
│  STAGE 5: SPIKE DETECTION                                           │
│  ├── Method: threshold, template matching, deconvolution, ML       │
│  ├── Threshold: fixed std, adaptive, percentile-based              │
│  ├── Constraints: min spike width, refractory period               │
│  └── Post-processing: amplitude filter, artifact rejection         │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘
```
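
Stage 4 leaves the baseline percentile and window size open. One way to realize a rolling-percentile baseline is `scipy.ndimage.percentile_filter` along time. This is a minimal sketch in which `rolling_percentile_dff` is a hypothetical helper and the defaults (`percentile=10`, `window_s=2.0`) are assumptions chosen inside the `baseline_percentile` and `baseline_window_seconds` ranges of Section 7.2.

```python
import numpy as np
from scipy.ndimage import percentile_filter

def rolling_percentile_dff(trace, fps=200.0, percentile=10, window_s=2.0):
    """
    ΔF/F against a sliding-window percentile baseline F0(t) (hypothetical helper).

    A low percentile over a multi-second window ignores brief deflections
    while following slow drifts such as photobleaching.

    Returns:
        dff: (F - F0) / F0
        f0: the baseline itself
    """
    trace = np.asarray(trace, dtype=np.float64)
    window = max(1, int(round(window_s * fps)))
    f0 = percentile_filter(trace, percentile=percentile, size=window, mode='nearest')
    return (trace - f0) / f0, f0
```

Because Voltron-2 spikes point downward, a very low percentile can be dragged toward spike troughs when firing is dense; the percentile range in the search space lets the optimizer trade this off.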

## 2.3 Input Format

```python
# Input: 3D numpy array
video: np.ndarray  # shape: (n_frames, height, width), dtype: uint16 or float32
fps: float  # frame rate in Hz (typically 200)
```
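
The three outputs in 2.1 correspond to the dictionary returned by `baseline_pipeline` in Section 6 (`roi_masks`, `traces`, `spike_times`, `spike_frames`). A hypothetical checker, assuming exactly those keys, spike times in seconds, and one trace per ROI, can reject malformed candidates before they are scored:

```python
import numpy as np

def check_pipeline_output(results, video, fps):
    """
    Hypothetical validator for a pipeline's output dict (keys as returned by
    baseline_pipeline). Raises ValueError on the first problem found.
    """
    n_frames, height, width = video.shape
    masks = np.asarray(results['roi_masks'])
    traces = np.asarray(results['traces'])
    n_rois = len(masks)
    if n_rois and masks.shape[1:] != (height, width):
        raise ValueError(f"roi_masks must have shape (n_rois, {height}, {width})")
    if n_rois and traces.shape != (n_rois, n_frames):
        raise ValueError(f"traces must have shape ({n_rois}, {n_frames})")
    if len(results['spike_times']) != n_rois:
        raise ValueError("need exactly one spike-time array per ROI")
    duration = n_frames / fps
    for i, times in enumerate(results['spike_times']):
        times = np.asarray(times, dtype=float)
        if times.size and (times.min() < 0 or times.max() >= duration):
            raise ValueError(f"ROI {i}: spike times must lie in [0, {duration}) s")
        if np.any(np.diff(times) < 0):
            raise ValueError(f"ROI {i}: spike times must be sorted")
    return True
```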


---
# 3. Challenges and Failure Modes

## 3.1 Motion Artifacts

**Problem**: The animal moves during imaging, causing:
- Rigid translation (x/y shifts)
- Non-rigid deformation (tissue stretching)
- Focus changes (neurons going in/out of focus)
- ROI contamination (wrong pixels included after movement)

**Failure modes**:
- Correlated "spikes" across many neurons (motion artifact detected as spike)
- Lost neurons (moved out of ROI)
- False spikes from intensity changes due to focus shifts

## 3.2 Segmentation Challenges

**Problem**: Neurons are hard to delineate:
- Ring-like appearance (soma-targeted membrane labeling)
- Variable brightness
- Overlapping neurons
- Similar size to noise blobs

**Failure modes**:
- Merging adjacent neurons into one ROI
- Splitting one neuron into multiple ROIs
- Including non-neuronal structures
- Missing dim neurons

## 3.3 Spike Detection Challenges

**Problem**: Spikes are fast and small:
- Only 1-2 frames wide at 200 Hz
- Low SNR (amplitude similar to noise)
- Variable amplitude across neurons and within the same neuron
- Negative polarity (opposite to calcium indicators)

**Failure modes**:
- Missing true spikes (false negatives)
- Detecting noise as spikes (false positives)
- Correlated false detections from global artifacts

## 3.4 Known Artifacts

| Artifact | Cause | Detection Method |
|----------|-------|------------------|
| Blood cell shadows | Blood flow | Moving dark spots |
| Excitation variation | Laser fluctuation | Correlated intensity changes |
| Photobleaching | Dye degradation | Slow intensity decay |
| Focus drift | Mechanical drift | Blur changes over time |
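
Two rows of this table (photobleaching and excitation variation) can be screened from the FOV-mean fluorescence alone. A minimal sketch, assuming bleaching is roughly mono-exponential and excitation jumps show up as outliers in the frame-to-frame difference of the global mean; the 6 robust-std cutoff and the name `global_artifact_report` are assumptions, not part of the benchmark.

```python
import numpy as np
from scipy.optimize import curve_fit

def global_artifact_report(video, fps=200.0, jump_k=6.0):
    """
    Screen the FOV-mean fluorescence for photobleaching and excitation jumps.

    Returns:
        bleach_tau_s: time constant (s) of a fit a * exp(-t / tau) + c,
                      NaN if the fit fails
        jump_frames: frames where the global mean changes by more than
                     jump_k robust standard deviations from the previous frame
    """
    g = video.reshape(len(video), -1).mean(axis=1).astype(np.float64)
    t = np.arange(len(g)) / fps

    def mono_exp(tt, a, tau, c):
        return a * np.exp(-tt / tau) + c

    p0 = (g[0] - g[-1], t[-1] / 3 + 1e-9, g[-1])
    try:
        (_, tau, _), _ = curve_fit(mono_exp, t, g, p0=p0, maxfev=5000)
        bleach_tau_s = float(tau)
    except (RuntimeError, ValueError):  # fit did not converge / invalid input
        bleach_tau_s = np.nan

    # Robust z-score of frame-to-frame changes (1.4826 * MAD ~ std for Gaussian noise)
    d = np.diff(g)
    dev = np.abs(d - np.median(d))
    robust_std = 1.4826 * np.median(dev) + 1e-12
    jump_frames = np.flatnonzero(dev > jump_k * robust_std) + 1
    return bleach_tau_s, jump_frames
```

Frames listed in `jump_frames` are natural candidates for the correlated false spikes described in 3.1 and 3.3.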

---
# 4. Evaluation Metrics

## 4.1 Ground Truth Comparison (Supervised)

When expert annotations are available:

### 4.1.1 ROI Segmentation Metrics

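```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def evaluate_segmentation(predicted_masks, ground_truth_masks, iou_threshold=0.5):
    """
    Match predicted to ground-truth ROIs one-to-one (Hungarian matching on IoU);
    a matched pair is a true positive if IoU >= iou_threshold.
    Returns dict: precision, recall, f1, mean_iou (over true-positive pairs).
    """
    if len(predicted_masks) == 0 or len(ground_truth_masks) == 0:
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'mean_iou': 0.0}
    pred = np.asarray(predicted_masks, dtype=bool).reshape(len(predicted_masks), -1).astype(float)
    gt = np.asarray(ground_truth_masks, dtype=bool).reshape(len(ground_truth_masks), -1).astype(float)
    inter = pred @ gt.T                                   # (n_pred, n_gt) overlap areas
    union = pred.sum(1)[:, None] + gt.sum(1)[None, :] - inter
    iou = inter / np.maximum(union, 1)
    rows, cols = linear_sum_assignment(-iou)              # maximize total IoU
    matched_iou = iou[rows, cols][iou[rows, cols] >= iou_threshold]
    tp = len(matched_iou)
    precision, recall = tp / len(pred), tp / len(gt)
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    mean_iou = float(matched_iou.mean()) if tp else 0.0
    return {'precision': precision, 'recall': recall, 'f1': f1, 'mean_iou': mean_iou}
```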

### 4.1.2 Spike Detection Metrics

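```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def evaluate_spikes(predicted_times, ground_truth_times, tolerance_ms=5):
    """
    Compare predicted spike times to ground truth (both in seconds), matching
    spikes one-to-one so each true spike is credited at most once.
    
    Args:
        tolerance_ms: maximum time difference to count as match
    
    Returns dict with:
        precision: fraction of predicted spikes that are true
        recall: fraction of true spikes that were detected
        f1: harmonic mean of precision and recall
        timing_error: mean timing error for matched spikes (ms), NaN if none
    """
    pred = np.asarray(predicted_times, dtype=float)
    gt = np.asarray(ground_truth_times, dtype=float)
    if len(pred) == 0 or len(gt) == 0:
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'timing_error': np.nan}
    dt_ms = np.abs(pred[:, None] - gt[None, :]) * 1000
    # Out-of-tolerance pairs get a prohibitive cost, so they are never preferred
    cost = np.where(dt_ms <= tolerance_ms, dt_ms, 1e9)
    rows, cols = linear_sum_assignment(cost)
    ok = dt_ms[rows, cols] <= tolerance_ms
    tp = int(ok.sum())
    precision, recall = tp / len(pred), tp / len(gt)
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    timing_error = float(dt_ms[rows, cols][ok].mean()) if tp else np.nan
    return {'precision': precision, 'recall': recall, 'f1': f1, 'timing_error': timing_error}
```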

## 4.2 Automated Validation (Unsupervised)

Metrics that don't require ground truth annotations:

### 4.2.1 Shuffled Control Test

**Principle**: Spike detection on temporally shuffled data should yield far fewer spikes than real data.

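```python
def shuffled_control_score(trace, spike_detector, n_shuffles=100, method='frame',
                           fps=200, block_ms=100, seed=0):
    """
    Compare spike detection on real vs shuffled traces.
    
    Args:
        trace: 1D ΔF/F trace of one neuron
        spike_detector: callable mapping a 1D trace to an array of spike indices
        method: shuffling method
            'frame':    random permutation of frames
            'circular': roll trace by a random offset
            'block':    shuffle contiguous blocks of ~block_ms
        seed: seed for the random generator (makes scores reproducible)
    
    Returns:
        spike_ratio: n_real_spikes / (mean(n_shuffled_spikes) + 1); the +1
            avoids division by zero when shuffles yield no spikes
    """
    rng = np.random.default_rng(seed)
    trace = np.asarray(trace)
    real_spikes = spike_detector(trace)
    block = max(1, int(round(block_ms * fps / 1000)))
    shuffled_spikes = []
    
    for _ in range(n_shuffles):
        if method == 'frame':
            shuffled_trace = rng.permutation(trace)
        elif method == 'circular':
            shuffled_trace = np.roll(trace, rng.integers(1, len(trace)))
        elif method == 'block':
            blocks = [trace[i:i + block] for i in range(0, len(trace), block)]
            shuffled_trace = np.concatenate([blocks[k] for k in rng.permutation(len(blocks))])
        else:
            raise ValueError(f"Unknown shuffle method: {method!r}")
        shuffled_spikes.append(len(spike_detector(shuffled_trace)))
    
    spike_ratio = len(real_spikes) / (np.mean(shuffled_spikes) + 1)
    return spike_ratio
```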

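**Interpretation**:
- `spike_ratio >> 1`: Good, detecting real structure
- `spike_ratio ≈ 1`: Bad, just detecting noise
- Caveat: every shuffle keeps the trace's amplitude distribution, so a detector that only thresholds amplitude finds roughly as many supra-threshold samples after shuffling as before, even when its detections are real spikes. The ratio is most informative with frame shuffling and a `spike_detector` that uses temporal structure (smoothing, template matching, width constraints).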
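
The ratio alone does not say how surprising the real count is. From the same list of per-shuffle spike counts, an empirical permutation p-value (the standard `(1 + #{shuffled >= real}) / (n_shuffles + 1)` estimator) and a false-positive rate in spikes per second can be computed. `shuffle_p_value` and `shuffle_fp_rate_hz` are hypothetical helpers, a sketch rather than part of the benchmark:

```python
import numpy as np

def shuffle_p_value(n_real, shuffled_counts):
    """
    One-sided empirical p-value: probability of a shuffled count at least as
    large as the real one. The +1 terms keep p > 0 with finitely many shuffles.
    """
    shuffled_counts = np.asarray(shuffled_counts)
    return (1 + np.sum(shuffled_counts >= n_real)) / (len(shuffled_counts) + 1)

def shuffle_fp_rate_hz(shuffled_counts, duration_s):
    """False-positive rate estimate: mean spikes/s the detector finds in shuffled data."""
    return float(np.mean(shuffled_counts)) / duration_s
```

With 100 shuffles the smallest attainable p-value is 1/101, so the number of shuffles bounds how significant a result can look.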

### 4.2.2 Spike Template Consistency

**Principle**: Spikes from the same neuron should have similar waveform shapes.

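```python
def template_consistency_score(trace, spike_times, window_ms=10, fps=200):
    """
    Measure consistency of spike waveforms within a neuron.
    
    Args:
        spike_times: spike times in seconds
        window_ms: half-width of the waveform window around each spike
            (10 ms at 200 Hz = 2 frames, i.e. 5-sample waveforms)
    
    Returns:
        mean_correlation: mean pairwise correlation between spike waveforms
        std_correlation: variability in spike shapes
    """
    window_frames = max(1, int(round(window_ms * fps / 1000)))
    waveforms = []
    
    for t in spike_times:
        idx = int(round(t * fps))  # round: t * fps can land just below an integer
        if window_frames <= idx < len(trace) - window_frames:
            # Symmetric window centered on the spike frame
            waveforms.append(trace[idx - window_frames:idx + window_frames + 1])
    
    if len(waveforms) < 2:
        return 0, 0
    
    # All pairwise correlations from one matrix (rows = waveforms);
    # flat waveforms give NaN, which nanmean/nanstd ignore
    with np.errstate(invalid='ignore', divide='ignore'):
        corr_matrix = np.corrcoef(np.array(waveforms))
    correlations = corr_matrix[np.triu_indices(len(waveforms), k=1)]
    
    return np.nanmean(correlations), np.nanstd(correlations)
```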

**Interpretation**:
- `mean_correlation > 0.8`: Good, consistent spike shapes
- `mean_correlation < 0.5`: Bad, likely detecting noise
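
The same waveform snippets can be averaged into a per-neuron template, which also underlies the 'template' option for spike detection in Section 7.1. A minimal sketch, assuming spike frame indices (not seconds) and an odd-length template; normalized cross-correlation is one common matched-filter variant, not necessarily what any reference pipeline uses. `spike_template` and `template_match` are hypothetical helpers.

```python
import numpy as np

def spike_template(trace, spike_frames, half_window=2):
    """Spike-triggered average: mean waveform in a symmetric window around each spike frame."""
    trace = np.asarray(trace, dtype=np.float64)
    snippets = [trace[f - half_window:f + half_window + 1]
                for f in spike_frames
                if half_window <= f < len(trace) - half_window]
    if not snippets:
        raise ValueError("no spikes far enough from the trace edges")
    return np.mean(snippets, axis=0)

def template_match(trace, template):
    """
    Normalized cross-correlation between the template and the trace window
    centered on each frame (values in [-1, 1]; frames without a full window get 0).
    """
    trace = np.asarray(trace, dtype=np.float64)
    n = len(template)
    tz = (template - template.mean()) / (template.std() + 1e-12)
    windows = np.lib.stride_tricks.sliding_window_view(trace, n)  # (len - n + 1, n)
    wz = (windows - windows.mean(axis=1, keepdims=True)) / (windows.std(axis=1, keepdims=True) + 1e-12)
    out = np.zeros(len(trace))
    out[n // 2:n // 2 + len(windows)] = wz @ tz / n
    return out
```

Normalization discards amplitude, so short noisy windows can also score high; in practice the match score would be combined with an amplitude criterion.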

### 4.2.3 Inter-Neuron Correlation (Artifact Detection)

**Principle**: High correlation in spike timing across many neurons suggests artifacts.

```python
def artifact_correlation_score(all_spike_times, n_rois, duration, bin_size_ms=10):
    """
    Detect correlated spiking that suggests motion/illumination artifacts.
    
    Args:
        all_spike_times: list (one entry per ROI) of spike times in seconds
        duration: recording length in seconds
    
    Returns:
        mean_pairwise_correlation: should be low for real data
            (NaN if fewer than 2 ROIs or no correlations are defined)
        synchrony_events: number of time bins with >50% of neurons spiking
    """
    # Bin spikes into a binary (ROI x time-bin) matrix
    n_bins = int(duration * 1000 / bin_size_ms)
    spike_matrix = np.zeros((n_rois, n_bins))
    
    for i, times in enumerate(all_spike_times):
        for t in times:
            bin_idx = int(t * 1000 / bin_size_ms)
            if 0 <= bin_idx < n_bins:
                spike_matrix[i, bin_idx] = 1
    
    if n_rois < 2:
        return np.nan, 0
    
    # Pairwise correlations; ROIs without spikes give NaN rows, ignored below
    with np.errstate(invalid='ignore', divide='ignore'):
        corr_matrix = np.corrcoef(spike_matrix)
    upper_tri = corr_matrix[np.triu_indices(n_rois, k=1)]
    
    # Count synchrony events
    fraction_active = spike_matrix.sum(axis=0) / n_rois
    synchrony_events = int(np.sum(fraction_active > 0.5))
    
    mean_corr = np.nanmean(upper_tri) if np.any(np.isfinite(upper_tri)) else np.nan
    return mean_corr, synchrony_events
```

**Interpretation**:
- `mean_correlation < 0.1`: Good, neurons fire independently
- `mean_correlation > 0.3`: Bad, likely motion artifacts
- `synchrony_events > 0`: Suspicious, check these timepoints
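
To check the suspicious timepoints, the bins exceeding the synchrony cutoff can be listed directly. A small sketch, assuming the same binning as `artifact_correlation_score` (spike times in seconds, 10 ms bins, more than 50% of ROIs active); `synchrony_times` is a hypothetical helper.

```python
import numpy as np

def synchrony_times(all_spike_times, duration, bin_size_ms=10, min_fraction=0.5):
    """Start times (s) of bins in which more than `min_fraction` of ROIs spike."""
    n_rois = len(all_spike_times)
    n_bins = int(duration * 1000 / bin_size_ms)
    active = np.zeros((n_rois, n_bins), dtype=bool)
    for i, times in enumerate(all_spike_times):
        idx = (np.asarray(times, dtype=float) * 1000 / bin_size_ms).astype(int)
        active[i, idx[(idx >= 0) & (idx < n_bins)]] = True
    fraction = active.mean(axis=0) if n_rois else np.zeros(n_bins)
    return np.flatnonzero(fraction > min_fraction) * bin_size_ms / 1000
```

The returned times can be cross-referenced with the global-mean trace or the motion-correction shifts to decide whether an event is an artifact.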

### 4.2.4 ROI Quality Metrics

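```python
def roi_quality_score(roi_mask, std_projection, min_area=30, max_area=500):
    """
    Assess ROI quality based on morphological and intensity features.
    min_area/max_area are in pixels (default: baseline_pipeline's 30-500 px range;
    convert the expected 5-10 µm soma size using your pixel size).
    Returns:
        size_score: 1 if area is within [min_area, max_area], 0 otherwise
        shape_score: circularity 4*pi*area/perimeter**2, clipped to [0, 1]
        snr_score: mean std inside ROI / mean std in a 3-px annulus (contrast proxy)
    """
    from scipy import ndimage
    from skimage import measure
    mask = np.asarray(roi_mask, dtype=bool)
    area = int(mask.sum())
    if area == 0:
        return 0, 0.0, 0.0
    size_score = int(min_area <= area <= max_area)
    perimeter = measure.perimeter(mask)
    shape_score = min(1.0, 4 * np.pi * area / perimeter ** 2) if perimeter > 0 else 0.0
    annulus = ndimage.binary_dilation(mask, iterations=3) & ~mask
    snr_score = std_projection[mask].mean() / (std_projection[annulus].mean() + 1e-10) if annulus.any() else np.nan
    return size_score, float(shape_score), float(snr_score)
```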
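
Section 1.2 describes neurons as ring-like because of membrane labeling. One simple ring-likeness proxy, assumed here rather than taken from the benchmark, compares `std_projection` on the ROI's outer rim with its interior; `ring_index` is a hypothetical helper.

```python
import numpy as np
from scipy import ndimage

def ring_index(roi_mask, std_projection, rim_width=1):
    """
    Mean std on the ROI rim divided by the mean std in its interior.
    Values > 1 suggest a ring (active membrane, quieter interior);
    returns NaN if no interior or no rim remains after erosion.
    """
    mask = np.asarray(roi_mask, dtype=bool)
    interior = ndimage.binary_erosion(mask, iterations=rim_width)
    rim = mask & ~interior
    if not interior.any() or not rim.any():
        return np.nan
    return float(std_projection[rim].mean() / (std_projection[interior].mean() + 1e-10))
```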

### 4.2.5 Physiological Plausibility

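```python
def physiological_plausibility_score(spike_times, fps=200, duration=None, refractory_ms=2.0):
    """
    Check if detected spikes are physiologically plausible.
    
    Args:
        spike_times: sorted spike times in seconds
        fps: frame rate in Hz. If spike times are frame indices / fps, ISIs at
            200 Hz are multiples of 5 ms, so a 2 ms violation can only mean
            two spikes in the same frame.
        duration: recording length in seconds; if None, the span between the
            first and last spike is used, which overestimates the rate of
            sparse spike trains
    
    Returns:
        refractory_violations: ISIs shorter than refractory_ms (should be 0)
        firing_rate_plausible: True if rate is 0.1-50 Hz
        isi_cv: coefficient of variation of inter-spike intervals
    """
    spike_times = np.asarray(spike_times, dtype=float)
    if len(spike_times) < 2:
        return 0, True, 0
    
    isis = np.diff(spike_times) * 1000  # Convert to ms
    refractory_violations = int(np.sum(isis < refractory_ms))
    
    if duration is None:
        duration = spike_times[-1] - spike_times[0]
    firing_rate = len(spike_times) / duration if duration > 0 else 0
    firing_rate_plausible = 0.1 < firing_rate < 50
    
    isi_cv = np.std(isis) / np.mean(isis) if np.mean(isis) > 0 else 0
    
    return refractory_violations, firing_rate_plausible, isi_cv
```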
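
When `refractory_violations` is nonzero, one post-processing option (the refractory-period constraint listed under Stage 5) is to keep only the largest spike within each refractory window. A minimal greedy sketch, assuming each spike has an amplitude such as the inverted ΔF/F at its peak; `enforce_refractory` is a hypothetical helper.

```python
import numpy as np

def enforce_refractory(spike_times, amplitudes, refractory_ms=2.0):
    """
    Greedy non-maximum suppression: visit spikes from largest to smallest
    amplitude and drop any spike within `refractory_ms` of one already kept.
    Returns the kept spike times (seconds), sorted.
    """
    times = np.asarray(spike_times, dtype=float)
    amps = np.asarray(amplitudes, dtype=float)
    kept = []
    for i in np.argsort(-amps):
        if all(abs(times[i] - times[j]) * 1000 >= refractory_ms for j in kept):
            kept.append(i)
    return np.sort(times[kept])
```

The pairwise check is quadratic in the number of spikes, which is acceptable at the firing rates this benchmark expects.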

---
# 5. Composite Evaluation Score

## 5.1 Weighted Score Function

Combine all metrics into a single optimization target:

```python
def compute_benchmark_score(results, video, fps, ground_truth=None, spike_detector=None):
    """
    Compute overall benchmark score for a pipeline's output.
    
    Args:
        results: dict with roi_masks, traces, spike_times
        video: original video array
        fps: frame rate
        ground_truth: optional dict with GT masks and spikes
        spike_detector: callable mapping a 1D trace to spike indices, used by
            the shuffled control; if None, the baseline rule (3 std threshold
            on the inverted trace) is used
    
    Returns:
        score: float in [0, 1], higher is better
        details: dict with individual metric scores
    """
    from scipy import signal
    
    if spike_detector is None:
        def spike_detector(trace):
            inverted = -np.asarray(trace)  # Voltron-2 spikes are negative
            peaks, _ = signal.find_peaks(inverted, height=inverted.mean() + 3 * inverted.std())
            return peaks
    
    scores = {}
    n_rois = len(results['roi_masks'])
    
    # === UNSUPERVISED METRICS (always computed) ===
    # Weights are applied in the weighted sums at the end of the function.
    
    # 1. Shuffled control
    shuffle_ratios = [shuffled_control_score(trace, spike_detector)
                      for trace in results['traces']]
    scores['shuffle_ratio'] = float(np.median(shuffle_ratios)) if shuffle_ratios else 0.0
    # Normalize: ratio of 10 = perfect score
    scores['shuffle_score'] = min(1.0, scores['shuffle_ratio'] / 10)
    
    # 2. Template consistency (ROIs with at least 3 spikes)
    consistencies = []
    for trace, times in zip(results['traces'], results['spike_times']):
        if len(times) >= 3:
            corr, _ = template_consistency_score(trace, times, fps=fps)
            consistencies.append(corr)
    # Clip at 0 so anti-correlated or undefined (NaN) waveforms cannot push the total below 0
    mean_consistency = np.nanmean(consistencies) if consistencies else 0.0
    scores['template_consistency'] = 0.0 if np.isnan(mean_consistency) else max(0.0, float(mean_consistency))
    
    # 3. Artifact correlation (lower correlation is better; needs >= 2 ROIs)
    if n_rois >= 2:
        corr, sync = artifact_correlation_score(
            results['spike_times'],
            n_rois,
            video.shape[0] / fps
        )
        # corr >= 0.33 gives a score of 0; NaN (e.g. no spikes anywhere) also gives 0
        scores['artifact_score'] = 0.0 if np.isnan(corr) else max(0.0, 1 - corr * 3)
        scores['synchrony_events'] = int(sync)
    else:
        scores['artifact_score'] = 0.0
    
    # 4. Physiological plausibility: fraction of ROIs with a plausible firing rate
    violations = 0
    plausible_count = 0
    for times in results['spike_times']:
        v, p, _ = physiological_plausibility_score(times, fps)
        violations += v
        plausible_count += int(p)
    n_trains = len(results['spike_times'])
    scores['physiology_score'] = plausible_count / n_trains if n_trains else 0.0
    scores['refractory_violations'] = violations
    
    # 5. ROI count reasonableness: expect 10-100 ROIs
    if 10 <= n_rois <= 100:
        scores['roi_count_score'] = 1.0
    elif 5 <= n_rois <= 150:
        scores['roi_count_score'] = 0.5
    else:
        scores['roi_count_score'] = 0.0
    
    # === SUPERVISED METRICS (if ground truth available) ===
    
    if ground_truth is not None:
        # 6. ROI segmentation F1
        seg_metrics = evaluate_segmentation(
            results['roi_masks'],
            ground_truth['roi_masks']
        )
        scores['segmentation_f1'] = seg_metrics['f1']
        
        # 7. Spike detection F1
        # Assumes predicted ROI i corresponds to GT ROI i; for independently
        # segmented ROIs, pair them via the segmentation matching first.
        # GT ROIs without a predicted counterpart count as F1 = 0.
        spike_f1s = []
        for i, gt_times in enumerate(ground_truth['spike_times']):
            if i < len(results['spike_times']):
                metrics = evaluate_spikes(results['spike_times'][i], gt_times)
                spike_f1s.append(metrics['f1'])
            else:
                spike_f1s.append(0.0)
        scores['spike_f1'] = np.mean(spike_f1s) if spike_f1s else 0
        
        # Compute weighted score with GT (weights sum to 1)
        total_score = (
            0.15 * scores['shuffle_score'] +
            0.10 * scores['template_consistency'] +
            0.10 * scores['artifact_score'] +
            0.05 * scores['physiology_score'] +
            0.05 * scores['roi_count_score'] +
            0.25 * scores['segmentation_f1'] +
            0.30 * scores['spike_f1']
        )
    else:
        # Compute weighted score without GT (weights sum to 1)
        total_score = (
            0.35 * scores['shuffle_score'] +
            0.25 * scores['template_consistency'] +
            0.25 * scores['artifact_score'] +
            0.10 * scores['physiology_score'] +
            0.05 * scores['roi_count_score']
        )
    
    return float(total_score), scores
```

---
# 6. Baseline Implementation

A simple baseline pipeline for comparison:

```python
def baseline_pipeline(video, fps=200):
    """
    Simple baseline voltage imaging analysis pipeline.
    
    Args:
        video: np.ndarray, shape (n_frames, height, width)
        fps: frame rate in Hz
    
    Returns:
        dict with roi_masks, traces, spike_times, spike_frames
    """
    import numpy as np
    from scipy import ndimage, signal
    from skimage import measure, morphology
    
    n_frames, height, width = video.shape
    
    # === STAGE 1: No motion correction (baseline) ===
    corrected = video
    
    # === STAGE 2: Simple denoising ===
    # Temporal Gaussian smoothing
    # float32 halves memory relative to float64 for multi-GB recordings
    denoised = ndimage.gaussian_filter1d(corrected.astype(np.float32), sigma=1, axis=0)
    
    # === STAGE 3: ROI segmentation ===
    # Use std projection to find active regions
    std_proj = np.std(denoised, axis=0)
    
    # Threshold
    threshold = np.mean(std_proj) + 1.5 * np.std(std_proj)
    binary = std_proj > threshold
    
    # Clean up
    binary = morphology.remove_small_objects(binary, min_size=30)
    binary = morphology.remove_small_holes(binary, area_threshold=30)
    
    # Label connected components
    labeled = measure.label(binary)
    regions = measure.regionprops(labeled)
    
    # Filter by size
    roi_masks = []
    for region in regions:
        if 30 < region.area < 500:  # Expected neuron size
            mask = labeled == region.label
            roi_masks.append(mask)
    
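    # Keep a (0, height, width) shape when no ROI passes the filters
    roi_masks = np.array(roi_masks) if roi_masks else np.zeros((0, height, width), dtype=bool)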
    
    # === STAGE 4: Trace extraction ===
    traces = []
    for mask in roi_masks:
        trace = np.mean(denoised[:, mask], axis=1)
        # Compute dF/F
        f0 = np.percentile(trace, 10)
        dff = (trace - f0) / f0
        # Detrend
        dff = signal.detrend(dff)
        traces.append(dff)
    
    traces = np.array(traces)
    
    # === STAGE 5: Spike detection ===
    spike_times = []
    spike_frames = []
    
    for trace in traces:
        # Invert (Voltron-2 has negative spikes)
        inverted = -trace
        
        # Threshold at 3 std
        threshold = np.mean(inverted) + 3 * np.std(inverted)
        
        # Find peaks
        peaks, _ = signal.find_peaks(
            inverted,
            height=threshold,
            distance=max(1, int(0.005 * fps))  # >= 5 ms apart; at 200 Hz this is 1 frame (no extra constraint)
        )
        
        spike_frames.append(peaks)
        spike_times.append(peaks / fps)
    
    return {
        'roi_masks': roi_masks,
        'traces': traces,
        'spike_times': spike_times,
        'spike_frames': spike_frames
    }
```
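
The baseline's `mean + 3 * std` threshold is inflated by the spikes themselves. The 'mad' option listed under `threshold_type` in Section 7.1 can be sketched with a median-absolute-deviation noise estimate (1.4826 × MAD approximates the standard deviation for Gaussian noise). The factor `k=4.0` is an assumed default inside the `spike_threshold_std` range, and `detect_spikes_mad` is a hypothetical drop-in for Stage 5.

```python
import numpy as np
from scipy import signal

def detect_spikes_mad(trace, fps=200.0, k=4.0, min_distance_ms=5.0):
    """
    Negative-going spike detection with a robust (MAD-based) threshold.

    Returns:
        spike_frames: frame indices of detected troughs
        spike_times: the same events in seconds
    """
    inverted = -np.asarray(trace, dtype=np.float64)  # Voltron-2 spikes point down
    center = np.median(inverted)
    noise = 1.4826 * np.median(np.abs(inverted - center))
    distance = max(1, int(round(min_distance_ms * fps / 1000)))
    peaks, _ = signal.find_peaks(inverted, height=center + k * noise, distance=distance)
    return peaks, peaks / fps
```

Because the median and MAD are barely affected by sparse spikes, the threshold stays close to the noise level even in neurons that fire often.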

---
# 7. Search Space for Optimization

## 7.1 Discrete Choices

```python
SEARCH_SPACE = {
    'motion_correction': {
        'method': ['none', 'rigid', 'nonrigid'],
        'reference': ['first', 'mean', 'middle'],
        'algorithm': ['phase', 'template', 'optical_flow'],
    },
    'denoising': {
        'temporal': ['none', 'gaussian', 'savgol', 'median', 'lowpass'],
        'spatial': ['none', 'gaussian', 'bilateral'],
        'pca': ['none', 'global', 'local'],
    },
    'segmentation': {
        'method': ['threshold', 'watershed', 'nmf', 'correlation'],
        'feature': ['std', 'max', 'mean', 'correlation'],
    },
    'trace_extraction': {
        'aggregation': ['mean', 'median'],
        'background': ['none', 'annulus', 'neuropil'],
        'baseline': ['percentile', 'rolling_percentile'],
        'detrend': ['none', 'linear', 'polynomial'],
    },
    'spike_detection': {
        'method': ['threshold', 'template', 'adaptive'],
        'threshold_type': ['fixed_std', 'mad', 'percentile'],
    }
}
```

## 7.2 Continuous Parameters

```python
PARAMETER_RANGES = {
    'motion_max_shift': (10, 100),  # pixels
    'denoise_sigma': (0.5, 5.0),
    'savgol_window': (3, 15),  # must be odd
    'lowpass_cutoff': (20, 100),  # Hz
    'pca_components': (5, 50),
    'segmentation_threshold_factor': (1.0, 3.0),
    'roi_min_area': (20, 100),  # pixels
    'roi_max_area': (200, 1000),  # pixels
    'baseline_percentile': (5, 20),
    'baseline_window_seconds': (1, 10),
    'spike_threshold_std': (2.0, 5.0),
    'spike_min_distance_ms': (2, 20),
    'spike_prominence_factor': (0.3, 1.0),
}
```
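
For random search, or to seed a population in an iterative optimizer, one configuration can be drawn from the two dictionaries above. A minimal sketch, assuming uniform sampling and forcing `savgol_window` to be odd as its comment requires; which parameters are integer-valued (`INTEGER_PARAMS`) and the name `sample_config` are assumptions.

```python
import numpy as np

# Assumed: these parameters take integer values in practice
INTEGER_PARAMS = {'motion_max_shift', 'savgol_window', 'pca_components',
                  'roi_min_area', 'roi_max_area'}

def sample_config(search_space, parameter_ranges, seed=None):
    """Draw one pipeline configuration uniformly from the discrete and continuous spaces."""
    rng = np.random.default_rng(seed)
    config = {stage: {option: str(rng.choice(values)) for option, values in options.items()}
              for stage, options in search_space.items()}
    params = {}
    for name, (lo, hi) in parameter_ranges.items():
        if name in INTEGER_PARAMS:
            value = int(rng.integers(lo, hi + 1))  # inclusive upper bound
            if name == 'savgol_window' and value % 2 == 0:
                value = value + 1 if value < hi else value - 1  # keep it odd
        else:
            value = float(rng.uniform(lo, hi))
        params[name] = value
    config['params'] = params
    return config
```

Usage: `sample_config(SEARCH_SPACE, PARAMETER_RANGES, seed=0)` returns one nested dict per call; log-uniform sampling may suit scale-like parameters such as `denoise_sigma` better.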

---
# 8. Data and Ground Truth

## 8.1 Available Data

| Dataset | Description | Frames | ROIs | Has GT? |
|---------|-------------|--------|------|--------|
| fish1_fov1 | Example recording | ~2000 | ~20 | Yes (manual) |
| fish1_fov2 | Different FOV | ~2000 | ~30 | Yes (manual) |
| ... | ... | ... | ... | ... |

## 8.2 Ground Truth Format

```python
ground_truth = {
    'roi_masks': np.ndarray,  # shape: (n_rois, height, width)
    'roi_metadata': [{
        'id': int,
        'quality': str,  # 'good', 'uncertain', 'overlapping'
        'shape': str,    # 'ring', 'filled', 'irregular'
    }],
    'spike_times': List[np.ndarray],  # seconds, per ROI
    'spike_confidence': List[List[str]],  # 'certain', 'probable', 'uncertain'
}
```

## 8.3 Loading Data

```python
def load_benchmark_data(data_path):
    """
    Load benchmark data and ground truth.
    
    Returns:
        video: np.ndarray
        fps: float
        ground_truth: dict or None
    """
    import tifffile
    import json
    import numpy as np
    from pathlib import Path
    
    data_path = Path(data_path)
    
    # Load video
    video = tifffile.imread(data_path / 'video.tif')
    
    # Load metadata
    with open(data_path / 'metadata.json') as f:
        metadata = json.load(f)
    fps = metadata['fps']
    
    # Load ground truth if available
    gt_path = data_path / 'ground_truth.npz'
    if gt_path.exists():
        gt_data = np.load(gt_path, allow_pickle=True)
        ground_truth = {
            'roi_masks': gt_data['roi_masks'],
            'spike_times': gt_data['spike_times'].tolist(),
        }
    else:
        ground_truth = None
    
    return video, fps, ground_truth
```
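
To add a new recording that `load_benchmark_data` can read, the ground truth and metadata can be written in the matching layout. A sketch, assuming the file names used by the loader; ragged per-ROI spike-time arrays are stored as an object array, which is why the loader passes `allow_pickle=True`. `save_ground_truth` is a hypothetical helper, and the video itself would be written separately, e.g. with `tifffile.imwrite(data_path / 'video.tif', video)`.

```python
import json
from pathlib import Path
import numpy as np

def save_ground_truth(data_path, roi_masks, spike_times, fps):
    """Write ground_truth.npz and metadata.json in the layout load_benchmark_data expects."""
    data_path = Path(data_path)
    data_path.mkdir(parents=True, exist_ok=True)
    # Object array: element i holds ROI i's spike times (s); lengths may differ.
    # Assign element-wise so equal-length lists are not merged into a 2D array.
    times = np.empty(len(spike_times), dtype=object)
    for i, t in enumerate(spike_times):
        times[i] = np.asarray(t, dtype=float)
    np.savez(data_path / 'ground_truth.npz',
             roi_masks=np.asarray(roi_masks, dtype=bool), spike_times=times)
    with open(data_path / 'metadata.json', 'w') as f:
        json.dump({'fps': float(fps)}, f)
```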

---
# 9. References and Prior Work

## 9.1 Existing Tools

| Tool | What it does | Strengths | Limitations |
|------|--------------|-----------|-------------|
| **VolPy** (CaImAn) | Full pipeline | Mask R-CNN segmentation, SpikePursuit | Complex setup |
| **NoRMCorre** | Motion correction | Well-validated | Part of CaImAn |
| **Suite2p** | Calcium imaging | Fast, scalable | Not optimized for voltage |

## 9.2 Key Papers

1. **VolPy** - Cai et al., 2021, PLOS Comp Bio
   - Mask R-CNN for segmentation
   - SpikePursuit for spike detection
   - F1 > 90% on benchmark data

2. **Voltage imaging pipeline** - (citation to be added)
   - Camera noise correction
   - Local PCA denoising
   - Semi-NMF segmentation
   - LSTM spike detection

3. **Whole-brain voltage imaging** - (Positron2 study; citation to be added)
   - NoRMCorre motion correction
   - UMAP+DBSCAN artifact removal
   - SNR > 4 filtering

## 9.3 Novel Ideas to Explore

- Shuffled control validation
- Template consistency scoring
- Artifact correlation detection
- Combined supervised + unsupervised evaluation

---
# 9.5 Implementation Ideas and Advanced Techniques

## 9.5.1 Motion Correction Approaches

### Template Matching (Recommended for Rigid Motion)
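```python
import cv2
import numpy as np

def motion_correct_template(video, max_shift=50, reference='mean'):
    """
    Rigid motion correction using OpenCV template matching.
    Integer-pixel shifts only, searched within +/- max_shift pixels.

    reference: 'mean', 'first', or anything else for the middle frame.
    Returns (corrected video, per-frame (shift_x, shift_y) array).
    Pixels shifted in from outside the frame are filled with 0.
    """
    if reference == 'mean':
        ref_frame = np.mean(video, axis=0).astype(np.float32)
    elif reference == 'first':
        ref_frame = video[0].astype(np.float32)
    else:  # 'middle'
        ref_frame = video[len(video)//2].astype(np.float32)
    
    h, w = ref_frame.shape
    margin = max_shift
    # Central crop of the reference; it can slide +/- margin pixels inside each frame
    template = ref_frame[margin:h-margin, margin:w-margin]
    
    corrected = np.zeros_like(video)
    shifts = []
    
    for i, frame in enumerate(video):
        # TM_CCOEFF_NORMED subtracts the mean, so it is less sensitive to global
        # brightness changes (bleaching, excitation drift) than TM_CCORR_NORMED
        result = cv2.matchTemplate(frame.astype(np.float32), template, cv2.TM_CCOEFF_NORMED)
        _, _, _, max_loc = cv2.minMaxLoc(result)  # max_loc = (x, y) of the best match
        shift_x = margin - max_loc[0]
        shift_y = margin - max_loc[1]
        shifts.append((shift_x, shift_y))
        
        # Translate the frame back by the detected offset
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        corrected[i] = cv2.warpAffine(frame, M, (w, h))
    
    return corrected, np.array(shifts)
```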

### NoRMCorre-style Piecewise Rigid
```python
def motion_correct_piecewise(video, patch_size=128, overlap=32, max_shift=20):
    """
    Piecewise rigid correction - divides FOV into patches.
    Better for non-uniform motion (e.g., tissue deformation).
    """
    # Divide into overlapping patches
    # Compute shift per patch
    # Interpolate shifts across boundaries
    # Apply smooth deformation field
    pass
```
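
The four commented steps above can be sketched with `skimage.registration.phase_cross_correlation` per patch and bilinear interpolation of the patch shifts into a dense field. This is a simplified stand-in, not NoRMCorre itself: the reference is assumed to be the mean image, the coarse shift grid is stretched over the whole frame with `scipy.ndimage.zoom` (so patch centers are only approximately respected), shifts are clipped to `max_shift`, and `motion_correct_piecewise_sketch` is a hypothetical name.

```python
import numpy as np
from scipy import ndimage
from skimage.registration import phase_cross_correlation

def motion_correct_piecewise_sketch(video, patch_size=128, overlap=32,
                                    max_shift=20, upsample_factor=10):
    """Piecewise-rigid correction to the mean image. Assumes frames >= patch_size."""
    n_frames, h, w = video.shape
    ref = video.mean(axis=0)
    stride = patch_size - overlap
    ys = range(0, h - patch_size + 1, stride)
    xs = range(0, w - patch_size + 1, stride)
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float64)
    corrected = np.empty((n_frames, h, w), dtype=np.float32)
    for i in range(n_frames):
        frame = video[i].astype(np.float64)
        # 1-2. Rigid (dy, dx) shift of each overlapping patch against the reference
        grid = np.zeros((2, len(ys), len(xs)))
        for a, y in enumerate(ys):
            for b, x in enumerate(xs):
                sl = (slice(y, y + patch_size), slice(x, x + patch_size))
                shift, _, _ = phase_cross_correlation(ref[sl], frame[sl],
                                                      upsample_factor=upsample_factor)
                grid[:, a, b] = np.clip(shift, -max_shift, max_shift)
        # 3. Interpolate the coarse shift grid to one (dy, dx) per pixel
        zoom = (h / grid.shape[1], w / grid.shape[2])
        dy = ndimage.zoom(grid[0], zoom, order=1)
        dx = ndimage.zoom(grid[1], zoom, order=1)
        # 4. Warp: output(y, x) = frame(y - dy, x - dx), i.e. apply the estimated shift
        corrected[i] = ndimage.map_coordinates(frame, [yy - dy, xx - dx],
                                               order=1, mode='nearest')
    return corrected
```

Patches with little texture give unreliable shifts; production tools typically smooth or regularize the shift field, which this sketch does not do.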

### Optical Flow (Non-rigid)
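```python
import cv2
import numpy as np
def motion_correct_optical_flow(video, reference_frame):
    """
    Dense (Farneback) optical flow for non-rigid deformation.
    Use sparingly - can introduce artifacts.
    """
    flow_params = dict(
        pyr_scale=0.5, levels=3, winsize=15,
        iterations=3, poly_n=5, poly_sigma=1.2, flags=0
    )
    # Farneback needs 8-bit single-channel images; rescale each image to 0-255
    def to_u8(img):
        return cv2.normalize(img.astype(np.float32), None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    ref_u8 = to_u8(reference_frame)
    h, w = reference_frame.shape
    grid_x, grid_y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
    corrected = np.empty(video.shape, dtype=np.float32)
    for i, frame in enumerate(video):
        # flow[y, x] = (dx, dy) such that ref(y, x) ~ frame(y + dy, x + dx)
        flow = cv2.calcOpticalFlowFarneback(ref_u8, to_u8(frame), None, **flow_params)
        corrected[i] = cv2.remap(frame.astype(np.float32), grid_x + flow[..., 0],
                                 grid_y + flow[..., 1], cv2.INTER_LINEAR)
    return corrected
```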

## 9.5.2 Denoising Strategies

### Local PCA Denoising (from paper methods)
```python
def local_pca_denoise(video, patch_size=32, n_components=10, stride=16):
    """
    Local PCA denoising - preserves local structure while removing noise.
    
    For each spatial patch:
    1. Extract time series for all pixels in patch
    2. Perform PCA, keep top n_components
    3. Reconstruct patch from low-rank approximation
    4. Average overlapping regions
    Assumes height and width are >= patch_size.
    """
    n_frames, h, w = video.shape
    denoised = np.zeros(video.shape, dtype=np.float32)
    weights = np.zeros((h, w), dtype=np.float32)
    
    # Patch origins; add a final patch flush with the bottom/right edge so every
    # pixel is covered even when (size - patch_size) is not a multiple of stride
    ys = list(range(0, h - patch_size + 1, stride))
    xs = list(range(0, w - patch_size + 1, stride))
    if ys[-1] != h - patch_size:
        ys.append(h - patch_size)
    if xs[-1] != w - patch_size:
        xs.append(w - patch_size)
    
    for y in ys:
        for x in xs:
            # Extract patch time series: (n_frames, patch_size^2)
            patch = video[:, y:y+patch_size, x:x+patch_size].astype(np.float32)
            patch_flat = patch.reshape(n_frames, -1)
            
            # Center each pixel's time series
            mean_vals = patch_flat.mean(axis=0)
            centered = patch_flat - mean_vals
            
            # PCA via SVD of the centered data matrix
            U, S, Vt = np.linalg.svd(centered, full_matrices=False)
            
            # Rank-n_components reconstruction (scale U's columns by S)
            k = n_components
            reconstructed = (U[:, :k] * S[:k]) @ Vt[:k, :] + mean_vals
            
            # Accumulate; overlapping patches are averaged below
            denoised[:, y:y+patch_size, x:x+patch_size] += reconstructed.reshape(n_frames, patch_size, patch_size)
            weights[y:y+patch_size, x:x+patch_size] += 1
    
    # Normalize by overlap count (every pixel now has weight >= 1)
    denoised /= weights[np.newaxis, :, :]
    return denoised
```

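### Temporal Median Filter (Edge-Preserving; Removes Narrow Spikes)
```python
def temporal_median_filter(video, window=3):
    """
    Temporal median filter (odd window, in frames).

    A median filter preserves step-like edges (e.g. baseline shifts) but removes
    any deflection narrower than (window + 1) // 2 frames. At 200 Hz even
    window=3 erases 1-frame spikes and window=5 erases 2-frame spikes, so it is
    not suitable as a pre-detection denoiser at this frame rate; with a long
    window it can instead estimate a slowly varying baseline.
    """
    from scipy.ndimage import median_filter
    return median_filter(video, size=(window, 1, 1))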
```
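
A quick check of how a temporal median filter treats a single-frame spike, compared with the Gaussian smoothing (`sigma=1`) used by the baseline pipeline. The trace is an idealized flat baseline chosen for illustration, not real data:

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d, median_filter

trace = np.zeros(21)
trace[10] = -1.0  # one-frame, downward (Voltron-2-like) spike

median_out = median_filter(trace, size=3)
gauss_out = gaussian_filter1d(trace, sigma=1)

print(median_out[10])  # median of (0, -1, 0): the spike is removed entirely
print(gauss_out[10])   # attenuated to roughly 40% of its height, but still present
```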

### Wavelet Denoising
```python
def wavelet_denoise(trace, wavelet='db4', level=3, threshold_mode='soft'):
    """
    Wavelet denoising for individual traces.
    Good for separating spike frequencies from noise.
    """
    import pywt  # PyWavelets; not in the install cell above (pip install PyWavelets)
    coeffs = pywt.wavedec(trace, wavelet, level=level)
    
    # Estimate noise from finest detail coefficients
    sigma = np.median(np.abs(coeffs[-1])) / 0.6745
    threshold = sigma * np.sqrt(2 * np.log(len(trace)))
    
    # Threshold detail coefficients
    denoised_coeffs = [coeffs[0]]  # Keep approximation
    for c in coeffs[1:]:
        denoised_coeffs.append(pywt.threshold(c, threshold, mode=threshold_mode))
    
    return pywt.waverec(denoised_coeffs, wavelet)[:len(trace)]
```

## 9.5.3 ROI Segmentation Methods

### Correlation Image Segmentation
```python
def compute_correlation_image(video, radius=4):
    """
    For each pixel, compute the mean temporal correlation with every pixel in its
    (2*radius+1)^2 neighborhood (excluding itself).
    Neurons show high local correlation; noise does not.
    Note: np.roll wraps around, so pixels within `radius` of the border are
    partly correlated with the opposite edge.
    """
    n_frames, h, w = video.shape
    v = video.astype(np.float32)
    
    # z-score each pixel's time series once; a spatial shift only permutes pixels,
    # so the shifted z-scores are just rolled copies
    v_norm = (v - v.mean(axis=0)) / (v.std(axis=0) + 1e-10)
    
    corr_img = np.zeros((h, w), dtype=np.float64)
    
    for dy in range(-radius, radius+1):
        for dx in range(-radius, radius+1):
            if dy == 0 and dx == 0:
                continue
            shifted = np.roll(v_norm, (dy, dx), axis=(1, 2))
            # Mean product of z-scores = Pearson correlation per pixel
            corr_img += (v_norm * shifted).mean(axis=0)
    
    corr_img /= (2*radius + 1)**2 - 1
    return corr_img

def segment_from_correlation(corr_img, std_proj, min_area=30, max_area=500):
    """Segment ROIs from the product of the correlation image and the std projection."""
    from scipy import ndimage
    from skimage import measure
    from skimage.feature import peak_local_max
    from skimage.filters import threshold_otsu
    from skimage.segmentation import watershed
    
    # Combine correlation and std projection, then rescale to [0, 1]
    combined = corr_img * std_proj
    combined = (combined - combined.min()) / (combined.max() - combined.min() + 1e-10)
    
    # Global Otsu threshold
    binary = combined > threshold_otsu(combined)
    
    # Watershed on the distance transform to split touching neurons
    distance = ndimage.distance_transform_edt(binary)
    local_max = peak_local_max(distance, min_distance=5, labels=binary.astype(int))
    markers = np.zeros(binary.shape, dtype=int)
    markers[tuple(local_max.T)] = np.arange(len(local_max)) + 1
    labels = watershed(-distance, markers, mask=binary)
    
    # Keep regions within the size range
    roi_masks = [labels == region.label
                 for region in measure.regionprops(labels)
                 if min_area < region.area < max_area]
    
    # Return an (n_rois, h, w) array even when nothing is found
    return np.array(roi_masks, dtype=bool).reshape(-1, *corr_img.shape)
```
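`compute_correlation_image` averages over a square neighborhood using `np.roll`, which wraps around the frame edges. A minimal alternative, sketched below under the assumption that the four nearest neighbors are enough, uses array slicing so border pixels average only over neighbors that exist. The function name and the synthetic test data (a 5×5 patch sharing one signal, plus independent noise) are illustrative.

```python
import numpy as np

def correlation_image_4nn(video):
    """Mean Pearson correlation of each pixel with its up/down/left/right neighbors (no wrap-around)."""
    video = video.astype(np.float64)
    z = (video - video.mean(axis=0)) / (video.std(axis=0) + 1e-10)  # z-score each pixel
    n_frames, h, w = z.shape
    corr_sum = np.zeros((h, w))
    n_neighbors = np.zeros((h, w))

    # Vertical pairs: pixel (y, x) with (y+1, x); each pair counts for both pixels
    vert = (z[:, :-1, :] * z[:, 1:, :]).mean(axis=0)
    corr_sum[:-1, :] += vert
    corr_sum[1:, :] += vert
    n_neighbors[:-1, :] += 1
    n_neighbors[1:, :] += 1

    # Horizontal pairs: pixel (y, x) with (y, x+1)
    horz = (z[:, :, :-1] * z[:, :, 1:]).mean(axis=0)
    corr_sum[:, :-1] += horz
    corr_sum[:, 1:] += horz
    n_neighbors[:, :-1] += 1
    n_neighbors[:, 1:] += 1

    return corr_sum / n_neighbors

# Synthetic check: a 5x5 patch shares one signal; everything else is independent noise
rng = np.random.default_rng(0)
video = rng.normal(size=(500, 20, 20))
video[:, 5:10, 5:10] += 2.0 * rng.normal(size=(500, 1, 1))
corr_img = correlation_image_4nn(video)
print(corr_img[7, 7], corr_img[15, 15])
```

Inside the patch each pixel is `2s + noise`, so the expected neighbor correlation is 4/5 = 0.8; background pixels hover near zero.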

### Ring-Detection for Membrane-Labeled Neurons
```python
def detect_ring_rois(std_proj, inner_radius=2, outer_radius=6):
    """
    Detect ring-shaped ROIs typical of soma-targeted membrane indicators.
    Uses ring-shaped matched filter.
    """
    from skimage.draw import disk
    
    # Create ring template
    size = outer_radius * 2 + 1
    template = np.zeros((size, size))
    rr_outer, cc_outer = disk((outer_radius, outer_radius), outer_radius)
    rr_inner, cc_inner = disk((outer_radius, outer_radius), inner_radius)
    template[rr_outer, cc_outer] = 1
    template[rr_inner, cc_inner] = -1  # Negative center penalizes filled (non-ring) blobs
    template /= np.abs(template).sum()
    
    # Convolve
    from scipy.signal import convolve2d
    response = convolve2d(std_proj, template, mode='same')
    
    # Find peaks
    from skimage.feature import peak_local_max
    peaks = peak_local_max(response, min_distance=outer_radius, threshold_rel=0.3)
    
    # Create circular masks around peaks
    roi_masks = []
    h, w = std_proj.shape
    for y, x in peaks:
        mask = np.zeros((h, w), dtype=bool)
        rr, cc = disk((y, x), outer_radius, shape=(h, w))
        mask[rr, cc] = True
        roi_masks.append(mask)
    
    return np.array(roi_masks, dtype=bool).reshape(-1, h, w)  # (n_rois, h, w), even if empty
```
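The center-surround template in `detect_ring_rois` responds more strongly to a hollow ring than to a filled disk of the same outer size, because its negative center subtracts signal from filled blobs. The check below rebuilds that kind of template with a NumPy distance grid (an assumption: pixel inclusion uses a strict `<` on the radius, which may differ by a pixel from `skimage.draw.disk`) and compares the two responses at the centered position. Because the template is symmetric, correlation and convolution give the same value here.

```python
import numpy as np

def ring_template(inner_radius=2, outer_radius=6):
    """+1 on the outer disk, -1 on the inner disk, normalized to unit L1 norm."""
    size = 2 * outer_radius + 1
    yy, xx = np.mgrid[:size, :size] - outer_radius
    r2 = yy**2 + xx**2
    template = np.where(r2 < outer_radius**2, 1.0, 0.0)
    template[r2 < inner_radius**2] = -1.0
    return template / np.abs(template).sum()

template = ring_template()
yy, xx = np.mgrid[:13, :13] - 6
r2 = yy**2 + xx**2
ring_img = ((r2 < 36) & (r2 >= 4)).astype(float)  # membrane-like ring
disk_img = (r2 < 36).astype(float)                 # filled, soma-like blob

# At the centered position the filter output is the elementwise product summed
ring_response = np.sum(template * ring_img)
disk_response = np.sum(template * disk_img)
print(ring_response, disk_response)
```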

### Semi-NMF Segmentation
```python
def seminmf_segmentation(video, n_components=50, n_rois=30):
    """
    NMF-based segmentation.
    
    Note: despite the name, this runs standard NMF on |video|, so both the
    spatial and the temporal components are non-negative. A true semi-NMF
    (non-negative spatial masks, signed voltage traces) would then refine
    these factors with semi-NMF updates.
    """
    from sklearn.decomposition import NMF
    from skimage import morphology
    
    n_frames, h, w = video.shape
    video_flat = video.reshape(n_frames, h * w).T  # (pixels, frames)
    
    # Standard NMF on absolute values (NMF requires non-negative input)
    nmf = NMF(n_components=n_components, init='nndsvda', max_iter=500, random_state=0)
    W = nmf.fit_transform(np.abs(video_flat))  # Spatial (pixels, components)
    H = nmf.components_                        # Temporal (components, frames); unused below
    
    # Threshold each spatial component at its 95th percentile to get a mask
    roi_masks = []
    for i in range(n_components):
        spatial = W[:, i].reshape(h, w)
        mask = spatial > np.percentile(spatial, 95)
        
        # Drop small fragments
        mask = morphology.remove_small_objects(mask, min_size=30)
        if mask.sum() > 0:
            roi_masks.append(mask)
    
    return np.array(roi_masks[:n_rois], dtype=bool).reshape(-1, h, w)
```
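`seminmf_segmentation` runs standard NMF on |video| and stops there, so its temporal components are non-negative. For reference, a minimal sketch of the semi-NMF multiplicative updates of Ding, Li & Jordan (2010) follows, with the non-negative factor assigned to the spatial footprints and the unconstrained factor to the traces. It assumes random initialization, a fixed iteration count, and no sparsity penalty; it is an illustration, not the implementation used in any of the papers cited here.

```python
import numpy as np

def _pos(A):
    return (np.abs(A) + A) / 2   # positive part

def _neg(A):
    return (np.abs(A) - A) / 2   # negative part, as a non-negative number

def seminmf(X, k, n_iter=200, eps=1e-9, seed=0):
    """
    Semi-NMF: X (frames x pixels) ~= F @ G.T with G >= 0 (spatial footprints,
    pixels x k) and F unconstrained (temporal traces, frames x k).
    """
    rng = np.random.default_rng(seed)
    G = rng.random((X.shape[1], k)) + eps
    F = None
    for _ in range(n_iter):
        # F-step: unconstrained least squares for fixed G
        F = X @ G @ np.linalg.pinv(G.T @ G)
        # G-step: multiplicative update, which keeps G non-negative
        XtF = X.T @ F
        FtF = F.T @ F
        G *= np.sqrt((_pos(XtF) + G @ _neg(FtF)) / (_neg(XtF) + G @ _pos(FtF) + eps))
    return F, G

def relative_error(X, F, G):
    return np.linalg.norm(X - F @ G.T) / np.linalg.norm(X)

# Synthetic data: 3 non-negative footprints, signed traces
rng = np.random.default_rng(1)
G_true = rng.random((100, 3))
F_true = rng.normal(size=(400, 3))
X = F_true @ G_true.T

F1, G1 = seminmf(X, k=3, n_iter=1)
F, G = seminmf(X, k=3, n_iter=300)
err1, err300 = relative_error(X, F1, G1), relative_error(X, F, G)
print(err1, err300)
```

Each F-step is optimal for fixed G and each G-step does not increase the error, so the reconstruction error is non-increasing across iterations.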

## 9.5.4 Advanced Spike Detection

### Adaptive Threshold Detection
```python
def detect_spikes_adaptive(trace, fps=200, window_sec=1.0, n_std=3.5):
    """
    Adaptive threshold that adjusts to local noise level.
    Better for traces with non-stationary noise.
    """
    window = int(window_sec * fps)
    inverted = -trace  # Voltron-2 has negative spikes
    
    # Rolling statistics
    from scipy.ndimage import uniform_filter1d
    local_mean = uniform_filter1d(inverted, window)
    local_std = np.sqrt(uniform_filter1d((inverted - local_mean)**2, window))
    
    # Adaptive threshold
    threshold = local_mean + n_std * local_std
    
    # Find peaks above threshold
    peaks, properties = signal.find_peaks(
        inverted,
        height=threshold,
        distance=max(1, int(round(0.003 * fps))),  # 3 ms refractory, at least 1 frame
        prominence=0.5 * local_std.mean()
    )
    
    return peaks / fps, properties

```
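At 200 fps a 3 ms refractory period is 0.6 frames, and `int(0.003 * fps)` truncates it to 0. `scipy.signal.find_peaks` rejects `distance` values below 1 with a `ValueError`, so the conversion needs a floor of one frame. A small helper (hypothetical name `refractory_frames`) makes the conversion explicit:

```python
import numpy as np
from scipy import signal

def refractory_frames(refractory_ms, fps):
    """Convert a refractory period to a find_peaks `distance`, at least 1 frame."""
    return max(1, int(round(refractory_ms * fps / 1000)))

print(int(0.003 * 200))            # truncates to 0 frames
print(refractory_frames(3, 200))   # floored to 1 frame
print(refractory_frames(3, 1000))  # 3 frames at 1 kHz

trace = np.array([0.0, 2.0, 0.0, 3.0, 0.0])
try:
    signal.find_peaks(trace, distance=0)
except ValueError as err:
    print("find_peaks rejected distance=0:", err)

peaks, _ = signal.find_peaks(trace, distance=refractory_frames(3, 200))
print(peaks)
```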

### Template Matching Spike Detection
```python
def detect_spikes_template(trace, fps=200, template_width_ms=10):
    """
    Use average spike waveform as template for matched filtering.
    Two-pass: first detect rough spikes, then refine with template.
    """
    inverted = -trace
    template_samples = int(template_width_ms * fps / 1000)
    min_distance = max(1, int(round(0.003 * fps)))  # 3 ms refractory, at least 1 frame
    
    # Pass 1: Simple threshold to get initial spikes
    threshold = np.mean(inverted) + 4 * np.std(inverted)
    initial_peaks, _ = signal.find_peaks(inverted, height=threshold,
                                         distance=max(1, template_samples // 2))
    
    if len(initial_peaks) < 3:
        return initial_peaks / fps
    
    # Extract waveforms (skipping spikes too close to the edges) and average them
    waveforms = [inverted[peak - template_samples:peak + template_samples]
                 for peak in initial_peaks
                 if template_samples <= peak < len(trace) - template_samples]
    if len(waveforms) == 0:
        return initial_peaks / fps
    
    template = np.mean(waveforms, axis=0)
    template = (template - template.mean()) / (template.std() + 1e-10)
    
    # Pass 2: Matched filter (cross-correlate the trace with the template)
    filtered = np.correlate(inverted, template, mode='same')
    
    # Detect on filtered trace. Note: the 95th-percentile height is relative,
    # so some peaks pass it even in spike-free traces.
    refined_peaks, _ = signal.find_peaks(
        filtered,
        height=np.percentile(filtered, 95),
        distance=min_distance
    )
    
    return refined_peaks / fps
```

### MAD-based Robust Threshold
```python
def detect_spikes_mad(trace, fps=200, n_mad=5):
    """
    Use Median Absolute Deviation instead of std.
    More robust to outliers and non-Gaussian noise.
    """
    inverted = -trace
    median = np.median(inverted)
    mad = np.median(np.abs(inverted - median))
    
    # MAD to std conversion for Gaussian: std ≈ 1.4826 * MAD
    threshold = median + n_mad * 1.4826 * mad
    
    peaks, _ = signal.find_peaks(
        inverted,
        height=threshold,
        distance=max(1, int(round(0.003 * fps)))  # 3 ms refractory, at least 1 frame
    )
    
    return peaks / fps
```
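Why the MAD helps: a few large spikes inflate the standard deviation, which raises a std-based threshold and can hide smaller spikes, whereas the median absolute deviation barely moves. The synthetic example below (unit-variance Gaussian noise with 50 injected 20-unit spikes; illustrative values) compares the two noise estimates and checks where the 1.4826 constant comes from.

```python
import numpy as np
from scipy.stats import norm

# 1.4826 is 1 / Phi^{-1}(0.75), the Gaussian consistency constant for the MAD
mad_constant = 1 / norm.ppf(0.75)
print(mad_constant)

rng = np.random.default_rng(0)
noise = rng.normal(0.0, 1.0, size=10_000)  # unit-variance baseline noise
trace = noise.copy()
trace[::200] += 20.0                        # 50 large spikes

std_est = np.std(trace)
mad_est = 1.4826 * np.median(np.abs(trace - np.median(trace)))
print(std_est, mad_est)  # std is inflated by the spikes; the MAD estimate stays near 1
```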

## 9.5.5 Artifact Detection and Removal

### UMAP + DBSCAN Artifact Clustering (from Bharioke et al.)
```python
def remove_artifact_spikes_umap(traces, spike_times_list, fps=200, window_ms=15):
    """
    Cluster spike waveforms using UMAP + DBSCAN.
    Remove clusters that appear across many neurons (likely artifacts).
    """
    import umap
    from sklearn.cluster import DBSCAN
    
    window = int(window_ms * fps / 1000)
    all_waveforms = []
    waveform_info = []  # (roi_idx, spike_idx)
    
    # Collect all waveforms
    for roi_idx, (trace, spike_times) in enumerate(zip(traces, spike_times_list)):
        for spike_idx, t in enumerate(spike_times):
            frame = int(round(t * fps))  # round, since t * fps may be 36.999... instead of 37
            if window <= frame < len(trace) - window:
                wf = trace[frame-window:frame+window]
                wf = (wf - wf.mean()) / (wf.std() + 1e-10)
                all_waveforms.append(wf)
                waveform_info.append((roi_idx, spike_idx))
    
    if len(all_waveforms) < 10:
        return spike_times_list
    
    all_waveforms = np.array(all_waveforms)
    
    # UMAP embedding
    reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1)
    embedding = reducer.fit_transform(all_waveforms)
    
    # DBSCAN clustering
    clusterer = DBSCAN(eps=0.5, min_samples=5)
    labels = clusterer.fit_predict(embedding)
    
    # Find artifact clusters (present in >50% of ROIs)
    n_rois = len(traces)
    artifact_clusters = set()
    
    for cluster_id in set(labels):
        if cluster_id == -1:  # Noise
            continue
        cluster_mask = labels == cluster_id
        rois_in_cluster = set(waveform_info[i][0] for i in np.where(cluster_mask)[0])
        if len(rois_in_cluster) > 0.5 * n_rois:
            artifact_clusters.add(cluster_id)
    
    # Remove artifact spikes
    cleaned_spike_times = [list(st) for st in spike_times_list]
    for i, (roi_idx, spike_idx) in enumerate(waveform_info):
        if labels[i] in artifact_clusters:
            cleaned_spike_times[roi_idx][spike_idx] = None
    
    # Filter out None values
    cleaned_spike_times = [
        np.array([t for t in times if t is not None])
        for times in cleaned_spike_times
    ]
    
    return cleaned_spike_times
```
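The clustering step encodes one heuristic: a waveform shape that appears in more than half of the ROIs is likely an artifact. A simpler, UMAP-free check of a related idea flags frames where a large fraction of ROIs spike at (nearly) the same time. The sketch below is that alternative heuristic, not the Bharioke et al. procedure; the function names, the ±1-frame tolerance, and the 50% fraction are assumptions.

```python
import numpy as np

def synchronous_artifact_frames(spike_frames_list, n_frames, tolerance=1, min_fraction=0.5):
    """Return frames where more than `min_fraction` of ROIs spike within +/- `tolerance` frames."""
    n_rois = len(spike_frames_list)
    counts = np.zeros(n_frames, dtype=int)
    for frames in spike_frames_list:
        hit = np.zeros(n_frames, dtype=bool)
        for f in frames:
            # Each ROI counts at most once per frame, even with several nearby spikes
            hit[max(0, f - tolerance):min(n_frames, f + tolerance + 1)] = True
        counts += hit
    return np.flatnonzero(counts > min_fraction * n_rois)

def drop_flagged_spikes(spike_frames_list, bad_frames, tolerance=1):
    """Remove spikes within `tolerance` frames of any flagged frame."""
    bad = np.asarray(bad_frames)
    cleaned = []
    for frames in spike_frames_list:
        frames = np.asarray(frames, dtype=int)
        if bad.size == 0:
            cleaned.append(frames)
            continue
        dist = np.abs(frames[:, None] - bad[None, :]).min(axis=1)
        cleaned.append(frames[dist > tolerance])
    return cleaned

# 4 ROIs; frames 100-101 are shared by 3 of them (likely artifact), the rest are independent
spikes = [np.array([20, 100, 300]), np.array([55, 101, 250]), np.array([100, 400]), np.array([75, 180])]
bad = synchronous_artifact_frames(spikes, n_frames=500)
cleaned = drop_flagged_spikes(spikes, bad)
print(bad, cleaned)
```

Unlike the waveform clustering above, this check ignores spike shape entirely, so it can also remove genuinely synchronous neural events.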

### Global Signal Regression
```python
def remove_global_signal(traces):
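    """
    Regress out the global signal (first principal component across ROIs), which
    is often dominated by shared motion or illumination artifacts. Any genuinely
    synchronous neural activity is removed along with it.
    """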
    from sklearn.decomposition import PCA
    
    # Compute global signal (first PC)
    pca = PCA(n_components=1)
    global_signal = pca.fit_transform(traces.T).flatten()
    
    # Regress out from each trace
    cleaned = np.zeros_like(traces)
    for i, trace in enumerate(traces):
        # Linear regression
        coef = np.dot(trace, global_signal) / np.dot(global_signal, global_signal)
        cleaned[i] = trace - coef * global_signal
    
    return cleaned
```
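The same regression can be written without scikit-learn. The NumPy sketch below takes the first principal component of the time-by-ROI matrix via SVD (centering each ROI, as `PCA` does, so the component matches `PCA.fit_transform` up to sign) and removes its projection from all traces in one step. By construction, every cleaned trace is orthogonal to the global signal. The function name and synthetic data are illustrative.

```python
import numpy as np

def remove_global_signal_np(traces):
    """traces: (n_rois, n_frames). Regress the first-PC time course out of every trace."""
    X = traces.T - traces.T.mean(axis=0)        # (n_frames, n_rois), each ROI centered
    _, _, Vt = np.linalg.svd(X, full_matrices=False)
    global_signal = X @ Vt[0]                   # PC1 score over time
    coefs = traces @ global_signal / (global_signal @ global_signal)
    return traces - np.outer(coefs, global_signal), global_signal

rng = np.random.default_rng(0)
shared = np.sin(np.linspace(0, 20, 1000))       # artifact common to all ROIs
traces = rng.normal(size=(8, 1000)) * 0.1 + shared * rng.uniform(0.5, 2.0, size=(8, 1))
cleaned, g = remove_global_signal_np(traces)
print(np.abs(cleaned @ g).max())
```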

## 9.5.6 Quality Control Metrics

### SNR Estimation per ROI
```python
def estimate_snr(trace, spike_frames, fps=200, noise_window_ms=50):
    """
    Estimate SNR as spike amplitude / baseline noise.
    """
    inverted = -trace
    noise_window = int(noise_window_ms * fps / 1000)
    
    # Get spike amplitudes (raw values at the spike frames; assumes the
    # baseline is near zero, e.g. a dF/F trace)
    amplitudes = inverted[spike_frames] if len(spike_frames) > 0 else []
    
    # Estimate noise from spike-free regions
    spike_free = np.ones(len(trace), dtype=bool)
    for f in spike_frames:
        spike_free[max(0,f-noise_window):min(len(trace),f+noise_window)] = False
    
    if spike_free.sum() > 100:
        noise_std = np.std(inverted[spike_free])
    else:
        noise_std = np.std(inverted)
    
    if len(amplitudes) > 0 and noise_std > 0:
        snr = np.median(amplitudes) / noise_std
    else:
        snr = 0
    
    return snr

def filter_rois_by_snr(traces, spike_times_list, min_snr=4, fps=200):
    """Return indices of ROIs whose SNR is at least min_snr."""
    keep_idx = []
    for i, (trace, spikes) in enumerate(zip(traces, spike_times_list)):
        spike_frames = np.rint(np.asarray(spikes) * fps).astype(int)
        snr = estimate_snr(trace, spike_frames, fps=fps)
        if snr >= min_snr:
            keep_idx.append(i)
    return keep_idx
```

## 9.5.7 Complete Pipeline Examples from Literature

### Positron2 Pipeline (Wang et al. - Whole-brain Zebrafish Voltage Imaging)

This pipeline was used for imaging Positron2 across entire larval zebrafish brains.

#### Manual ROI Annotation Criteria
```python
# ROI selection criteria from the paper:
ROI_CRITERIA = {
    'size_pixels': 9,           # ~6.6 µm diameter
    'shape': 'ring-like',       # Membrane-targeted indicator
    'spike_width_max_ms': 20,   # Narrow spikes
    'spike_amplitude_std': 2.0, # >2× std of trace
    'spike_polarity': 'positive',  # Note: Positron2 has POSITIVE spikes (unlike Voltron-2)
}
```
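Most detectors in this section hard-code `inverted = -trace` for Voltron-2's negative-going spikes, so Positron2 traces would need the opposite sign. One way to keep this explicit is a small polarity table; the mapping below reflects the polarities stated in this notebook, and the dictionary and function names are hypothetical.

```python
import numpy as np

# Sign that makes spikes point upward, per indicator (polarities as noted in this notebook)
SPIKE_SIGN = {'Voltron-2': -1, 'Positron2': +1}

def spikes_up(trace, indicator):
    """Return the trace oriented so spikes are positive-going."""
    try:
        return SPIKE_SIGN[indicator] * np.asarray(trace, dtype=float)
    except KeyError:
        raise ValueError(f"Unknown indicator {indicator!r}; expected one of {sorted(SPIKE_SIGN)}") from None

trace = np.array([0.0, -3.0, 0.0])
print(spikes_up(trace, 'Voltron-2'))  # negative-going spike becomes positive
print(spikes_up(trace, 'Positron2'))  # unchanged
```

A detector can then call `spikes_up(trace, indicator)` in place of `-trace`.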

#### Overlap Removal (Spatial Correlation)
```python
def remove_overlapping_rois(roi_masks, traces, correlation_threshold=0.9):
    """
    Remove ROIs that are likely duplicates based on spatial correlation.
    From Wang et al.: ROIs with >0.9 correlation are rejected.
    """
    n_rois = len(roi_masks)
    if n_rois == 0:
        return np.array([], dtype=int)
    keep_mask = np.ones(n_rois, dtype=bool)
    
    # Flatten masks for correlation
    masks_flat = np.asarray(roi_masks).reshape(n_rois, -1).astype(float)
    
    # Compute pairwise spatial correlations
    for i in range(n_rois):
        if not keep_mask[i]:
            continue
        for j in range(i + 1, n_rois):
            if not keep_mask[j]:
                continue
            
            # Spatial correlation between masks
            corr = np.corrcoef(masks_flat[i], masks_flat[j])[0, 1]
            
            if corr > correlation_threshold:
                # Keep the one with higher SNR
                snr_i = np.std(traces[i])  # Simplified SNR proxy
                snr_j = np.std(traces[j])
                if snr_i >= snr_j:
                    keep_mask[j] = False
                else:
                    keep_mask[i] = False
    
    return np.where(keep_mask)[0]
```

#### Iterative UMAP-DBSCAN Artifact Removal
```python
def iterative_umap_dbscan_artifact_removal(spike_rasters, fps=200,
                                            hanning_window_ms=250,
                                            max_iterations=10):
    """
    Iterative UMAP-DBSCAN clustering to remove artifact-contaminated ROIs.
    From Wang et al.: Iterate until 2D UMAP representation is near-Gaussian.
    
    Args:
        spike_rasters: list of binary spike trains per ROI
        fps: frame rate
        hanning_window_ms: smoothing window for firing rate estimation
        max_iterations: maximum clustering iterations
    
    Returns:
        clean_roi_indices: indices of ROIs that passed artifact removal
        artifact_roi_indices: indices of ROIs flagged as artifacts
    """
    import umap
    from sklearn.cluster import DBSCAN
    from scipy.signal import convolve
    from scipy.stats import shapiro
    
    n_rois = len(spike_rasters)
    n_samples = len(spike_rasters[0])
    
    # Create smoothed firing rate matrix
    window_samples = int(hanning_window_ms * fps / 1000)
    hanning = np.hanning(window_samples)
    hanning /= hanning.sum()
    
    firing_rates = np.zeros((n_rois, n_samples))
    for i, raster in enumerate(spike_rasters):
        firing_rates[i] = convolve(raster.astype(float), hanning, mode='same')
    
    # Track which ROIs are still candidates
    candidate_mask = np.ones(n_rois, dtype=bool)
    artifact_indices = []
    
    for iteration in range(max_iterations):
        current_indices = np.where(candidate_mask)[0]
        if len(current_indices) < 10:
            break
        
        current_rates = firing_rates[current_indices]
        
        # UMAP embedding
        reducer = umap.UMAP(n_components=2, n_neighbors=min(15, len(current_indices)-1),
                           min_dist=0.1, random_state=42)
        embedding = reducer.fit_transform(current_rates)
        
        # Stopping criterion: "near-Gaussian" is approximated here with a
        # Shapiro-Wilk test on each UMAP dimension (len >= 10 is guaranteed above)
        _, p_x = shapiro(embedding[:, 0])
        _, p_y = shapiro(embedding[:, 1])
        
        if p_x > 0.05 and p_y > 0.05:
            # Distribution is approximately Gaussian - stop iterating
            print(f"Iteration {iteration}: UMAP distribution is Gaussian, stopping.")
            break
        
        # DBSCAN clustering
        clusterer = DBSCAN(eps=0.5, min_samples=3)
        labels = clusterer.fit_predict(embedding)
        
        # Find outlier cluster (well-separated from main cluster)
        unique_labels = set(labels) - {-1}  # Exclude noise
        if len(unique_labels) <= 1:
            break
        
        # Main cluster = largest cluster
        cluster_sizes = {l: np.sum(labels == l) for l in unique_labels}
        main_cluster = max(cluster_sizes, key=cluster_sizes.get)
        
        # Find clusters that are well-separated (potential artifacts)
        main_points = embedding[labels == main_cluster]
        main_centroid = main_points.mean(axis=0)
        main_std = main_points.std()
        n_removed = 0
        
        for cluster_id in unique_labels:
            if cluster_id == main_cluster:
                continue
            
            cluster_centroid = embedding[labels == cluster_id].mean(axis=0)
            distance = np.linalg.norm(cluster_centroid - main_centroid)
            
            # If cluster is far from main cluster, flag as artifact
            if distance > 3 * main_std:
                # Mark these ROIs as artifacts
                artifact_mask = labels == cluster_id
                artifact_roi_indices = current_indices[artifact_mask]
                artifact_indices.extend(artifact_roi_indices)
                candidate_mask[artifact_roi_indices] = False
                n_removed += len(artifact_roi_indices)
                print(f"Iteration {iteration}: Removed {len(artifact_roi_indices)} artifact ROIs")
        
        # Nothing was removed: with a fixed random_state, the next iteration would repeat this result
        if n_removed == 0:
            break
    
    clean_indices = np.where(candidate_mask)[0]
    return clean_indices, np.array(artifact_indices)
```
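The smoothed firing rates inside `iterative_umap_dbscan_artifact_removal` are in spikes per frame, because the Hanning kernel is normalized to sum to 1. Multiplying by `fps` converts them to Hz, which is easier to sanity-check. The synthetic example below uses a regular 10 Hz spike train at 200 fps (illustrative values).

```python
import numpy as np
from scipy.signal import convolve

fps = 200
hanning_window_ms = 250
n_frames = 4000

raster = np.zeros(n_frames)
raster[::fps // 10] = 1                 # one spike every 20 frames = 10 Hz

window = np.hanning(int(hanning_window_ms * fps / 1000))
window /= window.sum()                  # unit area: output is spikes per frame

rate_hz = convolve(raster, window, mode='same') * fps
middle = rate_hz[500:3500]              # avoid edge effects
print(middle.mean())
```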

### SNR Calculation (VolPy-style)
```python
def compute_snr_volpy(trace, spike_frames, fps=200):
    """
    Compute a VolPy-style SNR:
    SNR = median(spike_amplitudes) / noise_std
    
    Noise is estimated from spike-free regions.
    """
    if len(spike_frames) == 0:
        return 0
    
    # Amplitudes are |peak - baseline|, so this works for both negative-going
    # (Voltron-2) and positive-going (Positron2) spikes without inverting.
    
    # Get spike amplitudes (peak - local baseline).
    # 10 ms window, but at least 3 frames: at 200 fps, 10 ms is only 2 frames,
    # which would leave the baseline slice below empty.
    window = max(int(10 * fps / 1000), 3)
    amplitudes = []
    
    for frame in spike_frames:
        if frame < window + 1 or frame >= len(trace) - window:
            continue
        
        # Local baseline: median of the `window` frames ending one frame before the peak
        baseline = np.median(trace[frame - window - 1:frame - 1])
        # Peak value
        peak = trace[frame]
        amplitudes.append(np.abs(peak - baseline))
    
    if len(amplitudes) == 0:
        return 0
    
    # Estimate noise from spike-free regions
    spike_free_mask = np.ones(len(trace), dtype=bool)
    for frame in spike_frames:
        start = max(0, frame - window)
        end = min(len(trace), frame + window)
        spike_free_mask[start:end] = False
    
    if spike_free_mask.sum() < 100:
        noise_std = np.std(trace)
    else:
        noise_std = np.std(trace[spike_free_mask])
    
    snr = np.median(amplitudes) / (noise_std + 1e-10)
    return snr

def filter_by_snr(traces, spike_times_list, fps=200, min_snr=4):
    """
    Filter ROIs by SNR threshold.
    Wang et al. used SNR > 4 as cutoff.
    """
    keep_indices = []
    snr_values = []
    
    for i, (trace, spike_times) in enumerate(zip(traces, spike_times_list)):
        spike_frames = np.rint(np.array(spike_times) * fps).astype(int)
        snr = compute_snr_volpy(trace, spike_frames, fps)
        snr_values.append(snr)
        
        if snr >= min_snr:
            keep_indices.append(i)
    
    return keep_indices, snr_values
```

### Second Paper Pipeline (Camera Noise Correction, Local PCA, Semi-NMF, LSTM)

This pipeline combines the following techniques:
- Camera noise correction
- dipy-based motion correction
- Local PCA denoising
- Semi-NMF segmentation
- LSTM spike detection
- Subthreshold activity extraction

General implementations of several of these techniques follow; they are generic sketches rather than the paper's exact methods.

#### Camera Noise Correction
```python
def correct_camera_noise(video, dark_frame=None, flat_field=None):
    """
    Correct for camera-specific noise patterns.
    
    Args:
        video: raw video (n_frames, height, width)
        dark_frame: average of frames with no illumination (fixed pattern noise)
        flat_field: normalized response to uniform illumination
    
    Returns:
        corrected video
    """
    corrected = video.astype(np.float32)
    
    # Subtract dark frame (fixed pattern noise)
    if dark_frame is not None:
        corrected -= dark_frame
    
    # Divide by flat field (pixel sensitivity variation)
    if flat_field is not None:
        # Normalize flat field
        flat_norm = flat_field / np.mean(flat_field)
        flat_norm = np.clip(flat_norm, 0.1, 10)  # Avoid extreme corrections
        corrected /= flat_norm
    
    # Replace per-frame spatial outliers (e.g. hot pixels) with the local 3x3 median
    from scipy.ndimage import median_filter
    for i in range(len(corrected)):
        frame = corrected[i]
        median_frame = median_filter(frame, size=3)
        # Replace pixels that deviate significantly from local median
        outlier_mask = np.abs(frame - median_frame) > 5 * np.std(frame)
        corrected[i][outlier_mask] = median_frame[outlier_mask]
    
    return corrected

def estimate_dark_frame(dark_video):
    """Estimate dark frame from a recording with no illumination."""
    return np.median(dark_video, axis=0)

def estimate_flat_field(flat_video):
    """Estimate flat field from uniform illumination recording."""
    mean_frame = np.mean(flat_video, axis=0)
    # Normalize to mean of 1
    return mean_frame / np.mean(mean_frame)
```
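`correct_camera_noise` handles offset and flat-field sensitivity but not per-pixel gain. For a shot-noise-limited sCMOS pixel with gain g (ADU per photoelectron), variance above the dark level grows linearly with mean signal above the offset, with slope g; this is the photon-transfer relation that the Kawashima et al. correction later in this section fits pixel by pixel. A vectorized least-squares fit through the origin, sketched below with hypothetical array names and synthetic calibration data, estimates that slope for every pixel at once.

```python
import numpy as np

def fit_gain(level_means, level_vars, offset, dark_var, min_levels=2):
    """
    level_means, level_vars: (n_levels, h, w) per-pixel mean and variance at each illumination level.
    offset, dark_var: (h, w) mean and variance of dark frames.
    Returns (h, w) gain from variance_excess = gain * intensity_excess (fit through the origin).
    """
    intensity = level_means - offset   # (n_levels, h, w)
    variance = level_vars - dark_var
    valid = intensity > 0              # only levels above the offset contribute
    num = np.sum(np.where(valid, intensity * variance, 0.0), axis=0)
    den = np.sum(np.where(valid, intensity ** 2, 0.0), axis=0)
    gain = np.ones(offset.shape, dtype=np.float64)  # fall back to unit gain
    ok = (valid.sum(axis=0) >= min_levels) & (den > 0)
    gain[ok] = num[ok] / den[ok]
    return gain

# Synthetic calibration: true gain 2.0 ADU/e-, offset 100 ADU, read-noise variance 4 ADU^2
h, w = 4, 5
offset = np.full((h, w), 100.0)
dark_var = np.full((h, w), 4.0)
photons = np.array([50.0, 200.0, 800.0])[:, None, None] * np.ones((1, h, w))
level_means = offset + 2.0 * photons
level_vars = dark_var + 2.0 ** 2 * photons
gain = fit_gain(level_means, level_vars, offset, dark_var)
print(gain)
```

Dividing offset-subtracted frames by this gain expresses them in photoelectrons.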

#### Background Subtraction Methods
```python
def subtract_background_annulus(video, roi_mask, inner_gap=2, outer_width=5, neuropil_coef=0.7):
    """
    Subtract local background using an annulus around the ROI.
    Common in voltage imaging to remove neuropil contamination.
    
    neuropil_coef=0.7 is a common default in two-photon calcium imaging
    (e.g. Suite2p); it is not calibrated for voltage indicators.
    """
    from scipy.ndimage import binary_dilation
    
    # Annulus: pixels between inner_gap and inner_gap + outer_width dilation steps from the ROI
    dilated = binary_dilation(roi_mask, iterations=inner_gap + outer_width)
    inner = binary_dilation(roi_mask, iterations=inner_gap)
    annulus = dilated & ~inner
    
    # Extract traces
    roi_trace = np.mean(video[:, roi_mask], axis=1)
    bg_trace = np.mean(video[:, annulus], axis=1) if annulus.sum() > 0 else 0
    
    return roi_trace - neuropil_coef * bg_trace

def subtract_background_percentile(trace, percentile=8, window_sec=1.0, fps=200):
    """
    Rolling percentile baseline subtraction.
    """
    from scipy.ndimage import percentile_filter
    window = int(window_sec * fps)
    baseline = percentile_filter(trace, percentile, size=window)
    return trace - baseline
```
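An annulus drawn around one ROI can overlap neighboring cells, which then leak into the background estimate. A common refinement, sketched here as an assumption rather than a step from the papers above, removes every pixel that belongs to any ROI from the annulus. `annulus_mask` and `all_rois` (an `(n_rois, h, w)` boolean stack) are hypothetical names.

```python
import numpy as np
from scipy.ndimage import binary_dilation

def annulus_mask(roi_mask, all_rois, inner_gap=2, outer_width=5):
    """Annulus around roi_mask, excluding every pixel covered by any ROI."""
    outer = binary_dilation(roi_mask, iterations=inner_gap + outer_width)
    inner = binary_dilation(roi_mask, iterations=inner_gap)
    return outer & ~inner & ~np.any(all_rois, axis=0)

h, w = 40, 40
rois = np.zeros((2, h, w), dtype=bool)
rois[0, 15:20, 15:20] = True   # target ROI
rois[1, 15:20, 24:29] = True   # neighbor that falls inside the target's annulus
ann = annulus_mask(rois[0], rois)
print(ann.sum(), (ann & rois[1]).sum())
```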

#### Baseline Correction and dF/F Computation
```python
def compute_dff(trace, method='percentile', percentile=10, window_sec=None, fps=200):
    """
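    Compute ΔF/F with various baseline methods.
    F0 should be clearly positive: after neuropil subtraction it can approach
    zero, which makes ΔF/F unstable.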
    
    Methods:
        'percentile': global percentile as F0
        'rolling_percentile': sliding window percentile
        'exponential': fit exponential decay (for photobleaching)
    """
    if method == 'percentile':
        f0 = np.percentile(trace, percentile)
        dff = (trace - f0) / (f0 + 1e-10)
        
    elif method == 'rolling_percentile':
        from scipy.ndimage import percentile_filter
        window = int((window_sec or 2.0) * fps)
        f0 = percentile_filter(trace, percentile, size=window)
        dff = (trace - f0) / (f0 + 1e-10)
        
    elif method == 'exponential':
        # Fit exponential decay for photobleaching correction
        from scipy.optimize import curve_fit
        
        def exp_decay(t, a, b, c):
            return a * np.exp(-b * t) + c
        
        t = np.arange(len(trace))
        try:
            # Use robust fitting (median of rolling windows as targets)
            window = int(fps)  # 1-second windows
            n_windows = len(trace) // window
            t_fit = np.array([i * window + window//2 for i in range(n_windows)])
            y_fit = np.array([np.median(trace[i*window:(i+1)*window]) for i in range(n_windows)])
            
            popt, _ = curve_fit(exp_decay, t_fit, y_fit,
                               p0=[trace[0]-trace[-1], 0.001, trace[-1]],
                               maxfev=1000)
            f0 = exp_decay(t, *popt)
        except (RuntimeError, ValueError, TypeError):  # fit failed or too few windows
            f0 = np.percentile(trace, percentile)
        
        dff = (trace - f0) / (f0 + 1e-10)
    
    else:
        raise ValueError(f"Unknown baseline method: {method!r}")
    
    return dff
```
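The choice of baseline matters under photobleaching. For a synthetic, steadily bleaching trace with no activity (illustrative values), a single global percentile F0 turns the decay into a large spurious ΔF/F trend, while a rolling percentile baseline tracks it. The trade-off is that a 2 s rolling window also absorbs genuinely slow changes in fluorescence.

```python
import numpy as np
from scipy.ndimage import percentile_filter

fps = 200
t = np.arange(4000)
trace = 100.0 * np.exp(-t / 2000)            # bleaching only, no neural activity

f0_global = np.percentile(trace, 10)
dff_global = (trace - f0_global) / f0_global

f0_rolling = percentile_filter(trace, 10, size=2 * fps)   # 2 s window
dff_rolling = (trace - f0_rolling) / f0_rolling

print(np.ptp(dff_global), np.ptp(dff_rolling))
```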

#### Subthreshold Activity Extraction
```python
def extract_subthreshold(trace, spike_frames, fps=200,
                         interpolation_window_ms=20,
                         lowpass_cutoff_hz=10):
    """
    Extract subthreshold membrane potential by removing spikes and lowpass filtering.
    
    1. Interpolate over spike regions
    2. Lowpass filter to get slow membrane potential fluctuations
    """
    from scipy.interpolate import interp1d
    from scipy.signal import butter, filtfilt
    
    # Create mask of spike regions to interpolate
    window = int(interpolation_window_ms * fps / 1000)
    spike_mask = np.zeros(len(trace), dtype=bool)
    for frame in spike_frames:
        start = max(0, frame - window//2)
        end = min(len(trace), frame + window//2 + 1)  # symmetric window around the spike
        spike_mask[start:end] = True
    
    # Interpolate over spikes
    x = np.arange(len(trace))
    valid_x = x[~spike_mask]
    valid_y = trace[~spike_mask]
    
    if len(valid_x) < 10:
        interpolated = trace.copy()
    else:
        f = interp1d(valid_x, valid_y, kind='linear',
                     bounds_error=False, fill_value='extrapolate')
        interpolated = f(x)
    
    # Lowpass filter
    nyquist = fps / 2
    b, a = butter(4, lowpass_cutoff_hz / nyquist, btype='low')
    subthreshold = filtfilt(b, a, interpolated)
    
    return subthreshold
```
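A quick synthetic check of the interpolate-then-lowpass idea: a 1 Hz "subthreshold" sine with single-frame spikes every 100 frames should come back as the sine. This sketch swaps `interp1d` for `np.interp`, which holds the edge values constant instead of extrapolating linearly; the function name and test values are illustrative.

```python
import numpy as np
from scipy.signal import butter, filtfilt

def subthreshold_np(trace, spike_frames, fps=200, window_ms=20, cutoff_hz=10):
    half = int(window_ms * fps / 1000) // 2
    spike_mask = np.zeros(len(trace), dtype=bool)
    for f in spike_frames:
        spike_mask[max(0, f - half):min(len(trace), f + half + 1)] = True  # symmetric window
    x = np.arange(len(trace))
    filled = np.interp(x, x[~spike_mask], trace[~spike_mask])  # bridge linearly over spikes
    b, a = butter(4, cutoff_hz / (fps / 2), btype='low')
    return filtfilt(b, a, filled)

fps = 200
t = np.arange(2000) / fps
slow = np.sin(2 * np.pi * 1.0 * t)    # 1 Hz subthreshold fluctuation
spike_frames = np.arange(50, 2000, 100)
trace = slow.copy()
trace[spike_frames] += 5.0             # single-frame spikes
sub = subthreshold_np(trace, spike_frames, fps)
print(np.abs(sub - slow)[100:-100].max())
```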

#### Complete Pipeline Function
```python
def voltage_imaging_pipeline_full(video, fps=200, config=None):
    """
    Full voltage imaging pipeline with all stages.
    
    Args:
        video: (n_frames, height, width) raw video
        fps: frame rate
        config: dict of parameters for each stage
    
    Returns:
        dict with roi_masks, traces, spike_times, subthreshold, metadata
    """
    if config is None:
        config = {}
    
    results = {'metadata': {'fps': fps, 'config': config}}
    
    # === Stage 0: Camera noise correction ===
    if config.get('camera_correction', False):
        video = correct_camera_noise(video)
    
    # === Stage 1: Motion correction ===
    mc_method = config.get('motion_correction', 'template')
    if mc_method == 'template':
        video, shifts = motion_correct_template(video, max_shift=50)
        results['metadata']['motion_shifts'] = shifts
    elif mc_method == 'none':
        pass
    
    # === Stage 2: Denoising ===
    denoise_method = config.get('denoising', 'local_pca')
    if denoise_method == 'local_pca':
        video = local_pca_denoise(video,
                                   patch_size=config.get('pca_patch_size', 32),
                                   n_components=config.get('pca_components', 10))
    elif denoise_method == 'temporal_median':
        video = temporal_median_filter(video, window=3)
    
    # === Stage 3: ROI Segmentation ===
    seg_method = config.get('segmentation', 'correlation')
    std_proj = np.std(video, axis=0)
    
    if seg_method == 'correlation':
        corr_img = compute_correlation_image(video)
        roi_masks = segment_from_correlation(corr_img, std_proj)
    elif seg_method == 'ring':
        roi_masks = detect_ring_rois(std_proj)
    elif seg_method == 'threshold':
        # Simple threshold segmentation
        thresh = np.mean(std_proj) + 1.5 * np.std(std_proj)
        binary = std_proj > thresh
        binary = morphology.remove_small_objects(binary, min_size=30)
        labeled = measure.label(binary)
        roi_masks = np.array([labeled == i for i in range(1, labeled.max() + 1)],
                             dtype=bool).reshape(-1, *std_proj.shape)
    else:
        raise ValueError(f"Unknown segmentation method: {seg_method!r}")
    
    results['roi_masks'] = roi_masks
    results['metadata']['n_rois_initial'] = len(roi_masks)
    
    # === Stage 4: Trace extraction ===
    traces = []
    for mask in roi_masks:
        if config.get('background_subtraction', 'annulus') == 'annulus':
            trace = subtract_background_annulus(video, mask)
        else:
            trace = np.mean(video[:, mask], axis=1)
        
        # Compute dF/F
        dff = compute_dff(trace,
                         method=config.get('baseline_method', 'rolling_percentile'),
                         fps=fps)
        traces.append(dff)
    
    traces = np.array(traces)
    results['traces'] = traces
    
    # === Stage 5: Spike detection ===
    spike_method = config.get('spike_detection', 'adaptive')
    spike_times = []
    spike_frames = []
    
    for trace in traces:
        if spike_method == 'adaptive':
            times, _ = detect_spikes_adaptive(trace, fps)
        elif spike_method == 'mad':
            times = detect_spikes_mad(trace, fps)
        elif spike_method == 'template':
            times = detect_spikes_template(trace, fps)
        else:
            # Simple threshold
            inverted = -trace
            threshold = np.mean(inverted) + 3 * np.std(inverted)
            peaks, _ = signal.find_peaks(inverted, height=threshold)
            times = peaks / fps
        
        spike_times.append(times)
        spike_frames.append(np.rint(np.asarray(times) * fps).astype(int))
    
    results['spike_times'] = spike_times
    results['spike_frames'] = spike_frames
    
    # === Stage 6: Quality filtering ===
    # SNR filter
    if config.get('snr_filter', True):
        keep_idx, snr_values = filter_by_snr(traces, spike_times, fps,
                                              min_snr=config.get('min_snr', 4))
        results['metadata']['snr_values'] = snr_values
    else:
        keep_idx = list(range(len(roi_masks)))
    
    # Overlap removal
    if config.get('overlap_removal', True):
        keep_idx_overlap = remove_overlapping_rois(
            roi_masks[keep_idx],
            traces[keep_idx],
            correlation_threshold=0.9
        )
        keep_idx = [keep_idx[i] for i in keep_idx_overlap]
    
    results['metadata']['n_rois_final'] = len(keep_idx)
    results['keep_indices'] = keep_idx
    
    # === Stage 7: Artifact removal (optional) ===
    if config.get('artifact_removal', False) and len(keep_idx) > 10:
        # Create spike rasters
        n_frames = video.shape[0]
        spike_rasters = []
        for idx in keep_idx:
            raster = np.zeros(n_frames)
            raster[spike_frames[idx]] = 1
            spike_rasters.append(raster)
        
        clean_idx, artifact_idx = iterative_umap_dbscan_artifact_removal(spike_rasters, fps)
        keep_idx = [keep_idx[i] for i in clean_idx]
        results['metadata']['n_rois_after_artifact_removal'] = len(keep_idx)
    
    # === Stage 8: Subthreshold extraction (optional) ===
    if config.get('extract_subthreshold', False):
        subthreshold = []
        for idx in keep_idx:
            sub = extract_subthreshold(traces[idx], spike_frames[idx], fps)
            subthreshold.append(sub)
        results['subthreshold'] = np.array(subthreshold)
    
    # Filter results to kept ROIs
    results['roi_masks_filtered'] = roi_masks[keep_idx]
    results['traces_filtered'] = traces[keep_idx]
    results['spike_times_filtered'] = [spike_times[i] for i in keep_idx]
    
    return results
```
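For reference, these are the `config` keys that `voltage_imaging_pipeline_full` reads, with the defaults it falls back to when a key is missing (taken from the function body above). Copying and editing this dictionary is a convenient way to define a pipeline variant for the benchmark; the comments list the other branches the function implements. The name `DEFAULT_PIPELINE_CONFIG` is hypothetical.

```python
DEFAULT_PIPELINE_CONFIG = {
    'camera_correction': False,               # True -> correct_camera_noise(video)
    'motion_correction': 'template',          # or 'none'
    'denoising': 'local_pca',                 # or 'temporal_median'; other values skip denoising
    'pca_patch_size': 32,
    'pca_components': 10,
    'segmentation': 'correlation',            # or 'ring', 'threshold'
    'background_subtraction': 'annulus',      # other values -> plain mean over the ROI
    'baseline_method': 'rolling_percentile',  # or 'percentile', 'exponential'
    'spike_detection': 'adaptive',            # or 'mad', 'template'; other values -> simple 3-std threshold
    'snr_filter': True,
    'min_snr': 4,
    'overlap_removal': True,
    'artifact_removal': False,                # only runs when more than 10 ROIs survive filtering
    'extract_subthreshold': False,
}

# Example variant: MAD spike detection with a stricter SNR cutoff
config = {**DEFAULT_PIPELINE_CONFIG, 'spike_detection': 'mad', 'min_snr': 5}
```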

### Kawashima et al. Pipeline (Voltron in Zebrafish - Raphe Serotonin Study)

This pipeline was used for voltage imaging of Voltron in zebrafish raphe neurons. Notable for detailed camera noise correction, local PCA denoising with whiteness stopping criterion, semi-NMF segmentation, and LSTM spike detection.

#### Step 1: Camera Noise Correction (Huang/Liu method)
```python
def camera_noise_correction_kawashima(video, dark_frames, calibration_data=None):
    """
    Camera noise correction from Kawashima et al.
    
    Corrected pixel: s_i = (s_i^r - o_i) / g_i
    
    Where:
    - s_i^r: raw readout at pixel i
    - o_i: camera offset (mean of dark frames)
    - g_i: camera gain (fitted from variance vs intensity relationship)
    
    Args:
        video: raw video (n_frames, height, width)
        dark_frames: video acquired with no illumination (for offset estimation)
        calibration_data: dict with multiple illumination levels for gain estimation
            {'level_k': {'mean': D_ik, 'variance': v_ik}, ...}
    
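    Returns:
        corrected: offset- and gain-corrected video, s_i = (s_i^r - o_i) / g_i
        params: dict with 'offset', 'gain', and 'baseline_variance' maps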
    """
    # Estimate offset (o_i) from dark frames
    offset = np.mean(dark_frames, axis=0)  # Mean of ~60k dark images
    baseline_variance = np.var(dark_frames, axis=0)  # Variance in dark condition
    
    # Estimate gain (g_i) from calibration data at multiple illumination levels
    if calibration_data is not None:
        # Per pixel, fit (v_ik - v_i) = g_i * (D_ik - o_i) by least squares through the
        # origin, i.e. minimize sum_k ((v_ik - v_i) - g_i * (D_ik - o_i))^2.
        # Vectorized over pixels; a per-pixel Python loop is far too slow for full frames.
        levels = sorted(calibration_data.keys())
        means = np.stack([calibration_data[k]['mean'] for k in levels])          # (n_levels, h, w)
        variances = np.stack([calibration_data[k]['variance'] for k in levels])  # (n_levels, h, w)
        
        intensity_excess = means - offset
        variance_excess = variances - baseline_variance
        
        # Only levels brighter than the dark offset contribute to the fit
        valid = intensity_excess > 0
        I = np.where(valid, intensity_excess, 0.0)
        V = np.where(valid, variance_excess, 0.0)
        denom = np.sum(I ** 2, axis=0)
        
        # Pixels with fewer than 2 usable levels keep a gain of 1
        fit_ok = (valid.sum(axis=0) >= 2) & (denom > 0)
        gain = np.ones_like(offset, dtype=np.float32)
        gain[fit_ok] = np.sum(I * V, axis=0)[fit_ok] / denom[fit_ok]
        gain = np.clip(gain, 0.1, 10)  # Reasonable bounds
    else:
        # Without calibration data, assume uniform gain
        gain = np.ones_like(offset)
    
    # Apply correction: s_i = (s_i^r - o_i) / g_i
    corrected = (video.astype(np.float32) - offset) / gain
    
    return corrected, {'offset': offset, 'gain': gain, 'baseline_variance': baseline_variance}

def estimate_camera_calibration(calibration_videos):
    """
    Estimate camera calibration from videos at multiple illumination levels.
    
    Args:
        calibration_videos: dict mapping illumination level to video array
            e.g., {0: dark_video, 5: video_5mW, 10: video_10mW, 18: video_18mW}
    
    Returns:
        calibration_data for camera_noise_correction_kawashima
    """
    calibration_data = {}
    for level, video in calibration_videos.items():
        calibration_data[level] = {
            'mean': np.mean(video, axis=0),
            'variance': np.var(video, axis=0)
        }
    return calibration_data
```
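As a sanity check on the gain model, the sketch below simulates an idealized sCMOS sensor (Poisson photoelectrons scaled by a per-pixel gain, plus Gaussian read noise around a fixed offset) and recovers the gain with the same through-origin fit of variance excess against intensity excess. The helper name `estimate_gain_ptc` and all simulation parameters are illustrative assumptions, not values from Kawashima et al.

```python
import numpy as np

def estimate_gain_ptc(dark, level_videos):
    """Per-pixel gain from a photon-transfer fit: (var - dark var) = g * (mean - offset)."""
    offset = dark.mean(axis=0)
    base_var = dark.var(axis=0)
    I = np.stack([v.mean(axis=0) - offset for v in level_videos])  # (n_levels, h, w)
    V = np.stack([v.var(axis=0) - base_var for v in level_videos])
    gain = np.sum(I * V, axis=0) / np.sum(I ** 2, axis=0)          # least squares through the origin
    return offset, gain

# Hypothetical sensor: per-pixel gain (ADU per photoelectron), fixed offset, Gaussian read noise
rng = np.random.default_rng(0)
n_frames, h, w = 5000, 4, 4
true_gain = rng.uniform(1.5, 2.5, size=(h, w))
true_offset, read_noise = 100.0, 2.0

def acquire(mean_photoelectrons):
    electrons = rng.poisson(mean_photoelectrons, size=(n_frames, h, w))
    return true_offset + true_gain * electrons + rng.normal(0.0, read_noise, size=(n_frames, h, w))

dark = acquire(0)
levels = [acquire(p) for p in (20, 50, 100)]
offset, gain = estimate_gain_ptc(dark, levels)

# After correction, frames are in photoelectrons: s = (s_raw - o) / g
corrected = (levels[-1] - offset) / gain
print(np.round(gain / true_gain, 3))       # ratios should be close to 1
print(round(float(corrected.mean()), 1))   # should be close to 100
```

The fit works because shot noise in ADU has variance g² times the mean photoelectron count while the mean excess is g times that count, so their ratio is g.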

#### Step 2: Motion Correction (dipy-based)
```python
def motion_correct_dipy(video, reference='mean', subpixel=True):
    """
    2D rigid registration using dipy package.
    Subpixel-level correction for sequential images.
    
    Args:
        video: input video (n_frames, height, width), e.g. after camera noise correction
        reference: 'mean', 'first', or frame index
        subpixel: if True, use subpixel registration (skimage fallback only; dipy
            registration is always subpixel)
    
    Returns:
        corrected video, shifts array
    """
    try:
        from dipy.align.imaffine import AffineRegistration, MutualInformationMetric
        from dipy.align.transforms import TranslationTransform2D
        USE_DIPY = True
    except ImportError:
        USE_DIPY = False
        print("dipy not installed, falling back to skimage")
    
    n_frames, h, w = video.shape
    
    # Create reference frame
    if reference == 'mean':
        ref_frame = np.mean(video, axis=0)
    elif reference == 'first':
        ref_frame = video[0]
    elif isinstance(reference, int):
        ref_frame = video[reference]
    else:
        ref_frame = np.mean(video, axis=0)
    
    corrected = np.zeros_like(video, dtype=np.float32)
    shifts = np.zeros((n_frames, 2))
    
    if USE_DIPY:
        # dipy-based registration (translation only)
        metric = MutualInformationMetric(nbins=32)
        affreg = AffineRegistration(metric=metric, verbosity=0)
        transform = TranslationTransform2D()
        ref = ref_frame.astype(np.float64)
        
        for i in range(n_frames):
            frame = video[i].astype(np.float64)
            
            # Register the moving frame to the static reference
            affine_map = affreg.optimize(ref, frame, transform, params0=None)
            
            # The affine maps reference coordinates to frame coordinates; negate its
            # translation so `shifts` uses the same convention as the skimage branch
            # (the shift applied to the frame to align it with the reference)
            shifts[i] = -affine_map.affine[:2, 2]
            
            # Resample the frame onto the reference grid
            corrected[i] = affine_map.transform(frame)
    else:
        # Fallback to skimage phase_cross_correlation
        from skimage.registration import phase_cross_correlation
        from scipy.ndimage import shift as ndi_shift
        
        for i in range(n_frames):
            frame = video[i].astype(np.float64)
            
            shift, error, diffphase = phase_cross_correlation(
                ref_frame, frame, upsample_factor=10 if subpixel else 1
            )
            shifts[i] = shift
            corrected[i] = ndi_shift(frame, shift, mode='constant', cval=0)
    
    return corrected, shifts
```
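The skimage fallback relies on phase correlation, and its sign convention is easy to get backwards. The numpy-only sketch below (integer-pixel only; `phase_correlation_shift` is a hypothetical helper, not the skimage implementation) returns the shift that must be applied to `moving` to align it with `reference`, the same convention in which the fallback passes `shift` to `ndi_shift`.

```python
import numpy as np

def phase_correlation_shift(reference, moving):
    """Integer (row, col) shift to apply to `moving` so that it aligns with `reference`."""
    cross_power = np.fft.fft2(reference) * np.conj(np.fft.fft2(moving))
    cross_power /= np.abs(cross_power) + 1e-12      # keep phase only
    corr = np.fft.ifft2(cross_power).real
    peak = np.array(np.unravel_index(np.argmax(corr), corr.shape))
    shape = np.array(reference.shape)
    wrap = peak > shape // 2                         # peaks past the midpoint are negative shifts
    peak[wrap] -= shape[wrap]
    return peak

rng = np.random.default_rng(0)
reference = rng.normal(size=(64, 64))
moving = np.roll(reference, shift=(3, -5), axis=(0, 1))   # frame displaced by (+3, -5) pixels

shift = phase_correlation_shift(reference, moving)
aligned = np.roll(moving, shift=tuple(int(s) for s in shift), axis=(0, 1))
print(shift)                                # [-3  5]
print(np.array_equal(aligned, reference))   # True
```

A frame displaced by +d therefore yields a shift of -d; subpixel methods (`upsample_factor` in skimage) refine the same peak.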

#### Step 3: Local PCA Denoising (with whiteness criterion)
```python
def local_pca_denoise_kawashima(video, patch_size=64, overlap=32, max_components=50):
    """
    Local PCA denoising with automatic rank selection.
    
    Stop adding components when residual is statistically white
    (within 99% confidence interval).
    
    From Kawashima et al.:
    - Find low-rank representation Y_bar = sum_k u_k v_k^T (truncated SVD)
    - Stop when residual R_k = Y - Y_bar is white noise
    
    Args:
        video: (n_frames, height, width) motion-corrected video
        patch_size: size of spatial patches
        overlap: overlap between patches
        max_components: maximum number of PCA components
    
    Returns:
        denoised video
    """
    
    n_frames, h, w = video.shape
    stride = patch_size - overlap
    
    # Output arrays
    denoised = np.zeros_like(video, dtype=np.float64)
    weights = np.zeros((h, w), dtype=np.float64)
    
    rng = np.random.default_rng(0)  # fixed seed so the pixel subsample is reproducible
    
    def is_white_noise(residual, confidence=0.99):
        """
        Simplified whiteness check (a stand-in for a formal test such as Ljung-Box).
        
        Under white noise, the lag-1 autocorrelation of each pixel's residual is
        approximately N(0, 1/n_frames). The residual is declared white when the mean
        |lag-1 autocorrelation| over a random subset of pixels falls below the
        two-sided `confidence` bound for a single coefficient. This is a loose
        criterion: for true white noise the mean |r| is only about 0.8/sqrt(n_frames).
        """
        from scipy.stats import norm
        
        n_frames, n_pixels = residual.shape
        
        # Sample some pixels for efficiency
        sample_idx = rng.choice(n_pixels, min(100, n_pixels), replace=False)
        
        autocorrs = []
        for idx in sample_idx:
            trace = residual[:, idx]
            if np.std(trace) > 1e-10:
                # Lag-1 autocorrelation
                autocorrs.append(np.corrcoef(trace[:-1], trace[1:])[0, 1])
        
        if len(autocorrs) == 0:
            return True
        
        mean_abs_autocorr = np.mean(np.abs(autocorrs))
        threshold = norm.ppf(0.5 + confidence / 2) / np.sqrt(n_frames)  # 2.58/sqrt(n) at 99%
        
        return mean_abs_autocorr < threshold
    
    # Patch start positions; append a final start so the right/bottom edges are covered
    def patch_starts(size):
        last = max(size - patch_size, 0)
        starts = list(range(0, last + 1, stride))
        if starts[-1] != last:
            starts.append(last)
        return starts
    
    # Process each patch
    for y_start in patch_starts(h):
        for x_start in patch_starts(w):
            # Extract patch: (n_frames, ph, pw); smaller than patch_size only if the frame is
            patch = video[:, y_start:y_start+patch_size, x_start:x_start+patch_size]
            ph, pw = patch.shape[1:]
            patch_flat = patch.reshape(n_frames, -1).astype(np.float64)  # (n_frames, n_pixels)
            
            # Center each pixel's time series
            mean_vals = patch_flat.mean(axis=0)
            centered = patch_flat - mean_vals
            
            # One SVD per patch; components are added until the residual looks white
            U, S, Vt = np.linalg.svd(centered, full_matrices=False)
            
            n_max = min(max_components, len(S))
            best_k = n_max
            for k in range(1, n_max + 1):
                residual = centered - (U[:, :k] * S[:k]) @ Vt[:k, :]
                if is_white_noise(residual):
                    best_k = k
                    break
            
            # Final reconstruction with best_k components
            reconstructed = (U[:, :best_k] * S[:best_k]) @ Vt[:best_k, :] + mean_vals
            reconstructed = reconstructed.reshape(n_frames, ph, pw)
            
            # Accumulate; overlapping regions are averaged below
            denoised[:, y_start:y_start+ph, x_start:x_start+pw] += reconstructed
            weights[y_start:y_start+ph, x_start:x_start+pw] += 1
    
    # Every pixel is covered by at least one patch, so all weights are >= 1
    denoised /= weights[np.newaxis, :, :]
    
    return denoised.astype(np.float32)
```

#### Step 4: Semi-NMF Cell Segmentation
```python
def semi_nmf_segmentation_kawashima(video, n_components=100, correlation_threshold=0.8,
                                     max_iter=500, min_roi_size=20):
    """
    Semi-nonnegative matrix factorization for cell segmentation.
    
    From Kawashima et al.:
    minimize ||Y - A*F - B||^2
    subject to: A >= 0, B = b * 1^T, b >= 0
    
    Where:
    - A: spatial components (non-negative ROI masks)
    - F: temporal components (can be negative - voltage traces)
    - B: temporally constant background
    
    Initialization: super-pixels with local correlation > 0.8
    
    Args:
        video: denoised video (n_frames, height, width)
        n_components: number of components to extract
        correlation_threshold: for super-pixel initialization
        max_iter: maximum iterations for optimization
        min_roi_size: minimum ROI size in pixels
    
    Returns:
        roi_masks: (n_rois, height, width) boolean masks
        temporal_components: (n_rois, n_frames) fluorescence traces
        background: (height, width) static background
    """
    from skimage import morphology  # used below to clean up ROI masks
    
    n_frames, h, w = video.shape
    n_pixels = h * w
    
    # Reshape video: Y is (n_pixels, n_frames)
    Y = video.reshape(n_frames, n_pixels).T.astype(np.float64)
    
    # === Step 1: Initialize with super-pixels ===
    def compute_local_correlation_map(video, radius=2):
        """Compute local correlation for each pixel (np.roll wraps around at the frame edges)."""
        corr_map = np.zeros((h, w))
        video_norm = (video - video.mean(axis=0)) / (video.std(axis=0) + 1e-10)
        
        for dy in range(-radius, radius+1):
            for dx in range(-radius, radius+1):
                if dy == 0 and dx == 0:
                    continue
                shifted = np.roll(np.roll(video_norm, dy, axis=1), dx, axis=2)
                corr = (video_norm * shifted).mean(axis=0)
                corr_map += corr
        
        corr_map /= (2*radius + 1)**2 - 1
        return corr_map
    
    def find_super_pixels(video, corr_threshold=0.8):
        """Find super-pixels: connected regions with high local correlation."""
        corr_map = compute_local_correlation_map(video)
        binary = corr_map > corr_threshold
        
        from scipy.ndimage import label
        labeled, n_labels = label(binary)
        
        super_pixels = []
        for i in range(1, n_labels + 1):
            mask = labeled == i
            if mask.sum() >= min_roi_size:
                super_pixels.append(mask)
        
        return super_pixels
    
    super_pixels = find_super_pixels(video, correlation_threshold)
    n_init = min(len(super_pixels), n_components)
    
    if n_init == 0:
        print("Warning: No super-pixels found, using random initialization")
        # Random initialization
        n_init = n_components
        super_pixels = []
        for _ in range(n_init):
            mask = np.zeros((h, w), dtype=bool)
            cy, cx = np.random.randint(10, h-10), np.random.randint(10, w-10)
            mask[cy-3:cy+3, cx-3:cx+3] = True
            super_pixels.append(mask)
    
    # Initialize A (spatial components)
    A = np.zeros((n_pixels, n_init))
    for i, mask in enumerate(super_pixels[:n_init]):
        A[:, i] = mask.flatten().astype(np.float64)
    
    # Initialize background b: one constant level estimated from the lowest-variance
    # (likely non-cell) pixels, broadcast to every pixel
    pixel_var = Y.var(axis=1)
    low_var_mask = pixel_var <= np.percentile(pixel_var, 20)
    b = np.full(n_pixels, max(Y[low_var_mask].mean(), 0.0))  # Non-negative background
    
    # B = b * 1^T
    B = np.outer(b, np.ones(n_frames))
    
    # === Step 2: Alternating optimization ===
    for iteration in range(max_iter):
        # Update F (temporal components): F = (A^T A)^-1 A^T (Y - B)
        Y_bg = Y - B
        AtA = A.T @ A + 1e-6 * np.eye(A.shape[1])  # Regularization
        AtY = A.T @ Y_bg
        F = np.linalg.solve(AtA, AtY)
        
        # Update A (spatial components): A = (Y - B) F^T (F F^T)^-1, then clip to >= 0
        FFt = F @ F.T + 1e-6 * np.eye(F.shape[0])
        YFt = Y_bg @ F.T
        A = np.linalg.solve(FFt.T, YFt.T).T
        A = np.maximum(A, 0)  # Non-negative constraint
        
        # Update background b: b = mean((Y - A*F), axis=1), clipped to >= 0
        residual = Y - A @ F
        b = residual.mean(axis=1)
        b = np.maximum(b, 0)
        B = np.outer(b, np.ones(n_frames))
        
        # Log progress every 50 iterations (all max_iter iterations run; no early stopping)
        if iteration % 50 == 0:
            reconstruction_error = np.linalg.norm(Y - A @ F - B, 'fro')
            print(f"Iteration {iteration}: reconstruction error = {reconstruction_error:.2f}")
    
    # === Step 3: Extract ROI masks ===
    roi_masks = []
    temporal_components = []
    
    for i in range(A.shape[1]):
        spatial = A[:, i].reshape(h, w)
        temporal = F[i, :]
        
        # Threshold spatial component to get mask
        thresh = np.percentile(spatial[spatial > 0], 90) if (spatial > 0).sum() > 0 else 0
        mask = spatial > thresh
        
        # Clean up mask
        mask = morphology.remove_small_objects(mask, min_size=min_roi_size)
        
        if mask.sum() >= min_roi_size:
            roi_masks.append(mask)
            temporal_components.append(temporal)
    
    return (np.array(roi_masks),
            np.array(temporal_components),
            b.reshape(h, w))
```
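The super-pixel initialization depends on the local correlation image. The sketch below computes the same mean-correlation-with-neighbors map but pairs pixels by slicing instead of `np.roll`, so edge pixels are not correlated with the opposite edge of the frame. It uses `radius=1` and a synthetic 10×10 "cell" sharing one time course; the helper name and the synthetic data are illustrative assumptions.

```python
import numpy as np

def local_correlation_map(video, radius=1):
    """Mean correlation of each pixel with its neighbors, without wrap-around at the edges."""
    n_frames, h, w = video.shape
    z = (video - video.mean(axis=0)) / (video.std(axis=0) + 1e-10)
    corr_sum = np.zeros((h, w))
    n_neighbors = np.zeros((h, w))
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            if dy == 0 and dx == 0:
                continue
            # Pixel (y, x) in the target slice is paired with neighbor (y - dy, x - dx)
            ty, sy = slice(max(dy, 0), h + min(dy, 0)), slice(max(-dy, 0), h + min(-dy, 0))
            tx, sx = slice(max(dx, 0), w + min(dx, 0)), slice(max(-dx, 0), w + min(-dx, 0))
            corr_sum[ty, tx] += (z[:, ty, tx] * z[:, sy, sx]).mean(axis=0)
            n_neighbors[ty, tx] += 1
    return corr_sum / n_neighbors

# Synthetic movie: a 10x10 "cell" sharing one time course, surrounded by independent noise
rng = np.random.default_rng(0)
video = rng.normal(size=(1000, 20, 20))
shared = rng.normal(size=(1000, 1, 1))
video[:, 5:15, 5:15] = shared + 0.2 * rng.normal(size=(1000, 10, 10))

corr = local_correlation_map(video)
seed_mask = corr > 0.8   # candidate super-pixel seed, cf. correlation_threshold above
```

Inside the cell the expected correlation is 1 / (1 + 0.2²) ≈ 0.96, while background pixels stay near 0, so the 0.8 threshold separates them cleanly in this toy case.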

#### Step 5: LSTM Spike Detection
```python
def create_lstm_spike_detector(window_size=41, hidden_size=64, dropout=0.3):
    """
    Create LSTM network for spike detection.
    
    From Kawashima et al.:
    - Two LSTM layers with dropout between them
    - Input: time series window (41 frames = 136.67 ms at 300 Hz)
    - Output: probability of spike at center of window
    - Trained on simultaneous ephys + imaging data
    
    Args:
        window_size: input window size (41 frames in paper)
        hidden_size: LSTM hidden layer size
        dropout: dropout rate between LSTM layers
    
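    Returns:
        PyTorch nn.Module if torch is installed, else a compiled Keras model
        (detect_spikes_lstm and train_lstm_spike_detector below expect the PyTorch model)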
    """
    try:
        import torch
        import torch.nn as nn
        
        class LSTMSpikeDetector(nn.Module):
            def __init__(self, input_size=1, hidden_size=64, dropout=0.3):
                super().__init__()
                self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True)
                self.dropout = nn.Dropout(dropout)
                self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True)
                self.fc = nn.Linear(hidden_size, 1)
                self.sigmoid = nn.Sigmoid()
            
            def forward(self, x):
                # x: (batch, window_size, 1)
                out, _ = self.lstm1(x)
                out = self.dropout(out)
                out, _ = self.lstm2(out)
                # Take output at last time step
                out = self.fc(out[:, -1, :])
                return self.sigmoid(out)
        
        return LSTMSpikeDetector(hidden_size=hidden_size, dropout=dropout)
    
    except ImportError:
        # Keras/TensorFlow fallback
        try:
            from tensorflow import keras
            from tensorflow.keras import layers
            
            model = keras.Sequential([
                layers.LSTM(hidden_size, return_sequences=True,
                           input_shape=(window_size, 1)),
                layers.Dropout(dropout),
                layers.LSTM(hidden_size),
                layers.Dense(1, activation='sigmoid')
            ])
            model.compile(optimizer='adam', loss='binary_crossentropy',
                         metrics=['accuracy'])
            return model
        
        except ImportError:
            print("Neither PyTorch nor TensorFlow available")
            return None

def detect_spikes_lstm(trace, model, fps=300, window_size=41, threshold=0.5):
    """
    Detect spikes using trained LSTM model.
    
    Args:
        trace: ΔF/F trace (n_frames,)
        model: trained LSTM spike detector
        fps: frame rate (300 Hz in Kawashima paper)
        window_size: input window size
        threshold: probability threshold for spike detection
    
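    Returns:
        spike_times: array of spike times in seconds
        spike_probs: (n_frames,) spike probability per frame (0 within half a window of either end)
    """
    import torch
    from numpy.lib.stride_tricks import sliding_window_view
    from scipy.signal import find_peaks
    
    n_frames = len(trace)
    half_window = window_size // 2
    spike_probs = np.zeros(n_frames)
    if n_frames < window_size:
        return np.array([]), spike_probs
    
    # Normalize trace
    trace_norm = (trace - np.mean(trace)) / (np.std(trace) + 1e-10)
    
    # All windows at once: row j covers frames j .. j + window_size - 1,
    # i.e. it is centered on frame j + half_window (window_size is odd)
    windows = sliding_window_view(trace_norm, window_size).astype(np.float32)
    
    # Batched inference instead of one forward pass per frame
    batch_size = 4096  # windows per forward pass
    model.eval()
    probs = []
    with torch.no_grad():
        for start in range(0, len(windows), batch_size):
            batch = torch.from_numpy(windows[start:start + batch_size]).unsqueeze(-1)
            probs.append(model(batch).squeeze(-1).numpy())
    spike_probs[half_window:half_window + len(windows)] = np.concatenate(probs)
    
    # Find peaks above threshold. find_peaks requires distance >= 1, and a 2 ms
    # refractory period rounds to 0 frames at 300 Hz, so enforce at least 1 frame.
    min_distance = max(1, int(round(0.002 * fps)))
    peaks, _ = find_peaks(spike_probs, height=threshold, distance=min_distance)
    
    spike_times = peaks / fps
    return spike_times, spike_probs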

def train_lstm_spike_detector(traces, ground_truth_spikes, fps=300,
                               window_size=41, epochs=100, batch_size=64):
    """
    Train LSTM spike detector on paired imaging + electrophysiology data.
    
    Args:
        traces: list of ΔF/F traces
        ground_truth_spikes: list of spike frame indices from electrophysiology
        fps: frame rate
        window_size: input window size
        epochs: training epochs
        batch_size: training batch size
    
    Returns:
        trained model
    """
    import torch
    from torch.utils.data import DataLoader, TensorDataset
    
    # Create training data
    X = []
    y = []
    half_window = window_size // 2
    
    for trace, spikes in zip(traces, ground_truth_spikes):
        trace_norm = (trace - np.mean(trace)) / (np.std(trace) + 1e-10)
        spike_set = set(spikes)
        
        for i in range(half_window, len(trace) - half_window):
            window = trace_norm[i - half_window:i + half_window + 1]
            label = 1.0 if i in spike_set else 0.0
            X.append(window)
            y.append(label)
    
    X = np.array(X)
    y = np.array(y)
    
    # Balance classes (spike vs no-spike)
    spike_idx = np.where(y == 1)[0]
    no_spike_idx = np.where(y == 0)[0]
    n_samples = min(len(spike_idx), len(no_spike_idx))
    
    balanced_idx = np.concatenate([
        spike_idx[:n_samples],
        np.random.choice(no_spike_idx, n_samples, replace=False)
    ])
    np.random.shuffle(balanced_idx)
    
    X = X[balanced_idx]
    y = y[balanced_idx]
    
    # Create model and train
    model = create_lstm_spike_detector(window_size=window_size)
    
    X_tensor = torch.FloatTensor(X).unsqueeze(-1)
    y_tensor = torch.FloatTensor(y).unsqueeze(-1)
    
    dataset = TensorDataset(X_tensor, y_tensor)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()
    
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            output = model(batch_X)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: loss = {total_loss/len(loader):.4f}")
    
    return model
```
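Both `train_lstm_spike_detector` and `detect_spikes_lstm` rely on windows centered on a frame. The numpy-only sketch below makes that frame-to-window mapping explicit (41 frames at 300 Hz) and reproduces the 1:1 class balancing used for training. The helper names and the toy spike positions are hypothetical; spikes closer than half a window to either end of the trace get no window and are dropped, as in the training loop above.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def centered_windows(trace, spike_frames, window_size=41):
    """Normalized windows centered on every frame that has a full window, plus 0/1 labels."""
    half = window_size // 2
    z = (trace - trace.mean()) / (trace.std() + 1e-10)
    windows = sliding_window_view(z, window_size)        # row j covers frames j .. j + window_size - 1
    centers = np.arange(half, half + len(windows))       # frame at the middle of each row
    labels = np.isin(centers, spike_frames).astype(np.float32)
    return windows, centers, labels

def balanced_indices(labels, rng):
    """Equal numbers of spike and non-spike windows, shuffled."""
    pos = np.flatnonzero(labels == 1)
    neg = np.flatnonzero(labels == 0)
    n = min(len(pos), len(neg))
    idx = np.concatenate([rng.choice(pos, n, replace=False), rng.choice(neg, n, replace=False)])
    rng.shuffle(idx)
    return idx

rng = np.random.default_rng(0)
fps, window_size = 300, 41
print(f"window = {1000 * window_size / fps:.2f} ms")   # window = 136.67 ms
trace = rng.normal(size=3000)
spike_frames = np.array([5, 100, 250, 1200, 2990])      # 5 and 2990 lack a full window
windows, centers, labels = centered_windows(trace, spike_frames, window_size)
idx = balanced_indices(labels, rng)
```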

#### Step 6: Subthreshold Activity Estimation
```python
def extract_subthreshold_kawashima(trace, spike_frames, fps=300,
                                    median_window_ms=70, clip_window_frames=1):
    """
    Estimate subthreshold activity using rolling median.
    
    From Kawashima et al.:
    - Rolling median filter with 70 ms window
    - Clip frames around spikes (-1 to +1 frames) to avoid spike nonlinearity
    
    Args:
        trace: ΔF/F trace
        spike_frames: detected spike frame indices
        fps: frame rate
        median_window_ms: median filter window (70 ms in paper)
        clip_window_frames: frames to clip around spikes (1 in paper)
    
    Returns:
        subthreshold: subthreshold membrane potential estimate
    """
    from scipy.ndimage import median_filter
    
    n_frames = len(trace)
    
    # Create mask of frames to exclude (around spikes)
    valid_mask = np.ones(n_frames, dtype=bool)
    for spike_frame in np.asarray(spike_frames, dtype=int):
        start = max(0, spike_frame - clip_window_frames)
        end = min(n_frames, spike_frame + clip_window_frames + 1)
        valid_mask[start:end] = False
    
    # Interpolate over clipped regions (work in float so integer traces are not truncated)
    trace_interpolated = np.asarray(trace, dtype=np.float64).copy()
    if not valid_mask.all():
        valid_indices = np.where(valid_mask)[0]
        invalid_indices = np.where(~valid_mask)[0]
        
        if len(valid_indices) > 2:
            from scipy.interpolate import interp1d
            f = interp1d(valid_indices, trace[valid_indices],
                        kind='linear', bounds_error=False,
                        fill_value='extrapolate')
            trace_interpolated[invalid_indices] = f(invalid_indices)
    
    # Apply rolling median filter
    window_samples = int(median_window_ms * fps / 1000)
    window_samples = window_samples if window_samples % 2 == 1 else window_samples + 1
    
    subthreshold = median_filter(trace_interpolated, size=window_samples)
    
    return subthreshold
```
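The clip-interpolate-median sequence can be checked on a trace whose subthreshold component is known. In the numpy-only sketch below (a hypothetical `subthreshold_sketch`, using `np.interp` and symmetric edge padding instead of `interp1d` and `median_filter`), single-frame spikes ride on a linear ramp. Linear interpolation across the clipped frames restores the ramp exactly, and a centered median of a linear ramp returns the ramp, so the estimate matches the ramp away from the edges.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def subthreshold_sketch(trace, spike_frames, fps=300, median_window_ms=70, clip_window_frames=1):
    n = len(trace)
    valid = np.ones(n, dtype=bool)
    for s in spike_frames:
        valid[max(0, s - clip_window_frames):s + clip_window_frames + 1] = False
    frames = np.arange(n)
    # Linear interpolation across clipped frames (np.interp holds end values constant)
    filled = np.interp(frames, frames[valid], trace[valid])
    window = int(median_window_ms * fps / 1000) | 1      # force an odd window (21 at 300 Hz)
    half = window // 2
    padded = np.pad(filled, half, mode='symmetric')
    return np.median(sliding_window_view(padded, window), axis=1)

t = np.arange(600)
ramp = 0.01 * t                                  # slow depolarization
trace = ramp.copy()
spike_frames = np.array([100, 250, 400])
trace[spike_frames] += 5.0                       # single-frame spikes
sub = subthreshold_sketch(trace, spike_frames)
```

Near the ends the padded window is no longer symmetric about a linear segment, so small deviations from the ramp remain there.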

#### ΔF/F Computation (Kawashima method)
```python
def compute_dff_kawashima(trace, fps=300, baseline_percentile=20,
                           baseline_window_sec=180):
    """
    Compute ΔF/F using running percentile baseline.
    
    From Kawashima et al.:
    ΔF/F = (F - F0) / F0
    where F0 is running 20th percentile within 3-minute window
    
    For Voltron (negative indicator): use -ΔF/F for analysis
    
    Args:
        trace: raw fluorescence trace
        fps: frame rate
        baseline_percentile: percentile for baseline (20 in paper)
        baseline_window_sec: window size in seconds (180 = 3 minutes)
    
    Returns:
        dff: ΔF/F trace
        dff_inverted: -ΔF/F (for Voltron analysis)
    """
    from scipy.ndimage import percentile_filter
    
    window_samples = int(baseline_window_sec * fps)
    
    # Running percentile baseline
    F0 = percentile_filter(trace, baseline_percentile, size=window_samples)
    
    # Compute ΔF/F
    dff = (trace - F0) / (F0 + 1e-10)
    
    # Inverted for Voltron (negative indicator)
    dff_inverted = -dff
    
    return dff, dff_inverted
```
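With a 3-minute window at 300 Hz, `percentile_filter` evaluates a 54,000-sample window at every frame, which can be slow on long recordings. A common approximation, assumed here rather than taken from the paper, evaluates the running percentile only at a coarse grid of centers and linearly interpolates between them, relying on the baseline changing slowly over a 3-minute window. The helper name and `step_sec` are hypothetical.

```python
import numpy as np

def running_percentile_baseline(trace, fps, percentile=20, window_sec=180, step_sec=1.0):
    """Approximate running-percentile F0: exact percentiles every `step_sec`, interpolated between."""
    n = len(trace)
    half = int(window_sec * fps) // 2
    step = max(1, int(step_sec * fps))
    centers = np.arange(0, n, step)
    f0_coarse = np.array([np.percentile(trace[max(0, c - half):c + half + 1], percentile)
                          for c in centers])
    return np.interp(np.arange(n), centers, f0_coarse)

fps = 300
trace = np.full(60 * fps, 100.0)     # 1 min of constant fluorescence
trace[::50] = 90.0                   # sparse 10% dips (Voltron reports depolarization as a decrease)
F0 = running_percentile_baseline(trace, fps)
dff = (trace - F0) / (F0 + 1e-10)
dff_inverted = -dff                  # dips become positive deflections
```

Because the dips occupy only 2% of frames, the 20th percentile ignores them and F0 stays at the resting level.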

#### Behavioral Kernel Fitting (GLM for spike prediction)
```python
def fit_spike_glm_kernels(spike_times, behavior_data, fps=300,
                          history_sec=1.0, l2_reg=0.01):
    """
    Fit GLM kernels to predict spikes from behavioral variables.
    
    From Kawashima et al. (a binomial GLM with a logit link):
    P(spike at time t) = sigmoid(w_s^T * S_t + w_v^T * V_t - w_sp^T * SP_t)
    
    Where:
    - S_t: swim vigor history
    - V_t: visual input history
    - SP_t: recent spike history
    - w_s, w_v, w_sp: learned kernels
    
    Args:
        spike_times: array of spike times in seconds
        behavior_data: dict with 'swim_vigor', 'visual_input' arrays
        fps: frame rate
        history_sec: history window for kernels (1 second in paper)
        l2_reg: L2 regularization strength (passed to sklearn as C = 1/l2_reg)
    
    Returns:
        kernels: dict with 'swim_vigor' (w_s), 'visual_input' (w_v), and
            'spike_history' (the raw coefficient, i.e. -w_sp)
        model: fitted sklearn LogisticRegression
        performance: dict with 5-fold cross-validated ROC AUC ('cv_auc', 'cv_std')
            on the class-balanced samples
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    
    n_history = int(history_sec * fps)  # 300 time points for 1 sec at 300 Hz
    
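    swim_vigor = np.asarray(behavior_data['swim_vigor'], dtype=float)
    # Visual input defaults to zeros of matching length if not provided
    visual_input = np.asarray(behavior_data.get('visual_input', np.zeros_like(swim_vigor)), dtype=float)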
    n_frames = len(swim_vigor)
    
    # Create spike raster
    spike_raster = np.zeros(n_frames)
    spike_frames = (np.array(spike_times) * fps).astype(int)
    spike_frames = spike_frames[(spike_frames >= 0) & (spike_frames < n_frames)]
    spike_raster[spike_frames] = 1
    
    # Use sqrt of swim vigor (as in paper), keeping the sign
    swim_vigor_sqrt = np.sqrt(np.abs(swim_vigor)) * np.sign(swim_vigor)
    
    # Candidate time points need a full history window [t - n_history, t)
    t_all = np.arange(n_history, n_frames)
    y_all = spike_raster[t_all]
    
    # Balance classes: keep every spike bin plus up to twice as many randomly
    # chosen non-spike bins (roughly 1:2 spike to non-spike)
    spike_t = t_all[y_all == 1]
    no_spike_t = t_all[y_all == 0]
    n_samples = min(len(spike_t) * 2, len(no_spike_t))
    sampled_t = np.concatenate([spike_t, np.random.choice(no_spike_t, n_samples, replace=False)])
    np.random.shuffle(sampled_t)
    
    # Build the design matrix only for the sampled time points (the full
    # n_frames x 3*n_history matrix can be very large at 300 Hz).
    # sliding_window_view(x, n_history)[t - n_history] == x[t - n_history:t]
    from numpy.lib.stride_tricks import sliding_window_view
    def history(x):
        return sliding_window_view(x, n_history)[sampled_t - n_history]
    
    X_balanced = np.hstack([history(swim_vigor_sqrt),   # swim vigor history
                            history(visual_input),      # visual input history
                            history(spike_raster)])     # spike history
    y_balanced = spike_raster[sampled_t]
    
    # Fit logistic regression with L2 regularization
    model = LogisticRegression(penalty='l2', C=1/l2_reg, max_iter=1000)
    model.fit(X_balanced, y_balanced)
    
    # Extract kernels
    coef = model.coef_[0]
    kernels = {
        'swim_vigor': coef[:n_history],
        'visual_input': coef[n_history:2*n_history],
        'spike_history': coef[2*n_history:]
    }
    
    # Cross-validation performance
    cv_scores = cross_val_score(model, X_balanced, y_balanced, cv=5, scoring='roc_auc')
    
    return kernels, model, {'cv_auc': cv_scores.mean(), 'cv_std': cv_scores.std()}

def plot_behavioral_kernels(kernels, fps=300):
    """Plot fitted behavioral kernels."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 3))
    
    time_ms = np.arange(len(kernels['swim_vigor'])) * 1000 / fps - len(kernels['swim_vigor']) * 1000 / fps
    
    axes[0].plot(time_ms, kernels['swim_vigor'])
    axes[0].set_xlabel('Time before spike (ms)')
    axes[0].set_ylabel('Weight')
    axes[0].set_title('Swim Vigor Kernel')
    axes[0].axhline(0, color='gray', linestyle='--', alpha=0.5)
    
    axes[1].plot(time_ms, kernels['visual_input'])
    axes[1].set_xlabel('Time before spike (ms)')
    axes[1].set_title('Visual Input Kernel')
    axes[1].axhline(0, color='gray', linestyle='--', alpha=0.5)
    
    axes[2].plot(time_ms, -kernels['spike_history'])  # Negative because it's subtracted
    axes[2].set_xlabel('Time before spike (ms)')
    axes[2].set_title('Spike History Kernel (refractory)')
    axes[2].axhline(0, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    return fig
```
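`LogisticRegression` fits an L2-penalized Bernoulli GLM with a logit link. To make that fit concrete, the sketch below implements it directly with iteratively reweighted least squares (Newton's method) on simulated data with a single swim-like regressor. The helper name, the simulated weights, and the choice to leave the intercept unpenalized are assumptions for illustration; the penalty (l2_reg / 2) * ||w||² has the same form as sklearn's objective with C = 1/l2_reg.

```python
import numpy as np

def fit_logistic_irls(X, y, l2_reg=0.01, n_iter=25):
    """Penalized logistic regression by Newton/IRLS; returns [intercept, weights...]."""
    X1 = np.column_stack([np.ones(len(X)), X])
    penalty = l2_reg * np.eye(X1.shape[1])
    penalty[0, 0] = 0.0                                   # intercept is not penalized
    w = np.zeros(X1.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X1 @ w))                 # P(spike) under current weights
        grad = X1.T @ (y - p) - penalty @ w               # gradient of penalized log-likelihood
        hess = X1.T @ (X1 * (p * (1 - p))[:, None]) + penalty
        w = w + np.linalg.solve(hess, grad)               # Newton step
    return w

# Simulated bins: spike probability rises with a swim-like regressor
rng = np.random.default_rng(0)
swim = rng.normal(size=20000)
true_w = np.array([-2.0, 1.5])                            # [intercept, swim weight]
p_spike = 1.0 / (1.0 + np.exp(-(true_w[0] + true_w[1] * swim)))
spikes = (rng.random(20000) < p_spike).astype(float)

w_hat = fit_logistic_irls(swim[:, None], spikes)
```

In `fit_spike_glm_kernels` each history bin of each regressor is simply another column of X, so the fitted coefficients read out directly as kernels.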

#### Population Coding Analysis (sparse LDA)
```python
def compute_population_coding(neural_activity, conditions, l2_gamma=0.5):
    """
    Compute population coding direction using sparse linear discriminant analysis.
    
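    From Kawashima et al.:
    l = argmin_l -(l^T (r_high - r_low))^2 / (l^T Σ_r l)
    
    With regularized covariance: Σ_r = (1-γ) * cov(r) + γ * I.
    The minimizer is proportional to Σ_r^-1 (r_high - r_low). Note that this L2
    shrinkage toward the identity does not by itself make l sparse.
    
    Args:
        neural_activity: (n_neurons, n_timepoints) firing rates
        conditions: (n_timepoints,) array of condition labels (e.g., 'high' or 'low')
        l2_gamma: regularization parameter γ in [0, 1]
    
    Returns:
        coding_direction: (n_neurons,) unit vector pointing from the first to the
            second label in sorted order (np.unique); for 'high'/'low' it points
            toward 'low', so flip the sign if r_high - r_low is wanted
        explained_variance: between / (between + within), where between is the squared
            difference of the projected condition means and within is the sum of the
            projected condition variances
    """
    conditions = np.asarray(conditions)
    unique_conditions = np.unique(conditions)
    if len(unique_conditions) != 2:
        raise ValueError("Exactly 2 conditions required")
    
    cond1, cond2 = unique_conditions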
    
    # Mean activity per condition
    r_cond1 = neural_activity[:, conditions == cond1].mean(axis=1)
    r_cond2 = neural_activity[:, conditions == cond2].mean(axis=1)
    
    # Regularized covariance
    r_centered = neural_activity - neural_activity.mean(axis=1, keepdims=True)
    cov_r = np.cov(r_centered)
    n_neurons = cov_r.shape[0]
    
    sigma_r = (1 - l2_gamma) * cov_r + l2_gamma * np.eye(n_neurons)
    
    # Solve for coding direction (Fisher LDA)
    # l = Σ_r^-1 (r_high - r_low)
    diff = r_cond2 - r_cond1
    coding_direction = np.linalg.solve(sigma_r, diff)
    
    # Normalize
    coding_direction /= np.linalg.norm(coding_direction)
    
    # Compute explained variance
    projected = coding_direction @ neural_activity
    proj_cond1 = projected[conditions == cond1]
    proj_cond2 = projected[conditions == cond2]
    
    between_var = (proj_cond1.mean() - proj_cond2.mean())**2
    within_var = proj_cond1.var() + proj_cond2.var()
    explained_variance = between_var / (between_var + within_var + 1e-10)
    
    return coding_direction, explained_variance
```
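Because `np.unique` sorts labels, the sign of the returned direction depends on label spelling. The sketch below is a hypothetical variant that takes the two labels explicitly and checks the result on synthetic data in which only neuron 0 responds to the 'high' condition; the regularized-covariance solve is the same as above.

```python
import numpy as np

def coding_direction(activity, conditions, cond_a, cond_b, gamma=0.5):
    """Unit vector proportional to Σ_r^-1 (mean_b - mean_a), pointing from cond_a toward cond_b."""
    conditions = np.asarray(conditions)
    diff = (activity[:, conditions == cond_b].mean(axis=1)
            - activity[:, conditions == cond_a].mean(axis=1))
    sigma_r = (1 - gamma) * np.cov(activity) + gamma * np.eye(activity.shape[0])
    l = np.linalg.solve(sigma_r, diff)
    return l / np.linalg.norm(l)

# 5 neurons; only neuron 0 fires more in the 'high' condition
rng = np.random.default_rng(0)
conditions = np.repeat(['low', 'high'], 1000)
activity = rng.normal(size=(5, 2000))
activity[0, conditions == 'high'] += 2.0

l = coding_direction(activity, conditions, cond_a='low', cond_b='high')
```

The direction is dominated by neuron 0 and has a positive weight on it, so projecting onto `l` gives larger values in the 'high' condition.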

---
# 10. Running the Benchmark

```python
def run_benchmark(pipeline_fn, data_paths, verbose=True):
    """
    Run benchmark on multiple datasets.
    
    Args:
        pipeline_fn: function(video, fps) -> results dict
        data_paths: list of paths to benchmark data
        verbose: if True, print per-dataset scores and score components
    
    Returns:
        mean_score: float
        all_scores: list of (dataset, score, details) tuples
    """
    all_scores = []
    
    for path in data_paths:
        if verbose:
            print(f"\nProcessing: {path}")
        
        # Load data
        video, fps, ground_truth = load_benchmark_data(path)
        
        # Run pipeline
        results = pipeline_fn(video, fps)
        
        # Evaluate
        score, details = compute_benchmark_score(
            results, video, fps, ground_truth
        )
        
        all_scores.append((path, score, details))
        
        if verbose:
            print(f"  Score: {score:.4f}")
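            for k, v in details.items():
                # Non-numeric entries (e.g. strings, arrays) are printed as-is
                print(f"    {k}: {v:.4f}" if isinstance(v, (int, float, np.number)) else f"    {k}: {v}")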
    
    mean_score = np.mean([s[1] for s in all_scores])
    
    if verbose:
        print(f"\n=== FINAL SCORE: {mean_score:.4f} ===")
    
    return mean_score, all_scores

# Example usage:
# score, details = run_benchmark(baseline_pipeline, ['data/fish1', 'data/fish2'])
```