# GRNsuite Demo Workflows

This notebook demonstrates the interactive vs non-interactive workflows for processing electrophysiology data.

## Interactive Workflow Looping Through Files

This workflow allows you to loop through all files in the data directory and process them interactively.

Data is saved in the results directory with the same filename as the original data, alongside the metadata for that file.

## Step 1: Spike Detection (loops through all files)

In [1]:
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import glob
from matplotlib.widgets import Button

import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

# Load parameters
params = utils.load_parameters('parameters.yaml')

# Get list of files to process
if params['process_mode'] == 'all':
    # Get all .txt files in the data directory
    files = glob.glob(os.path.join("data", "*.txt"))
    filenames = [os.path.basename(f) for f in files]
else:
    # Get specific files from parameters if they exist
    filenames = []
    if 'recordings' in params:
        filenames = [f"{name}.txt" for name in params['recordings'].keys()]
    else:
        raise ValueError("No recordings specified in parameters.yaml and process_mode is not 'all'")

print(f"Found {len(filenames)} files to process:")
for f in filenames:
    print(f"  - {f}")

# Add global control variable
should_end = False

# Process each file
for filename in filenames:
    print(f"\nProcessing {filename}...")
    
    # Set up paths
    filepath = os.path.join("data", filename)
    sample_name = filename.replace('.txt', '')
    output_dir = os.path.join("results", sample_name)
    os.makedirs(output_dir, exist_ok=True)
    
    # Get metadata from filename
    metadata = utils.parse_filename_metadata(sample_name, params)
    
    print("\n=== Interactive Pre-processing ===")
    # Load raw data
    raw_data = preprocessing._load_ephys_data(filepath)
    
    # Step-by-step processing (interactive path)
    selected_signal = preprocessing.interactive_contact_selection(raw_data)
    filtered_signal = preprocessing.filter_signal(selected_signal, params)
    
    # Save zoomed data visualization
    zoomed_signal, current_time = preprocessing.zoom_to_region(filtered_signal, params, output_dir=output_dir)
    
    mad_signal = preprocessing.normalize_signal(zoomed_signal)
    
    # Save the interactive results
    temp_data_path = os.path.join(output_dir, 'interactive_processed.csv')
    preprocessing.save_results(mad_signal, current_time, temp_data_path)
    
    print("\n=== Interactive Spike Detection ===")
    # Run automatic detection
    spikes_file_auto, spike_times_auto, spike_values_auto = spikes.schmidt_trigger_auto(
        temp_data_path, 
        output_dir
    )

    # Allow manual threshold adjustment and detect spikes
    manual_threshold = spikes.adjust_threshold(mad_signal)
    spikes_file_manual, spike_times, spike_values = spikes.detect_spikes_manual_threshold(
        temp_data_path,
        output_dir,
        threshold=manual_threshold
    )
    
    # Extract and save waveforms using the manual threshold detection results
    print("\n=== Extracting Spike Waveforms ===")
    waveforms_file = spikes.extract_waveforms(
        data_file=temp_data_path,
        spikes_file=spikes_file_manual,  # Use the manual threshold results
        output_dir=output_dir,
        pre_peak_length=2,  # 2ms before spike
        post_peak_length=2   # 2ms after spike
    )
    
    # Load the waveforms for visualization
    waveforms_df = pd.read_csv(waveforms_file)
    waveforms = waveforms_df.values
    
    # Plot and save waveforms visualization
    avg_waveform, std_waveform = spikes.plot_waveforms(
        waveforms=waveforms,
        pre_peak_ms=2.0,
        post_peak_ms=2.0,
        show_time_axis=True,
        output_dir=output_dir
    )

    print("\n=== Automated Workflow ===")
    # Process data using the complete automated workflow
    auto_output_path = os.path.join(output_dir, 'auto_processed.csv')
    preprocessing.process_recording(
        input_file=filepath,
        output_file=auto_output_path,
        interactive=False  # Use automated contact selection
    )

    # Load and plot results
    processed_data = pd.read_csv(auto_output_path)
    detected_spikes = pd.read_csv(spikes_file_auto)  # Using the auto detection results

    print("\n=== Compare Results ===")
    print(f"Interactive workflow detected {len(spike_times)} spikes")
    print(f"Automated workflow detected {len(detected_spikes)} spikes")
    print(f"Extracted {waveforms.shape[0]} waveforms")

    # Plot both results overlaid
    fig = plt.figure(figsize=(12, 6))
    plt.plot(current_time, mad_signal, 'b-', alpha=0.5, label="Signal")

    # Add offset to interactive points for better visualization
    y_offset = 0.5  # Adjust this value to change the separation
    plt.scatter(spike_times, spike_values + y_offset, 
               color='red', label="Interactive", alpha=0.6)
    plt.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'], 
               color='green', label="Automated", alpha=0.6)
    plt.legend()
    plt.title(f"Spike Detection Results - {sample_name}")
    plt.xlabel("Time (s)")
    plt.ylabel("Normalized Voltage")
    
    # Save comparison figure
    comparison_path = os.path.join(output_dir, 'spike_comparison.png')
    plt.savefig(comparison_path, dpi=300)
    print(f"Comparison figure saved to: {comparison_path}")

    # Add Next and End buttons
    if filename != filenames[-1]:
        next_button = Button(plt.axes([0.7, 0.01, 0.1, 0.05]), 'Next File')
        
        def handle_next(event):
            plt.close(fig)
            print(f"\nMoving to next file...")

        next_button.on_clicked(handle_next)

    # Always show End button
    end_button = Button(plt.axes([0.85, 0.01, 0.1, 0.05]), 'End')
    
    def handle_end(event):
        plt.close(fig)
        print("\nStopping processing. Files completed:")
        for f in filenames[:filenames.index(filename) + 1]:
            print(f"  - {f}")
        if filename != filenames[-1]:
            print("\nFiles remaining:")
            for f in filenames[filenames.index(filename) + 1:]:
                print(f"  - {f}")
        global should_end  # Use global variable to control the loop
        should_end = True

    end_button.on_clicked(handle_end)
    plt.show()

    # Check if we should end processing
    if should_end:
        break

