# CUMIN (CUrated MINion): A Fluorescence Analysis Pipeline for Curated ROIs

This notebook provides an interactive interface to the fluorescence analysis pipeline, allowing you to explore parameter effects on ROI detection, trace extraction, and event analysis.

## Overview
- Load and select fluorescence imaging data
- Explore parameter effects with interactive visualizations:
  - Gaussian denoising for image preprocessing
  - ROI processing with PNR refinement
  - Background subtraction for improved signal
  - Event detection and analysis
- Save optimized configurations
- Run the full pipeline with optimized parameters

Each section includes in-depth explanations of the algorithms and parameters, helping you understand how different settings affect your analysis results.

In [None]:
#!pip install datashader

In [None]:
# @title Import Libraries and Setup {display-mode: "form"}
# Import standard libraries
import os
import sys
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import tifffile
import cv2
from pathlib import Path
from datetime import datetime
import json
import copy
import pickle
import time
import warnings
from scipy.ndimage import binary_dilation, median_filter
import holoviews as hv
import xarray as xr
from holoviews.operation.datashader import datashade, regrid
from holoviews.util import Dynamic

# Import interactive libraries
import ipywidgets as widgets
from ipywidgets import interact, fixed, interact_manual, interactive
from IPython.display import display, clear_output

# Add parent directory to path if notebook is in notebooks/
if '..' not in sys.path:
    sys.path.append('..')

# Add current directory to path to import local modules
if '.' not in sys.path:
    sys.path.append('.')

# Import pipeline modules
try:
    from modules.file_matcher import match_tif_and_roi_files
    from modules.preprocessing import (
        correct_photobleaching,
        remove_background,
        denoise,
        stripe_correction
    )
    from modules.roi_processing import (
        extract_roi_fluorescence, 
        subtract_background,
        subtract_global_background,
        extract_rois_from_zip, 
        save_masks_for_cnmf, 
        extract_roi_fluorescence_with_cnmf,
        refine_rois_with_cnmfe,
        refine_rois_with_pnr,
        split_signal_noise,
        visualize_pnr_results,
        save_trace_data
    )
    from modules.analysis import (
        analyze_fluorescence, 
        perform_qc_checks,
        extract_peak_parameters,
        extract_spontaneous_activity,
        calculate_baseline_excluding_peaks
    )
    from modules.visualization import generate_visualizations
    from modules.utils import setup_logging, save_slice_data, save_mouse_summary
    from modules.visualization_helpers import (
        create_denoising_visualization,
        create_pnr_refinement_visualization,
        create_background_subtraction_visualization,
        create_event_detection_visualization
    )
    
    # Try importing advanced analysis module if available
    try:
        from modules.advanced_analysis import run_advanced_analysis
        ADVANCED_ANALYSIS_AVAILABLE = True
    except ImportError:
        ADVANCED_ANALYSIS_AVAILABLE = False
        print("Advanced analysis module not available. Some features will be disabled.")
    
    print("Successfully imported all modules")
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Make sure the required modules are in the 'modules' directory or adjust the import path.")
    
# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

## Setting Up Pipeline Parameters

Before starting the analysis, we need to configure the pipeline parameters including:

- **Input Directory**: Location of your .tif (image) and .zip (ROI) files
- **Output Directory**: Where analysis results will be saved
- **Configuration File**: YAML file with pipeline parameters
- **Pipeline Mode**: "all" runs the complete pipeline; other options are "preprocess", "extract", and "analyze"
- **Max Workers**: Number of parallel processes for multi-core processing
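
As a rough illustration of how these settings might be sanity-checked before a run, the sketch below validates the mode and worker count. `PipelineArgs` and `validate_args` are hypothetical names that are not part of the pipeline; only the field names and the four mode strings come from this notebook, and capping workers at `os.cpu_count()` is an assumption.

```python
import os
from dataclasses import dataclass

# Mode strings listed in this notebook.
VALID_MODES = ("all", "preprocess", "extract", "analyze")


@dataclass
class PipelineArgs:
    """Hypothetical stand-in for the notebook's Args class."""
    input_dir: str
    output_dir: str
    config: str = "../config.yaml"
    mode: str = "all"
    max_workers: int = 4


def validate_args(args):
    """Return a list of problems; an empty list means the settings look usable."""
    problems = []
    if not os.path.isdir(args.input_dir):
        problems.append(f"input_dir does not exist: {args.input_dir}")
    if args.mode not in VALID_MODES:
        problems.append(f"mode must be one of {VALID_MODES}, got {args.mode!r}")
    cpu_count = os.cpu_count() or 1  # os.cpu_count() can return None
    if not 1 <= args.max_workers <= cpu_count:
        problems.append(f"max_workers should be between 1 and {cpu_count}")
    return problems


demo = PipelineArgs(input_dir=os.getcwd(), output_dir="results", mode="everything")
for problem in validate_args(demo):
    print(problem)
```

Returning a list of problems rather than raising on the first one lets all issues be reported in a single pass, which suits an interactive notebook.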

You can adjust these parameters below. Click "Update Parameters" after making changes.

In [None]:
# @title Pipeline Parameters {display-mode: "form"}
class Args:
    """Class to simulate command line arguments"""
    def __init__(self):
        self.input_dir = r"F:\Recovered\Research\BoninLab\PainModelingProject\calcium_imaging_data\CAAR Testing\CAAR part2 data\paclitaxel"  # CHANGE THIS
        self.output_dir = r"F:\Recovered\Research\BoninLab\PainModelingProject\calcium_imaging_data\CAAR Testing\CUMIN output\CUMIN_51_optimized_15"   # CHANGE THIS
        self.config = "../config.yaml"  # Path to your config file
        self.mode = "all"  # Options: "all", "preprocess", "extract", "analyze"
        self.max_workers = 4  # Adjust based on your CPU cores
        self.disable_advanced = False

# Create args object
args = Args()

# Create interactive widgets for adjusting parameters
input_dir_widget = widgets.Text(
    value=args.input_dir,
    description='Input Directory:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

output_dir_widget = widgets.Text(
    value=args.output_dir,
    description='Output Directory:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

config_widget = widgets.Text(
    value=args.config,
    description='Config File:',
    style={'description_width': 'initial'}
)

mode_widget = widgets.Dropdown(
    options=['all', 'preprocess', 'extract', 'analyze'],
    value=args.mode,
    description='Pipeline Mode:',
    style={'description_width': 'initial'}
)

workers_widget = widgets.IntSlider(
    value=args.max_workers,
    min=1,
    max=12,
    step=1,
    description='Max Workers:',
    style={'description_width': 'initial'}
)

disable_advanced_widget = widgets.Checkbox(
    value=args.disable_advanced,
    description='Disable Advanced Analysis',
    style={'description_width': 'initial'}
)

# Function to update args object based on widget values
def update_args():
    args.input_dir = input_dir_widget.value
    args.output_dir = output_dir_widget.value
    args.config = config_widget.value
    args.mode = mode_widget.value
    args.max_workers = workers_widget.value
    args.disable_advanced = disable_advanced_widget.value
    print("Parameters updated:")
    print(f"Input Directory: {args.input_dir}")
    print(f"Output Directory: {args.output_dir}")
    print(f"Config File: {args.config}")
    print(f"Pipeline Mode: {args.mode}")
    print(f"Max Workers: {args.max_workers}")
    print(f"Disable Advanced Analysis: {args.disable_advanced}")

# Create update button
update_button = widgets.Button(
    description='Update Parameters',
    button_style='info',
    tooltip='Click to update parameters'
)

update_button.on_click(lambda b: update_args())

# Display widgets
display(input_dir_widget)
display(output_dir_widget)
display(config_widget)
display(mode_widget)
display(workers_widget)
display(disable_advanced_widget)
display(update_button)

## Loading Configuration

The pipeline configuration is stored in a YAML file that defines parameters for each processing step. 
The configuration includes settings for:

- **Preprocessing**: Methods and parameters for photobleaching correction, denoising, etc.
- **ROI Processing**: ROI extraction, refinement, and background subtraction
- **Analysis**: Event detection parameters, condition-specific settings, etc.
- **Visualization**: Plot types, colormaps, and other visualization settings
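
For orientation, the sketch below shows one possible shape of the loaded configuration as a nested dictionary. The key names are the ones this notebook reads or writes. The values marked as placeholders are invented for illustration and are not the pipeline's defaults. `get_nested` and `apply_overrides` are hypothetical helpers, not functions from the pipeline modules.

```python
import copy

# Nested layout inferred from the keys this notebook accesses.
example_config = {
    "preprocessing": {
        "correction_method": "placeholder_method",  # placeholder value
        "denoise": {
            "enabled": True,
            "method": "gaussian",
            "params": {"ksize": [5, 5], "sigmaX": 1.5},
        },
    },
    "roi_processing": {
        "background": {"method": "placeholder_method"},  # placeholder value
    },
    "advanced_analysis": {"enabled": True},
}


def get_nested(mapping, *keys, default="Not specified"):
    """Walk nested dictionaries, returning `default` if any key is missing."""
    current = mapping
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]
    return current


def apply_overrides(config, disable_advanced=False):
    """Return a copy of `config` with advanced analysis optionally disabled."""
    updated = copy.deepcopy(config)
    if disable_advanced and "advanced_analysis" in updated:
        updated["advanced_analysis"]["enabled"] = False
    return updated


print(get_nested(example_config, "roi_processing", "background", "method"))
print(get_nested(example_config, "visualization", "cmap"))
```

Working on a deep copy keeps the original dictionary untouched, which mirrors the backup copy this notebook keeps of the loaded configuration.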

Click "Load Configuration" to load the configuration file specified in the parameters section.

In [None]:
# @title Load Configuration {display-mode: "form"}
def load_config(config_path, args=None):
    """Load configuration from YAML file and apply command line overrides."""
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        print(f"Loaded configuration from {config_path}")
        
        # Apply command line overrides if provided
        if args and args.disable_advanced:
            # Disable advanced analysis if requested via command line
            if "advanced_analysis" in config:
                config["advanced_analysis"]["enabled"] = False
                print("Advanced analysis disabled via command line argument")
        
        return config
    except Exception as e:
        print(f"Failed to load configuration: {str(e)}")
        return None

# Setup logging
logger = setup_logging()

# Create a load button
load_config_button = widgets.Button(
    description='Load Configuration',
    button_style='success',
    tooltip='Click to load the configuration file'
)

def on_load_config_click(b):
    global config, config_original
    # Load configuration
    config = load_config(args.config, args)

    if config:
        print("Configuration loaded successfully.")
        
        # Create a backup of original config for reference
        config_original = copy.deepcopy(config)
        
        # Print some key configuration settings
        print("\nKey Configuration Settings:")
        print(f"Photobleaching correction method: {config['preprocessing'].get('correction_method', 'Not specified')}")
        if 'denoise' in config['preprocessing']:
            print(f"Denoising enabled: {config['preprocessing']['denoise'].get('enabled', False)}")
        print(f"Background subtraction method: {config['roi_processing'].get('background', {}).get('method', 'Not specified')}")
    else:
        print("Failed to load configuration. Please check the config file path.")

load_config_button.on_click(on_load_config_click)
display(load_config_button)

## Data Selection and Loading

This section allows you to select a file pair (TIF image stack and ZIP ROI file) for analysis. The pipeline searches for matching file pairs in the input directory.

### File Pairs
A file pair consists of:
- A **.tif file** containing the fluorescence imaging data (time series of frames)
- A **.zip file** containing ImageJ/FIJI ROI definitions
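
The matching itself is done by `match_tif_and_roi_files`, whose rules are not shown in this notebook. As a minimal sketch of the idea, assuming a TIF and its ROI archive share the same file stem (the real matcher may use a different rule), pairing could look like this. `pair_by_stem` is a hypothetical helper.

```python
from pathlib import Path


def pair_by_stem(filenames):
    """Pair .tif/.tiff files with .zip files that have an identical stem."""
    tifs = {}
    zips = {}
    for name in filenames:
        path = Path(name)
        suffix = path.suffix.lower()
        if suffix in (".tif", ".tiff"):
            tifs[path.stem] = name
        elif suffix == ".zip":
            zips[path.stem] = name
    # Keep only stems that have both an image stack and an ROI archive.
    return [(tifs[stem], zips[stem]) for stem in sorted(tifs) if stem in zips]


names = [
    "CFA1_7.23.20_ipsi1_0um.tif",
    "CFA1_7.23.20_ipsi1_0um.zip",
    "CFA1_7.23.20_ipsi1_10um.tif",  # no ROI archive, so it is left unpaired
    "notes.txt",
]
print(pair_by_stem(names))
```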

### Loading Process
When you load a file pair, the following happens:
1. The TIF stack is loaded and preprocessed
2. ROI masks are extracted from the ZIP file
3. Intermediate data is prepared for interactive visualization
4. Visualization data is saved for future use

Select a file pair from the dropdown menu and click "Load Selected File Pair" to start.

In [None]:
# @title File Selection and Loading {display-mode: "form"}
# Find matching TIF and ROI file pairs
# Define extract_metadata_from_filename function
def extract_metadata_from_filename(filename):
    """Extract metadata from custom filename pattern 'CFA1_7.23.20_ipsi1_0um'."""
    import re
    
    # Initialize metadata dictionary
    metadata = {
        "mouse_id": "unknown",
        "date": "unknown",
        "pain_model": "unknown",
        "slice_type": "unknown",
        "slice_number": "1",
        "condition": "unknown"
    }
    
    # Split filename by underscore
    parts = filename.split('_')
    
    if len(parts) < 3:
        return metadata
    
    # First part typically contains pain model + mouse number (e.g., "CFA1")
    if parts[0]:
        # Extract pain model (letters) and mouse number (digits)
        model_match = re.match(r'([A-Za-z]+)([0-9]*)', parts[0])
        if model_match:
            metadata["pain_model"] = model_match.group(1)
            mouse_number = model_match.group(2) or "1"
            metadata["mouse_id"] = f"{metadata['pain_model']}{mouse_number}"
        else:
            metadata["mouse_id"] = parts[0]
    
    # Second part is usually the date
    if len(parts) > 1:
        metadata["date"] = parts[1]
    
    # Third part usually contains slice type and number
    if len(parts) > 2:
        # Look for ipsi/contra with optional number
        slice_match = re.match(r'(ipsi|contra)([0-9]*)', parts[2].lower())
        if slice_match:
            metadata["slice_type"] = slice_match.group(1).capitalize()  # Capitalize first letter
            metadata["slice_number"] = slice_match.group(2) or "1"
    
    # The condition (e.g. "0um") may appear in any part; use the first match
    for part in parts:
        if any(cond in part.lower() for cond in ["0um", "10um", "25um"]):
            metadata["condition"] = part
            break
    
    return metadata

def find_file_pairs():
    print(f"Finding file pairs in {args.input_dir}...")
    file_pairs = match_tif_and_roi_files(args.input_dir, logger)
    print(f"Found {len(file_pairs)} matched file pairs")
    return file_pairs

# Find file pairs button
find_pairs_button = widgets.Button(
    description='Find File Pairs',
    button_style='info',
    tooltip='Click to find matching TIF and ROI files'
)

def on_find_pairs_click(b):
    print("Find button clicked!")
    global file_pairs
    
    # Find file pairs
    file_pairs = find_file_pairs()
    
    if len(file_pairs) > 0:
        print("Found file pairs:")
        for i, (tif_path, roi_path) in enumerate(file_pairs):
            print(f"{i+1}: {Path(tif_path).stem}")
        
        # Always create dropdown, even for one pair
        create_file_selection_dropdown()
    else:
        print("No file pairs found. Please check the input directory.")

def create_file_selection_dropdown():
    """Create the dropdown for selecting file pairs when multiple are found"""
    global file_pair_dropdown
    
    # Create dropdown for file pair selection
    pair_names = [f"{i+1}: {Path(tif).stem}" for i, (tif, roi) in enumerate(file_pairs)]
    file_pair_dropdown = widgets.Dropdown(
        options=list(zip(pair_names, range(len(file_pairs)))),
        description='Select File Pair:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='80%')
    )
    
    def on_select_change(change):
        index = change['new']
        tif_path, roi_path = file_pairs[index]
        print(f"Selected file pair {index+1}:")
        print(f"TIF: {tif_path}")
        print(f"ROI: {roi_path}")

    file_pair_dropdown.observe(on_select_change, names='value')
    
    # Display the dropdown
    display(file_pair_dropdown)
    
    # Create a button to load the selected file pair
    load_button = widgets.Button(
        description='Load Selected File Pair',
        button_style='success',
        tooltip='Click to load the selected file pair for analysis'
    )
    
    load_button.on_click(lambda b: load_selected_pair(file_pair_dropdown.value))
    display(load_button)

def load_selected_pair(index):
    """Load the selected file pair"""
    global selected_tif_path, selected_roi_path, image_data, image_shape
    global roi_masks, roi_centers, slice_output_dir, corrected_data, roi_data, metadata
    
    selected_tif_path, selected_roi_path = file_pairs[index]
    
    print(f"Loading file pair {index+1}:")
    print(f"TIF: {selected_tif_path}")
    print(f"ROI: {selected_roi_path}")
    
    # Load image data
    try:
        print("Loading image data...")
        with tifffile.TiffFile(selected_tif_path) as tif:
            image_data = tif.asarray()
            
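            # Ensure data is in (frames, height, width) format.
            # Heuristic: the smallest axis is assumed to be the frame axis, so a stack
            # with more frames than pixels along a spatial axis may be reordered incorrectly.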
            if len(image_data.shape) == 3:
                if image_data.shape[0] < image_data.shape[1] and image_data.shape[0] < image_data.shape[2]:
                    # Already in (frames, height, width) format
                    pass
                else:
                    # Try to rearrange to (frames, height, width)
                    if image_data.shape[2] < image_data.shape[0] and image_data.shape[2] < image_data.shape[1]:
                        image_data = np.moveaxis(image_data, 2, 0)
                    elif image_data.shape[1] < image_data.shape[0] and image_data.shape[1] < image_data.shape[2]:
                        image_data = np.moveaxis(image_data, 1, 0)
        
        n_frames, height, width = image_data.shape
        image_shape = (height, width)
        
        print(f"Image loaded successfully with shape: {image_data.shape}")
        print(f"Number of frames: {n_frames}")
        print(f"Frame dimensions: {height}x{width}")
        
        # Convert to float32 if needed
        if image_data.dtype != np.float32:
            image_data = image_data.astype(np.float32)
            print("Converted data to float32")
        
        # Extract metadata from filename
        slice_name = Path(selected_tif_path).stem
        metadata = extract_metadata_from_filename(slice_name)
        print(f"Extracted metadata: {metadata}")
        
        # Extract ROI masks for visualization
        roi_masks, roi_centers = extract_rois_from_zip(selected_roi_path, image_shape, logger)
        print(f"Extracted {len(roi_masks)} ROI masks")
        
        # Create output directory for this slice
        slice_output_dir = os.path.join(args.output_dir, slice_name)
        os.makedirs(slice_output_dir, exist_ok=True)
        print(f"Created output directory: {slice_output_dir}")
        
        # Run initial preprocessing to set up visualization data
        print("Performing initial preprocessing for visualization...")
        corrected_data, _ = correct_photobleaching(
            image_data,
            None,  # No output file needed for visualization
            config["preprocessing"],
            logger,
            save_output=False
        )
        print("Initial preprocessing complete")
        
        # Extract ROI fluorescence for visualization
        _, roi_data = extract_roi_fluorescence(
            selected_roi_path,
            corrected_data,
            image_shape,
            slice_output_dir,
            config["roi_processing"],
            logger
        )
        print(f"Extracted fluorescence traces for {len(roi_masks)} ROIs")
        
        # Save visualization data for later use
        vis_data = {
            'image_data': image_data,
            'corrected_data': corrected_data,
            'roi_masks': roi_masks,
            'roi_centers': roi_centers,
            'roi_data': roi_data,
            'metadata': metadata,
            'selected_tif_path': selected_tif_path,
            'selected_roi_path': selected_roi_path,
            'image_shape': image_shape
        }
        
        # Save visualization data
        vis_data_file = os.path.join(slice_output_dir, "visualization_data.pkl")
        with open(vis_data_file, 'wb') as f:
            pickle.dump(vis_data, f)
            
        print(f"Saved visualization data to {vis_data_file}")
        print("Data loaded successfully and ready for visualization!")
        
    except Exception as e:
        print(f"Error loading files: {str(e)}")
        import traceback
        traceback.print_exc()

# Set up button click handler
find_pairs_button.on_click(on_find_pairs_click)
print("Button ready. Click to find file pairs.")
display(find_pairs_button)

In [None]:
# Re-display the file selection widgets (requires `file_pairs` from "Find File Pairs" above)
create_file_selection_dropdown()

## Load Previously Saved Visualization Data

If you've previously run this notebook and saved visualization data, you can load it here instead of reprocessing the data. This saves time when you want to continue exploring the same dataset.

The visualization data includes:
- Original image data
- Preprocessed (corrected) data
- ROI masks and centers
- ROI fluorescence traces
- Metadata extracted from the filename
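
Before relying on a saved file, it can help to confirm that all expected entries are present. The sketch below is one way to do that, using the key names written by the loading cell above. `missing_keys` is a hypothetical helper, and the in-memory buffer stands in for a real `visualization_data.pkl`. Because unpickling can execute arbitrary code, only load pickle files you created yourself or otherwise trust.

```python
import io
import pickle

# Keys stored in visualization_data.pkl by the loading cell in this notebook.
REQUIRED_KEYS = {
    "image_data", "corrected_data", "roi_masks", "roi_centers", "roi_data",
    "metadata", "selected_tif_path", "selected_roi_path", "image_shape",
}


def missing_keys(data):
    """Return the required keys that are absent from a loaded dictionary."""
    return sorted(REQUIRED_KEYS - set(data))


# Round-trip an incomplete dictionary through pickle to simulate a stale file.
buffer = io.BytesIO()
pickle.dump({"image_data": [], "metadata": {}}, buffer)
buffer.seek(0)
restored = pickle.load(buffer)

print(missing_keys(restored))
```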

Enter the path to the saved visualization data file (`visualization_data.pkl`) and click "Load Saved Visualization Data" to continue.

In [None]:
# @title Load Saved Visualization Data {display-mode: "form"}
def load_visualization_data(data_file):
    """Load saved visualization data from pickle file"""
    try:
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
        
        # Assign to global variables for use in visualizations
        global image_data, corrected_data, roi_masks, roi_centers, roi_data, metadata, selected_tif_path, selected_roi_path, image_shape
        image_data = data['image_data']
        corrected_data = data['corrected_data']
        roi_masks = data['roi_masks']
        roi_centers = data['roi_centers']
        roi_data = data['roi_data']
        metadata = data['metadata']
        selected_tif_path = data['selected_tif_path']
        selected_roi_path = data['selected_roi_path']
        image_shape = data['image_shape']
        
        print("Visualization data loaded successfully!")
        print(f"File: {Path(selected_tif_path).stem}")
        print(f"Image shape: {image_data.shape}")
        print(f"Number of ROIs: {len(roi_masks)}")
        print(f"Condition: {metadata.get('condition', 'unknown')}")
        return True
    except Exception as e:
        print(f"Error loading visualization data: {str(e)}")
        return False

# Widget to select a visualization data file
vis_data_path_widget = widgets.Text(
    placeholder='Enter path to visualization_data.pkl file',
    description='Data File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

load_vis_data_button = widgets.Button(
    description='Load Saved Visualization Data',
    button_style='info',
    tooltip='Load previously saved visualization data'
)

def on_load_vis_data_click(b):
    path = vis_data_path_widget.value
    if path:
        success = load_visualization_data(path)
        if success:
            print("Ready for visualization!")
    else:
        print("Please enter a valid path to the visualization_data.pkl file")

load_vis_data_button.on_click(on_load_vis_data_click)

display(vis_data_path_widget)
display(load_vis_data_button)

## Gaussian Denoising

Gaussian denoising helps reduce noise in fluorescence images while preserving important features. This is particularly important for accurately identifying ROIs and detecting events in your calcium imaging data.

### How It Works
The Gaussian blur operation applies a weighted average to each pixel, where nearby pixels have more influence than distant ones. The weighting follows a Gaussian distribution centered at the pixel being processed.

### Key Parameters
- **Kernel Size**: Controls the size of the filter window (must be odd). Larger values remove more noise but may blur important details.
- **Sigma**: Controls the width of the Gaussian distribution. Higher values produce more blurring and stronger noise reduction.
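
To make the two parameters concrete, the sketch below builds a one-dimensional Gaussian kernel in plain Python and applies it to a single spike. It is an illustration rather than the pipeline's implementation: `gaussian_kernel_1d` and `blur_1d` are hypothetical helpers, and the border handling (repeating edge values) is a simplification that differs from OpenCV's default. A 2D Gaussian blur is separable, so applying such a kernel along rows and then along columns gives the 2D result.

```python
import math


def gaussian_kernel_1d(ksize, sigma):
    """Return normalized Gaussian weights for an odd kernel size."""
    if ksize < 1 or ksize % 2 == 0:
        raise ValueError("ksize must be a positive odd integer")
    half = ksize // 2
    weights = [math.exp(-(offset ** 2) / (2.0 * sigma ** 2))
               for offset in range(-half, half + 1)]
    total = sum(weights)
    return [w / total for w in weights]  # weights sum to 1


def blur_1d(values, kernel):
    """Weighted average of each sample and its neighbors (edges repeated)."""
    half = len(kernel) // 2
    last = len(values) - 1
    blurred = []
    for i in range(len(values)):
        acc = 0.0
        for k, weight in enumerate(kernel):
            j = min(max(i + k - half, 0), last)  # clamp index at the borders
            acc += weight * values[j]
        blurred.append(acc)
    return blurred


kernel = gaussian_kernel_1d(5, 1.5)
print([round(w, 3) for w in kernel])
print([round(v, 3) for v in blur_1d([0.0, 0.0, 10.0, 0.0, 0.0], kernel)])
```

The center weight is the largest and the weights fall off symmetrically, so an isolated spike is spread over its neighbors while a flat region is left unchanged.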

### Finding Optimal Values
The best parameters balance noise reduction and feature preservation:
- For noisy recordings, try larger kernel sizes (7-11) and higher sigma values (2-4)
- For cleaner recordings, use smaller kernels (3-5) and lower sigma values (0.5-1.5)

Use the interactive tool below to experiment with different settings and observe their effects.

In [None]:
# @title Gaussian Denoising Visualization
# Import necessary visualization libraries
import holoviews as hv
import panel as pn
import numpy as np
from IPython.display import display

# Initialize holoviews with bokeh backend
hv.extension('bokeh')

def run_gaussian_denoising():
    # Check that image data is loaded
    if 'image_data' not in globals() or image_data is None:
        print("Please load image data first")
        return
    
    # Create panel widgets for parameters
    max_frame = min(image_data.shape[0] - 1, 100)
    frame_slider = pn.widgets.IntSlider(name='Frame', start=0, end=max_frame,
                                        step=1, value=min(10, max_frame))
    ksize_slider = pn.widgets.IntSlider(name='Kernel Size', start=1, end=21, step=2, value=5)
    sigma_slider = pn.widgets.FloatSlider(name='Sigma X', start=0.1, end=10.0, step=0.1, value=1.5)
    
    @pn.depends(frame_slider, ksize_slider, sigma_slider)
    def apply_gaussian_blur(frame_idx, ksize, sigma):
        # Ensure ksize is odd
        if ksize % 2 == 0:
            ksize += 1
            
        # Get the selected frame
        sample_frame = image_data[frame_idx].copy()
        
        # Apply Gaussian denoising
        denoised_frame = cv2.GaussianBlur(sample_frame, (ksize, ksize), sigma)
        
        # Calculate difference
        diff = np.abs(sample_frame - denoised_frame)
        
        # Normalize for display
        norm_orig = (sample_frame - sample_frame.min()) / (sample_frame.max() - sample_frame.min() + 1e-10)
        norm_denoised = (denoised_frame - denoised_frame.min()) / (denoised_frame.max() - denoised_frame.min() + 1e-10)
        norm_diff = diff / (diff.max() + 1e-10)
        
        # Convert to HoloViews Image objects
        orig_img = hv.Image(norm_orig).opts(
            title='Original Frame',
            cmap='gray', 
            width=300, 
            height=300
        )
        
        denoised_img = hv.Image(norm_denoised).opts(
            title=f'Gaussian Denoised (k={ksize}, σ={sigma:.1f})',
            cmap='gray', 
            width=300, 
            height=300
        )
        
        diff_img = hv.Image(norm_diff).opts(
            title='Difference (Brighter=More Change)',
            cmap='hot', 
            width=300, 
            height=300
        )
        
        # Update config with current values
        if 'denoise' not in config['preprocessing']:
            config['preprocessing']['denoise'] = {}
        
        config['preprocessing']['denoise']['enabled'] = True
        config['preprocessing']['denoise']['method'] = 'gaussian'
        config['preprocessing']['denoise']['params'] = {
            'ksize': [ksize, ksize],
            'sigmaX': sigma
        }
        
        # Return layout with all three images
        return (orig_img + denoised_img + diff_img).cols(3)
    
    # Create a Panel app with the widgets and visualization
    app = pn.Column(
        "## Gaussian Denoising Tool",
        pn.Row(
            pn.Column(frame_slider, ksize_slider, sigma_slider, width=250),
            apply_gaussian_blur
        )
    )
    
    return app

# Allow user to run the visualization
try:
    import panel as pn
    print("Panel is installed. Run the following to start the visualization:")
    print("app = run_gaussian_denoising()")
    print("display(app)")
except ImportError:
    print("You need to install panel first:")
    print("!pip install panel")
    print("Then restart the kernel and try again.")

In [None]:
app = run_gaussian_denoising()
display(app)

## ROI Processing with PNR Refinement

Peak-to-Noise Ratio (PNR) refinement helps identify and select ROIs with strong neuronal signals while filtering out ROIs with poor signal quality. This is essential for accurate calcium imaging analysis.

### How PNR Refinement Works
1. **Frequency Separation**: The fluorescence trace is split into signal (low-frequency) and noise (high-frequency) components
2. **Signal Smoothing**: Optional smoothing can be applied to the signal component to reduce fluctuations
3. **PNR Calculation**: The ratio between peak signal value and noise standard deviation is calculated
4. **Thresholding**: ROIs with PNR values below the threshold are excluded from analysis
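
The steps above can be sketched with NumPy alone. Note that this sketch swaps in a brick-wall FFT mask for the frequency split rather than a Butterworth filter, and it omits the optional smoothing step. `split_by_frequency`, `compute_pnr`, and `select_rois` are illustrative names, the cutoff is taken as a fraction of the Nyquist frequency, and the synthetic traces are made up for the demonstration.

```python
import numpy as np


def split_by_frequency(trace, cutoff):
    """Split a trace into low-frequency (signal) and high-frequency (noise) parts.

    `cutoff` is a fraction of the Nyquist frequency (0 to 1).
    """
    trace = np.asarray(trace, dtype=float)
    spectrum = np.fft.rfft(trace)
    freqs = np.fft.rfftfreq(trace.size) / 0.5  # rescale so Nyquist equals 1.0
    low_spectrum = np.where(freqs <= cutoff, spectrum, 0.0)
    signal_part = np.fft.irfft(low_spectrum, n=trace.size)
    return signal_part, trace - signal_part


def compute_pnr(trace, cutoff=0.03, percentile=99.0):
    """Peak (percentile of the signal part) divided by the noise standard deviation."""
    signal_part, noise_part = split_by_frequency(trace, cutoff)
    noise_std = np.std(noise_part)
    if noise_std == 0:
        return 0.0  # avoid division by zero
    return float(np.percentile(signal_part, percentile) / noise_std)


def select_rois(traces, min_pnr=8.0, **kwargs):
    """Return PNR values and a boolean mask of ROIs meeting the threshold."""
    pnr = np.array([compute_pnr(t, **kwargs) for t in traces])
    return pnr, pnr >= min_pnr


# Synthetic data: one ROI with a slow transient, one with noise only.
rng = np.random.default_rng(0)
frames = np.arange(1000)
transient = 50.0 * np.exp(-((frames - 300) / 40.0) ** 2)
active = transient + rng.normal(0.0, 1.0, frames.size)
quiet = rng.normal(0.0, 1.0, frames.size)

pnr_values, keep = select_rois(np.vstack([active, quiet]))
print(pnr_values.round(1), keep)
```

In this toy case the ROI with the transient clears the default threshold and the noise-only ROI does not, which is the behavior the thresholding step is meant to produce.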

### Key Parameters
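- **Noise Frequency Cutoff**: Determines the boundary between signal and noise components (0.01-0.2). The filter in this notebook is designed without a sampling rate, so this value is a fraction of the Nyquist frequency (half the frame rate) rather than a frequency in Hz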
- **Percentile Threshold**: Percentile used to determine peak signal value (90-99.9%)
- **Trace Smoothing**: Window size for signal smoothing (0 = no smoothing)
- **Min PNR**: Minimum PNR threshold for accepting an ROI (typically 5-10)

### Finding Optimal Values
- Higher PNR thresholds produce more reliable results but may exclude valid ROIs
- The ideal noise frequency cutoff depends on the temporal characteristics of your signal
- Smoothing can help stabilize PNR values but may mask transient events

The tool below allows you to visualize and adjust these parameters to optimize ROI selection.

In [None]:
# @title PNR Refinement Interactive Visualization {display-mode: "form"}

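# Note: this local definition shadows the split_signal_noise imported from modules.roi_processing
def split_signal_noise(traces, cutoff_freq, logger=None):
    """Split traces into signal and noise components using frequency filtering."""
    from scipy import signal
    import numpy as np
    
    n_rois, n_frames = traces.shape
    
    # No sampling rate is passed to signal.butter, so cutoff_freq is normalized
    # to the Nyquist frequency (1.0 = half the frame rate), not expressed in Hz.
    # Fall back to the default when the value is outside the accepted range.
    if cutoff_freq <= 0 or cutoff_freq >= 0.5:
        cutoff_freq = 0.03
    
    # Design Butterworth low-pass filter
    b_low, a_low = signal.butter(2, cutoff_freq, 'low')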
    
    # Design Butterworth high-pass filter (same cutoff)
    b_high, a_high = signal.butter(2, cutoff_freq, 'high')
    
    # Initialize output arrays
    signal_traces = np.zeros_like(traces)
    noise_traces = np.zeros_like(traces)
    
    # Apply filters to each ROI
    for i in range(n_rois):
        # Apply low-pass filter for signal
        signal_traces[i] = signal.filtfilt(b_low, a_low, traces[i])
        
        # Apply high-pass filter for noise
        noise_traces[i] = signal.filtfilt(b_high, a_high, traces[i])
    
    return signal_traces, noise_traces

def smooth_trace(trace, window_size):
    """Apply moving average smoothing to a trace."""
    from scipy import signal
    import numpy as np
    
    if window_size <= 0:
        return trace
        
    # Create window coefficients (simple moving average)
    window = np.ones(window_size) / window_size
    
    # Apply convolution for smoothing
    smoothed = signal.convolve(trace, window, mode='same')
    
    # Handle edge effects by copying original values at edges
    half_window = window_size // 2
    if half_window > 0:
        smoothed[:half_window] = trace[:half_window]
        smoothed[-half_window:] = trace[-half_window:]
    
    return smoothed

def run_pnr_refinement_visualization():
    """Create interactive visualization for PNR refinement"""
    # Check that required data is loaded
    if 'roi_data' not in globals() or roi_data is None:
        print("Please load image data first")
        return None
    
    # Create widgets for parameters
    noise_freq_cutoff = widgets.FloatSlider(
        value=0.03,
        min=0.01,
        max=0.2,
        step=0.01,
        description='Noise Freq Cutoff:',
        style={'description_width': 'initial'}
    )
    
    percentile_threshold = widgets.FloatSlider(
        value=99,
        min=90,
        max=99.9,
        step=0.1,
        description='Percentile Threshold:',
        style={'description_width': 'initial'}
    )
    
    trace_smoothing = widgets.IntSlider(
        value=3,
        min=0,
        max=15,
        step=1,
        description='Trace Smoothing:',
        style={'description_width': 'initial'}
    )
    
    min_pnr = widgets.FloatSlider(
        value=8.0,
        min=3.0,
        max=20.0,
        step=0.5,
        description='Min PNR:',
        style={'description_width': 'initial'}
    )
    
    # Widget to select ROIs to display
    roi_options = [(f"ROI {i+1}", i) for i in range(min(10, len(roi_data)))]
    roi_indices = widgets.SelectMultiple(
        options=roi_options,
        value=tuple(range(min(3, len(roi_options)))),  # Default: first 3 ROIs (fewer if not available)
        description='ROIs to Display:',
        disabled=False,
        style={'description_width': 'initial'}
    )
    
    def display_pnr_refinement(noise_freq_cutoff, percentile_threshold, trace_smoothing, min_pnr, roi_indices):
        import matplotlib.pyplot as plt
        import numpy as np
        
        if not roi_indices or len(roi_indices) == 0:
            print("Please select at least one ROI to display")
            return
            
        # Split traces into signal and noise components
        sample_traces = roi_data[list(roi_indices)]
        signal_traces, noise_traces = split_signal_noise(sample_traces, noise_freq_cutoff)
        
        # Apply smoothing if enabled
        if trace_smoothing > 0:
            smoothed_signal = np.zeros_like(signal_traces)
            for i in range(len(signal_traces)):
                smoothed_signal[i] = smooth_trace(signal_traces[i], trace_smoothing)
        else:
            smoothed_signal = signal_traces.copy()
        
        # Compute PNR values
        pnr_values = np.zeros(len(roi_indices))
        for i in range(len(roi_indices)):
            # Get peak value (using percentile)
            peak_value = np.percentile(smoothed_signal[i], percentile_threshold)
            
            # Calculate noise standard deviation
            noise_std = np.std(noise_traces[i])
            
            # Avoid division by zero
            if noise_std > 0:
                pnr_values[i] = peak_value / noise_std
            else:
                pnr_values[i] = 0
        
        # Display traces and PNR values
        n_rois = len(roi_indices)
        fig, axes = plt.subplots(n_rois, 3, figsize=(15, 4*n_rois))
        
        # Handle single ROI case
        if n_rois == 1:
            axes = np.array([axes])
        
        for i, roi_idx in enumerate(roi_indices):
            # Original trace
            axes[i, 0].plot(roi_data[roi_idx], 'k-', label='Original')
            axes[i, 0].set_title(f'ROI {roi_idx+1} - Original Trace')
            axes[i, 0].set_xlabel('Frame')
            axes[i, 0].set_ylabel('Fluorescence')
            axes[i, 0].grid(True, alpha=0.3)
            
            # Signal component
            axes[i, 1].plot(signal_traces[i], 'g-', label='Signal')
            if trace_smoothing > 0:
                axes[i, 1].plot(smoothed_signal[i], 'r-', label='Smoothed Signal')
            axes[i, 1].set_title(f'Signal Component (cutoff={noise_freq_cutoff})')
            axes[i, 1].set_xlabel('Frame')
            axes[i, 1].set_ylabel('Fluorescence')
            axes[i, 1].grid(True, alpha=0.3)
            axes[i, 1].legend()
            
            # Noise component
            axes[i, 2].plot(noise_traces[i], 'b-', label='Noise')
            axes[i, 2].set_title(f'Noise Component (PNR={pnr_values[i]:.2f})')
            axes[i, 2].set_xlabel('Frame')
            axes[i, 2].set_ylabel('Fluorescence')
            axes[i, 2].grid(True, alpha=0.3)
            
            # Add a zero reference line, then flag whether the ROI passes the PNR threshold
            axes[i, 2].axhline(y=0, color='k', linestyle='--', alpha=0.3)
            if pnr_values[i] >= min_pnr:
                status = "PASS"
                color = 'green'
            else:
                status = "FAIL"
                color = 'red'
            
            axes[i, 2].text(0.05, 0.95, f"PNR: {pnr_values[i]:.2f} ({status})", 
                            transform=axes[i, 2].transAxes, 
                            fontsize=10, va='top', ha='left',
                            bbox=dict(facecolor=color, alpha=0.3))
        
        plt.tight_layout()
        plt.show()
        
        # Display summary
        n_pass = sum(pnr >= min_pnr for pnr in pnr_values)
        print(f"PNR Summary: {n_pass}/{len(roi_indices)} selected ROIs pass the threshold (>= {min_pnr})")
        
        # Update config with current values, creating missing sections on first use
        config.setdefault('roi_processing', {}).setdefault('pnr_refinement', {})
        
        config['roi_processing']['pnr_refinement']['noise_freq_cutoff'] = noise_freq_cutoff
        config['roi_processing']['pnr_refinement']['min_pnr'] = min_pnr
        config['roi_processing']['pnr_refinement']['percentile_threshold'] = percentile_threshold
        config['roi_processing']['pnr_refinement']['trace_smoothing'] = trace_smoothing
        
        # Set these parameters to be enabled in the config
        if 'steps' not in config['roi_processing']:
            config['roi_processing']['steps'] = {}
        config['roi_processing']['steps']['refine_with_pnr'] = True
        
        print(f"Updated config with: noise_freq_cutoff={noise_freq_cutoff}, min_pnr={min_pnr}")
        print(f"percentile_threshold={percentile_threshold}, trace_smoothing={trace_smoothing}")
        print("To apply these settings to your pipeline, update your config.yaml file.")
    
    # Create and return interactive widget
    interactive_plot = interactive(
        display_pnr_refinement,
        noise_freq_cutoff=noise_freq_cutoff,
        percentile_threshold=percentile_threshold,
        trace_smoothing=trace_smoothing,
        min_pnr=min_pnr,
        roi_indices=roi_indices
    )
    
    return interactive_plot

print("\n## Interactive ROI Processing with PNR Refinement ##")
print("Run the following commands to start the PNR refinement visualization:")
print("app = run_pnr_refinement_visualization()")
print("display(app)")

In [None]:
app = run_pnr_refinement_visualization()
display(app)

## Background Subtraction

Background subtraction removes non-specific fluorescence signals that can mask the true neuronal activity. This improves signal-to-noise ratio and helps detect true calcium events.

### Available Methods
1. **Darkest Pixels**: Uses the darkest regions of the image (likely non-cellular regions) to estimate background
2. **ROI Periphery**: Estimates background from the area surrounding each ROI
3. **Global Background**: Uses a global approach to identify and subtract background signal
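
A minimal sketch of the darkest-pixels idea on a synthetic stack. The function name `darkest_pixel_background` and the choice to average the masked pixels per frame are assumptions made for illustration; the pipeline's own `subtract_background` may estimate and apply the background differently.

```python
import numpy as np
from scipy.ndimage import median_filter

def darkest_pixel_background(stack, percentile=0.2, median_filter_size=5):
    """Return (mask, background_trace) for a (frames, height, width) stack."""
    mean_image = stack.mean(axis=0)
    threshold = np.percentile(mean_image, percentile)
    raw_mask = mean_image <= threshold
    mask = raw_mask
    if median_filter_size > 1:
        # Median-filter the mask to drop isolated dark pixels
        mask = median_filter(raw_mask.astype(float), size=median_filter_size) > 0.5
    if not mask.any():
        # Assumed fallback: keep the unfiltered mask if filtering removed everything
        mask = raw_mask
    # Assumption: background is the per-frame mean over the masked pixels
    background_trace = stack[:, mask].mean(axis=1)
    return mask, background_trace

# Synthetic recording: uniform tissue, one dark corner, one bright "cell"
rng = np.random.default_rng(0)
stack = rng.normal(100.0, 1.0, size=(50, 32, 32))
stack[:, :8, :8] -= 20.0
stack[:, 12:20, 12:20] += 50.0

mask, background = darkest_pixel_background(stack, percentile=5.0, median_filter_size=3)
roi_trace = stack[:, 12:20, 12:20].mean(axis=(1, 2))
corrected = roi_trace - background
print(f"background pixels: {int(mask.sum())}, mean corrected signal: {corrected.mean():.1f}")
```

Raising `percentile` grows the mask toward brighter pixels, which is why the text recommends starting low.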

### Key Parameters
- **Percentile (%)**: For darkest pixels method, determines how much of the image is considered background
- **Min Background Area**: Minimum size of the area to be considered a valid background region
- **Median Filter Size**: Size of median filter for noise reduction in background mask
- **Periphery Size**: For ROI periphery method, size of the expansion around ROIs

### Finding Optimal Values
- For darkest pixels, start with a low percentile (0.1-1%) and increase if needed
- Larger median filter sizes produce a smoother background mask but may miss fine spatial variations
- For ROI periphery, larger values capture more surrounding tissue but risk including other cells
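
The periphery trade-off can be seen on a toy mask. This is a sketch only: `periphery_ring` and its `other_rois` argument are hypothetical names, and excluding neighbouring ROIs from the ring is an assumption rather than something the notebook states the pipeline does.

```python
import numpy as np
from scipy.ndimage import binary_dilation

def periphery_ring(roi_mask, periphery_size, other_rois=None):
    """Ring of pixels around an ROI, built by dilating the mask."""
    roi_mask = roi_mask.astype(bool)
    expanded = binary_dilation(roi_mask, iterations=periphery_size)
    ring = expanded & ~roi_mask
    if other_rois is not None:
        # Hypothetical refinement: drop ring pixels that belong to other cells
        ring &= ~other_rois.astype(bool)
    return ring

# A 10x10 ROI with a neighbouring cell two pixels to its right
roi = np.zeros((40, 40), dtype=bool)
roi[15:25, 15:25] = True
neighbour = np.zeros_like(roi)
neighbour[15:25, 27:35] = True

for size in (1, 2, 4):
    ring = periphery_ring(roi, size)
    overlap = int((ring & neighbour).sum())
    print(f"periphery_size={size}: ring pixels={int(ring.sum())}, "
          f"pixels shared with neighbour={overlap}")
```

Once the ring reaches the neighbouring cell, that cell's fluorescence leaks into the background estimate, which is the risk the last bullet describes.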

The interactive tool below allows you to visualize how different background subtraction methods and parameters affect your fluorescence traces.

In [None]:
# @title Background Subtraction Tool {display-mode: "form"}

def normalize_for_display(img):
    """Scale an image to the range [0, 1] for display; constant images are returned unchanged"""
    img_min = img.min()
    img_max = img.max()
    if img_max > img_min:
        return (img - img_min) / (img_max - img_min)
    return img

def run_background_subtraction_visualization():
    """Create interactive visualization for background subtraction"""
    # Check if required data is loaded (the tool uses traces, masks, and the image stack)
    required = ('roi_data', 'roi_masks', 'image_data')
    if any(globals().get(name) is None for name in required):
        print("Please load ROI data, ROI masks, and image data first")
        return None
    
    # Create widgets for background subtraction parameters
    bg_method = widgets.Dropdown(
        options=[
            ('Darkest Pixels', 'darkest_pixels'), 
            ('ROI Periphery', 'roi_periphery'),
            ('Global Background', 'global_background')
        ],
        value='darkest_pixels',
        description='Method:',
        style={'description_width': 'initial'}
    )
    
    percentile = widgets.FloatSlider(
        value=0.2,
        min=0.1,
        max=10.0,
        step=0.1,
        description='Percentile (%):',
        style={'description_width': 'initial'}
    )
    
    min_bg_area = widgets.IntSlider(
        value=200,
        min=50,
        max=1000,
        step=50,
        description='Min Background Area:',
        style={'description_width': 'initial'}
    )
    
    median_filter_size = widgets.IntSlider(
        value=5,
        min=1,  # odd sizes only; a size of 1 leaves the mask unchanged
        max=15,
        step=2,
        description='Median Filter Size:',
        style={'description_width': 'initial'}
    )
    
    periphery_size = widgets.IntSlider(
        value=2,
        min=1,
        max=10,
        step=1,
        description='Periphery Size:',
        style={'description_width': 'initial'}
    )
    
    # Widget to select ROIs to display
    roi_options = [(f"ROI {i+1}", i) for i in range(min(10, len(roi_data)))]
    roi_indices = widgets.SelectMultiple(
        options=roi_options,
        value=tuple(range(min(3, len(roi_options)))),  # Default: up to the first 3 ROIs
        description='ROIs to Display:',
        disabled=False,
        style={'description_width': 'initial'}
    )
    
    def display_background_subtraction(bg_method, percentile, min_bg_area, median_filter_size, periphery_size, roi_indices):
        import matplotlib.pyplot as plt
        import numpy as np
        from scipy.ndimage import binary_dilation, median_filter
        
        if not roi_indices or len(roi_indices) == 0:
            print("Please select at least one ROI to display")
            return
        
        # Create configuration for background subtraction
        bg_config = {
            "method": bg_method,
            "percentile": percentile,
            "min_background_area": min_bg_area,
            "median_filter_size": median_filter_size,
            "periphery_size": periphery_size
        }
        
        # Get ROI data for selected ROIs
        selected_roi_data = roi_data[list(roi_indices)]
        selected_roi_masks = [roi_masks[i] for i in roi_indices]
        
        # Apply background subtraction with the function that matches the selected method
        if bg_method == 'global_background':
            # The global method has its own entry point in modules.roi_processing
            from modules.roi_processing import subtract_global_background
            bg_corrected_data = subtract_global_background(
                image_data, 
                selected_roi_data,
                selected_roi_masks,
                bg_config,
                logger
            )
        else:
            # For other methods, use standard background subtraction
            from modules.roi_processing import subtract_background
            bg_corrected_data = subtract_background(
                image_data, 
                selected_roi_data,
                selected_roi_masks,
                bg_config,
                logger
            )
        
        # Display original vs background-corrected traces
        n_rois = len(roi_indices)
        fig, axes = plt.subplots(n_rois, 2, figsize=(15, 4*n_rois))
        
        # Handle single ROI case
        if n_rois == 1:
            axes = np.array([axes])
        
        for i, roi_idx in enumerate(roi_indices):
            # Original trace
            axes[i, 0].plot(selected_roi_data[i], 'k-', label='Original')
            axes[i, 0].set_title(f'ROI {roi_idx+1} - Original Trace')
            axes[i, 0].set_xlabel('Frame')
            axes[i, 0].set_ylabel('Fluorescence')
            axes[i, 0].grid(True, alpha=0.3)
            
            # Background-corrected trace
            axes[i, 1].plot(bg_corrected_data[i], 'g-', label='Background Corrected')
            axes[i, 1].set_title(f'Background Corrected ({bg_method})')
            axes[i, 1].set_xlabel('Frame')
            axes[i, 1].set_ylabel('Fluorescence')
            axes[i, 1].grid(True, alpha=0.3)
            
            # Add zero line for reference
            axes[i, 1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Create visualizations based on method
        if bg_method == 'darkest_pixels':
            # Create darkest pixels mask
            avg_intensity = np.mean(image_data, axis=0)
            threshold = np.percentile(avg_intensity, percentile)
            darkest_pixels_mask = avg_intensity <= threshold
            
            # Apply median filter to remove noise
            if median_filter_size > 0:
                darkest_pixels_mask = median_filter(darkest_pixels_mask.astype(float), 
                                                   size=median_filter_size) > 0.5
            
            # Create a visualization of the background mask
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            
            # Display average intensity image
            axes[0].imshow(normalize_for_display(avg_intensity), cmap='gray')
            axes[0].set_title('Average Intensity')
            axes[0].axis('off')
            
            # Display background mask
            axes[1].imshow(darkest_pixels_mask, cmap='hot')
            axes[1].set_title(f'Background Mask (percentile={percentile}%)')
            axes[1].axis('off')
            
            plt.tight_layout()
            plt.show()
            
        elif bg_method == 'roi_periphery' and n_rois > 0:
            # Create periphery mask for the first selected ROI
            first_roi_idx = roi_indices[0]
            # Cast to bool so that ~mask is a logical NOT even for integer masks
            mask = roi_masks[first_roi_idx].astype(bool)
            expanded_mask = binary_dilation(mask, iterations=periphery_size)
            periphery_mask = expanded_mask & ~mask
            
            # Create a visualization of the ROI periphery
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            
            # Get first frame for background display
            first_frame = image_data[0]
            
            # Display original ROI
            axes[0].imshow(normalize_for_display(first_frame), cmap='gray')
            axes[0].imshow(mask, cmap='hot', alpha=0.5)
            axes[0].set_title(f'ROI {first_roi_idx+1} Mask')
            axes[0].axis('off')
            
            # Display expanded ROI
            axes[1].imshow(normalize_for_display(first_frame), cmap='gray')
            axes[1].imshow(expanded_mask, cmap='hot', alpha=0.5)
            axes[1].set_title(f'Expanded Mask (periphery={periphery_size})')
            axes[1].axis('off')
            
            # Display periphery only
            axes[2].imshow(normalize_for_display(first_frame), cmap='gray')
            axes[2].imshow(periphery_mask, cmap='hot', alpha=0.5)
            axes[2].set_title('Periphery Mask (for background)')
            axes[2].axis('off')
            
            plt.tight_layout()
            plt.show()
        
        # Update config with current values, creating missing sections on first use
        config.setdefault('roi_processing', {}).setdefault('background', {})
        
        config['roi_processing']['background']['method'] = bg_method
        config['roi_processing']['background']['percentile'] = percentile
        config['roi_processing']['background']['min_background_area'] = min_bg_area
        config['roi_processing']['background']['median_filter_size'] = median_filter_size
        config['roi_processing']['background']['periphery_size'] = periphery_size
        
        print(f"Updated config with: method={bg_method}, percentile={percentile}")
        print(f"min_background_area={min_bg_area}, median_filter_size={median_filter_size}")
        print(f"periphery_size={periphery_size}")
        print("To apply these settings to your pipeline, update your config.yaml file.")
    
    # Create and return interactive widget
    interactive_plot = interactive(
        display_background_subtraction,
        bg_method=bg_method,
        percentile=percentile,
        min_bg_area=min_bg_area,
        median_filter_size=median_filter_size,
        periphery_size=periphery_size,
        roi_indices=roi_indices
    )
    
    return interactive_plot

print("\n## Interactive Background Subtraction Tool ##")
print("Run the following commands to start the background subtraction visualization:")
print("app = run_background_subtraction_visualization()")
print("display(app)")

In [None]:
app = run_background_subtraction_visualization()
display(app)

## Event Detection and Analysis

Accurate detection of calcium events is crucial for analyzing neuronal activity. This tool allows you to adjust event detection parameters to optimize sensitivity and specificity for your data.

### Event Detection Methods
The pipeline detects events with SciPy's `find_peaks` function (`scipy.signal.find_peaks`); the parameters below control its sensitivity.

### Key Parameters
- **Prominence**: Minimum height difference between a peak and surrounding baseline
- **Width**: Minimum width (in frames) of a valid peak
- **Distance**: Minimum separation (in frames) between valid peaks
- **Height**: Minimum absolute height above baseline for a valid peak
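- **Activity Threshold**: Value the activity metric must exceed for an ROI to count as 'active' (peak frequency for spontaneous recordings, peak amplitude for evoked ones)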
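
As a rough illustration of how these parameters map onto `scipy.signal.find_peaks`, the sketch below runs it on a synthetic dF/F trace. The trace, the event positions, and the two prominence values are made up for the example; only the keyword names come from SciPy.

```python
import numpy as np
from scipy.signal import find_peaks

# Synthetic dF/F trace: three Gaussian events of different size plus noise
frames = np.arange(300)
trace = np.zeros(frames.size)
for centre, amplitude in [(60, 0.10), (150, 0.04), (240, 0.15)]:
    trace += amplitude * np.exp(-0.5 * ((frames - centre) / 4.0) ** 2)
rng = np.random.default_rng(1)
trace += rng.normal(0.0, 0.003, size=frames.size)

# Permissive settings, matching the slider defaults in the tool below
peaks, props = find_peaks(trace, prominence=0.03, width=2, distance=10, height=0.02)

# Stricter prominence: the small event no longer qualifies
strict_peaks, _ = find_peaks(trace, prominence=0.08, width=2, distance=10, height=0.02)

print("peaks at frames:", peaks.tolist())
print("prominences:", np.round(props["prominences"], 3).tolist())
print("widths (frames):", np.round(props["widths"], 1).tolist())
print("peaks with prominence >= 0.08:", strict_peaks.tolist())
```

The returned `props` dictionary holds the measured prominence and width of every accepted peak, which is useful when deciding where to set each threshold.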

### Condition-Specific Analysis
Different experimental conditions require different analysis approaches:
- **Spontaneous (0µm)**: Analyzes spontaneous activity throughout the recording
- **Evoked (10µm/25µm)**: Focuses on activity following stimulus application (frame 100)
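
The two analysis modes can be sketched as follows. The function `activity_metric` is hypothetical; it assumes, following the description above, that spontaneous recordings are scored by peak frequency over the whole trace and evoked recordings by the largest dF/F value after the stimulus at frame 100. The metric names mirror those written to the config by the tool below.

```python
import numpy as np
from scipy.signal import find_peaks

STIM_FRAME = 100  # stimulus frame given in the text above

def activity_metric(trace, condition, prominence=0.03, distance=10):
    """Return (metric_name, value) for one dF/F trace."""
    if condition == "0um":
        # Spontaneous: events per 100 frames across the recording
        peaks, _ = find_peaks(trace, prominence=prominence, distance=distance)
        frequency = len(peaks) / (len(trace) / 100)
        return "spont_peak_frequency", frequency
    # Evoked: largest response after the stimulus
    window = trace[STIM_FRAME:]
    amplitude = float(window.max()) if window.size else 0.0
    return "peak_amplitude", amplitude

# Noise-free toy trace with one event before and one after the stimulus
frames = np.arange(300)
trace = (0.08 * np.exp(-0.5 * ((frames - 50) / 4.0) ** 2)
         + 0.12 * np.exp(-0.5 * ((frames - 180) / 4.0) ** 2))

for condition in ("0um", "10um"):
    name, value = activity_metric(trace, condition)
    print(f"{condition}: {name} = {value:.3f}")
```

An ROI would then be called active when the returned value exceeds the activity threshold chosen for that condition.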

### Finding Optimal Values
- Higher prominence values detect stronger events but may miss subtle ones
- A minimum width helps reject brief noise spikes that are narrower than a true event
- The activity threshold should be set based on your experimental design and expected effect size

The interactive tool below allows you to visualize how parameter adjustments affect event detection sensitivity.

In [None]:
# @title Event Detection Tool {display-mode: "form"}

def run_event_detection_visualization():
    """Create interactive visualization for event detection and analysis"""
    # Check if required data is loaded
    if 'roi_data' not in globals() or roi_data is None:
        print("Please load ROI data first")
        return None
    
    # Create a copy of traces for visualization
    # We'll convert ROI data to dF/F for the event detection
    if 'corrected_data' in globals() and corrected_data is not None:
        # Extract traces directly from corrected_data using ROI masks
        n_rois = len(roi_masks)
        n_frames = corrected_data.shape[0]
        traces_for_analysis = np.zeros((n_rois, n_frames), dtype=np.float32)
        for i, mask in enumerate(roi_masks):
            binary_mask = mask.astype(bool)  # convert once per ROI, not once per frame
            for t in range(n_frames):
                traces_for_analysis[i, t] = np.mean(corrected_data[t][binary_mask])
    else:
        # If corrected_data isn't available, use roi_data directly
        traces_for_analysis = roi_data.copy()
    
    # Convert to dF/F using a simple baseline calculation
    # This is just for visualization - the real pipeline will use more sophisticated methods
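    # Force a float output so integer-valued traces are not truncated
    df_f_traces = np.zeros_like(traces_for_analysis, dtype=np.float64)
    # Use first 100 frames or fewer for baseline calculation
    baseline_frames = min(100, traces_for_analysis.shape[1])
    for i in range(len(traces_for_analysis)):
        # Baseline F0 is the 8th percentile of the baseline window
        baseline = np.percentile(traces_for_analysis[i, :baseline_frames], 8)
        if baseline > 0:
            df_f_traces[i] = (traces_for_analysis[i] - baseline) / baseline
        else:
            # No positive baseline: keep the raw trace rather than divide by it
            df_f_traces[i] = traces_for_analysis[i]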
    
    # Create widgets for event detection parameters
    prominence = widgets.FloatSlider(
        value=0.03,
        min=0.01,
        max=0.2,
        step=0.01,
        description='Prominence:',
        style={'description_width': 'initial'}
    )
    
    width = widgets.IntSlider(
        value=2,
        min=1,
        max=10,
        step=1,
        description='Width:',
        style={'description_width': 'initial'}
    )
    
    distance = widgets.IntSlider(
        value=10,
        min=5,
        max=30,
        step=1,
        description='Distance:',
        style={'description_width': 'initial'}
    )
    
    height = widgets.FloatSlider(
        value=0.02,
        min=0.01,
        max=0.2,
        step=0.01,
        description='Height:',
        style={'description_width': 'initial'}
    )
    
    # Activity threshold
    active_threshold = widgets.FloatSlider(
        value=0.02,
        min=0.01,
        max=0.1,
        step=0.01,
        description='Activity Threshold:',
        style={'description_width': 'initial'}
    )
    
    # Widget for condition selection
    condition = widgets.Dropdown(
        options=[
            ('Spontaneous (0µm)', '0um'),
            ('Evoked (10µm)', '10um'),
            ('Evoked (25µm)', '25um')
        ],
        value='0um',
        description='Condition:',
        style={'description_width': 'initial'}
    )
    
    # Widget to select ROIs to display
    roi_options = [(f"ROI {i+1}", i) for i in range(min(10, len(df_f_traces)))]
    roi_indices = widgets.SelectMultiple(
        options=roi_options,
        value=tuple(range(min(3, len(roi_options)))),  # Default: up to the first 3 ROIs
        description='ROIs to Display:',
        disabled=False,
        style={'description_width': 'initial'}
    )
    
    def display_event_detection(prominence, width, distance, height, active_threshold, condition, roi_indices):
        import matplotlib.pyplot as plt
        from scipy.signal import find_peaks
        import numpy as np
        
        if not roi_indices or len(roi_indices) == 0:
            print("Please select at least one ROI to display")
            return
        
        # Create peak detection config
        peak_config = {
            "prominence": prominence,
            "width": width,
            "distance": distance,
            "height": height,
            "rel_height": 0.5
        }
        
        # Create the peak detection and display
        n_rois = len(roi_indices)
        fig, axes = plt.subplots(n_rois, 1, figsize=(15, 4*n_rois))
        
        # Handle single ROI case
        if n_rois == 1:
            axes = np.array([axes])
        
        # Set analysis frames based on condition
        if condition == '0um':
            # For spontaneous, the window spans all frames (peak frequency is counted in the baseline window below)
            analysis_frames = [0, df_f_traces.shape[1]-1]
            active_metric = "spont_peak_frequency"
            title_suffix = "Spontaneous Activity"
        else:
            # For evoked, focus on frames after stimulus
            analysis_frames = [100, df_f_traces.shape[1]-1]
            active_metric = "peak_amplitude"
            title_suffix = f"Evoked Activity ({condition})"
        
        # Calculate baseline frames - just use first 100 frames or fewer
        baseline_frames = [0, min(100, df_f_traces.shape[1]-1)]
        
        # Process and display each selected ROI
        active_rois = 0
        for i, roi_idx in enumerate(roi_indices):
            trace = df_f_traces[roi_idx]
            
            # Extract analysis window
            analysis_start, analysis_end = analysis_frames
            analysis_window = trace[analysis_start:analysis_end+1]
            
            # For evoked conditions, calculate and display stimulus time
            if condition != '0um':
                stim_frame = 100  # Frame where stimulus occurs
            
            # Extract peaks
            if condition == '0um':
                # For spontaneous, look at peaks during baseline period
                baseline_trace = trace[baseline_frames[0]:baseline_frames[1]+1]
                peaks, properties = find_peaks(
                    baseline_trace,
                    prominence=prominence/2,  # Use lower threshold for spontaneous
                    width=width,
                    distance=distance,
                    height=active_threshold
                )
                
                # Calculate peak frequency (peaks per 100 frames)
                peak_freq = len(peaks) / (len(baseline_trace) / 100) if len(baseline_trace) > 0 else 0
                
                # Check if ROI is active based on peak frequency
                is_active = peak_freq > active_threshold
                if is_active:
                    active_rois += 1
                
                # Plot trace
                axes[i].plot(trace, 'k-', label='dF/F')
                
                # Highlight baseline window
                axes[i].axvspan(baseline_frames[0], baseline_frames[1], color='lightgray', alpha=0.2, label='Baseline Window')
                
                # Find and highlight peaks in full trace
                all_peaks, _ = find_peaks(
                    trace,
                    prominence=prominence/2,
                    width=width,
                    distance=distance,
                    height=active_threshold
                )
                
                if len(all_peaks) > 0:
                    axes[i].plot(all_peaks, trace[all_peaks], 'ro', label='Peaks')
                
                # Add title with metrics
                axes[i].set_title(f"ROI {roi_idx+1} - {'Active' if is_active else 'Inactive'} - Peak Freq: {peak_freq:.2f}/100 frames")
                
            else:
                # For evoked, look at peaks after stimulus
                peaks, properties = find_peaks(
                    analysis_window,
                    prominence=prominence,
                    width=width,
                    distance=distance,
                    height=height
                )
                
                # Calculate peak amplitude (max value)
                peak_amplitude = np.max(analysis_window) if len(analysis_window) > 0 else 0
                
                # Check if ROI is active based on peak amplitude
                is_active = peak_amplitude > active_threshold
                if is_active:
                    active_rois += 1
                
                # Plot trace
                axes[i].plot(trace, 'k-', label='dF/F')
                
                # Add a vertical line at stimulus time
                axes[i].axvline(x=stim_frame, color='r', linestyle='--', label='Stimulus', alpha=0.7)
                
                # Highlight analysis window
                axes[i].axvspan(analysis_start, analysis_end, color='lightgray', alpha=0.2, label='Analysis Window')
                
                # Find and highlight peaks
                if len(peaks) > 0:
                    # Adjust peak indices to match original trace
                    adjusted_peaks = peaks + analysis_start
                    axes[i].plot(adjusted_peaks, trace[adjusted_peaks], 'ro', label='Peaks')
                
                # Add title with metrics
                axes[i].set_title(f"ROI {roi_idx+1} - {'Active' if is_active else 'Inactive'} - Peak Amplitude: {peak_amplitude:.4f}")
            
            # Add zero line for reference
            axes[i].axhline(y=0, color='k', linestyle='--', alpha=0.3)
            
            # Add a threshold line
            axes[i].axhline(y=active_threshold, color='g', linestyle=':', 
                           label=f'Threshold ({active_threshold:.2f})', alpha=0.5)
            
            axes[i].set_xlabel('Frame')
            axes[i].set_ylabel('dF/F')
            axes[i].legend()
            axes[i].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.suptitle(f"Event Detection - {title_suffix} ({active_rois}/{n_rois} ROIs Active)", fontsize=16, y=1.02)
        plt.show()
        
        # Update config with current values, creating missing sections on first use
        # Peak detection parameters
        config.setdefault('analysis', {}).setdefault('peak_detection', {})
        
        config['analysis']['peak_detection']['prominence'] = prominence
        config['analysis']['peak_detection']['width'] = width
        config['analysis']['peak_detection']['distance'] = distance
        config['analysis']['peak_detection']['height'] = height
        
        # Activity threshold
        config['analysis']['active_threshold'] = active_threshold
        
        # Condition-specific parameters
        if 'condition_specific' not in config['analysis']:
            config['analysis']['condition_specific'] = {}
        
        if condition not in config['analysis']['condition_specific']:
            config['analysis']['condition_specific'][condition] = {}
        
        config['analysis']['condition_specific'][condition]['active_threshold'] = active_threshold
        config['analysis']['condition_specific'][condition]['active_metric'] = active_metric
        
        print(f"Updated config with: prominence={prominence}, width={width}, distance={distance}, height={height}")
        print(f"active_threshold={active_threshold}, condition={condition}, active_metric={active_metric}")
        print("To apply these settings to your pipeline, update your config.yaml file.")
    
    # Create and return interactive widget
    interactive_plot = interactive(
        display_event_detection,
        prominence=prominence,
        width=width,
        distance=distance,
        height=height,
        active_threshold=active_threshold,
        condition=condition,
        roi_indices=roi_indices
    )
    
    return interactive_plot

print("\n## Interactive Event Detection Tool ##")
print("Run the following commands to start the event detection visualization:")
print("app = run_event_detection_visualization()")
print("display(app)")

In [None]:
app = run_event_detection_visualization()
display(app)

## Save Optimized Configuration

After exploring different parameter settings in the interactive visualizations, you can save your optimized configuration for future use. This will create a new YAML configuration file with all the parameter adjustments you've made.

### Options
- **Save Configuration**: Write the current parameters to a YAML file
- **Print Current Configuration**: Display the current configuration in the notebook
- **Reset Configuration**: Revert to the original configuration that was loaded

The saved configuration can be used with the command-line version of the pipeline by specifying the `--config` parameter.
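
A minimal sketch of the save, reload, and reset cycle, using a made-up two-section config and a temporary directory. The keys shown are only a small subset chosen for illustration, and `yaml.safe_dump`/`yaml.safe_load` are used here instead of the `yaml.dump` call in the cell below.

```python
import copy
import os
import tempfile

import yaml

# Illustrative config; the real one is loaded earlier in the notebook
config_original = {
    "roi_processing": {"background": {"method": "darkest_pixels", "percentile": 0.2}},
    "analysis": {"peak_detection": {"prominence": 0.03, "width": 2}},
}

# Work on a deep copy so the original stays available for a reset
config = copy.deepcopy(config_original)
config["analysis"]["peak_detection"]["prominence"] = 0.05

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "optimized_config.yaml")
    with open(path, "w") as f:
        yaml.safe_dump(config, f, default_flow_style=False)
    with open(path) as f:
        reloaded = yaml.safe_load(f)

print("round trip preserved the config:", reloaded == config)

# Reset: discard the edits by copying the original again
config = copy.deepcopy(config_original)
```

Note that `yaml.safe_dump` refuses NumPy scalars, so values taken from arrays need converting with `float()` or `int()` before saving.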

In [None]:
# @title Configuration Management {display-mode: "form"}
def save_config_to_file():
    """Show widgets to save, print, or reset the current configuration"""
    # Create a text box for the output file path
    output_path = widgets.Text(
        value='optimized_config.yaml',
        placeholder='Enter file path to save config',
        description='Output File:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='80%')
    )
    
    display(output_path)
    
    # Create a save button
    save_button = widgets.Button(
        description='Save Configuration',
        button_style='success',
        tooltip='Click to save the current configuration to a file'
    )
    
    def on_save_click(b):
        try:
            path = output_path.value
            if not path:
                print("Please enter a valid file path")
                return
            
            # Save the configuration to the specified file
            with open(path, 'w') as f:
                yaml.dump(config, f, default_flow_style=False)
            
            print(f"Configuration saved to {path}")
            print("To use this configuration in your pipeline, specify it with the --config parameter")
        except Exception as e:
            print(f"Error saving configuration: {str(e)}")
    
    save_button.on_click(on_save_click)
    display(save_button)
    
    # Create a button to print the current configuration
    print_button = widgets.Button(
        description='Print Current Configuration',
        button_style='info',
        tooltip='Click to print the current configuration to the notebook'
    )
    
    def on_print_click(b):
        # Print the configuration in a readable format
        print("Current Configuration:")
        print("=" * 50)
        
        # Print as formatted YAML
        print(yaml.dump(config, default_flow_style=False))
    
    print_button.on_click(on_print_click)
    display(print_button)
    
    # Create a button to reset configuration to original
    reset_button = widgets.Button(
        description='Reset Configuration',
        button_style='danger',
        tooltip='Click to reset the configuration to the original values'
    )
    
    def on_reset_click(b):
        # Reset the configuration to the original values
        global config
        config = copy.deepcopy(config_original)
        print("Configuration reset to original values")
    
    reset_button.on_click(on_reset_click)
    display(reset_button)

# Call the save configuration function
save_config_to_file()

## Run Full Pipeline

Once you've optimized the parameters using the interactive visualizations, you can run the complete analysis pipeline on all your data. The pipeline will:

1. Process all matched TIF/ROI file pairs in the input directory
2. Apply preprocessing with your optimized parameters
3. Extract and refine ROIs
4. Perform background subtraction
5. Detect and analyze events
6. Generate visualizations and metrics

The results will be saved in the output directory specified in the parameters section. Each file pair will have its own subdirectory containing:
- Corrected data (HDF5 format)
- ROI masks and traces
- Analysis metrics (Excel and CSV)
- Visualizations (PNG format)
- A processing log (`<slice_name>_processing.log`)
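
The per-slice layout can be sketched from the filenames that the run cell below writes directly. This is a minimal illustration: the helper name `expected_outputs` is hypothetical, and ROI and figure filenames are left out because they are chosen inside the pipeline modules.

```python
from pathlib import Path

def expected_outputs(output_dir, tif_path):
    """Return the per-slice files that process_file_pair writes directly.

    Hypothetical helper for illustration only. The corrected HDF5 file is
    written only when save_corrected_data is enabled in the configuration.
    """
    slice_name = Path(tif_path).stem
    slice_dir = Path(output_dir) / slice_name
    return {
        "corrected": slice_dir / f"{slice_name}_corrected.h5",
        "metrics_xlsx": slice_dir / f"{slice_name}_metrics.xlsx",
        "metrics_csv": slice_dir / f"{slice_name}_metrics.csv",
        "log": slice_dir / f"{slice_name}_processing.log",
    }

# Example with made-up paths
for label, path in expected_outputs("results", "data/mouse1_slice2.tif").items():
    print(f"{label}: {path}")
```

A check like this can help confirm after a run that every slice produced its metrics files.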

### Performance Considerations
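- Each worker loads one full image stack into memory, so peak memory use grows with the number of workers
- The pipeline can process file pairs in parallel across multiple CPU cores
- Set "Max Workers" no higher than your core count, and lower it if memory runs short; a value of 1 processes files sequentially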
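
One way to pick a starting value for "Max Workers" is to bound it by core count, file count, and available memory. The sketch below is a rough heuristic and not part of the pipeline: the function name `suggest_max_workers` and the `overhead` multiplier (an assumed factor for the raw, corrected, and intermediate copies a worker may hold) are illustrative assumptions.

```python
import os

def suggest_max_workers(n_files, stack_gb, available_gb, overhead=3.0):
    """Suggest a worker count bounded by CPU cores, file count, and memory.

    Hypothetical heuristic: `overhead` is an assumed multiplier on the size
    of one image stack, not a measured property of the pipeline.
    """
    cpu_limit = os.cpu_count() or 1
    per_worker_gb = stack_gb * overhead
    if per_worker_gb > 0:
        mem_limit = int(available_gb // per_worker_gb)
    else:
        mem_limit = cpu_limit
    # Never return less than one worker
    return max(1, min(cpu_limit, n_files, mem_limit))

print(suggest_max_workers(n_files=12, stack_gb=2.0, available_gb=16.0))
```

Treat the result as a starting point and watch memory use during the first run.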

Click "Run Full Pipeline" to start processing all file pairs.

In [None]:
# @title Run Full Pipeline {display-mode: "form"}
# Define process_file_pair at module level (outside of any other function) so ProcessPoolExecutor can pickle it
def process_file_pair(pair_info):
    """Process a single file pair.
    
    Parameters
    ----------
    pair_info : tuple
        Tuple containing (pair_idx, tif_path, roi_path, config, args)
    
    Returns
    -------
    dict
        Processing results
    """
    pair_idx, tif_path, roi_path, config, args = pair_info
    slice_name = Path(tif_path).stem
    print(f"Processing {slice_name}...")
    
    # Create output directory for this slice
    slice_output_dir = os.path.join(args.output_dir, slice_name)
    os.makedirs(slice_output_dir, exist_ok=True)
    
    # Extract metadata from filename
    metadata = extract_metadata_from_filename(slice_name)
    
    # Load and preprocess the image data
    with tifffile.TiffFile(tif_path) as tif:
        image_data = tif.asarray()
        
        # Ensure data is in (frames, height, width) format
        if len(image_data.shape) == 3:
            if image_data.shape[0] < image_data.shape[1] and image_data.shape[0] < image_data.shape[2]:
                # Already in (frames, height, width) format
                pass
            else:
                # Try to rearrange to (frames, height, width)
                if image_data.shape[2] < image_data.shape[0] and image_data.shape[2] < image_data.shape[1]:
                    image_data = np.moveaxis(image_data, 2, 0)
                elif image_data.shape[1] < image_data.shape[0] and image_data.shape[1] < image_data.shape[2]:
                    image_data = np.moveaxis(image_data, 1, 0)
    
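    # Fail with a clear message if the file is not a 3D stack
    if image_data.ndim != 3:
        raise ValueError(f"Expected a 3D image stack, got shape {image_data.shape}")
    n_frames, height, width = image_data.shape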
    image_shape = (height, width)
    
    # Setup logging
    log_file = os.path.join(slice_output_dir, f"{slice_name}_processing.log")
    slice_logger = setup_logging(log_file, process_id=pair_idx)
    
    # Apply photobleaching correction
    output_h5 = os.path.join(slice_output_dir, f"{slice_name}_corrected.h5")
    corrected_data, _ = correct_photobleaching(
        image_data,
        output_h5,
        config["preprocessing"],
        slice_logger,
        save_output=config["preprocessing"].get("save_corrected_data", True)
    )
    
    # Extract ROIs
    roi_masks, roi_data = extract_roi_fluorescence(
        roi_path,
        corrected_data,
        image_shape,
        slice_output_dir,
        config["roi_processing"],
        slice_logger
    )
    
    # Background subtraction
    if config["roi_processing"].get("steps", {}).get("subtract_background", True):
        bg_method = config["roi_processing"]["background"].get("method", "darkest_pixels")
        
        if bg_method == "global_background":
            bg_corrected_data = subtract_global_background(
                corrected_data,
                roi_data,
                roi_masks,
                config["roi_processing"]["background"],
                slice_logger,
                output_dir=slice_output_dir
            )
        else:
            bg_corrected_data = subtract_background(
                corrected_data,
                roi_data,
                roi_masks,
                config["roi_processing"]["background"],
                slice_logger,
                output_dir=slice_output_dir
            )
    else:
        bg_corrected_data = roi_data
    
    # Analyze fluorescence
    metrics_df, df_f_traces = analyze_fluorescence(
        bg_corrected_data,
        roi_masks,
        tif_path,
        config["analysis"],
        slice_logger,
        output_dir=slice_output_dir,
        metadata=metadata
    )
    
    # Save metrics to Excel
    metrics_file = os.path.join(slice_output_dir, f"{slice_name}_metrics.xlsx")
    metrics_df.to_excel(metrics_file, index=False)
    
    # Also save as CSV for easier processing
    csv_file = os.path.join(slice_output_dir, f"{slice_name}_metrics.csv")
    metrics_df.to_csv(csv_file, index=False)
    
    # Run QC checks, then generate visualizations
    flagged_rois = perform_qc_checks(
        bg_corrected_data,
        metrics_df,
        config["analysis"].get("qc_thresholds", {}),
        slice_logger
    )
    
    generate_visualizations(
        df_f_traces,
        roi_masks,
        metrics_df,
        flagged_rois,
        tif_path,
        slice_output_dir,
        config["visualization"],
        slice_logger,
        metadata=metadata
    )
    
    return {
        "slice_name": slice_name,
        "metrics_file": metrics_file,
        "metadata": metadata
    }

def run_full_pipeline():
    """Run the full pipeline with the current configuration"""
    # Create a run button
    run_button = widgets.Button(
        description='Run Full Pipeline',
        button_style='success',
        tooltip='Click to run the full pipeline with the current configuration'
    )
    
    def on_run_click(b):
        try:
            # Update args with current widget values before reporting them
            update_args()
            
            print("Running pipeline...")
            print(f"Input Directory: {args.input_dir}")
            print(f"Output Directory: {args.output_dir}")
            print(f"Pipeline Mode: {args.mode}")
            print(f"Max Workers: {args.max_workers}")
            
            # Match tif and roi files
            file_pairs = match_tif_and_roi_files(args.input_dir, logger)
            print(f"Found {len(file_pairs)} matched file pairs")
            
            if len(file_pairs) == 0:
                print("No file pairs found. Please check the input directory.")
                return
            
            # Process each file pair
            import concurrent.futures
            
            # Create a progress widget
            progress = widgets.IntProgress(
                value=0,
                min=0,
                max=len(file_pairs),
                description='Processing:',
                bar_style='info',
                orientation='horizontal'
            )
            display(progress)
            
            # Process file pairs sequentially or in parallel based on max_workers
            results = []
            
            # Create tuples of arguments for the process_file_pair function
            # This allows us to pass the config to each process
            pair_infos = [
                (i, tif_path, roi_path, config, args) 
                for i, (tif_path, roi_path) in enumerate(file_pairs)
            ]
            
            if args.max_workers > 1:
                print(f"Using {args.max_workers} parallel workers")
                with concurrent.futures.ProcessPoolExecutor(max_workers=args.max_workers) as executor:
                    # Map each future back to its TIF path so failures can be attributed
                    futures = {executor.submit(process_file_pair, pair_info): pair_info[1]
                               for pair_info in pair_infos}
                    
                    for future in concurrent.futures.as_completed(futures):
                        try:
                            results.append(future.result())
                        except Exception as e:
                            print(f"Error processing {Path(futures[future]).name}: {e}")
                        finally:
                            progress.value += 1
            else:
                print("Processing files sequentially")
                for pair_info in pair_infos:
                    try:
                        result = process_file_pair(pair_info)
                        results.append(result)
                    except Exception as e:
                        print(f"Error processing {Path(pair_info[1]).name}: {e}")
                    finally:
                        progress.value += 1
            
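            # Flag the progress bar if any file pair failed
            n_failed = len(pair_infos) - len(results)
            progress.bar_style = 'success' if n_failed == 0 else 'warning'
            print(f"Processing completed for {len(results)} of {len(pair_infos)} file pairs")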
            
            # Generate summary if we have successful results
            if results:
                print("Generating summary...")
                # Group results by mouse ID
                mouse_data = {}
                for result in results:
                    if "metadata" in result:
                        mouse_id = result["metadata"].get("mouse_id", "unknown")
                        if mouse_id not in mouse_data:
                            mouse_data[mouse_id] = []
                        mouse_data[mouse_id].append(result)
                
                # Create summary for each mouse
                for mouse_id, slices in mouse_data.items():
                    summary_path = save_mouse_summary(mouse_id, slices, args.output_dir, logger)
                    print(f"Saved summary for mouse {mouse_id} to {summary_path}")
            
        except Exception as e:
            print(f"Error running pipeline: {str(e)}")
            import traceback
            traceback.print_exc()
    
    run_button.on_click(on_run_click)
    display(run_button)

# Call the run pipeline function
run_full_pipeline()

In [None]:
# @title Run Full Pipeline {display-mode: "form"}
def run_full_pipeline():
    """Run the full pipeline with the current configuration"""
    # Create a run button
    run_button = widgets.Button(
        description='Run Full Pipeline',
        button_style='success',
        tooltip='Click to run the full pipeline with the current configuration'
    )
    
    def on_run_click(b):
        try:
            # Update args with current widget values before reporting them
            update_args()
            
            print("Running pipeline...")
            print(f"Input Directory: {args.input_dir}")
            print(f"Output Directory: {args.output_dir}")
            print(f"Pipeline Mode: {args.mode}")
            print(f"Max Workers: {args.max_workers}")
            
            # Match tif and roi files
            file_pairs = match_tif_and_roi_files(args.input_dir, logger)
            print(f"Found {len(file_pairs)} matched file pairs")
            
            if len(file_pairs) == 0:
                print("No file pairs found. Please check the input directory.")
                return
            
            # Process each file pair
            import concurrent.futures
            
            # Process a single file pair
            def process_file_pair(pair_idx):
                tif_path, roi_path = file_pairs[pair_idx]
                slice_name = Path(tif_path).stem
                print(f"Processing {slice_name}...")
                
                # Create output directory for this slice
                slice_output_dir = os.path.join(args.output_dir, slice_name)
                os.makedirs(slice_output_dir, exist_ok=True)
                
                # Extract metadata from filename
                metadata = extract_metadata_from_filename(slice_name)
                
                # Load and preprocess the image data
                with tifffile.TiffFile(tif_path) as tif:
                    image_data = tif.asarray()
                    
                    # Ensure data is in (frames, height, width) format
                    if len(image_data.shape) == 3:
                        if image_data.shape[0] < image_data.shape[1] and image_data.shape[0] < image_data.shape[2]:
                            # Already in (frames, height, width) format
                            pass
                        else:
                            # Try to rearrange to (frames, height, width)
                            if image_data.shape[2] < image_data.shape[0] and image_data.shape[2] < image_data.shape[1]:
                                image_data = np.moveaxis(image_data, 2, 0)
                            elif image_data.shape[1] < image_data.shape[0] and image_data.shape[1] < image_data.shape[2]:
                                image_data = np.moveaxis(image_data, 1, 0)
                
                n_frames, height, width = image_data.shape
                image_shape = (height, width)
                
                # Apply photobleaching correction
                output_h5 = os.path.join(slice_output_dir, f"{slice_name}_corrected.h5")
                corrected_data, _ = correct_photobleaching(
                    image_data,
                    output_h5,
                    config["preprocessing"],
                    logger,
                    save_output=config["preprocessing"].get("save_corrected_data", True)
                )
                
                # Extract ROIs
                roi_masks, roi_data = extract_roi_fluorescence(
                    roi_path,
                    corrected_data,
                    image_shape,
                    slice_output_dir,
                    config["roi_processing"],
                    logger
                )
                
                # Background subtraction
                if config["roi_processing"].get("steps", {}).get("subtract_background", True):
                    bg_method = config["roi_processing"]["background"].get("method", "darkest_pixels")
                    
                    if bg_method == "global_background":
                        bg_corrected_data = subtract_global_background(
                            corrected_data,
                            roi_data,
                            roi_masks,
                            config["roi_processing"]["background"],
                            logger,
                            output_dir=slice_output_dir
                        )
                    else:
                        bg_corrected_data = subtract_background(
                            corrected_data,
                            roi_data,
                            roi_masks,
                            config["roi_processing"]["background"],
                            logger,
                            output_dir=slice_output_dir
                        )
                else:
                    bg_corrected_data = roi_data
                
                # Analyze fluorescence
                metrics_df, df_f_traces = analyze_fluorescence(
                    bg_corrected_data,
                    roi_masks,
                    tif_path,
                    config["analysis"],
                    logger,
                    output_dir=slice_output_dir,
                    metadata=metadata
                )
                
                # Save metrics to Excel
                metrics_file = os.path.join(slice_output_dir, f"{slice_name}_metrics.xlsx")
                metrics_df.to_excel(metrics_file, index=False)
                
                # Also save as CSV for easier processing
                csv_file = os.path.join(slice_output_dir, f"{slice_name}_metrics.csv")
                metrics_df.to_csv(csv_file, index=False)
                
                # Run QC checks, then generate visualizations
                flagged_rois = perform_qc_checks(
                    bg_corrected_data,
                    metrics_df,
                    config["analysis"].get("qc_thresholds", {}),
                    logger
                )
                
                generate_visualizations(
                    df_f_traces,
                    roi_masks,
                    metrics_df,
                    flagged_rois,
                    tif_path,
                    slice_output_dir,
                    config["visualization"],
                    logger,
                    metadata=metadata
                )
                
                return {
                    "slice_name": slice_name,
                    "metrics_file": metrics_file,
                    "metadata": metadata
                }
            
            # Create a progress widget
            progress = widgets.IntProgress(
                value=0,
                min=0,
                max=len(file_pairs),
                description='Processing:',
                bar_style='info',
                orientation='horizontal'
            )
            display(progress)
            
            # Process file pairs sequentially or in parallel based on max_workers
            results = []
            if args.max_workers > 1:
                print(f"Using {args.max_workers} parallel workers")
                # process_file_pair is defined inside this callback, so it cannot be
                # pickled for a process pool; threads are used here instead
                with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
                    futures = [executor.submit(process_file_pair, i) for i in range(len(file_pairs))]
                    for future in concurrent.futures.as_completed(futures):
                        results.append(future.result())
                        progress.value += 1
            else:
                print("Processing files sequentially")
                for i in range(len(file_pairs)):
                    result = process_file_pair(i)
                    results.append(result)
                    progress.value += 1
            
            progress.bar_style = 'success'
            print(f"Processing completed for {len(results)} file pairs")
            
        except Exception as e:
            print(f"Error running pipeline: {str(e)}")
            import traceback
            traceback.print_exc()
    
    run_button.on_click(on_run_click)
    display(run_button)

# Call the run pipeline function
run_full_pipeline()

## Notebook Summary

This interactive notebook provides a user-friendly interface to the fluorescence analysis pipeline, allowing you to:

1. **Load and Configure**: Set up pipeline parameters and load the configuration from a YAML file
2. **Select and Preprocess Data**: Load TIF/ROI file pairs and prepare them for analysis
3. **Explore Parameter Effects**: Use interactive visualizations to understand how different parameters affect:
   - Image quality after Gaussian denoising
   - ROI selection with PNR refinement
   - Background subtraction effectiveness
   - Event detection sensitivity and specificity
4. **Optimize Settings**: Adjust parameters to find optimal values for your specific dataset
5. **Save Configuration**: Save your optimized configuration for future use
6. **Run Full Pipeline**: Process all file pairs with your optimized settings

### Output Data
The analysis results saved in the output directory include:
- Corrected fluorescence data (HDF5 format)
- ROI masks and fluorescence traces
- Analysis metrics for each ROI (Excel and CSV format)
- Visualizations of ROIs, traces, and detected events

### Advanced Analysis
For more sophisticated analysis of the pipeline outputs, you can:
- Use the generated CSVs and Excel files for statistical analysis
- Load the HDF5 files for custom visualization and processing
- Examine the visualization PNGs for quality control
- Compare results across different experimental conditions
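
As a starting point for cross-condition comparisons, the per-slice metrics CSVs can be gathered into one table. The sketch below uses only the standard library and assumes the `<output_dir>/<slice_name>/<slice_name>_metrics.csv` layout written by the run cell; the function name `collect_metrics` and the added `slice_name` key are illustrative, and the metric column names are whatever the pipeline wrote.

```python
import csv
from pathlib import Path

def collect_metrics(output_dir):
    """Gather every <slice>/<slice>_metrics.csv under output_dir into one list of rows.

    Hypothetical helper: each row is a dict of strings as read from the CSV,
    with a "slice_name" key added so rows can be traced back to their slice.
    """
    root = Path(output_dir)
    if not root.is_dir():
        return []
    rows = []
    for csv_path in sorted(root.glob("*/*_metrics.csv")):
        slice_name = csv_path.parent.name
        with open(csv_path, newline="") as f:
            for row in csv.DictReader(f):
                row["slice_name"] = slice_name
                rows.append(row)
    return rows

rows = collect_metrics("results")  # "results" is a placeholder output directory
print(f"Loaded {len(rows)} metric rows")
```

The resulting list can be passed to `pd.DataFrame(rows)` for grouping and statistics; values are read as strings, so convert numeric columns before computing on them.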

For questions or issues with the pipeline, please refer to the documentation or contact the developers.