# CUMIN (CUrated MINion): A Fluorescence Analysis Pipeline for Curated ROIs

This notebook provides an interactive interface to the fluorescence analysis pipeline, allowing you to explore parameter effects on ROI processing, trace extraction, and event analysis.

## Overview
- Load and select fluorescence imaging data
- Explore parameter effects with interactive visualizations:
  - Motion correction and photobleaching correction
  - Gaussian denoising for image preprocessing
  - ROI processing with peak-to-noise ratio (PNR) refinement
  - Background removal and subtraction to improve signal contrast
  - Event detection and analysis
- Save optimized configurations
- Run the full pipeline with optimized parameters

Sections include explanations of the relevant algorithms and parameters where helpful, so you can see how different settings affect your analysis results.

In [1]:
# @title Import Libraries and Setup {display-mode: "form"}
# Import standard libraries
import os
import sys
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import tifffile
import cv2
from pathlib import Path
from datetime import datetime
import json
import copy
import pickle
import time
import warnings
from scipy.ndimage import binary_dilation, median_filter
from IPython.display import display, clear_output

# Import interactive libraries
import ipywidgets as widgets
from ipywidgets import interact, fixed, interact_manual, interactive

# Check for HoloViews and Panel
try:
    import holoviews as hv
    import panel as pn
    import param
    HAS_PANEL = True
    hv.extension('bokeh')
    pn.extension()
except ImportError:
    HAS_PANEL = False
    print("HoloViews and Panel are not installed. Some visualizations will be limited.")
    print("To install: pip install holoviews panel param bokeh")

# Add parent directory to path if notebook is in notebooks/
if '..' not in sys.path:
    sys.path.append('..')

# Add current directory to path to import local modules
if '.' not in sys.path:
    sys.path.append('.')

# Import pipeline modules
try:
    from modules.file_matcher import match_tif_and_roi_files, extract_metadata_from_filename
    from modules.preprocessing import (
        correct_photobleaching, remove_background, denoise, stripe_correction
    )
    from modules.roi_processing import (
        extract_roi_fluorescence, subtract_background, subtract_global_background,
        extract_rois_from_zip, save_masks_for_cnmf, extract_roi_fluorescence_with_cnmf,
        refine_rois_with_cnmfe, refine_rois_with_pnr, split_signal_noise, visualize_pnr_results,
        save_trace_data
    )
    from modules.analysis import (
        analyze_fluorescence, perform_qc_checks, extract_peak_parameters,
        extract_spontaneous_activity, calculate_baseline_excluding_peaks
    )
    from modules.visualization import generate_visualizations
    from modules.utils import setup_logging, save_slice_data, save_mouse_summary
    from modules.visualization_helpers import (
        run_denoising_visualization, run_background_removal_visualization,
        run_motion_correction_visualization, run_event_detection_visualization,
        normalize_for_display
    )
    
    # Check for motion correction module
    try:
        from modules.motion_correction import apply_normcorre_correction, estimate_motion
        HAS_MOTION_CORRECTION = True
    except ImportError:
        HAS_MOTION_CORRECTION = False
        print("Motion correction module not available. This feature will be disabled.")
    
    # Try importing the advanced analysis module if available
    try:
        from modules.advanced_analysis import run_advanced_analysis
        ADVANCED_ANALYSIS_AVAILABLE = True
    except ImportError:
        ADVANCED_ANALYSIS_AVAILABLE = False
        print("Advanced analysis module not available. Some features will be disabled.")
    
    print("Successfully imported all modules")
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Make sure the required modules are in the 'modules' directory or adjust the import path.")

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")



Successfully imported all modules


## Setting Up Pipeline Parameters

Before starting the analysis, we need to configure the pipeline parameters including:

- **Input Directory**: Location of your .tif (image) and .zip (ROI) files
- **Output Directory**: Where analysis results will be saved
- **Configuration File**: YAML file with pipeline parameters
- **Pipeline Mode**: "all" runs the complete pipeline; other options are "preprocess", "extract", and "analyze"
- **Max Workers**: Number of parallel processes for multi-core processing

You can adjust these parameters below. Click "Update Parameters" after making changes.

In [2]:
# @title Pipeline Parameters {display-mode: "form"}
class Args:
    """Class to simulate command line arguments"""
    def __init__(self):
        self.input_dir = r"F:\Recovered\Research\BoninLab\PainModelingProject\ForCumin"  # CHANGE THIS
        self.output_dir = r"F:\Recovered\Research\BoninLab\PainModelingProject\ForCumin\New folder"  # CHANGE THIS
        self.config = "config.yaml"  # Path to your config file
        self.mode = "all"  # Options: "all", "preprocess", "extract", "analyze"
        self.max_workers = 4  # Adjust based on your CPU cores
        self.disable_advanced = False

# Create args object
args = Args()

# Create interactive widgets for adjusting parameters
input_dir_widget = widgets.Text(
    value=args.input_dir,
    description='Input Directory:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

output_dir_widget = widgets.Text(
    value=args.output_dir,
    description='Output Directory:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

config_widget = widgets.Text(
    value=args.config,
    description='Config File:',
    style={'description_width': 'initial'}
)

mode_widget = widgets.Dropdown(
    options=['all', 'preprocess', 'extract', 'analyze'],
    value=args.mode,
    description='Pipeline Mode:',
    style={'description_width': 'initial'}
)

workers_widget = widgets.IntSlider(
    value=args.max_workers,
    min=1,
    max=12,
    step=1,
    description='Max Workers:',
    style={'description_width': 'initial'}
)

disable_advanced_widget = widgets.Checkbox(
    value=args.disable_advanced,
    description='Disable Advanced Analysis',
    style={'description_width': 'initial'}
)

# Function to update args object based on widget values
def update_args():
    args.input_dir = input_dir_widget.value
    args.output_dir = output_dir_widget.value
    args.config = config_widget.value
    args.mode = mode_widget.value
    args.max_workers = workers_widget.value
    args.disable_advanced = disable_advanced_widget.value
    print("Parameters updated:")
    print(f"Input Directory: {args.input_dir}")
    print(f"Output Directory: {args.output_dir}")
    print(f"Config File: {args.config}")
    print(f"Pipeline Mode: {args.mode}")
    print(f"Max Workers: {args.max_workers}")
    print(f"Disable Advanced Analysis: {args.disable_advanced}")

# Create update button
update_button = widgets.Button(
    description='Update Parameters',
    button_style='info',
    tooltip='Click to update parameters'
)

update_button.on_click(lambda b: update_args())

# Display widgets
display(input_dir_widget)
display(output_dir_widget)
display(config_widget)
display(mode_widget)
display(workers_widget)
display(disable_advanced_widget)
display(update_button)

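The `Args` class above imitates command-line arguments but does not check them, so a mistyped path or mode only shows up later in the run. The following is a minimal sketch of a typed alternative that validates up front. `PipelineArgs` and `validate` are hypothetical names and are not part of the pipeline's modules:

```python
from dataclasses import dataclass
from pathlib import Path

# Modes offered by the notebook's Pipeline Mode dropdown
VALID_MODES = ("all", "preprocess", "extract", "analyze")


@dataclass
class PipelineArgs:
    """Hypothetical typed stand-in for the notebook's Args class."""
    input_dir: str
    output_dir: str
    config: str = "config.yaml"
    mode: str = "all"
    max_workers: int = 4
    disable_advanced: bool = False

    def validate(self):
        """Return a list of problems; an empty list means the settings look usable."""
        problems = []
        if not Path(self.input_dir).is_dir():
            problems.append(f"Input directory not found: {self.input_dir}")
        if not Path(self.config).is_file():
            problems.append(f"Config file not found: {self.config}")
        if self.mode not in VALID_MODES:
            problems.append(f"Unknown mode {self.mode!r}; expected one of {VALID_MODES}")
        if self.max_workers < 1:
            problems.append("max_workers must be at least 1")
        return problems
```

In the notebook, you could build one from the widget values, for example `PipelineArgs(input_dir_widget.value, output_dir_widget.value, config_widget.value, mode_widget.value, workers_widget.value)`, and print any problems before loading data.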

## Loading Configuration

The pipeline configuration is stored in a YAML file that defines parameters for each processing step. 
The configuration includes settings for:

- **Preprocessing**: Methods and parameters for photobleaching correction, denoising, etc.
- **ROI Processing**: ROI extraction, refinement, and background subtraction
- **Analysis**: Event detection parameters, condition-specific settings, etc.
- **Visualization**: Plot types, colormaps, and other visualization settings

Click "Load Configuration" to load the configuration file specified in the parameters section.

In [3]:
# @title Load Configuration {display-mode: "form"}
def load_config(config_path, args=None):
    """Load configuration from YAML file and apply command line overrides."""
    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        print(f"Loaded configuration from {config_path}")
        
        # Apply command line overrides if provided
        if args and args.disable_advanced:
            # Disable advanced analysis if requested via command line
            if "advanced_analysis" in config:
                config["advanced_analysis"]["enabled"] = False
                print("Advanced analysis disabled via command line argument")
        
        return config
    except Exception as e:
        print(f"Failed to load configuration: {str(e)}")
        return None

# Setup logging
logger = setup_logging()

# Create a load button
load_config_button = widgets.Button(
    description='Load Configuration',
    button_style='success',
    tooltip='Click to load the configuration file'
)

def on_load_config_click(b):
    global config, config_original
    # Load configuration
    config = load_config(args.config, args)
    if config:
        print("Configuration loaded successfully.")
        # Create a backup of original config for reference
        config_original = copy.deepcopy(config)
        # Print some key configuration settings
        print("\nKey Configuration Settings:")
        print(f"Photobleaching correction method: {config['preprocessing'].get('correction_method', 'Not specified')}")
        print(f"Motion correction enabled: {config.get('motion_correction', {}).get('enabled', False)}")
        if 'denoise' in config['preprocessing']:
            print(f"Denoising enabled: {config['preprocessing']['denoise'].get('enabled', False)}")
        if 'background_removal' in config['preprocessing']:
            print(f"Background removal enabled: {config['preprocessing']['background_removal'].get('enabled', False)}")
        print(f"Background subtraction method: {config['roi_processing'].get('background', {}).get('method', 'Not specified')}")
    else:
        print("Failed to load configuration. Please check the config file path.")

load_config_button.on_click(on_load_config_click)
display(load_config_button)

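The load handler above reads nested settings with chained `.get()` calls such as `config.get('motion_correction', {}).get('enabled', False)`. Those chains fail if a section is present but empty, because PyYAML loads a key with no value (for example a bare `motion_correction:` line) as `None`. The following sketches a more forgiving lookup. `get_setting` and its dotted-key convention are hypothetical and not defined by the pipeline:

```python
def get_setting(config, dotted_key, default=None):
    """Hypothetical helper: look up e.g. 'preprocessing.denoise.enabled'.

    Returns `default` if any level is missing or is not a dict
    (for example, an empty YAML section that loaded as None).
    """
    node = config
    for key in dotted_key.split("."):
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node


# Example config using key names the notebook reads (values are illustrative)
example_config = {
    "preprocessing": {
        "correction_method": "polynomial_detrend",
        "denoise": {"enabled": True},
    },
    "motion_correction": None,  # an empty section in the YAML file
}

print(get_setting(example_config, "motion_correction.enabled", False))
print(get_setting(example_config, "roi_processing.background.method", "Not specified"))
```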

## Data Selection and Loading

This section lets you select a TIF image stack from the input directory and load it for interactive exploration. When the full pipeline runs, it searches the input directory for matching file pairs.

### File Pairs
A file pair consists of:
- A **.tif file** containing the fluorescence imaging data (time series of frames)
- A **.zip file** containing ImageJ/FIJI ROI definitions

### Loading Process
When you click "Load Data", the cell:
1. Reads the first "Max Frames" frames of the selected TIF stack (or all frames, if the stack is shorter) to limit memory use
2. Stores them in `image_data` as a `float32` array of shape (frames, height, width)
3. Prints the minimum, maximum, and mean intensity and displays the first frame

This cell loads only the image stack. It does not read ROI masks from the ZIP file.

Select a TIF file from the dropdown menu, set "Max Frames", and click "Load Data" to start.

In [None]:
# @title Select and Load Data {display-mode: "form"}

def select_and_load_data():
    """Select and load a TIFF file for analysis"""
    global image_data, tif_path
    
    # Check if input directory exists
    if not os.path.exists(args.input_dir):
        print(f"Input directory {args.input_dir} does not exist!")
        return
    
    # Get list of TIFF files
    tiff_files = sorted([f for f in os.listdir(args.input_dir) if f.endswith(('.tif', '.tiff'))])
    
    if not tiff_files:
        print(f"No TIFF files found in {args.input_dir}")
        return
    
    # Create file selection widget
    file_selector = widgets.Dropdown(
        options=tiff_files,
        description='Select TIFF file:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='80%')
    )
    
    max_frames_slider = widgets.IntSlider(
        value=100,
        min=50,
        max=500,
        step=50,
        description='Max Frames:',
        style={'description_width': 'initial'}
    )
    
    load_button = widgets.Button(
        description='Load Data',
        button_style='success',
        tooltip='Click to load the selected TIFF file'
    )
    
    info_output = widgets.Output()
    
    def on_load_click(b):
        # Declare globals up front so later cells can use the loaded data and path
        global image_data, tif_path, selected_tif_path, corrected_data
        with info_output:
            clear_output()
            print(f"Loading {file_selector.value}...")
            
            try:
                tif_path = os.path.join(args.input_dir, file_selector.value)
                # Later cells (e.g. motion correction) look for `selected_tif_path`
                selected_tif_path = tif_path
                # Clear corrected data left over from a previously loaded file
                corrected_data = None
                
                # Load only the first "Max Frames" frames to limit memory use
                with tifffile.TiffFile(tif_path) as tif:
                    n_frames = min(max_frames_slider.value, len(tif.pages))
                    height, width = tif.pages[0].shape[:2]
                    image_data = np.zeros((n_frames, height, width), dtype=np.float32)
                    for i in range(n_frames):
                        image_data[i] = tif.pages[i].asarray()
                
                print(f"Loaded {n_frames} frames with shape {height}x{width}")
                print("Data statistics:")
                print(f"  Min: {image_data.min():.2f}")
                print(f"  Max: {image_data.max():.2f}")
                print(f"  Mean: {image_data.mean():.2f}")
                print("Data loaded successfully. You can now proceed with processing steps.")
                
                # Display a sample frame
                plt.figure(figsize=(8, 8))
                plt.imshow(image_data[0], cmap='gray')
                plt.title(f"First Frame of {file_selector.value}")
                plt.colorbar(label='Intensity')
                plt.show()
            except Exception as e:
                print(f"Error loading file: {str(e)}")
    
    load_button.on_click(on_load_click)
    
    # Display widgets
    display(widgets.VBox([
        file_selector, 
        max_frames_slider, 
        load_button, 
        info_output
    ]))

# Initialize global variables
image_data = None
tif_path = None

# Run the function
select_and_load_data()
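The loading cell preallocates a `float32` array of shape (frames, height, width). Memory use therefore grows linearly with the "Max Frames" setting, at 4 bytes per pixel per frame. The following sketch estimates the requirement before loading. The function names are hypothetical, and the 512 × 512 frame size in the example is an illustrative assumption:

```python
import numpy as np


def stack_memory_mib(n_frames, height, width, dtype=np.float32):
    """Hypothetical helper: MiB needed for an (n_frames, height, width) array."""
    bytes_needed = n_frames * height * width * np.dtype(dtype).itemsize
    return bytes_needed / 2**20


def max_frames_for_budget(budget_mib, height, width, dtype=np.float32):
    """Hypothetical helper: how many whole frames fit in a memory budget."""
    per_frame = stack_memory_mib(1, height, width, dtype)
    return int(budget_mib // per_frame)


print(stack_memory_mib(500, 512, 512))                  # float32, as the loading cell uses
print(stack_memory_mib(500, 512, 512, dtype=np.uint16))  # the same frames as 16-bit integers
```

For example, 500 frames of 512 × 512 pixels need 500 MiB as `float32`. That is twice what the same frames occupy as 16-bit integers.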

## Load Previously Saved Visualization Data

If you've previously run this notebook and saved visualization data, you can load it here instead of reprocessing the data. This saves time when you want to continue exploring the same dataset.

The visualization data includes:
- Original image data
- Preprocessed (corrected) data
- ROI masks and centers
- ROI fluorescence traces
- Metadata extracted from the filename

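Enter the path to the saved visualization data file (`visualization_data.pkl`) and click "Load Saved Visualization Data" to continue. Only load pickle files that you created yourself or otherwise trust, because unpickling can execute arbitrary code.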

In [None]:
# @title Load Saved Visualization Data {display-mode: "form"}
def load_visualization_data(data_file):
    """Load saved visualization data from pickle file"""
    try:
        with open(data_file, 'rb') as f:
            data = pickle.load(f)
        
        # Assign to global variables for use in visualizations
        global image_data, corrected_data, roi_masks, roi_centers, roi_data, metadata, selected_tif_path, selected_roi_path, image_shape
        image_data = data['image_data']
        corrected_data = data['corrected_data']
        roi_masks = data['roi_masks']
        roi_centers = data['roi_centers']
        roi_data = data['roi_data']
        metadata = data['metadata']
        selected_tif_path = data['selected_tif_path']
        selected_roi_path = data['selected_roi_path']
        image_shape = data['image_shape']
        
        print("Visualization data loaded successfully!")
        print(f"File: {Path(selected_tif_path).stem}")
        print(f"Image shape: {image_data.shape}")
        print(f"Number of ROIs: {len(roi_masks)}")
        print(f"Condition: {metadata.get('condition', 'unknown')}")
        return True
    except Exception as e:
        print(f"Error loading visualization data: {str(e)}")
        return False

# Widget to select a visualization data file
vis_data_path_widget = widgets.Text(
    placeholder='Enter path to visualization_data.pkl file',
    description='Data File:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)

load_vis_data_button = widgets.Button(
    description='Load Saved Visualization Data',
    button_style='info',
    tooltip='Load previously saved visualization data'
)

def on_load_vis_data_click(b):
    path = vis_data_path_widget.value
    if path:
        success = load_visualization_data(path)
        if success:
            print("Ready for visualization!")
    else:
        print("Please enter a valid path to the visualization_data.pkl file")

load_vis_data_button.on_click(on_load_vis_data_click)
display(vis_data_path_widget)
display(load_vis_data_button)

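## Motion Correction

This cell opens the interactive motion correction viewer (`run_motion_correction_visualization` from `modules.visualization_helpers`). It requires loaded image data, a loaded configuration, and the HoloViews and Panel packages.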

In [None]:
# @title Motion Correction {display-mode: "form"}
def run_motion_correction():
    """Run interactive motion correction visualization"""
    global image_data, config, selected_tif_path
    
    if 'image_data' not in globals() or image_data is None:
        print("Please load image data first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
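    # `selected_tif_path` comes from saved visualization data; the data-loading
    # cell stores the path in `tif_path`. Fall back to `tif_path` if needed.
    if globals().get('selected_tif_path') is None:
        selected_tif_path = globals().get('tif_path')
    
    # Create the visualization
    motion_app = run_motion_correction_visualization(image_data, selected_tif_path, config, logger)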
    
    if motion_app is not None:
        display(motion_app)
    else:
        print("Failed to create motion correction visualization.")
        print("Make sure HoloViews and Panel are installed:")
        print("pip install holoviews panel param bokeh")

# Run the function
run_motion_correction()
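The cell above delegates to the pipeline's own helper. The setup cell also imports `apply_normcorre_correction` from `modules.motion_correction`, which suggests a NoRMCorre-style (piecewise rigid) method. The sketch below illustrates the underlying idea with a different, simpler technique: whole-frame, integer-pixel rigid registration by phase correlation, using NumPy only. It is not the pipeline's implementation, and the function names are hypothetical:

```python
import numpy as np


def estimate_rigid_shift(reference, frame):
    """Hypothetical helper: integer (dy, dx) that realigns `frame` to `reference` via np.roll."""
    cross_power = np.fft.fft2(reference) * np.conj(np.fft.fft2(frame))
    cross_power /= np.maximum(np.abs(cross_power), 1e-12)  # keep only the phase
    correlation = np.fft.ifft2(cross_power).real             # sharp peak at the shift
    peak = np.array(np.unravel_index(np.argmax(correlation), correlation.shape), dtype=float)
    dims = np.array(correlation.shape)
    wrap = peak > dims // 2          # peaks past the midpoint are negative shifts
    peak[wrap] -= dims[wrap]
    return peak


def rigid_motion_correct(stack, reference=None):
    """Hypothetical helper: shift every frame to match a reference (default: mean frame)."""
    stack = np.asarray(stack, dtype=np.float64)
    if reference is None:
        reference = stack.mean(axis=0)
    corrected = np.empty_like(stack)
    shifts = np.zeros((stack.shape[0], 2))
    for i, frame in enumerate(stack):
        shifts[i] = estimate_rigid_shift(reference, frame)
        dy, dx = shifts[i].astype(int)
        corrected[i] = np.roll(frame, shift=(dy, dx), axis=(0, 1))
    return corrected, shifts
```

`np.roll` wraps pixels that leave one edge around to the opposite edge. A real correction would pad or crop those borders. Subpixel or piecewise shifts also need more machinery than this sketch provides.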

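## Photobleaching Correction

Photobleaching gradually lowers fluorescence intensity over the course of a recording. This cell lets you choose a correction method and its parameters and applies the correction to `image_data`. It then plots the mean intensity per frame before and after correction and the per-frame correction factor, and shows a before/after comparison of a sample frame. The result is stored in `corrected_data`, which the denoising and background removal cells use when it is available.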

In [None]:
# @title Photobleaching Correction {display-mode: "form"}
def run_photobleaching_correction():
    """Apply photobleaching correction"""
    global image_data, config, corrected_data
    
    if 'image_data' not in globals() or image_data is None:
        print("Image data not found. Please load data first with the 'Select and Load Data' cell")
        print("or the 'Load Saved Visualization Data' cell.")
        print("If you have already loaded data, check that it is available in this kernel, e.g.:")
        print("    print(image_data.shape if image_data is not None else 'image_data is None')")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Get photobleaching correction parameters
    method = widgets.Dropdown(
        options=[
            ('Polynomial Detrend', 'polynomial_detrend'),
            ('Exponential Decay', 'exponential_decay'),
            ('Bi-Exponential', 'bi_exponential'),
            ('Adaptive Percentile', 'adaptive_percentile'),
            ('Two-Stage Detrend', 'two_stage_detrend')
        ],
        value=config['preprocessing'].get('correction_method', 'polynomial_detrend'),
        description='Method:',
        style={'description_width': 'initial'}
    )
    
    polynomial_order = widgets.IntSlider(
        value=config['preprocessing'].get('polynomial_order', 3),
        min=1,
        max=5,
        step=1,
        description='Polynomial Order:',
        style={'description_width': 'initial'}
    )
    
    smoothing_sigma = widgets.FloatSlider(
        value=config['preprocessing'].get('smoothing_sigma', 2.0),
        min=0.0,
        max=5.0,
        step=0.1,
        description='Smoothing Sigma:',
        style={'description_width': 'initial'}
    )
    
    generate_plot = widgets.Checkbox(
        value=config['preprocessing'].get('generate_plot', True),
        description='Generate Plot',
        style={'description_width': 'initial'}
    )
    
    run_button = widgets.Button(
        description='Run Correction',
        button_style='warning',
        tooltip='Click to run photobleaching correction'
    )
    
    info_output = widgets.Output()
    
    def on_run_click(b):
        with info_output:
            clear_output()
            print(f"Running photobleaching correction with method: {method.value}")
            
            # Update config with current values
            config['preprocessing']['correction_method'] = method.value
            config['preprocessing']['polynomial_order'] = polynomial_order.value
            config['preprocessing']['smoothing_sigma'] = smoothing_sigma.value
            config['preprocessing']['generate_plot'] = generate_plot.value
            
            try:
                # Apply photobleaching correction
                global corrected_data
                corrected_data, _ = correct_photobleaching(
                    image_data,
                    None,  # No output path needed
                    config['preprocessing'],
                    logger,
                    save_output=False
                )
                
                print("Photobleaching correction completed successfully.")
                
                # Plot mean intensity before and after correction
                mean_original = np.mean(image_data, axis=(1, 2))
                mean_corrected = np.mean(corrected_data, axis=(1, 2))
                
                fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [3, 1]})
                
                # Format method name for display
                method_display = method.value.replace('_', ' ').title()
                
                # Plot mean intensity over time
                ax1.plot(mean_original, 'r-', alpha=0.7, label='Original')
                ax1.plot(mean_corrected, 'g-', alpha=0.7, label=f'Corrected ({method_display})')
                ax1.set_title('Photobleaching Correction Verification', fontsize=14)
                ax1.set_xlabel('Frame')
                ax1.set_ylabel('Mean Intensity')
                ax1.legend(loc='best')
                ax1.grid(True, alpha=0.3)
                
                # Plot the correction factor (ratio of corrected to original mean intensity per frame)
                correction_factor = mean_corrected / np.maximum(mean_original, 1e-6)
                ax2.plot(correction_factor, 'b-', alpha=0.7)
                ax2.set_title('Correction Factor', fontsize=12)
                ax2.set_xlabel('Frame')
                ax2.set_ylabel('Factor')
                ax2.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.show()
                
                # Show a before/after frame comparison
                sample_frame = 50  # Use frame 50 for comparison
                if sample_frame >= len(image_data):
                    sample_frame = len(image_data) // 2
                
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
                
                ax1.imshow(image_data[sample_frame], cmap='gray')
                ax1.set_title(f"Original (Frame {sample_frame})")
                ax1.axis('off')
                
                ax2.imshow(corrected_data[sample_frame], cmap='gray')
                ax2.set_title(f"Corrected (Frame {sample_frame})")
                ax2.axis('off')
                
                plt.tight_layout()
                plt.show()
                
                # Report completion; the result is stored in the global corrected_data
                print("\nPhotobleaching correction is complete. The corrected data is now available.")
                print("You can now proceed to ROI extraction and analysis.")
                
            except Exception as e:
                print(f"Error during photobleaching correction: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Create description of methods
    method_info = widgets.HTML(
        """
        <h3>Photobleaching Correction Methods:</h3>
        <ul>
            <li><strong>Polynomial Detrend</strong>: Fits a polynomial to mean intensities and normalizes. Good for simple bleaching patterns.</li>
            <li><strong>Exponential Decay</strong>: Models bleaching as an exponential decay. Good for typical photobleaching.</li>
            <li><strong>Bi-Exponential</strong>: Uses two exponential components for complex decay patterns.</li>
            <li><strong>Adaptive Percentile</strong>: Uses a sliding window percentile approach. More adaptive to complex patterns.</li>
            <li><strong>Two-Stage Detrend</strong>: Applies two polynomial fits in sequence. Good for mixed decay patterns.</li>
        </ul>
        """
    )
    
    # Display widgets
    display(method_info)
    display(widgets.VBox([
        method,
        polynomial_order,
        smoothing_sigma,
        generate_plot,
        run_button,
        info_output
    ]))

# Initialize corrected_data
corrected_data = None

# Run the function
run_photobleaching_correction()
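The method list in the cell above describes Polynomial Detrend as fitting a polynomial to mean intensities and normalizing, and Exponential Decay as modeling bleaching with an exponential. The following NumPy sketch implements both ideas under two assumptions: a multiplicative model, where bleaching scales every pixel in a frame by the same factor, and a single exponential with no constant offset. `correct_photobleaching` in `modules.preprocessing` may differ on both points, and the function names here are hypothetical:

```python
import numpy as np


def _mean_trace(stack):
    """Mean intensity of each frame."""
    return np.asarray(stack, dtype=np.float64).mean(axis=(1, 2))


def polynomial_detrend(stack, order=3):
    """Hypothetical sketch: divide each frame by a polynomial fit of the mean trace."""
    stack = np.asarray(stack, dtype=np.float64)
    t = np.arange(stack.shape[0])
    trend = np.polyval(np.polyfit(t, _mean_trace(stack), order), t)
    # Rescale so the corrected movie keeps the first frame's fitted brightness
    scale = trend[0] / np.maximum(trend, 1e-12)
    return stack * scale[:, None, None], trend


def exponential_detrend(stack):
    """Hypothetical sketch: fit mean(t) = A * exp(-k t) via a line fit to log(mean)."""
    stack = np.asarray(stack, dtype=np.float64)
    mean_trace = _mean_trace(stack)
    if np.any(mean_trace <= 0):
        raise ValueError("exponential fit needs a strictly positive mean trace")
    t = np.arange(stack.shape[0])
    slope, intercept = np.polyfit(t, np.log(mean_trace), 1)
    trend = np.exp(intercept + slope * t)
    scale = trend[0] / trend
    return stack * scale[:, None, None], trend
```

Both fits use the whole-frame mean. Slow changes in genuine activity that affect the entire field of view are therefore removed along with the bleaching.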

## Gaussian Denoising

Gaussian denoising reduces pixel-level noise in fluorescence images. It smooths out fluctuations smaller than the filter while largely preserving larger structures, at the cost of some blurring of fine detail. This helps produce clean ROI traces and makes events in your calcium imaging data easier to detect.

### How It Works
The Gaussian blur operation applies a weighted average to each pixel, where nearby pixels have more influence than distant ones. The weighting follows a Gaussian distribution centered at the pixel being processed.

### Key Parameters
- **Kernel Size**: Controls the size of the filter window (must be odd). Larger values remove more noise but may blur important details.
- **Sigma**: Controls the width of the Gaussian distribution. Higher values produce more blurring and stronger noise reduction.

### Finding Optimal Values
The best parameters balance noise reduction and feature preservation:
- For noisy recordings, try larger kernel sizes (7-11) and higher sigma values (2-4)
- For cleaner recordings, use smaller kernels (3-5) and lower sigma values (0.5-1.5)
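You can check numerically whether a kernel size suits a given sigma. A Gaussian places about 99.7% of its mass within ±3σ, so a kernel narrower than that truncates the weights. The sketch below computes the smallest odd kernel size that reaches ±3σ, and the approximate fraction of a continuous Gaussian that lies within a given kernel's outermost taps. Both helper names are hypothetical:

```python
import math


def min_ksize_for_sigma(sigma, n_sigma=3.0):
    """Hypothetical helper: smallest odd kernel size whose taps reach +/- n_sigma * sigma."""
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    half_width = math.ceil(n_sigma * sigma)
    return 2 * half_width + 1


def gaussian_mass_within_kernel(ksize, sigma):
    """Hypothetical helper: approximate share of a 1D Gaussian inside the kernel span."""
    half_width = (ksize - 1) / 2
    return math.erf(half_width / (sigma * math.sqrt(2)))


print(min_ksize_for_sigma(1.5))
print(round(gaussian_mass_within_kernel(7, 2.0), 3))
```

By this measure, a 7-pixel kernel with sigma 2 covers only ±1.5σ (less than 90% of the mass), and sigma 4 would need a 25-pixel kernel. When the kernel is much narrower than the Gaussian, further increases in sigma mostly turn the filter into a plain average over the kernel window.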

Use the interactive tool below to experiment with different settings and observe their effects.

In [None]:
# @title Advanced Denoising with Contour Plots {display-mode: "form"}
# These imports repeat the setup cell so this cell can also run on its own
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display

# Import our custom visualization module
from modules.visualization_helpers import run_denoising_visualization

# Fall back to a minimal default configuration if none has been loaded
if 'config' not in globals():
    config = {
        'preprocessing': {
            'denoise': {
                'enabled': False,
                'method': 'gaussian',
                'params': {'ksize': (5, 5), 'sigmaX': 1.5}
            }
        }
    }

def run_denoising():
    """Run interactive denoising visualization with contour plots"""
    global image_data, config, corrected_data
    
    # Use corrected data if available, otherwise use original image data
    data_to_use = corrected_data if 'corrected_data' in globals() and corrected_data is not None else image_data
    
    if data_to_use is None:
        print("Please load image data first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Create the visualization
    denoising_app = run_denoising_visualization(data_to_use, config, logger)
    
    if denoising_app is not None:
        display(denoising_app)
    else:
        print("Failed to create denoising visualization.")
        print("Make sure HoloViews and Panel are installed:")
        print("pip install holoviews panel param bokeh")

# Run the function
run_denoising()
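The weighted average described under "How It Works" can be written out explicitly. The config's `ksize` and `sigmaX` entries mirror the arguments of OpenCV's `cv2.GaussianBlur`. The NumPy-only sketch below is an illustration rather than the pipeline's `denoise` function. It assumes a square kernel, a positive sigma, and mirrored borders like OpenCV's default `BORDER_REFLECT_101`, which is NumPy's `"reflect"` pad mode. The function names are hypothetical:

```python
import numpy as np


def gaussian_kernel_1d(ksize, sigma):
    """Hypothetical helper: normalized 1D Gaussian weights for an odd kernel size."""
    if ksize < 1 or ksize % 2 == 0:
        raise ValueError("ksize must be a positive odd integer")
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    x = np.arange(ksize) - (ksize - 1) / 2          # offsets from the center tap
    weights = np.exp(-(x ** 2) / (2 * sigma ** 2))  # nearer pixels get larger weights
    return weights / weights.sum()                  # sum to 1, so flat regions keep their value


def gaussian_blur(frame, ksize=5, sigma=1.5):
    """Hypothetical helper: separable Gaussian blur (filter every row, then every column)."""
    kernel = gaussian_kernel_1d(ksize, sigma)
    pad = ksize // 2
    padded = np.pad(np.asarray(frame, dtype=np.float64), pad, mode="reflect")
    rows = np.apply_along_axis(np.convolve, 1, padded, kernel, mode="valid")
    return np.apply_along_axis(np.convolve, 0, rows, kernel, mode="valid")
```

The 2D Gaussian is separable. Filtering rows and then columns with the same 1D kernel therefore gives the same result as a single 2D convolution, at lower cost.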

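## Background Removal

This cell applies background removal to a single sample frame (frame 10, or the middle frame for shorter recordings). It uses `corrected_data` if available and `image_data` otherwise. The output shows the original frame, the background-removed frame, their absolute difference, and the intensity profile along the center row. Two methods are available: **Uniform** and **Tophat**. For Tophat, the window size sets the radius of a disk-shaped structuring element.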
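The cell below calls `remove_background_perframe` from `modules.visualization_helpers`, whose internals are not shown here. The following NumPy-only sketch illustrates what the two method names commonly denote. It assumes that **Uniform** subtracts a local mean computed over a square window, and that **Tophat** is a white top-hat: the frame minus its morphological opening. It uses a square footprint rather than the disk used in the cell below, and the function names are hypothetical:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def _square_windows(frame, size):
    """(H, W, size, size) view of every size x size neighborhood, with mirrored borders."""
    if size < 1 or size % 2 == 0:
        raise ValueError("size must be a positive odd integer")
    padded = np.pad(np.asarray(frame, dtype=np.float64), size // 2, mode="reflect")
    return sliding_window_view(padded, (size, size))


def remove_background_uniform(frame, size=7):
    """Hypothetical sketch: subtract a local-mean (box filter) background estimate."""
    frame = np.asarray(frame, dtype=np.float64)
    background = _square_windows(frame, size).mean(axis=(-2, -1))
    return frame - background


def remove_background_tophat(frame, size=7):
    """Hypothetical sketch: white top-hat, i.e. frame minus its grey opening."""
    frame = np.asarray(frame, dtype=np.float64)
    eroded = _square_windows(frame, size).min(axis=(-2, -1))    # grey erosion
    opened = _square_windows(eroded, size).max(axis=(-2, -1))   # then dilation
    return frame - opened
```

With either method, structures smaller than the window survive while smoother background is removed. The window should therefore be larger than the cells you want to keep.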

In [None]:
# @title Background Removal {display-mode: "form"}
def run_background_removal():
    """Run interactive background removal visualization"""
    global image_data, config, corrected_data
    
    # Use corrected data if available, otherwise use original image data
    data_to_use = corrected_data if 'corrected_data' in globals() and corrected_data is not None else image_data
    
    if data_to_use is None:
        print("Please load image data first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Create a custom implementation of background removal visualization
    # This avoids compatibility issues with HoloViews parameter naming
    
    # Get method options
    method = widgets.Dropdown(
        options=[
            ('Uniform', 'uniform'),
            ('Tophat', 'tophat')
        ],
        value='uniform',
        description='Method:',
        style={'description_width': 'initial'}
    )
    
    # Window size parameter
    window_size = widgets.IntSlider(
        value=7,
        min=3,
        max=99,
        step=2,
        description='Window Size:',
        style={'description_width': 'initial'}
    )
    
    # Output for displaying results
    output = widgets.Output()
    
    # Run button
    run_button = widgets.Button(
        description='Apply Background Removal',
        button_style='success',
        tooltip='Click to run background removal'
    )
    
    def on_button_click(b):
        with output:
            clear_output()
            print(f"Applying {method.value} background removal with window size {window_size.value}...")
            
            try:
                # Get the selected frame
                frame_idx = 10  # Use frame 10 as example
                if frame_idx >= data_to_use.shape[0]:
                    frame_idx = data_to_use.shape[0] // 2
                
                sample_frame = data_to_use[frame_idx].copy()
                
                # For tophat, create a disk-shaped structuring element.
                # skimage's disk(r) takes a radius, so the footprint is 2*r + 1 pixels wide.
                selem = None
                if method.value == 'tophat':
                    # Use skimage's disk function if available
                    try:
                        from skimage.morphology import disk
                        selem = disk(window_size.value)
                    except ImportError:
                        print("skimage.morphology.disk not available. Using default structuring element.")
                
                import time
                
                # Import the per-frame helper before starting the timer so that
                # import overhead is not included in the reported processing time
                from modules.visualization_helpers import remove_background_perframe
                
                # Apply background removal and time only the filtering step
                start_time = time.perf_counter()
                bg_removed_frame = remove_background_perframe(sample_frame, method.value,
                                                             window_size.value, selem)
                process_time = time.perf_counter() - start_time
                
                # Calculate difference
                diff = np.abs(sample_frame - bg_removed_frame)
                
                # Get a central row for profile
                center_row = sample_frame.shape[0] // 2
                original_profile = sample_frame[center_row, :]
                bg_removed_profile = bg_removed_frame[center_row, :]
                
                # Plot the results
                fig = plt.figure(figsize=(15, 10))
                
                # Original frame
                ax1 = fig.add_subplot(231)
                ax1.imshow(sample_frame, cmap='gray')
                ax1.set_title('Original Frame')
                ax1.axis('off')
                
                # Background removed
                ax2 = fig.add_subplot(232)
                ax2.imshow(bg_removed_frame, cmap='gray')
                ax2.set_title(f'{method.value.title()} Background Removal\nWindow: {window_size.value}, Time: {process_time:.3f}s')
                ax2.axis('off')
                
                # Difference
                ax3 = fig.add_subplot(233)
                ax3.imshow(diff, cmap='gray')
                ax3.set_title('Difference (Removed Background)')
                ax3.axis('off')
                
                # Intensity profile
                ax4 = fig.add_subplot(212)
                ax4.plot(original_profile, 'b-', alpha=0.8, label='Original')
                ax4.plot(bg_removed_profile, 'r-', alpha=0.8, label='Background Removed')
                ax4.set_title('Intensity Profile (Center Row)')
                ax4.set_xlabel('Pixel')
                ax4.set_ylabel('Intensity')
                ax4.grid(True, alpha=0.3)
                ax4.legend()
                
                plt.tight_layout()
                plt.show()
                
                # Record the chosen parameters in the configuration (creating sections if missing)
                bg_config = config.setdefault('preprocessing', {}).setdefault('background_removal', {})
                bg_config['enabled'] = True
                bg_config['method'] = method.value
                bg_config['window_size'] = window_size.value
                
                print("Background removal complete. Parameters added to configuration.")
                
            except Exception as e:
                print(f"Error applying {method.value} background removal: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_button_click)
    
    # Create method descriptions
    method_descriptions = widgets.HTML(
        """
        <h3>Background Removal Methods:</h3>
        <ul>
            <li><strong>Uniform</strong>: Subtracts a uniformly blurred (local mean) copy of the image. Good for removing slow spatial variations in background intensity.</li>
            <li><strong>Tophat</strong>: Applies a morphological white top-hat transform, which keeps bright features smaller than the structuring element and removes broader background. Well suited to extracting small bright features from an uneven background.</li>
        </ul>
        <p><em>The window size sets the spatial scale that separates features from background: structures much larger than the window are removed as background, while features smaller than the window are kept. Choose a window somewhat larger than the cells of interest.</em></p>
        """
    )
    
    # Display everything
    display(method_descriptions)
    display(widgets.VBox([method, window_size, run_button]))
    display(output)

# Run the function
run_background_removal()
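
The two options above can be sketched with SciPy's `ndimage` module. `remove_background_sketch` is a hypothetical stand-in for `remove_background_perframe`, whose source is not shown in this notebook; like the cell above, it uses a flat disk of radius `window_size` for the top-hat. On a synthetic frame, both methods flatten a linear intensity ramp (away from the image borders) while keeping a small bright spot.

```python
import numpy as np
from scipy import ndimage

def remove_background_sketch(frame, method="uniform", window_size=7):
    """Hypothetical stand-in for remove_background_perframe; not the module's code."""
    frame = np.asarray(frame, dtype=np.float64)
    if method == "uniform":
        # Background estimate = local mean over a window_size x window_size box
        return frame - ndimage.uniform_filter(frame, size=window_size)
    if method == "tophat":
        # Flat disk of radius window_size, built the same way as skimage.morphology.disk
        r = window_size
        yy, xx = np.mgrid[-r:r + 1, -r:r + 1]
        disk = xx**2 + yy**2 <= r**2
        # White top-hat = frame minus its morphological opening: bright features
        # that cannot contain the disk survive, broader background is removed
        return ndimage.white_tophat(frame, footprint=disk)
    raise ValueError(f"Unknown method: {method}")

# Synthetic 64 x 64 frame: a left-to-right intensity ramp plus one small bright spot
frame = np.tile(0.5 * np.arange(64, dtype=np.float64), (64, 1))
frame[31:34, 31:34] += 100.0

uniform_result = remove_background_sketch(frame, "uniform", window_size=7)
tophat_result = remove_background_sketch(frame, "tophat", window_size=7)
print(f"spot after uniform: {uniform_result[32, 32]:.1f}, after tophat: {tophat_result[32, 32]:.1f}")
```

The uniform method dims the spot and leaves a dark halo around it, because the spot also raises the local mean it is compared against; the top-hat avoids this for features that fit inside the disk, at a higher computational cost for large windows.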

## Extract ROIs

In [None]:
# @title Extract ROIs {display-mode: "form"}

def extract_rois():
    """Extract ROIs from a zip file"""
    global corrected_data, image_data, config, roi_masks, roi_data, tif_path
    
    # Use corrected data if available, otherwise fall back to the original data
    data_to_use = globals().get('corrected_data')
    if data_to_use is None:
        data_to_use = globals().get('image_data')
    
    if data_to_use is None:
        print("Please load or process image data first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Get list of ZIP files in input directory
    zip_files = sorted([f for f in os.listdir(args.input_dir) if f.endswith('.zip')])
    
    if not zip_files:
        print(f"No ZIP files found in {args.input_dir}")
        return
    
    # Create file selection widget
    zip_selector = widgets.Dropdown(
        options=zip_files,
        description='Select ROI file:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='80%')
    )
    
    run_button = widgets.Button(
        description='Extract ROIs',
        button_style='warning',
        tooltip='Click to extract ROIs'
    )
    
    info_output = widgets.Output()
    
    def on_run_click(b):
        with info_output:
            clear_output()
            print(f"Extracting ROIs from {zip_selector.value}...")
            
            try:
                # Set path to ROI zip file
                roi_path = os.path.join(args.input_dir, zip_selector.value)
                
                # Create temporary output directory
                temp_output_dir = os.path.join(args.output_dir, "temp_output")
                os.makedirs(temp_output_dir, exist_ok=True)
                
                # Get image shape
                image_shape = data_to_use.shape[1:]
                
                # Extract ROIs
                global roi_masks, roi_data
                roi_masks, roi_data = extract_roi_fluorescence(
                    roi_path,
                    data_to_use,
                    image_shape,
                    temp_output_dir,
                    config["roi_processing"],
                    logger
                )
                
                print(f"Extracted {len(roi_masks)} ROIs")
                
                # Display ROI visualization: an RGBA overlay with a distinct color per ROI and a
                # transparent alpha channel outside the ROIs, so the underlying frame is not darkened.
                # (A uint8 label image would also wrap around above 255 ROIs.)
                overlay = np.zeros((*image_shape, 4), dtype=np.float32)
                
                # Random but reproducible colors for each ROI, avoiding very dark values
                rng = np.random.default_rng(0)
                colors = rng.uniform(0.2, 1.0, size=(len(roi_masks), 3))
                
                # Paint each ROI; where ROIs overlap, later ROIs overwrite earlier ones
                for i, mask in enumerate(roi_masks):
                    roi_pixels = np.asarray(mask, dtype=bool)
                    overlay[roi_pixels, :3] = colors[i]
                    overlay[roi_pixels, 3] = 0.5
                
                # Display the ROI overlay on top of the first frame
                plt.figure(figsize=(10, 10))
                plt.imshow(data_to_use[0], cmap='gray')
                plt.imshow(overlay)
                plt.title(f"Extracted ROIs ({len(roi_masks)} total)")
                plt.axis('off')
                plt.show()
                
                # Plot some sample traces
                n_samples = min(5, len(roi_masks))
                plt.figure(figsize=(12, 2*n_samples))
                
                for i in range(n_samples):
                    plt.subplot(n_samples, 1, i+1)
                    plt.plot(roi_data[i])
                    plt.title(f"ROI {i+1} Trace")
                    plt.xlabel("Frame")
                    plt.ylabel("Fluorescence")
                    plt.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.show()
                
                # Display next steps message
                print("\nROI extraction complete. You can now proceed to background subtraction and analysis.")
                
            except Exception as e:
                print(f"Error extracting ROIs: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Display widgets
    display(widgets.VBox([
        zip_selector,
        run_button,
        info_output
    ]))

# Initialize ROI variables
roi_masks = None
roi_data = None

# Run the function
extract_rois()
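
`extract_roi_fluorescence` comes from the pipeline's own modules, and its internals are not shown in this notebook. A common definition of an ROI trace is the mean pixel intensity inside the mask for each frame; the sketch below computes exactly that, assuming boolean masks with the same height and width as the movie. The helper name is hypothetical, and the real function may also handle overlapping ROIs, weighting, or file output.

```python
import numpy as np

def mean_roi_traces(movie, roi_masks):
    """Hypothetical sketch: mean fluorescence inside each ROI mask, frame by frame.

    movie has shape (frames, height, width); each mask has shape (height, width).
    """
    movie = np.asarray(movie, dtype=np.float64)
    traces = np.empty((len(roi_masks), movie.shape[0]))
    for i, mask in enumerate(roi_masks):
        mask = np.asarray(mask, dtype=bool)
        if not mask.any():
            raise ValueError(f"ROI {i + 1} contains no pixels")
        # movie[:, mask] has shape (frames, pixels_in_roi); average over the pixels
        traces[i] = movie[:, mask].mean(axis=1)
    return traces

# Tiny example: ROI 1 brightens by 1 per frame, ROI 2 stays at 10
movie = np.zeros((5, 8, 8))
movie[:, 1:3, 1:3] = np.arange(5)[:, None, None]
movie[:, 5:7, 5:7] = 10.0
masks = [np.zeros((8, 8), dtype=bool), np.zeros((8, 8), dtype=bool)]
masks[0][1:3, 1:3] = True
masks[1][5:7, 5:7] = True
traces = mean_roi_traces(movie, masks)  # shape (2, 5)
```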

## Background Subtraction

Background subtraction removes non-specific fluorescence that does not originate from the ROI itself and can mask true neuronal activity. Removing it improves the signal-to-noise ratio and makes genuine calcium events easier to detect.

### Available Methods
1. **Darkest Pixels**: Uses the darkest regions of the image (likely non-cellular) to estimate the background
2. **ROI Periphery**: Estimates background from the area immediately surrounding each ROI
3. **Global Background**: Identifies a low-intensity region of the image and subtracts its signal as background
4. **Lowpass Filter**: Separates slow background changes from fast signals in each trace with a low-pass filter
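
To make the first two methods concrete, here is a minimal sketch of how each could estimate a background trace from a movie of shape (frames, height, width) and boolean ROI masks. The function names and details are hypothetical and are not the pipeline's `subtract_background` implementation; for example, this version ranks pixels by their time-averaged intensity and uses SciPy's default cross-shaped structuring element to build the ring.

```python
import numpy as np
from scipy.ndimage import binary_dilation, median_filter

def darkest_pixels_background(movie, roi_masks, percentile=0.5, median_size=3):
    """Hypothetical sketch: one background value per frame from the dimmest non-ROI pixels."""
    movie = np.asarray(movie, dtype=np.float64)
    outside = ~np.any(np.stack(roi_masks).astype(bool), axis=0)
    # Rank pixels by time-averaged intensity, median-filtered to suppress pixel noise
    mean_image = median_filter(movie.mean(axis=0), size=median_size)
    # `percentile` is a percentage (0-100), as in the "Percentile (%)" parameter
    threshold = np.percentile(mean_image[outside], percentile)
    background_mask = outside & (mean_image <= threshold)
    return movie[:, background_mask].mean(axis=1)

def roi_periphery_background(movie, roi_mask, all_masks, periphery_size=2):
    """Hypothetical sketch: local background from a ring of pixels around one ROI."""
    movie = np.asarray(movie, dtype=np.float64)
    roi_mask = np.asarray(roi_mask, dtype=bool)
    ring = binary_dilation(roi_mask, iterations=periphery_size) & ~roi_mask
    # Drop ring pixels that belong to any ROI so neighbouring cells do not leak in
    ring &= ~np.any(np.stack(all_masks).astype(bool), axis=0)
    return movie[:, ring].mean(axis=1)

# Synthetic movie: background rises by 1 per frame everywhere, two adjacent bright
# ROIs (+50), and a bright non-ROI blob (+30) in one corner
t = np.arange(20, dtype=np.float64)
movie = np.broadcast_to(t[:, None, None], (20, 32, 32)).copy()
roi_a = np.zeros((32, 32), dtype=bool)
roi_a[10:15, 10:15] = True
roi_b = np.zeros((32, 32), dtype=bool)
roi_b[10:15, 16:21] = True
movie[:, roi_a | roi_b] += 50.0
movie[:, 0:4, 28:32] += 30.0

bg_dark = darkest_pixels_background(movie, [roi_a, roi_b], percentile=0.5)
bg_ring = roi_periphery_background(movie, roi_a, [roi_a, roi_b], periphery_size=2)
```

Both estimates track the rising background (close to 0, 1, 2, ... per frame), while the bright blob and the neighbouring ROI are largely kept out of the estimate. The periphery version gives each ROI its own local background; the darkest-pixels version gives one value per frame shared by all ROIs.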

### Key Parameters
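- **Percentile (%)**: For the darkest pixels method, the percentage of the dimmest pixels that are treated as background
- **Median Filter Size**: For the darkest pixels method, size of the median filter used to reduce noise in the background mask
- **Min Background Area**: For the global background method, minimum area a region must have to be accepted as background
- **Periphery Size**: For the ROI periphery method, width of the band around each ROI that is sampled as background
- **Cutoff Freq / Filter Order**: For the lowpass filter method, the cutoff frequency and order of the filter; a lower cutoff treats only slower drifts as background, and a higher order gives a sharper transition between background and signal frequencies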

### Finding Optimal Values
- For darkest pixels, start with a low percentile (0.1-1%) and increase it if the background estimate is too noisy, since very few pixels are averaged at low percentiles
- Larger median filter sizes produce smoother background but may miss spatial variations
- For ROI periphery, larger values capture more surrounding tissue but risk including other cells
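
The **Lowpass Filter** option in the tool below treats the slowly varying part of each trace as background. A minimal sketch of that idea with a zero-phase Butterworth filter from SciPy follows; the filter type, the use of `filtfilt`, and the interpretation of the cutoff as cycles per frame are assumptions, not details confirmed by the pipeline, and `lowpass_background` is a hypothetical name.

```python
import numpy as np
from scipy.signal import butter, filtfilt

def lowpass_background(trace, cutoff=0.005, order=2):
    """Hypothetical sketch: split a trace into fast signal and slow background.

    Assumes `cutoff` is in cycles per frame (sampling rate fs = 1 frame^-1);
    the pipeline's own units and filter design may differ.
    """
    trace = np.asarray(trace, dtype=np.float64)
    # Butterworth low-pass design; a higher order gives a steeper roll-off
    b, a = butter(order, cutoff, btype="low", fs=1.0)
    # filtfilt runs the filter forward and backward, so the background has no phase lag
    background = filtfilt(b, a, trace)
    return trace - background, background

# Synthetic trace: slow drift (period 1000 frames) plus fast oscillation (period 10 frames)
n = np.arange(2000)
drift = 100 + 20 * np.sin(2 * np.pi * 0.001 * n)
fast = 5 * np.sin(2 * np.pi * 0.1 * n)
signal, background = lowpass_background(drift + fast, cutoff=0.01, order=2)
```

Away from the first and last few hundred frames, where filter edge effects can appear, `background` follows the drift and `signal` recovers the fast component.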

The interactive tool below allows you to visualize how different background subtraction methods and parameters affect your fluorescence traces.

In [None]:
# @title Background Subtraction {display-mode: "form"}
def run_background_subtraction():
    """Interactive background subtraction for ROI traces"""
    global corrected_data, roi_masks, roi_data, config
    
    if 'corrected_data' not in globals() or corrected_data is None:
        print("Please run photobleaching correction first")
        return
    
    if 'roi_masks' not in globals() or roi_masks is None or 'roi_data' not in globals() or roi_data is None:
        print("Please extract ROIs first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Get background subtraction parameters
    method = widgets.Dropdown(
        options=[
            ('ROI Periphery', 'roi_periphery'),
            ('Darkest Pixels', 'darkest_pixels'),
            ('Global Background', 'global_background'),
            ('Lowpass Filter', 'lowpass_filter')
        ],
        value=config['roi_processing'].get('background', {}).get('method', 'roi_periphery'),
        description='Method:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for ROI periphery
    periphery_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('periphery_size', 2),
        min=1,
        max=10,
        step=1,
        description='Periphery Size:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for darkest pixels
    percentile = widgets.FloatSlider(
        value=config['roi_processing'].get('background', {}).get('percentile', 0.1),
        min=0.01,
        max=1.0,
        step=0.01,
        description='Percentile:',
        style={'description_width': 'initial'}
    )
    
    median_filter_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('median_filter_size', 3),
        min=0,
        max=19,
        step=1,
        description='Median Filter:',
        style={'description_width': 'initial'}
    )
    
    dilation_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('dilation_size', 2),
        min=0,
        max=10,
        step=1,
        description='Dilation Size:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for global background
    min_background_area = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('min_background_area', 200),
        min=50,
        max=1000,
        step=50,
        description='Min Area:',
        style={'description_width': 'initial'}
    )
    
    background_dilation = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('background_dilation', 2),
        min=0,
        max=10,
        step=1,
        description='BG Dilation:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for lowpass filter
    cutoff_freq = widgets.FloatSlider(
        value=config['roi_processing'].get('background', {}).get('cutoff_freq', 0.001),
        min=0.0001,
        max=0.01,
        step=0.0001,
        description='Cutoff Freq:',
        style={'description_width': 'initial'}
    )
    
    filter_order = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('filter_order', 2),
        min=1,
        max=6,
        step=1,
        description='Filter Order:',
        style={'description_width': 'initial'}
    )
    
    # Add slope correction checkbox
    apply_slope_correction = widgets.Checkbox(
        value=True,
        description='Apply Slope Correction After Background Subtraction',
        style={'description_width': 'initial'}
    )
    
    # Create parameter containers for each method
    roi_periphery_params = widgets.VBox([periphery_size])
    darkest_pixels_params = widgets.VBox([percentile, median_filter_size, dilation_size])
    global_background_params = widgets.VBox([min_background_area, background_dilation])
    lowpass_filter_params = widgets.VBox([cutoff_freq, filter_order])
    
    # Container for method-specific parameters
    params_container = widgets.Output()
    
    # Show parameters based on selected method
    def update_params(change):
        with params_container:
            clear_output()
            if change.new == 'roi_periphery':
                display(roi_periphery_params)
            elif change.new == 'darkest_pixels':
                display(darkest_pixels_params)
            elif change.new == 'global_background':
                display(global_background_params)
            elif change.new == 'lowpass_filter':
                display(lowpass_filter_params)
    
    method.observe(update_params, names='value')
    
    # Display initial parameters
    with params_container:
        if method.value == 'roi_periphery':
            display(roi_periphery_params)
        elif method.value == 'darkest_pixels':
            display(darkest_pixels_params)
        elif method.value == 'global_background':
            display(global_background_params)
        elif method.value == 'lowpass_filter':
            display(lowpass_filter_params)
    
    # Create a temporary output directory for intermediate results
    temp_output_dir = os.path.join(args.output_dir, 'temp_bg_subtraction')
    os.makedirs(temp_output_dir, exist_ok=True)
    
    run_button = widgets.Button(
        description='Run Background Subtraction',
        button_style='success',
        tooltip='Apply background subtraction to ROI traces'
    )
    
    info_output = widgets.Output()
    
    def on_run_click(b):
        with info_output:
            clear_output()
            print(f"Running background subtraction with method: {method.value}")
            
            # Update config with current values
            if 'background' not in config['roi_processing']:
                config['roi_processing']['background'] = {}
            
            config['roi_processing']['background']['method'] = method.value
            
            if method.value == 'roi_periphery':
                config['roi_processing']['background']['periphery_size'] = periphery_size.value
            elif method.value == 'darkest_pixels':
                config['roi_processing']['background']['percentile'] = percentile.value
                config['roi_processing']['background']['median_filter_size'] = median_filter_size.value
                config['roi_processing']['background']['dilation_size'] = dilation_size.value
            elif method.value == 'global_background':
                config['roi_processing']['background']['min_background_area'] = min_background_area.value
                config['roi_processing']['background']['background_dilation'] = background_dilation.value
            elif method.value == 'lowpass_filter':
                config['roi_processing']['background']['cutoff_freq'] = cutoff_freq.value
                config['roi_processing']['background']['filter_order'] = filter_order.value
            
            try:
                # Apply background subtraction first
                global bg_corrected_data
                
                # Set save_intermediate_traces to true temporarily
                original_save_setting = config['roi_processing'].get('save_intermediate_traces', False)
                config['roi_processing']['save_intermediate_traces'] = True
                
                if method.value == 'global_background':
                    bg_subtracted_data = subtract_global_background(
                        corrected_data,
                        roi_data,
                        roi_masks,
                        config['roi_processing']['background'],
                        logger,
                        output_dir=temp_output_dir
                    )
                else:
                    bg_subtracted_data = subtract_background(
                        corrected_data,
                        roi_data,
                        roi_masks,
                        config['roi_processing']['background'],
                        logger,
                        output_dir=temp_output_dir
                    )
                
                # Store a copy of the background-subtracted data before slope correction
                bg_subtracted_only = bg_subtracted_data.copy()
                
                # Apply slope correction if enabled
                if apply_slope_correction.value:
                    print("\nApplying slope correction to background-subtracted traces...")
                    
                    # find_peaks detects transients to exclude from the baseline fit
                    from scipy.signal import find_peaks
                    
                    # Get the condition from metadata if available
                    condition = metadata.get("condition", "unknown") if 'metadata' in globals() else "unknown"
                    print(f"Using condition: {condition}")
                    
                    # Get photobleaching correction settings from config
                    pb_settings = config.get("analysis", {}).get("photobleaching_correction", {})
                    
                    # Default extended frames
                    default_extended = pb_settings.get("default_extended_frames", [0, 200])
                    extended_frames = default_extended.copy()
                    
                    # Default prominence
                    prominence = pb_settings.get("prominence", 0.05)
                    
                    # Apply condition-specific settings if available
                    if condition and "condition_specific" in pb_settings and condition in pb_settings["condition_specific"]:
                        condition_config = pb_settings["condition_specific"][condition]
                        
                        if "extended_frames" in condition_config:
                            extended_frames = condition_config["extended_frames"]
                            print(f"Using condition-specific range {extended_frames} for {condition} photobleaching correction")
                        
                        if "prominence" in condition_config:
                            prominence = condition_config["prominence"]
                            print(f"Using condition-specific prominence {prominence} for {condition} peak detection")
                    
                    # Float array for slope-corrected data, so values are not truncated if the input is integer-valued
                    bg_corrected_data = np.zeros_like(bg_subtracted_data, dtype=np.float64)
                    
                    # Process each ROI
                    for i in range(len(bg_subtracted_data)):
                        trace = bg_subtracted_data[i]
                        
                        # Baseline window: the configured frame range, clipped to the trace length
                        ex_frames = [extended_frames[0], min(extended_frames[1], len(trace)-1)]
                        baseline_window = trace[ex_frames[0]:ex_frames[1]+1]
                        # Absolute frame indices, so the fitted trend lines up with the full trace below
                        baseline_x = np.arange(ex_frames[0], ex_frames[1] + 1)
                        
                        # Find peaks in the baseline window to exclude them
                        peaks, _ = find_peaks(baseline_window, prominence=prominence)
                        
                        # Create mask to exclude peaks and their surrounding frames (±2 frames)
                        mask = np.ones(len(baseline_window), dtype=bool)
                        for peak in peaks:
                            start = max(0, peak - 2)
                            end = min(len(baseline_window), peak + 3)  # +3 because slicing is exclusive of end
                            mask[start:end] = False
                        
                        # If every frame would be excluded, fall back to excluding only the peak frames themselves
                        if not np.any(mask) and len(baseline_window) > 0:
                            print(f"ROI {i+1}: All baseline frames would be excluded. Excluding only the peak frames instead.")
                            mask = np.ones(len(baseline_window), dtype=bool)
                            mask[peaks] = False
                        
                        # Fit a line to the non-peak frames to estimate photobleaching slope
                        if np.sum(mask) > 1:  # Need at least 2 points for linear regression
                            x_fit = baseline_x[mask]
                            y_fit = baseline_window[mask]
                            
                            # Use polyfit for linear regression: y = mx + b
                            m, b = np.polyfit(x_fit, y_fit, 1)
                            
                            # If slope is essentially flat, don't correct
                            if abs(m) < 1e-5:
                                print(f"ROI {i+1}: No significant trend detected. Slope is nearly flat.")
                                bg_corrected_data[i] = trace.copy()
                            else:
                                # Calculate the trend using the estimated slope and intercept
                                x_all = np.arange(len(trace))
                                trend = m * x_all + b
                                
                                # Calculate the mean of the baseline points used for fitting
                                baseline_mean = np.mean(baseline_window[mask])
                                
                                # Create a flat baseline at the mean level
                                flat_baseline = np.ones(len(trace)) * baseline_mean
                                
                                # Replace the trended baseline with a flat baseline
                                # Preserve the fluctuations around the trend line
                                corrected_trace = trace - trend + flat_baseline
                                
                                if m < 0:
                                    print(f"ROI {i+1}: Negative slope detected ({m:.6f}). Applied correction to flatten baseline.")
                                else:
                                    print(f"ROI {i+1}: Positive slope detected ({m:.6f}). Applied correction to flatten baseline.")
                                    
                                bg_corrected_data[i] = corrected_trace
                        else:
                            print(f"ROI {i+1}: Not enough non-peak points to estimate trend. Using background-subtracted trace without slope correction.")
                            bg_corrected_data[i] = trace.copy()
                    
                    print("\nBackground subtraction and slope correction completed successfully.")
                else:
                    # If slope correction is disabled, use the background-subtracted data directly
                    bg_corrected_data = bg_subtracted_data
                    print("\nBackground subtraction completed successfully (slope correction skipped).")
                
                # Restore original setting
                config['roi_processing']['save_intermediate_traces'] = original_save_setting
                
                # Plot original vs. processed traces for a few ROIs on separate y-axes (their scales can differ widely)
                
                # Choose a few ROIs to plot
                sample_rois = min(5, len(roi_data))
                for i in range(sample_rois):
                    # Create a subplot with two y-axes
                    fig, ax1 = plt.subplots(figsize=(10, 3))
                    ax2 = ax1.twinx()
                    
                    # Plot before and after on different y-axes
                    line1 = ax1.plot(roi_data[i], 'r-', alpha=0.7, label='Original')
                    line2 = ax2.plot(bg_corrected_data[i], 'g-', alpha=0.7, label='Processed')
                    
                    # Set labels and title
                    ax1.set_title(f'ROI {i+1} - Before and After Processing')
                    ax1.set_xlabel('Frame')
                    ax1.set_ylabel('Original', color='red')
                    ax2.set_ylabel('Processed', color='green')
                    
                    # Color the y-axis tick labels
                    ax1.tick_params(axis='y', labelcolor='red')
                    ax2.tick_params(axis='y', labelcolor='green')
                    
                    # Add legend
                    lines = line1 + line2
                    labels = ['Original', 'Processed']
                    plt.legend(lines, labels, loc='best')
                    
                    plt.tight_layout()
                    plt.show()
                
                # If slope correction was applied, show before and after slope correction
                if apply_slope_correction.value:
                    for i in range(sample_rois):
                        fig, ax = plt.subplots(figsize=(10, 3))
                        
                        # Plot background-subtracted before and after slope correction
                        ax.plot(bg_subtracted_only[i], 'b-', alpha=0.7, label='After BG Subtraction')
                        ax.plot(bg_corrected_data[i], 'g-', alpha=0.7, label='After Slope Correction')
                        
                        # Set labels and title
                        ax.set_title(f'ROI {i+1} - Effect of Slope Correction')
                        ax.set_xlabel('Frame')
                        ax.set_ylabel('Fluorescence')
                        ax.legend(loc='best')
                        
                        plt.tight_layout()
                        plt.show()

                # Create a side-by-side comparison of all traces
                plt.figure(figsize=(16, 8))
                
                # Before
                plt.subplot(1, 2, 1)
                for i in range(min(10, len(roi_data))):
                    plt.plot(roi_data[i], alpha=0.7)
                plt.title('Original Traces')
                plt.xlabel('Frame')
                plt.ylabel('Fluorescence')
                plt.grid(True, alpha=0.3)
                
                # After
                plt.subplot(1, 2, 2)
                for i in range(min(10, len(bg_corrected_data))):
                    plt.plot(bg_corrected_data[i], alpha=0.7)
                plt.title('Processed Traces')
                plt.xlabel('Frame')
                plt.ylabel('Fluorescence')
                plt.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.show()
                
                print("Processing complete! The corrected traces are now available for further analysis.")
                
            except Exception as e:
                print(f"Error during processing: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Create description of background subtraction methods
    bg_info = widgets.HTML(
        """
        <h3>Background Subtraction Methods:</h3>
        <ul>
            <li><strong>ROI Periphery</strong>: Uses the area surrounding each ROI as local background.</li>
            <li><strong>Darkest Pixels</strong>: Uses the darkest pixels in the image as global background.</li>
            <li><strong>Global Background</strong>: Identifies a region in the image with low intensity as background.</li>
            <li><strong>Lowpass Filter</strong>: Uses a low-pass filter to separate fast signals from slow background changes.</li>
        </ul>
        """
    )
    
    # Display widgets
    display(bg_info)
    display(widgets.VBox([
        method,
        params_container,
        apply_slope_correction,
        run_button,
        info_output
    ]))

# Initialize the background-corrected data variable
bg_corrected_data = None

# Run the function
run_background_subtraction()
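
The slope-correction step in the cell above fits a straight line to the baseline window (after masking detected transients and +/- 2 frames around them) and replaces that trend with a flat line at the mean baseline level. The same logic, pulled out into a standalone function so it can be checked on a synthetic trace, might look like this. The name `flatten_baseline` and its keyword arguments are hypothetical, with defaults chosen to mirror `default_extended_frames` and `prominence` in the config; the fit uses absolute frame indices so the trend lines up with the full trace.

```python
import numpy as np
from scipy.signal import find_peaks

def flatten_baseline(trace, baseline_frames=(0, 200), prominence=0.05,
                     guard=2, min_slope=1e-5):
    """Standalone sketch of the slope correction above, for a single trace."""
    trace = np.asarray(trace, dtype=np.float64)
    start, stop = baseline_frames[0], min(baseline_frames[1], len(trace) - 1)
    window = trace[start:stop + 1]
    x = np.arange(start, stop + 1)  # absolute frame indices

    # Exclude detected transients, plus `guard` frames on each side, from the fit
    peaks, _ = find_peaks(window, prominence=prominence)
    keep = np.ones(len(window), dtype=bool)
    for p in peaks:
        keep[max(0, p - guard):p + guard + 1] = False
    if not keep.any():
        # Fall back to excluding only the peak frames themselves
        keep[:] = True
        keep[peaks] = False
    if keep.sum() < 2:
        return trace.copy()  # not enough points for a linear fit

    slope, intercept = np.polyfit(x[keep], window[keep], 1)
    if abs(slope) < min_slope:
        return trace.copy()  # baseline is already essentially flat

    # Swap the fitted linear trend for a flat line at the mean baseline level
    trend = slope * np.arange(len(trace)) + intercept
    return trace - trend + window[keep].mean()

# Synthetic trace: linear downward drift with a two-frame transient at frame 100
trace = 10.0 - 0.01 * np.arange(500)
trace[100] += 2.0
trace[101] += 1.0
corrected = flatten_baseline(trace)
```

After correction the drift is gone while the transient keeps its original height above the baseline, which is the point of excluding peaks from the fit: a transient included in the regression would tilt the fitted line.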

In [None]:
# @title Background Subtraction {display-mode: "form"}
def run_background_subtraction():
    """Interactive background subtraction for ROI traces"""
    global corrected_data, roi_masks, roi_data, config
    
    if 'corrected_data' not in globals() or corrected_data is None:
        print("Please run photobleaching correction first")
        return
    
    if 'roi_masks' not in globals() or roi_masks is None or 'roi_data' not in globals() or roi_data is None:
        print("Please extract ROIs first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Get background subtraction parameters
    method = widgets.Dropdown(
        options=[
            ('ROI Periphery', 'roi_periphery'),
            ('Darkest Pixels', 'darkest_pixels'),
            ('Global Background', 'global_background'),
            ('Lowpass Filter', 'lowpass_filter')
        ],
        value=config['roi_processing'].get('background', {}).get('method', 'roi_periphery'),
        description='Method:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for ROI periphery
    periphery_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('periphery_size', 2),
        min=1,
        max=10,
        step=1,
        description='Periphery Size:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for darkest pixels
    percentile = widgets.FloatSlider(
        value=config['roi_processing'].get('background', {}).get('percentile', 0.1),
        min=0.01,
        max=5.0,
        step=0.01,
        description='Percentile:',
        style={'description_width': 'initial'}
    )
    
    median_filter_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('median_filter_size', 3),
        min=1,
        max=19,
        step=2,
        description='Median Filter:',
        style={'description_width': 'initial'}
    )
    
    dilation_size = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('dilation_size', 2),
        min=0,
        max=10,
        step=1,
        description='Dilation Size:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for global background
    min_background_area = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('min_background_area', 200),
        min=50,
        max=1000,
        step=50,
        description='Min Area:',
        style={'description_width': 'initial'}
    )
    
    background_dilation = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('background_dilation', 2),
        min=0,
        max=10,
        step=1,
        description='BG Dilation:',
        style={'description_width': 'initial'}
    )
    
    # Parameters for lowpass filter
    cutoff_freq = widgets.FloatSlider(
        value=config['roi_processing'].get('background', {}).get('cutoff_freq', 0.001),
        min=0.0001,
        max=0.01,
        step=0.0001,
        description='Cutoff Freq:',
        style={'description_width': 'initial'}
    )
    
    filter_order = widgets.IntSlider(
        value=config['roi_processing'].get('background', {}).get('filter_order', 2),
        min=1,
        max=6,
        step=1,
        description='Filter Order:',
        style={'description_width': 'initial'}
    )
    
    # Create parameter containers for each method
    roi_periphery_params = widgets.VBox([periphery_size])
    darkest_pixels_params = widgets.VBox([percentile, median_filter_size, dilation_size])
    global_background_params = widgets.VBox([min_background_area, background_dilation])
    lowpass_filter_params = widgets.VBox([cutoff_freq, filter_order])
    
    # Container for method-specific parameters
    params_container = widgets.Output()
    
    # Map each method to its parameter panel
    method_params = {
        'roi_periphery': roi_periphery_params,
        'darkest_pixels': darkest_pixels_params,
        'global_background': global_background_params,
        'lowpass_filter': lowpass_filter_params,
    }
    
    # Show only the parameters for the selected method
    def show_params(selected_method):
        with params_container:
            clear_output()
            display(method_params[selected_method])
    
    method.observe(lambda change: show_params(change.new), names='value')
    
    # Display initial parameters
    show_params(method.value)
    
    # Create a temporary output directory for intermediate results
    temp_output_dir = os.path.join(args.output_dir, 'temp_bg_subtraction')
    os.makedirs(temp_output_dir, exist_ok=True)
    
    run_button = widgets.Button(
        description='Run Background Subtraction',
        button_style='success',
        tooltip='Apply background subtraction to ROI traces'
    )
    
    info_output = widgets.Output()
    
    def on_run_click(b):
        with info_output:
            clear_output()
            print(f"Running background subtraction with method: {method.value}")
            
            # Update config with current values
            if 'background' not in config['roi_processing']:
                config['roi_processing']['background'] = {}
            
            config['roi_processing']['background']['method'] = method.value
            
            if method.value == 'roi_periphery':
                config['roi_processing']['background']['periphery_size'] = periphery_size.value
            elif method.value == 'darkest_pixels':
                config['roi_processing']['background']['percentile'] = percentile.value
                config['roi_processing']['background']['median_filter_size'] = median_filter_size.value
                config['roi_processing']['background']['dilation_size'] = dilation_size.value
            elif method.value == 'global_background':
                config['roi_processing']['background']['min_background_area'] = min_background_area.value
                config['roi_processing']['background']['background_dilation'] = background_dilation.value
            elif method.value == 'lowpass_filter':
                config['roi_processing']['background']['cutoff_freq'] = cutoff_freq.value
                config['roi_processing']['background']['filter_order'] = filter_order.value
            
            try:
                # Apply background subtraction; the result is stored globally for later cells
                global bg_corrected_data
                
                # Temporarily enable save_intermediate_traces, restoring it even if subtraction fails
                original_save_setting = config['roi_processing'].get('save_intermediate_traces', False)
                config['roi_processing']['save_intermediate_traces'] = True
                try:
                    if method.value == 'global_background':
                        bg_corrected_data = subtract_global_background(
                            corrected_data,
                            roi_data,
                            roi_masks,
                            config['roi_processing']['background'],
                            logger,
                            output_dir=temp_output_dir
                        )
                    else:
                        bg_corrected_data = subtract_background(
                            corrected_data,
                            roi_data,
                            roi_masks,
                            config['roi_processing']['background'],
                            logger,
                            output_dir=temp_output_dir
                        )
                finally:
                    config['roi_processing']['save_intermediate_traces'] = original_save_setting
                
                print("Background subtraction completed successfully.")
                
                # Plot each sample ROI before and after subtraction on separate y-axes
                sample_rois = min(5, len(roi_data))
                for i in range(sample_rois):
                    # One figure per ROI, with a twin y-axis for the corrected trace
                    fig, ax1 = plt.subplots(figsize=(10, 3))
                    ax2 = ax1.twinx()
                    
                    # Plot before and after on different y-axes
                    line1 = ax1.plot(roi_data[i], 'r-', alpha=0.7, label='Before')
                    line2 = ax2.plot(bg_corrected_data[i], 'g-', alpha=0.7, label='After')
                    
                    # Set labels and title
                    ax1.set_title(f'ROI {i+1} - Before and After Background Subtraction')
                    ax1.set_xlabel('Frame')
                    ax1.set_ylabel('Before Correction', color='red')
                    ax2.set_ylabel('After Correction', color='green')
                    
                    # Color the y-axis tick labels
                    ax1.tick_params(axis='y', labelcolor='red')
                    ax2.tick_params(axis='y', labelcolor='green')
                    
                    # Combine handles from both axes into one legend on the top (twin) axes
                    lines = line1 + line2
                    ax2.legend(lines, [line.get_label() for line in lines], loc='best')
                    
                    fig.tight_layout()
                    plt.show()

                # Create a side-by-side comparison of all traces
                plt.figure(figsize=(16, 8))
                
                # Before
                plt.subplot(1, 2, 1)
                for i in range(min(10, len(roi_data))):
                    plt.plot(roi_data[i], alpha=0.7)
                plt.title('Traces Before Background Subtraction')
                plt.xlabel('Frame')
                plt.ylabel('Fluorescence')
                plt.grid(True, alpha=0.3)
                
                # After
                plt.subplot(1, 2, 2)
                for i in range(min(10, len(bg_corrected_data))):
                    plt.plot(bg_corrected_data[i], alpha=0.7)
                plt.title('Traces After Background Subtraction')
                plt.xlabel('Frame')
                plt.ylabel('Fluorescence')
                plt.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.show()
                
                print("Background subtraction complete! The corrected traces are now available for further analysis.")
                
            except Exception as e:
                print(f"Error during background subtraction: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Create description of background subtraction methods
    bg_info = widgets.HTML(
        """
        <h3>Background Subtraction Methods:</h3>
        <ul>
            <li><strong>ROI Periphery</strong>: Uses the area surrounding each ROI as local background.</li>
            <li><strong>Darkest Pixels</strong>: Uses the darkest pixels in the image as global background.</li>
            <li><strong>Global Background</strong>: Identifies a region in the image with low intensity as background.</li>
            <li><strong>Lowpass Filter</strong>: Uses a low-pass filter to separate fast signals from slow background changes.</li>
        </ul>
        """
    )
    
    # Display widgets
    display(bg_info)
    display(widgets.VBox([
        method,
        params_container,
        run_button,
        info_output
    ]))

# Initialize the background-corrected data variable
bg_corrected_data = None

# Run the function
run_background_subtraction()

## ROI Processing with PNR Refinement

Peak-to-Noise Ratio (PNR) refinement keeps ROIs with strong neuronal signals and filters out those with poor signal quality, so that noise-dominated traces do not contaminate downstream event detection.

### How PNR Refinement Works
1. **Frequency Separation**: The fluorescence trace is split into signal (low-frequency) and noise (high-frequency) components
2. **Signal Smoothing**: Optional smoothing can be applied to the signal component to reduce fluctuations
3. **PNR Calculation**: The ratio between peak signal value and noise standard deviation is calculated
4. **Thresholding**: ROIs with PNR values below the threshold are excluded from analysis
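The four steps above can be sketched in a few lines of NumPy/SciPy. This is an illustrative approximation, not the pipeline's `refine_rois_with_pnr`: it assumes a zero-phase Butterworth low-pass filter for the frequency split, treats the cutoff as a normalized frequency in cycles per frame (a cutoff in Hz would first be divided by the frame rate), and measures the peak relative to the median of the signal component. The names `compute_pnr` and `keep_roi` are hypothetical.

```python
import numpy as np
from scipy.signal import butter, filtfilt


def compute_pnr(trace, cutoff=0.03, percentile=99.0, smoothing=3, order=2):
    """Hypothetical PNR sketch: peak of the low-pass signal over the std of the residual noise.

    Assumes `cutoff` is in cycles/frame (0 < cutoff < 0.5, the Nyquist limit).
    """
    if not 0 < cutoff < 0.5:
        raise ValueError("cutoff must lie between 0 and 0.5 cycles/frame")
    trace = np.asarray(trace, dtype=float)

    # 1. Frequency separation: zero-phase low-pass gives the signal, the residual is the noise
    b, a = butter(order, cutoff / 0.5, btype='low')  # Wn is expressed relative to Nyquist
    signal = filtfilt(b, a, trace)
    noise = trace - signal

    # 2. Optional moving-average smoothing of the signal (edges are zero-padded)
    if smoothing > 1:
        kernel = np.ones(smoothing) / smoothing
        signal = np.convolve(signal, kernel, mode='same')

    # 3. PNR: high percentile of the signal (above its median) over the noise std
    noise_std = np.std(noise)
    if noise_std == 0:
        return np.inf
    peak = np.percentile(signal, percentile) - np.median(signal)
    return peak / noise_std


# 4. Thresholding: an ROI is kept when its PNR reaches the minimum
def keep_roi(trace, min_pnr=10.0, **kwargs):
    return compute_pnr(trace, **kwargs) >= min_pnr
```

A slow transient well below the cutoff passes almost entirely into the signal component, while white noise lands mostly in the residual, which is why the ratio separates active ROIs from noise-only ones.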

### Key Parameters
- **Noise Frequency Cutoff**: Determines the boundary between signal and noise components (0.01-0.2 Hz)
- **Percentile Threshold**: Percentile used to determine peak signal value (90-99.9%)
- **Trace Smoothing**: Window size for signal smoothing (0 = no smoothing)
- **Min PNR**: Minimum PNR threshold for accepting an ROI (typically 5-10)

### Finding Optimal Values
- Higher PNR thresholds produce more reliable results but may exclude valid ROIs
- The ideal noise frequency cutoff depends on the temporal characteristics of your signal
- Smoothing can help stabilize PNR values but may mask transient events
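The threshold trade-off is easiest to judge from counts. The hypothetical helper below (`sweep_pnr_thresholds`) applies the keep rule stated earlier (ROIs below the threshold are excluded) to a list of per-ROI PNR values and reports how many ROIs each candidate Min PNR would keep. After running the refinement cell it could be called on the notebook's `pnr_values`, assuming that is a flat sequence of per-ROI PNRs.

```python
import numpy as np


def sweep_pnr_thresholds(pnr_values, thresholds):
    """Hypothetical helper: number of ROIs kept (PNR >= threshold) for each candidate threshold."""
    pnr = np.asarray(pnr_values, dtype=float)
    return {float(t): int(np.count_nonzero(pnr >= t)) for t in thresholds}


# Illustrative PNR values for five ROIs (not real data)
example = sweep_pnr_thresholds([2.5, 6.0, 9.8, 12.0, 18.4], thresholds=[5, 10, 15])
print(example)  # {5.0: 4, 10.0: 2, 15.0: 1}
```

A sharp drop in the count between two neighboring thresholds suggests the threshold is cutting into a cluster of ROIs and deserves a closer look in the histogram.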

The tool below allows you to visualize and adjust these parameters to optimize ROI selection.

In [None]:
# @title PNR-Based ROI Refinement {display-mode: "form"}
import random

def run_pnr_refinement():
    """Interactive PNR-based refinement of ROIs"""
    global roi_data, roi_masks, bg_corrected_data, config
    
    # Use background-corrected data if available, otherwise use original ROI data
    data_to_use = bg_corrected_data if 'bg_corrected_data' in globals() and bg_corrected_data is not None else globals().get('roi_data')
    
    if 'roi_masks' not in globals() or roi_masks is None or data_to_use is None:
        print("Please extract ROIs and perform background subtraction first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Get PNR refinement parameters
    noise_freq_cutoff = widgets.FloatSlider(
        value=config['roi_processing'].get('pnr_refinement', {}).get('noise_freq_cutoff', 0.03),
        min=0.01,
        max=0.5,
        step=0.01,
        description='Noise Cutoff:',
        style={'description_width': 'initial'}
    )
    
    min_pnr = widgets.FloatSlider(
        value=config['roi_processing'].get('pnr_refinement', {}).get('min_pnr', 10),
        min=1.0,
        max=20.0,
        step=0.5,
        description='Min PNR:',
        style={'description_width': 'initial'}
    )
    
    percentile_threshold = widgets.FloatSlider(
        value=config['roi_processing'].get('pnr_refinement', {}).get('percentile_threshold', 99),
        min=90.0,
        max=99.9,
        step=0.1,
        description='Percentile:',
        style={'description_width': 'initial'}
    )
    
    trace_smoothing = widgets.IntSlider(
        value=config['roi_processing'].get('pnr_refinement', {}).get('trace_smoothing', 3),
        min=0,
        max=9,
        step=1,
        description='Smoothing:',
        style={'description_width': 'initial'}
    )
    
    auto_determine = widgets.Checkbox(
        value=config['roi_processing'].get('pnr_refinement', {}).get('auto_determine', False),
        description='Auto-Determine Cutoff',
        style={'description_width': 'initial'}
    )
    
    run_button = widgets.Button(
        description='Run PNR Refinement',
        button_style='success',
        tooltip='Apply PNR-based refinement to ROIs'
    )
    
    info_output = widgets.Output()
    
    def on_run_click(b):
        with info_output:
            clear_output()
            print("Running PNR-based ROI refinement...")
            
            # Update config with current values
            if 'pnr_refinement' not in config['roi_processing']:
                config['roi_processing']['pnr_refinement'] = {}
            
            config['roi_processing']['pnr_refinement']['noise_freq_cutoff'] = noise_freq_cutoff.value
            config['roi_processing']['pnr_refinement']['min_pnr'] = min_pnr.value
            config['roi_processing']['pnr_refinement']['percentile_threshold'] = percentile_threshold.value
            config['roi_processing']['pnr_refinement']['trace_smoothing'] = trace_smoothing.value
            config['roi_processing']['pnr_refinement']['auto_determine'] = auto_determine.value
            
            try:
                # Apply PNR refinement
                global refined_masks, refined_traces, pnr_values, diagnostic_info
                
                refined_masks, refined_traces, pnr_values, diagnostic_info = refine_rois_with_pnr(
                    data_to_use,
                    roi_masks,
                    config['roi_processing'],
                    logger
                )
                
                print(f"PNR refinement complete! Kept {len(refined_masks)}/{len(roi_masks)} ROIs.")
                
                # Create visualization
                # Split signal and noise components
                signal_traces, noise_traces = split_signal_noise(data_to_use, noise_freq_cutoff.value, logger)
                
                # Get some representative ROIs to visualize
                n_vis = min(5, len(roi_masks))
                vis_indices = []
                
                # Try to include both kept and discarded ROIs
                kept_indices = diagnostic_info['kept_indices']
                all_indices = list(range(len(roi_masks)))
                discarded_indices = [i for i in all_indices if i not in kept_indices]
                
                # Add some kept ROIs if available
                n_kept_vis = min(3, len(kept_indices))
                if n_kept_vis > 0:
                    vis_indices.extend(kept_indices[:n_kept_vis])
                
                # Add some discarded ROIs if available
                n_discarded_vis = min(2, len(discarded_indices))
                if n_discarded_vis > 0:
                    vis_indices.extend(discarded_indices[:n_discarded_vis])
                
                # If we still need more, add random ones
                if len(vis_indices) < n_vis:
                    remaining = n_vis - len(vis_indices)
                    remaining_indices = [i for i in all_indices if i not in vis_indices]
                    if remaining_indices:
                        vis_indices.extend(random.sample(remaining_indices, min(remaining, len(remaining_indices))))
                
                # Plot signal-noise decomposition for selected ROIs
                plt.figure(figsize=(12, 3 * len(vis_indices)))
                
                for i, idx in enumerate(vis_indices):
                    # Create original trace plot
                    plt.subplot(len(vis_indices), 3, i*3 + 1)
                    plt.plot(data_to_use[idx], 'k-', label=f'Original (ROI {idx+1})')
                    plt.title(f'ROI {idx+1} - Original')
                    plt.ylabel('Fluorescence')
                    plt.grid(True, alpha=0.3)
                    
                    # Create signal plot
                    plt.subplot(len(vis_indices), 3, i*3 + 2)
                    plt.plot(signal_traces[idx], 'g-', label='Signal')
                    is_kept = idx in kept_indices
                    pnr_val = pnr_values[idx] if idx < len(pnr_values) else 0
                    status = "Kept" if is_kept else "Discarded"
                    plt.title(f'Signal Component (PNR: {pnr_val:.2f}) - {status}')
                    plt.grid(True, alpha=0.3)
                    
                    # Create noise plot
                    plt.subplot(len(vis_indices), 3, i*3 + 3)
                    plt.plot(noise_traces[idx], 'r-', label='Noise')
                    plt.title('Noise Component')
                    plt.grid(True, alpha=0.3)
                
                plt.tight_layout()
                plt.show()
                
                # Generate a histogram of PNR values
                plt.figure(figsize=(10, 6))
                
                # Create bins for PNR values
                bins = np.linspace(0, max(pnr_values) * 1.1, 30)
                
                # Plot histogram
                n, bins, patches = plt.hist(pnr_values, bins=bins, alpha=0.7)
                
                # Add a vertical line at the threshold
                plt.axvline(x=min_pnr.value, color='r', linestyle='--', 
                           label=f'Threshold: {min_pnr.value}')
                
                plt.title('Distribution of Peak-to-Noise Ratio (PNR) Values')
                plt.xlabel('PNR')
                plt.ylabel('Count')
                plt.grid(True, alpha=0.3)
                plt.legend()
                
                plt.tight_layout()
                plt.show()
                
                # Plot showing before/after refinement
                if len(refined_masks) > 0:
                    # Sample frame to display ROIs
                    sample_frame = min(10, corrected_data.shape[0]-1)
                    
                    plt.figure(figsize=(12, 8))
                    
                    # Before refinement
                    plt.subplot(1, 2, 1)
                    plt.imshow(corrected_data[sample_frame], cmap='gray')
                    plt.title(f"All ROIs (Before Refinement: {len(roi_masks)})")
                    
                    # Draw all ROI outlines
                    for i, mask in enumerate(roi_masks):
                        # Find contours; [-2] selects the contour list under both OpenCV 3.x and 4.x return signatures
                        contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
                        for contour in contours:
                            contour = np.squeeze(contour)
                            if len(contour.shape) == 1:
                                # Single point contour
                                continue
                                
                            # Create color based on whether it was kept
                            color = 'g' if i in kept_indices else 'r'
                            
                            # Draw the contour as an unfilled, closed polygon
                            polygon = plt.Polygon(contour, fill=False, edgecolor=color, linewidth=1.5, closed=True)
                            plt.gca().add_patch(polygon)
                    
                    plt.axis('off')
                    
                    # After refinement
                    plt.subplot(1, 2, 2)
                    plt.imshow(corrected_data[sample_frame], cmap='gray')
                    plt.title(f"Kept ROIs (After Refinement: {len(refined_masks)})")
                    
                    # Draw only kept ROI outlines
                    for i, mask in enumerate(refined_masks):
                        # Find contours; [-2] selects the contour list under both OpenCV 3.x and 4.x return signatures
                        contours = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
                        for contour in contours:
                            contour = np.squeeze(contour)
                            if len(contour.shape) == 1:
                                # Single point contour
                                continue
                                
                            # Draw the contour as an unfilled, closed polygon
                            polygon = plt.Polygon(contour, fill=False, edgecolor='g', linewidth=1.5, closed=True)
                            plt.gca().add_patch(polygon)
                        
                        # Add ROI number
                        y_indices, x_indices = np.where(mask)
                        if len(y_indices) > 0 and len(x_indices) > 0:
                            # Calculate centroid
                            center_y = int(np.mean(y_indices))
                            center_x = int(np.mean(x_indices))
                            plt.text(center_x, center_y, str(i+1), color='white', 
                                    fontsize=8, ha='center', va='center',
                                    bbox=dict(boxstyle="circle", fc="black", alpha=0.7))
                    
                    plt.axis('off')
                    
                    plt.tight_layout()
                    plt.show()
                    
                    # Add button to use refined ROIs
                    use_refined_button = widgets.Button(
                        description='Use Refined ROIs',
                        button_style='success',
                        tooltip='Replace original ROIs with refined ROIs'
                    )
                    
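                    def on_use_refined_click(b):
                        global roi_masks, roi_data, bg_corrected_data
                        # Update global variables with refined data
                        roi_masks = refined_masks
                        roi_data = refined_traces
                        # Later cells prefer bg_corrected_data, so keep it in sync with the refined ROI set
                        if globals().get('bg_corrected_data') is not None:
                            bg_corrected_data = refined_traces
                        # Print inside the Output widget so the message is visible from the callback
                        with info_output:
                            print(f"Updated to use {len(refined_masks)} refined ROIs for further analysis")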
                    
                    use_refined_button.on_click(on_use_refined_click)
                    display(use_refined_button)
                
                print("\nPNR refinement complete. Review the results above to decide if you want to use the refined ROIs.")
                
            except Exception as e:
                print(f"Error during PNR refinement: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Create description of PNR refinement parameters
    pnr_info = widgets.HTML(
        """
        <h3>PNR-Based ROI Refinement:</h3>
        <p>This method refines ROIs based on their peak-to-noise ratio (PNR). ROIs with PNR above the threshold are kept, others are discarded.</p>
        <ul>
            <li><strong>Noise Cutoff</strong>: Frequency cutoff for signal/noise separation. Higher values assign more of the spectrum to the signal, leaving only the fastest fluctuations in the noise.</li>
            <li><strong>Min PNR</strong>: Minimum peak-to-noise ratio required to keep an ROI.</li>
            <li><strong>Percentile</strong>: Percentile of signal to use as peak value.</li>
            <li><strong>Smoothing</strong>: Window size for smoothing signal traces (0 to disable).</li>
            <li><strong>Auto-Determine</strong>: Automatically determine optimal frequency cutoff (overrides Noise Cutoff).</li>
        </ul>
        """
    )
    
    # Display widgets
    display(pnr_info)
    display(widgets.VBox([
        noise_freq_cutoff,
        min_pnr,
        percentile_threshold,
        trace_smoothing,
        auto_determine,
        run_button,
        info_output
    ]))

# Initialize variables for refined data
refined_masks = None
refined_traces = None
pnr_values = None
diagnostic_info = None

# Run the function
run_pnr_refinement()

## Event Detection and Analysis

Accurate detection of calcium events is crucial for analyzing neuronal activity. This tool allows you to adjust event detection parameters to optimize sensitivity and specificity for your data.

### Event Detection Methods
The pipeline detects events by running SciPy's `scipy.signal.find_peaks` on dF/F traces; the parameters below control its sensitivity.

### Key Parameters
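- **Prominence**: Minimum height of a peak above its local baseline (SciPy measures it from the higher of the two surrounding minima)
- **Width**: Minimum width (in frames) of a valid peak, measured at half the peak's prominence (`rel_height = 0.5`)
- **Distance**: Minimum separation (in frames) between valid peaks
- **Height**: Minimum dF/F value at the peak itself
- **Activity Threshold**: Threshold on the per-ROI activity metric (peak frequency for spontaneous recordings, peak amplitude for evoked ones) above which an ROI is declared 'active'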
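These parameters map directly onto keyword arguments of `scipy.signal.find_peaks`. The sketch below wraps that call for a single dF/F trace; the default values mirror the slider fallbacks in the event-detection cell, and the names `PEAK_CONFIG` and `detect_events` are illustrative rather than part of the pipeline.

```python
import numpy as np
from scipy.signal import find_peaks

# Defaults mirror the slider fallbacks in the event-detection cell (dF/F units, frames)
PEAK_CONFIG = {"prominence": 0.03, "width": 2, "distance": 10, "height": 0.02, "rel_height": 0.5}


def detect_events(dff_trace, cfg=PEAK_CONFIG):
    """Return peak indices and find_peaks properties for one dF/F trace."""
    return find_peaks(
        np.asarray(dff_trace, dtype=float),
        prominence=cfg["prominence"],  # minimum rise of the peak above its local baseline
        width=cfg["width"],            # minimum width in frames, measured at rel_height
        distance=cfg["distance"],      # minimum spacing between accepted peaks, in frames
        height=cfg["height"],          # minimum dF/F value at the peak itself
        rel_height=cfg["rel_height"],  # 0.5: width measured at half the prominence
    )


# Synthetic dF/F trace with two clean transients at frames 50 and 150
t = np.arange(200)
trace = 0.2 * np.exp(-((t - 50) / 3.0) ** 2) + 0.2 * np.exp(-((t - 150) / 3.0) ** 2)
peaks, props = detect_events(trace)
print(peaks)  # [ 50 150]
```

Because `height`, `prominence`, and `width` are all supplied, the returned `props` dictionary also carries `peak_heights`, `prominences`, and `widths`, which are handy for checking why a candidate event was accepted.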

### Condition-Specific Analysis
Different experimental conditions require different analysis approaches:
- **Spontaneous (0µm)**: Analyzes spontaneous activity throughout the recording
- **Evoked (10µm/25µm)**: Focuses on activity following stimulus application (frame 100)
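The condition-specific windows can be expressed as a small lookup. The sketch below follows the frame ranges used in the event-detection cell (the whole recording for `0um`, frame 100 onward for `10um`/`25um`). The activity test is hypothetical: since no frame rate is assumed, it approximates peak frequency as peaks per frame, takes peak amplitude as the maximum dF/F in the window, and compares the metric with the activity threshold using `>=`.

```python
import numpy as np
from scipy.signal import find_peaks

STIM_FRAME = 100  # evoked conditions: stimulus applied at frame 100


def analysis_window(n_frames, condition):
    """Inclusive (start, end) frame range analyzed for a condition label."""
    if condition == '0um':               # spontaneous: whole recording
        return 0, n_frames - 1
    if condition in ('10um', '25um'):    # evoked: post-stimulus frames only
        return STIM_FRAME, n_frames - 1
    raise ValueError(f"unknown condition: {condition!r}")


def is_active(dff_trace, condition, threshold, **peak_kwargs):
    """Hypothetical activity test on one dF/F trace."""
    start, end = analysis_window(len(dff_trace), condition)
    window = np.asarray(dff_trace[start:end + 1], dtype=float)
    if condition == '0um':
        # Peak frequency, approximated as detected peaks per frame
        peaks, _ = find_peaks(window, **peak_kwargs)
        metric = len(peaks) / len(window)
    else:
        # Peak amplitude: largest dF/F after the stimulus
        metric = float(window.max())
    return metric >= threshold
```

Any `find_peaks` keyword (for example `prominence=0.03`) can be passed through `peak_kwargs` so the spontaneous metric uses the same detection settings as the sliders.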

### Finding Optimal Values
- Higher prominence values restrict detection to stronger events but may miss subtle ones
- Width requirements help distinguish true events from noise
- The activity threshold should be set based on your experimental design and expected effect size

The interactive tool below allows you to visualize how parameter adjustments affect event detection sensitivity.

In [None]:
# @title Event Detection {display-mode: "form"}
import numpy as np
from scipy.signal import find_peaks
import matplotlib.pyplot as plt

def run_event_detection():
    """Interactive event detection and visualization"""
    global roi_data, roi_masks, corrected_data, bg_corrected_data, metadata, config
    
    # First check if we have the necessary data
    if 'roi_masks' not in globals() or roi_masks is None:
        print("Please extract ROIs first")
        return
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Create a copy of traces for visualization
    # First determine which data to use (priority: bg_corrected > roi_data > extract from corrected_data)
    if 'bg_corrected_data' in globals() and bg_corrected_data is not None:
        traces_for_analysis = bg_corrected_data
        print("Using background-corrected data for analysis")
    elif 'roi_data' in globals() and roi_data is not None:
        traces_for_analysis = roi_data
        print("Using ROI data for analysis")
    elif 'corrected_data' in globals() and corrected_data is not None and roi_masks is not None and len(roi_masks) > 0:
        # Extract traces directly from corrected_data: mean over each ROI's pixels, per frame
        print("Extracting traces from corrected data using ROI masks")
        traces_for_analysis = np.stack([
            corrected_data[:, np.asarray(mask, dtype=bool)].mean(axis=1)
            for mask in roi_masks
        ]).astype(np.float32)
    else:
        print("No suitable data found for event detection")
        return
    
    # Convert to dF/F using a simple baseline (8th percentile of the first 100 frames or fewer)
    # This is just for visualization - the real pipeline will use more sophisticated methods
    traces_for_analysis = np.asarray(traces_for_analysis, dtype=np.float32)  # float avoids integer truncation
    df_f_traces = np.zeros_like(traces_for_analysis)
    baseline_frames = min(100, traces_for_analysis.shape[1])
    for i in range(len(traces_for_analysis)):
        baseline = np.percentile(traces_for_analysis[i, :baseline_frames], 8)
        # Fall back to the raw trace when the baseline is non-positive (dF/F undefined)
        df_f_traces[i] = (traces_for_analysis[i] - baseline) / baseline if baseline > 0 else traces_for_analysis[i]
    
    # Create widgets for event detection parameters
    prominence = widgets.FloatSlider(
        value=config['analysis'].get('peak_detection', {}).get('prominence', 0.03),
        min=0.01,
        max=0.2,
        step=0.01,
        description='Prominence:',
        style={'description_width': 'initial'}
    )
    
    width = widgets.IntSlider(
        value=config['analysis'].get('peak_detection', {}).get('width', 2),
        min=1,
        max=10,
        step=1,
        description='Width:',
        style={'description_width': 'initial'}
    )
    
    distance = widgets.IntSlider(
        value=config['analysis'].get('peak_detection', {}).get('distance', 10),
        min=5,
        max=30,
        step=1,
        description='Distance:',
        style={'description_width': 'initial'}
    )
    
    height = widgets.FloatSlider(
        value=config['analysis'].get('peak_detection', {}).get('height', 0.02),
        min=0.01,
        max=0.2,
        step=0.01,
        description='Height:',
        style={'description_width': 'initial'}
    )
    
    # Activity threshold
    active_threshold = widgets.FloatSlider(
        value=config['analysis'].get('active_threshold', 0.02),
        min=0.01,
        max=0.1,
        step=0.01,
        description='Activity Threshold:',
        style={'description_width': 'initial'}
    )
    
    # Widget for condition selection
    condition_value = 'unknown'
    if 'metadata' in globals() and metadata and 'condition' in metadata:
        condition_value = metadata['condition']
    
    condition = widgets.Dropdown(
        options=[
            ('Spontaneous (0µm)', '0um'),
            ('Evoked (10µm)', '10um'),
            ('Evoked (25µm)', '25um')
        ],
        value=condition_value if condition_value in ['0um', '10um', '25um'] else '0um',
        description='Condition:',
        style={'description_width': 'initial'}
    )
    
    # Widget to select ROIs to display
    roi_options = [(f"ROI {i+1}", i) for i in range(min(10, len(df_f_traces)))]
    roi_indices = widgets.SelectMultiple(
        options=roi_options,
        value=tuple(range(min(3, len(roi_options)))),  # Default: first 3 ROIs (or all, if fewer)
        description='ROIs to Display:',
        disabled=False,
        style={'description_width': 'initial'}
    )
    
    # Output widget for results
    output = widgets.Output()
    
    def display_event_detection(b=None):
        with output:
            clear_output()
            
            selected_indices = list(roi_indices.value)
            if not selected_indices:
                print("Please select at least one ROI to display")
                return
            
            # Create peak detection config
            peak_config = {
                "prominence": prominence.value,
                "width": width.value,
                "distance": distance.value,
                "height": height.value,
                "rel_height": 0.5
            }
            
            # Create the peak detection and display
            n_rois = len(selected_indices)
            fig, axes = plt.subplots(n_rois, 1, figsize=(12, 3*n_rois))
            
            # Handle single ROI case
            if n_rois == 1:
                axes = np.array([axes])
            
            # Set analysis frames based on condition
            if condition.value == '0um':
                # Spontaneous: there is no stimulus. The whole trace is plotted, and
                # peak frequency is measured in the baseline window defined below.
                analysis_frames = [0, df_f_traces.shape[1]-1]
                active_metric = "spont_peak_frequency"
                title_suffix = "Spontaneous Activity"
            else:
                # Evoked: analyze frames from the stimulus onset (frame 100) onward
                analysis_frames = [100, df_f_traces.shape[1]-1]
                active_metric = "peak_amplitude"
                title_suffix = f"Evoked Activity ({condition.value})"
            
            # Baseline window: the first 100 frames (0-99 inclusive), or fewer for short traces
            baseline_frames = [0, min(99, df_f_traces.shape[1]-1)]
            
            # Process and display each selected ROI
            active_rois = 0
            for i, roi_idx in enumerate(selected_indices):
                trace = df_f_traces[roi_idx]
                
                # Extract analysis window
                analysis_start, analysis_end = analysis_frames
                analysis_window = trace[analysis_start:analysis_end+1]
                
                # For evoked conditions, calculate and display stimulus time
                if condition.value != '0um':
                    stim_frame = 100  # Frame where stimulus occurs
                
                # Extract peaks
                if condition.value == '0um':
                    # For spontaneous, look at peaks during baseline period
                    baseline_trace = trace[baseline_frames[0]:baseline_frames[1]+1]
                    peaks, properties = find_peaks(
                        baseline_trace,
                        prominence=prominence.value/2,  # Use lower threshold for spontaneous
                        width=width.value,
                        distance=distance.value,
                        height=height.value
                    )
                    
                    # Calculate peak frequency (peaks per 100 frames)
                    peak_freq = len(peaks) / (len(baseline_trace) / 100) if len(baseline_trace) > 0 else 0
                    
                    # Check if ROI is active based on peak frequency
                    is_active = peak_freq > active_threshold.value
                    if is_active:
                        active_rois += 1
                    
                    # Plot trace
                    axes[i].plot(trace, 'k-')
                    
                    # Highlight baseline window
                    axes[i].axvspan(baseline_frames[0], baseline_frames[1], color='lightgray', alpha=0.2)
                    
                    # Find and highlight peaks in full trace
                    all_peaks, _ = find_peaks(
                        trace,
                        prominence=prominence.value/2,
                        width=width.value,
                        distance=distance.value,
                        height=height.value
                    )
                    
                    if len(all_peaks) > 0:
                        axes[i].plot(all_peaks, trace[all_peaks], 'ro')
                    
                    # Add title with metrics
                    axes[i].set_title(f"ROI {roi_idx+1} - {'Active' if is_active else 'Inactive'} - Peak Freq: {peak_freq:.2f}/100 frames")
                    
                else:
                    # For evoked, look at peaks after stimulus
                    peaks, properties = find_peaks(
                        analysis_window,
                        prominence=prominence.value,
                        width=width.value,
                        distance=distance.value,
                        height=height.value
                    )
                    
                    # Calculate peak amplitude (max value)
                    peak_amplitude = np.max(analysis_window) if len(analysis_window) > 0 else 0
                    
                    # Check if ROI is active based on peak amplitude
                    is_active = peak_amplitude > active_threshold.value
                    if is_active:
                        active_rois += 1
                    
                    # Plot trace
                    axes[i].plot(trace, 'k-')
                    
                    # Add a vertical line at stimulus time
                    axes[i].axvline(x=stim_frame, color='r', linestyle='--', alpha=0.7)
                    
                    # Highlight analysis window
                    axes[i].axvspan(analysis_start, analysis_end, color='lightgray', alpha=0.2)
                    
                    # Find and highlight peaks
                    if len(peaks) > 0:
                        # Adjust peak indices to match original trace
                        adjusted_peaks = peaks + analysis_start
                        axes[i].plot(adjusted_peaks, trace[adjusted_peaks], 'ro')
                    
                    # Add title with metrics
                    axes[i].set_title(f"ROI {roi_idx+1} - {'Active' if is_active else 'Inactive'} - Peak Amplitude: {peak_amplitude:.4f}")
                
                # Add zero line for reference
                axes[i].axhline(y=0, color='k', linestyle='--', alpha=0.3)
                
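                # Add a threshold line for evoked conditions only: there the threshold is
                # a dF/F level, whereas for spontaneous activity it is a peak frequency
                if condition.value != '0um':
                    axes[i].axhline(y=active_threshold.value, color='g', linestyle=':', alpha=0.5)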
                
                axes[i].set_xlabel('Frame')
                axes[i].set_ylabel('dF/F')
                axes[i].grid(True, alpha=0.3)
            
            fig.suptitle(f"Event Detection - {title_suffix} ({active_rois}/{n_rois} ROIs Active)", fontsize=16)
            fig.tight_layout()
            plt.show()
            
            # Update config with current values
            # Peak detection parameters
            if 'peak_detection' not in config['analysis']:
                config['analysis']['peak_detection'] = {}
            
            config['analysis']['peak_detection']['prominence'] = prominence.value
            config['analysis']['peak_detection']['width'] = width.value
            config['analysis']['peak_detection']['distance'] = distance.value
            config['analysis']['peak_detection']['height'] = height.value
            
            # Activity threshold
            config['analysis']['active_threshold'] = active_threshold.value
            
            # Condition-specific parameters
            if 'condition_specific' not in config['analysis']:
                config['analysis']['condition_specific'] = {}
            
            if condition.value not in config['analysis']['condition_specific']:
                config['analysis']['condition_specific'][condition.value] = {}
            
            config['analysis']['condition_specific'][condition.value]['active_threshold'] = active_threshold.value
            config['analysis']['condition_specific'][condition.value]['active_metric'] = active_metric
            
            print(f"Updated config with: prominence={prominence.value}, width={width.value}, distance={distance.value}, height={height.value}")
            print(f"active_threshold={active_threshold.value}, condition={condition.value}, active_metric={active_metric}")
            print("To apply these settings to your pipeline, update your config.yaml file.")
    
    # Create display button
    display_button = widgets.Button(
        description='Display Event Detection',
        button_style='success',
        tooltip='Display event detection for selected ROIs'
    )
    display_button.on_click(display_event_detection)
    
    # Create description of event detection parameters
    event_info = widgets.HTML(
        """
        <h3>Event Detection Parameters:</h3>
        <ul>
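            <li><strong>Prominence</strong>: Minimum amount (in dF/F) by which a peak must stand out above the higher of its two surrounding bases. Higher values keep only more distinct peaks. The spontaneous (0µm) analysis uses half of this value.</li>
            <li><strong>Width</strong>: Minimum peak width in frames, measured at half the peak's prominence. Increase to keep only broader events.</li>
            <li><strong>Distance</strong>: Minimum spacing between neighboring peaks in frames; when peaks are closer, the smaller ones are discarded. Increase to avoid counting one event as several peaks.</li>
            <li><strong>Height</strong>: Minimum peak height in dF/F. Peaks below this value are ignored.</li>
            <li><strong>Activity Threshold</strong>: Criterion for calling an ROI active. For spontaneous activity (0µm) it is compared with the peak frequency (peaks per 100 frames); for evoked activity it is compared with the maximum dF/F after the stimulus.</li>
        </ul>
        <p><strong>Condition:</strong> Select the condition that matches your recording. For spontaneous activity (0µm), peak frequency is measured in the baseline window (the first 100 frames) and detected peaks are marked across the whole trace. For evoked activity (10µm, 25µm), only frames from the stimulus (frame 100) onward are analyzed.</p>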
        <p>Select ROIs to visualize, then click "Display Event Detection".</p>
        """
    )
    
    # Create widget layout
    parameter_widgets = widgets.VBox([
        prominence,
        width, 
        distance, 
        height,
        active_threshold,
        condition
    ])
    
    control_widgets = widgets.VBox([
        roi_indices,
        display_button
    ])
    
    # Display all widgets
    display(event_info)
    display(widgets.HBox([parameter_widgets, control_widgets]))
    display(output)

# Run the function directly
run_event_detection()
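To see how the peak-detection parameters interact outside the widget, here is a minimal, self-contained sketch on a synthetic dF/F trace. It calls `scipy.signal.find_peaks` directly. `synthetic_trace` and `classify_roi` are hypothetical helpers written for illustration: `classify_roi` only approximates the rules in the cell above (peaks per 100 baseline frames for 0µm, maximum post-stimulus dF/F otherwise, stimulus assumed at frame 100) and is not the pipeline's own implementation.

```python
import numpy as np
from scipy.signal import find_peaks


def synthetic_trace(n_frames=300, event_frames=(40, 120, 126, 200),
                    amp=0.3, sigma=2.0, noise=0.002, seed=0):
    """Hypothetical helper: toy dF/F trace of Gaussian transients plus white noise."""
    rng = np.random.default_rng(seed)
    t = np.arange(n_frames)
    trace = np.zeros(n_frames)
    for f in event_frames:
        trace += amp * np.exp(-0.5 * ((t - f) / sigma) ** 2)
    return trace + rng.normal(0.0, noise, n_frames)


def classify_roi(trace, condition, threshold, stim_frame=100, **peak_kwargs):
    """Hypothetical helper returning (metric, is_active), modeled on the notebook cell.

    '0um'  -> metric = peaks per 100 frames in the pre-stimulus baseline
    other  -> metric = maximum dF/F from stim_frame onward
    """
    if condition == "0um":
        baseline = trace[:stim_frame]
        peaks, _ = find_peaks(baseline, **peak_kwargs)
        metric = len(peaks) / (len(baseline) / 100) if len(baseline) else 0.0
    else:
        window = trace[stim_frame:]
        metric = float(np.max(window)) if len(window) else 0.0
    return metric, metric > threshold


trace = synthetic_trace()
# Prominence alone keeps every transient, including the two 6 frames apart
loose, props = find_peaks(trace, prominence=0.05)
# distance=10 drops the smaller of any two peaks closer than 10 frames,
# so the transients at frames 120 and 126 collapse to the taller one
merged, _ = find_peaks(trace, prominence=0.05, distance=10)
print("peaks (prominence only):", loose)
print("peaks (distance=10):", merged)
print("prominences:", np.round(props["prominences"], 3))
print(classify_roi(trace, "0um", threshold=0.5, prominence=0.05))
print(classify_roi(trace, "10um", threshold=0.1))
```

The same trade-off applies to the sliders: raising Distance merges closely spaced transients into one event, while raising Prominence mainly discards small noise bumps.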

## Save Optimized Configuration

After exploring different parameter settings in the interactive visualizations, you can save your optimized configuration for future use. This will create a new YAML configuration file with all the parameter adjustments you've made.

### Options
- **Save Configuration**: Write the current parameters to a YAML file
- **Print Current Configuration**: Display the current configuration in the notebook
- **Reset Configuration**: Revert to the original configuration that was loaded

The saved configuration can be used with the command-line version of the pipeline by specifying the `--config` parameter.
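The Save button writes with `yaml.dump`, while the Run Full Pipeline cell reads files back with `yaml.safe_load`. If any configuration value is a NumPy scalar (for example a threshold computed with NumPy rather than read from a slider), `yaml.dump` tags it as a Python object that `yaml.safe_load` refuses to load. A minimal sketch, assuming PyYAML and NumPy, that converts such values to plain Python types before saving; `to_builtin`, `save_config`, and `load_config` are hypothetical helpers, not part of the pipeline. The last lines show why Reset relies on `config_original` being a deep copy (an assumption about how it was created): the widgets edit nested dictionaries in place, and a shallow copy would share them.

```python
import copy

import numpy as np
import yaml


def to_builtin(obj):
    """Recursively replace NumPy scalars/arrays and tuples with plain Python types."""
    if isinstance(obj, dict):
        return {k: to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_builtin(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.float64, np.int64, np.bool_, ...
        return obj.item()
    return obj


def save_config(config, path):
    """Write config as plain YAML that yaml.safe_load can read back."""
    with open(path, "w") as f:
        yaml.safe_dump(to_builtin(config), f, default_flow_style=False, sort_keys=False)


def load_config(path):
    with open(path) as f:
        return yaml.safe_load(f)


# Why Reset needs a deep copy: the widgets modify nested dicts in place
config_original = {"analysis": {"peak_detection": {"height": 0.02}}}
shallow = copy.copy(config_original)
deep = copy.deepcopy(config_original)
shallow["analysis"]["peak_detection"]["height"] = 0.05
print(config_original["analysis"]["peak_detection"]["height"])  # 0.05: the "original" changed too
print(deep["analysis"]["peak_detection"]["height"])             # 0.02: the deep copy is independent
```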

In [None]:
# @title Configuration Management {display-mode: "form"}
def save_config_to_file():
    """Show controls to save, print, or reset the current configuration"""
    # Text field for the output file path
    output_path = widgets.Text(
        value='optimized_config.yaml',
        placeholder='Enter file path to save config',
        description='Output File:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='80%')
    )
    
    display(output_path)
    
    # Messages from the button callbacks go here: print() inside a widget
    # callback may not be shown unless an Output widget captures it
    config_output = widgets.Output()
    
    # Create a save button
    save_button = widgets.Button(
        description='Save Configuration',
        button_style='success',
        tooltip='Click to save the current configuration to a file'
    )
    
    def on_save_click(b):
        with config_output:
            clear_output()
            path = output_path.value.strip()
            if not path:
                print("Please enter a valid file path")
                return
            try:
                # Save the configuration to the specified file
                with open(path, 'w') as f:
                    yaml.dump(config, f, default_flow_style=False)
                print(f"Configuration saved to {path}")
                print("To use this configuration with the pipeline, pass it with the --config option")
            except Exception as e:
                print(f"Error saving configuration: {str(e)}")
    
    save_button.on_click(on_save_click)
    display(save_button)
    
    # Create a button to print the current configuration
    print_button = widgets.Button(
        description='Print Current Configuration',
        button_style='info',
        tooltip='Click to print the current configuration to the notebook'
    )
    
    def on_print_click(b):
        with config_output:
            clear_output()
            # Print the configuration as formatted YAML
            print("Current Configuration:")
            print("=" * 50)
            print(yaml.dump(config, default_flow_style=False))
    
    print_button.on_click(on_print_click)
    display(print_button)
    
    # Create a button to reset configuration to original
    reset_button = widgets.Button(
        description='Reset Configuration',
        button_style='danger',
        tooltip='Click to reset the configuration to the original values'
    )
    
    def on_reset_click(b):
        # Rebind the global config to a fresh deep copy of the originally loaded values
        global config
        config = copy.deepcopy(config_original)
        with config_output:
            clear_output()
            print("Configuration reset to original values")
    
    reset_button.on_click(on_reset_click)
    display(reset_button)
    display(config_output)

# Call the save configuration function
save_config_to_file()

## Run Full Pipeline

Once you've optimized the parameters using the interactive visualizations, you can run the complete analysis pipeline on all your data. The pipeline will:

1. Process all matched TIF/ROI file pairs in the input directory
2. Apply preprocessing with your optimized parameters
3. Extract and refine ROIs
4. Perform background subtraction
5. Detect and analyze events
6. Generate visualizations and metrics

The results will be saved in the output directory specified in the parameters section. Each file pair will have its own subdirectory containing:
- Corrected data (HDF5 format)
- ROI masks and traces
- Analysis metrics (Excel and CSV)
- Visualizations (PNG format)

### Performance Considerations
- Processing multiple large files can be memory-intensive
- The pipeline supports parallel processing using multiple CPU cores
- Adjust the "Max Workers" parameter based on your computer's capabilities

Click "Run Full Pipeline" to start processing all file pairs.
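The Max Workers setting bounds how many file pairs are processed at the same time. We have not checked how `pipeline.py` distributes its work, so the following is only a sketch of the general pattern with `concurrent.futures`, assuming one task per TIF/ROI pair; `suggest_max_workers` and `process_pairs` are hypothetical names. It also shows why memory is usually the limit: each running worker holds one recording, so if a single pair needs about 4 GB once loaded, 4 workers need roughly 16 GB.

```python
import os
from concurrent.futures import ProcessPoolExecutor, as_completed


def suggest_max_workers(reserve_cores=1):
    """Hypothetical heuristic: use every core except `reserve_cores`."""
    return max(1, (os.cpu_count() or 1) - reserve_cores)


def process_pairs(pairs, worker, max_workers=None, executor_cls=ProcessPoolExecutor):
    """Run worker(pair) for each (tif_path, roi_path) pair, at most max_workers at once.

    Returns {pair: result}. A pair whose worker raised maps to the exception
    instead, so one unreadable file does not abort the whole batch.
    """
    max_workers = max_workers or suggest_max_workers()
    results = {}
    with executor_cls(max_workers=max_workers) as pool:
        futures = {pool.submit(worker, pair): pair for pair in pairs}
        for future in as_completed(futures):
            pair = futures[future]
            try:
                results[pair] = future.result()
            except Exception as exc:
                results[pair] = exc
    return results
```

With `ProcessPoolExecutor`, the worker must be picklable (defined at module level), and functions defined in a notebook cell may not be importable by worker processes on Windows or macOS. Passing `executor_cls=ThreadPoolExecutor` is handy for quick tests.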

In [4]:
# @title Run Full Pipeline {display-mode: "form"}
def run_full_pipeline():
    """Run the full analysis pipeline with the option to choose configuration"""
    global config
    
    if 'config' not in globals() or config is None:
        print("Please load configuration first")
        return
    
    # Radio buttons to select the configuration source
    config_source = widgets.RadioButtons(
        options=[
            ('Use current configuration in memory', 'current'),
            ('Load configuration from file', 'file')
        ],
        value='current',
        description='Config Source:',
        style={'description_width': 'initial'},
        layout=widgets.Layout(width='50%')
    )
    
    # File selector widget (hidden initially)
    file_selector_container = widgets.VBox([])
    
    def update_file_selector():
        # Clear existing widgets
        file_selector_container.children = ()
        
        if config_source.value == 'file':
            # Create file path input
            file_path = widgets.Text(
                value='',
                placeholder='Enter path to config YAML file',
                description='Config file:',
                style={'description_width': 'initial'},
                layout=widgets.Layout(width='70%')
            )
            
            # Create a button to browse files
            browse_button = widgets.Button(
                description='Browse...',
                button_style='info',
                tooltip='Browse for configuration files',
                layout=widgets.Layout(width='120px')
            )
            
            # Function to handle file browsing
            def on_browse_click(b):
                # Minimal file browser: lists YAML files in the output directory.
                # Replace with a file picker suited to your environment if needed.
                config_dir = args.output_dir if hasattr(args, 'output_dir') else '.'
                
                # List YAML files in the directory
                import glob
                yaml_files = sorted(glob.glob(os.path.join(config_dir, '*.yaml')) + glob.glob(os.path.join(config_dir, '*.yml')))
                
                browse_output.clear_output()
                with browse_output:
                    if not yaml_files:
                        print("No YAML files found in", config_dir)
                        return
                    
                    # Create a dropdown for selecting a file
                    file_dropdown = widgets.Dropdown(
                        options=yaml_files,
                        description='Select file:',
                        style={'description_width': 'initial'},
                        layout=widgets.Layout(width='70%')
                    )
                    
                    # Copy the selected file into the path field when the selection changes
                    def on_selection_change(change):
                        file_path.value = change['new']
                    
                    file_dropdown.observe(on_selection_change, names='value')
                    
                    # The dropdown preselects its first option without firing a change
                    # event, so copy that initial selection into the path field as well
                    file_path.value = file_dropdown.value
                    
                    display(file_dropdown)
            
            browse_button.on_click(on_browse_click)
            
            # Container for file browser output
            browse_output = widgets.Output()
            
            # Add widgets to container
            file_selector_container.children = (
                widgets.HBox([file_path, browse_button]),
                browse_output
            )
    
    # Update file selector when config source changes
    config_source.observe(lambda change: update_file_selector(), names='value')
    
    # Create a button to run the pipeline
    run_button = widgets.Button(
        description='Run Full Pipeline',
        button_style='danger',
        tooltip='Click to run the full pipeline with selected configuration'
    )
    
    # Output for pipeline progress
    pipeline_output = widgets.Output()
    
    # Function to run the pipeline
    def on_run_click(b):
        with pipeline_output:
            clear_output()
            print("Starting pipeline execution...")
            
            try:
                # Determine which config to use
                if config_source.value == 'current':
                    # Use current config in memory
                    config_to_use = config
                    print("Using current configuration in memory")
                else:
                    # Load config from file
                    if len(file_selector_container.children) > 0:
                        file_path_widget = file_selector_container.children[0].children[0]
                        config_file_path = file_path_widget.value
                        
                        if not config_file_path:
                            print("Error: No configuration file specified")
                            return
                        
                        if not os.path.exists(config_file_path):
                            print(f"Error: Configuration file not found: {config_file_path}")
                            return
                        
                        # Load the specified config file
                        try:
                            with open(config_file_path, 'r') as f:
                                config_to_use = yaml.safe_load(f)
                            print(f"Loaded configuration from: {config_file_path}")
                        except Exception as e:
                            print(f"Error loading configuration file: {str(e)}")
                            return
                    else:
                        print("Error: File selection UI not properly initialized")
                        return
                
                print(f"Input directory: {args.input_dir}")
                print(f"Output directory: {args.output_dir}")
                print(f"Pipeline mode: {args.mode}")
                print(f"Max workers: {args.max_workers}")
                print("=" * 50)
                
                # Save the selected configuration to a temporary file
                # (create the output directory first in case it does not exist yet)
                os.makedirs(args.output_dir, exist_ok=True)
                temp_config_path = os.path.join(args.output_dir, 'temp_config.yaml')
                with open(temp_config_path, 'w') as f:
                    yaml.dump(config_to_use, f, default_flow_style=False)
                
                print(f"Saved configuration to {temp_config_path}")
                
                # Import the main pipeline function
                from pipeline import main
                
                # Update args with the temp config path
                args.config = temp_config_path
                
                # Run the pipeline
                print("\nExecuting pipeline. This may take a while...\n")
                
                # main() reads its options from sys.argv (the script's command line),
                # so build an equivalent command line and swap it in temporarily
                pipeline_argv = ['pipeline.py',
                                 f'--input_dir={args.input_dir}',
                                 f'--output_dir={args.output_dir}',
                                 f'--config={temp_config_path}',
                                 f'--mode={args.mode}',
                                 f'--max_workers={args.max_workers}']
                
                if args.disable_advanced:
                    pipeline_argv.append('--disable_advanced')
                
                original_argv = sys.argv
                sys.argv = pipeline_argv
                
                # Execute the pipeline, always restoring the original argv
                try:
                    main()
                    print("\nPipeline completed successfully!")
                except SystemExit as e:
                    # argparse and sys.exit() raise SystemExit, which `except Exception` does not catch
                    if e.code in (None, 0):
                        print("\nPipeline completed successfully!")
                    else:
                        print(f"Pipeline exited with status {e.code}")
                except Exception as e:
                    print(f"Error executing pipeline: {str(e)}")
                    import traceback
                    traceback.print_exc()
                finally:
                    sys.argv = original_argv
                
            except Exception as e:
                print(f"Error setting up pipeline: {str(e)}")
                import traceback
                traceback.print_exc()
    
    run_button.on_click(on_run_click)
    
    # Create warning about running the full pipeline
    warning_text = widgets.HTML(
        """
        <div style="background-color: #ffe6e6; padding: 10px; border-radius: 5px; margin-bottom: 10px;">
            <h3 style="color: #cc0000;">⚠️ Warning</h3>
            <p>This will run the full pipeline on all files in the input directory using the selected configuration.</p>
            <p>Depending on the number of files and your settings, this could take a long time to complete.</p>
            <p>Make sure your configuration parameters are optimized before running the full pipeline.</p>
        </div>
        """
    )
    
    # Display all components
    display(warning_text)
    display(config_source)
    display(file_selector_container)
    display(run_button)
    display(pipeline_output)
    
    # Initialize the file selector based on the default value
    update_file_selector()

# Run the function
run_full_pipeline()
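The Run button calls `pipeline.main()`, which reads its options from `sys.argv`, so the cell swaps in a substitute command line and restores the real one afterwards. The same pattern can be packaged as a reusable context manager. This is a minimal standard-library sketch; `patched_argv` and `run_cli_main` are hypothetical names. Restoring in `finally` and catching `SystemExit` both matter, because argparse calls `sys.exit()` on invalid arguments and `SystemExit` is not a subclass of `Exception`.

```python
import sys
from contextlib import contextmanager


@contextmanager
def patched_argv(argv):
    """Temporarily replace sys.argv; the original is restored even if the body raises."""
    original = sys.argv
    sys.argv = list(argv)
    try:
        yield
    finally:
        sys.argv = original


def run_cli_main(main, argv):
    """Call a script-style main() with the given argv and return an exit status.

    SystemExit (from argparse or sys.exit) is converted into its exit code;
    a code of None counts as success (0).
    """
    with patched_argv(argv):
        try:
            main()
        except SystemExit as exc:
            return exc.code if exc.code is not None else 0
    return 0
```

In the cell above, this would amount to something like `run_cli_main(main, ['pipeline.py', f'--config={temp_config_path}', ...])` with the same options the cell builds.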


## Notebook Summary

This interactive notebook provides a user-friendly interface to the fluorescence analysis pipeline, allowing you to:

1. **Load and Configure**: Set up pipeline parameters and load the configuration from a YAML file
2. **Select and Preprocess Data**: Load TIF/ROI file pairs and prepare them for analysis
3. **Explore Parameter Effects**: Use interactive visualizations to understand how different parameters affect:
   - Gaussian denoising image quality
   - ROI selection with PNR refinement
   - Background subtraction effectiveness 
   - Event detection sensitivity and specificity
4. **Optimize Settings**: Adjust parameters to find optimal values for your specific dataset
5. **Save Configuration**: Save your optimized configuration for future use
6. **Run Full Pipeline**: Process all file pairs with your optimized settings

### Output Data
The analysis results saved in the output directory include:
- Corrected fluorescence data (HDF5 format)
- ROI masks and fluorescence traces
- Analysis metrics for each ROI (Excel and CSV format)
- Visualizations of ROIs, traces, and detected events

### Advanced Analysis
For more sophisticated analysis of the pipeline outputs, you can:
- Use the generated CSVs and Excel files for statistical analysis
- Load the HDF5 files for custom visualization and processing
- Examine the visualization PNGs for quality control
- Compare results across different experimental conditions
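The per-pair CSV files can be pooled into a single table for statistics. A minimal pandas sketch, assuming each file pair's subdirectory sits directly under the output directory as described in the Run Full Pipeline section; `collect_metrics`, the default `*.csv` pattern, and the added `file_pair` column are assumptions, so narrow the pattern to the metrics file your pipeline actually writes.

```python
from pathlib import Path

import pandas as pd


def collect_metrics(output_dir, pattern="*.csv"):
    """Hypothetical helper: concatenate matching CSVs from each per-pair subdirectory.

    Adds a 'file_pair' column holding the subdirectory name so rows can be
    grouped or compared across recordings. Returns an empty DataFrame if
    nothing matches.
    """
    frames = []
    for csv_path in sorted(Path(output_dir).glob(f"*/{pattern}")):
        df = pd.read_csv(csv_path)
        df.insert(0, "file_pair", csv_path.parent.name)  # tag rows with their source pair
        frames.append(df)
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
```

For example, `collect_metrics(args.output_dir).groupby("file_pair").size()` counts the rows contributed by each recording.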

For questions or issues with the pipeline, please refer to the documentation or contact the developers.