# GRNsuite Demo Workflows

This notebook demonstrates both interactive and non-interactive workflows for processing electrophysiology data.

In [1]:
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils
import matplotlib.pyplot as plt
import os
import glob

# Load parameters
params = utils.load_parameters('parameters.yaml')

# Set parameters from yaml
filename = "20231103-M04-sucr-100-Gal-A1-02.txt"
fs = params['sampling_rate']
offset_time = params['offset_time']
analysis_length = params['analysis_length']

## Interactive Workflow

This workflow allows manual inspection and adjustment at each step.

In [7]:
import pandas as pd
import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

# Load and process data interactively
filepath = os.path.join("data", filename)
raw_data = preprocessing.load_ephys_data(filepath)

# Interactive contact selection
selected_signal = preprocessing.interactive_contact_selection(raw_data)

# Filter and denoise
filtered_data = preprocessing.ashfilt(selected_signal, [100, 1000], 'bandpass', fs=fs)
denoised_data = preprocessing.ashfilt(filtered_data, None, 'noise', fs=fs)

# zoom into a section of the recording
data_zoomed, current_time = preprocessing.zoom_data(denoised_data, fs, offset_time, analysis_length)

# calculate the median absolute deviation of the signal
mad_signal = preprocessing.MAD_signal(data_zoomed)

# Plot results
plt.figure(figsize=(10, 5))
plt.plot(current_time, mad_signal)
plt.title("Processed Data")
plt.xlabel("Time (s)")
plt.ylabel("Voltage")
plt.show()

## Interactive Workflow Looping Through Files

This workflow allows you to loop through all files in the data directory and process them interactively.

Data is saved in the results directory with the same filename as the original data, alongside the metadata for that file.

In [1]:
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import glob
from matplotlib.widgets import Button

import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

# Load parameters
params = utils.load_parameters('parameters.yaml')

# Get list of files to process
if params['process_mode'] == 'all':
    # Get all .txt files in the data directory
    files = glob.glob(os.path.join("data", "*.txt"))
    filenames = [os.path.basename(f) for f in files]
else:
    # Get specific files from parameters if they exist
    filenames = []
    if 'recordings' in params:
        filenames = [f"{name}.txt" for name in params['recordings'].keys()]
    else:
        raise ValueError("No recordings specified in parameters.yaml and process_mode is not 'all'")

print(f"Found {len(filenames)} files to process:")
for f in filenames:
    print(f"  - {f}")

# Add global control variable
should_end = False