Found 9 files to process:
  - 20231103-M05-sucr-100-Gal-A2-01.txt
  - 20231106-M01-sucr-100-Gal-A1-02.txt
  - 20231103-M04-sucr-100-Gal-A2-02.txt
  - 20231106-F02-sucr-100-Gal-A2-02.txt
  - 20231103-M05-sucr-100-Gal-A1-02.txt
  - 20231103-M05-sucr-100-Gal-A3-02.txt
  - 20231106-M01-sucr-100-Gal-A2-02.txt
  - 20231103-M04-sucr-100-Gal-A1-02.txt
  - 20231106-F02-sucr-100-Gal-A3-02.txt

Processing 20231103-M05-sucr-100-Gal-A2-01.txt...

=== Interactive Pre-processing ===


2025-03-14 14:58:28.444 Python[54717:16356064] +[IMKClient subclass]: chose IMKClient_Modern
2025-03-14 14:58:28.444 Python[54717:16356064] +[IMKInputSession subclass]: chose IMKInputSession_Modern


Selected segment from index 419319 to 509319
Zoomed data figure saved to: results/20231103-M05-sucr-100-Gal-A2-01/zoomed_data.png
Successfully saved processed data to: results/20231103-M05-sucr-100-Gal-A2-01/interactive_processed.csv

=== Interactive Spike Detection ===
Initial threshold calculated: 2.3982556828280686
Data range: min=-9.20173212259001, max=10.85913567763796
Selected threshold: 2.3982556828280686
Final threshold value: 2.3982556828280686
Using thresholds: upper=2.3983, lower=1.9186

=== Extracting Spike Waveforms ===
Waveform figure saved to: results/20231103-M05-sucr-100-Gal-A2-01/waveforms_plot.png

=== Automated Workflow ===
Auto-selected signal from index 419319 to 509319 (duration: 3.0s)
Successfully saved processed data to: results/20231103-M05-sucr-100-Gal-A2-01/auto_processed.csv
Saved metadata to: results/20231103-M05-sucr-100-Gal-A2-01/metadata.json

=== Compare Results ===
Interactive workflow detected 202 spikes
Automated workflow detected 202 spikes
Extract