# Process each file
for filename in filenames:
    print(f"\nProcessing {filename}...")
    
    # Set up paths
    filepath = os.path.join("data", filename)
    sample_name = filename.replace('.txt', '')
    output_dir = os.path.join("results", sample_name)
    os.makedirs(output_dir, exist_ok=True)
    
    # Get metadata from filename
    metadata = utils.parse_filename_metadata(sample_name, params)
    
    print("\n=== Pre-processing ===")
    # Load and process data interactively
    raw_data = preprocessing.load_ephys_data(filepath)

    # Interactive contact selection
    selected_signal = preprocessing.interactive_contact_selection(raw_data)

    # Filter and denoise
    filtered_data = preprocessing.ashfilt(selected_signal, [params["filter_low"], params["filter_high"]], 'bandpass', fs=params['sampling_rate'])
    denoised_data = preprocessing.ashfilt(filtered_data, None, 'noise', fs=params['sampling_rate'])
    
    # zoom into a section of the recording
    data_zoomed, current_time = preprocessing.zoom_data(denoised_data, 
                                                      params['sampling_rate'], 
                                                      params['offset_time'], 
                                                      params['analysis_length'])

     # calculate the MAD of the signal
    mad_signal = preprocessing.MAD_signal(data_zoomed)
    
    print("\n=== Interactive Spike Detection ===")
    # Interactive spike detection
    # Save the current segment to a temporary CSV file for schmidt_trigger_auto
    temp_data = pd.DataFrame({
        'time': current_time,
        'voltage': mad_signal
    })
    temp_data_path = os.path.join(output_dir, 'temp_processed.csv')
    temp_data.to_csv(temp_data_path, index=False)

    # Run automatic detection and get values directly
    spikes_file, spike_times_auto, spike_values_auto = spikes.schmidt_trigger_auto(temp_data_path, output_dir)

    # Allow manual threshold adjustment
    manual_threshold = spikes.adjust_threshold(mad_signal)
    spike_indices, spike_values = spikes.detect_spikes_manual_threshold(mad_signal, threshold=manual_threshold)
    spike_times = current_time[spike_indices]  # Convert indices to times

    

    print("\n=== Automated Workflow ===")
     # Process data automatically (with metadata)
    processed_data_path = preprocessing.load_and_process_data(
        input_file=filepath,
        output_file=os.path.join(output_dir, 'processed_data.csv')
    )

    # Detect spikes automatically
    spikes_path, waveforms_path = spikes.detect_and_save_spikes(
        processed_data_path,
        output_dir
    )

    # Load and plot results
    processed_data = pd.read_csv(processed_data_path)
    detected_spikes = pd.read_csv(spikes_path)

    print("\n=== Compare Results ===")
    print(f"Interactive workflow detected {len(spike_indices)} spikes")
    print(f"Automated workflow detected {len(detected_spikes)} spikes")

    # Plot both results overlaid
    fig = plt.figure(figsize=(12, 6))
    plt.plot(current_time, mad_signal, 'b-', alpha=0.5, label="Signal")

    # Add offset to interactive points for better visualization
    y_offset = 0.5  # Adjust this value to change the separation
    plt.scatter(spike_times, spike_values + y_offset, 
               color='red', label="Interactive", alpha=0.6)
    plt.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'], 
               color='green', label="Automated", alpha=0.6)

    plt.title("Comparison of Detection Methods")
    plt.legend()

    # Add Next and End buttons
    if filename != filenames[-1]:
        next_button = Button(plt.axes([0.7, 0.01, 0.1, 0.05]), 'Next File')
        
        def handle_next(event):
            plt.close(fig)
            print(f"\nMoving to next file...")

        next_button.on_clicked(handle_next)

    # Always show End button
    end_button = Button(plt.axes([0.85, 0.01, 0.1, 0.05]), 'End')
    
    def handle_end(event):
        plt.close(fig)
        print("\nStopping processing. Files completed:")
        for f in filenames[:filenames.index(filename) + 1]:
            print(f"  - {f}")
        if filename != filenames[-1]:
            print("\nFiles remaining:")
            for f in filenames[filenames.index(filename) + 1:]:
                print(f"  - {f}")
        global should_end  # Use global variable to control the loop
        should_end = True

    end_button.on_clicked(handle_end)
    plt.show()

    # Check if we should end processing
    if should_end:
        break

Found 9 files to process:
  - 20231103-M05-sucr-100-Gal-A2-01.txt
  - 20231106-M01-sucr-100-Gal-A1-02.txt
  - 20231103-M04-sucr-100-Gal-A2-02.txt
  - 20231106-F02-sucr-100-Gal-A2-02.txt
  - 20231103-M05-sucr-100-Gal-A1-02.txt
  - 20231103-M05-sucr-100-Gal-A3-02.txt
  - 20231106-M01-sucr-100-Gal-A2-02.txt
  - 20231103-M04-sucr-100-Gal-A1-02.txt
  - 20231106-F02-sucr-100-Gal-A3-02.txt

Processing 20231103-M05-sucr-100-Gal-A2-01.txt...

=== Pre-processing ===


2025-03-11 17:37:08.190 Python[70803:12045806] +[IMKClient subclass]: chose IMKClient_Modern
2025-03-11 17:37:08.190 Python[70803:12045806] +[IMKInputSession subclass]: chose IMKInputSession_Modern



=== Interactive Spike Detection ===
Initial threshold calculated: 2.399921237267498
Data range: min=-9.20761864106934, max=10.866049985879497
Selected threshold: 2.399921237267498
Final threshold value: 2.399921237267498
Using thresholds: upper=2.3999, lower=1.9199

=== Automated Workflow ===
Processing data/20231103-M05-sucr-100-Gal-A2-01.txt
Using parameters: fs=30000, offset=0.1, length=2.0
Loaded data with shape: (558000,)
Found contact artifact at index: 419319
Successfully saved processed data to: results/20231103-M05-sucr-100-Gal-A2-01/processed_data.csv

=== Compare Results ===
Interactive workflow detected 202 spikes
Automated workflow detected 202 spikes

Moving to next file...

Processing 20231106-M01-sucr-100-Gal-A1-02.txt...

=== Pre-processing ===

=== Interactive Spike Detection ===
Initial threshold calculated: 2.436599848986405
Data range: min=-8.81123824313137, max=13.86636872113898
Selected threshold: 2.4228993607071683
Final threshold value: 2.4228993607071683
Usin