# GRNsuite Demo Workflows

This notebook demonstrates both interactive and non-interactive workflows for processing electrophysiology data.

In [1]:
# --- Standard library ---
import glob
import os

# --- GRNsuite analysis modules ---
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils

# --- Plotting ---
import matplotlib.pyplot as plt

# All tunable analysis settings live in parameters.yaml
params = utils.load_parameters('parameters.yaml')

# Recording analysed by the single-file interactive cells that follow
filename = "20231103-M04-sucr-100-Gal-A1-02.txt"

# Sampling rate and analysis window (start offset + length), as stored in the YAML
fs = params['sampling_rate']
offset_time = params['offset_time']
analysis_length = params['analysis_length']

## Interactive Workflow

This workflow lets you inspect and adjust each step manually: choosing the recording contact, checking the filtered signal, and tuning the spike-detection threshold.

In [2]:
import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

# Read the raw recording named in the setup cell
filepath = os.path.join("data", filename)
raw_data = preprocessing.load_ephys_data(filepath)

# Let the user choose which contact to analyse
selected_signal = preprocessing.interactive_contact_selection(raw_data)

# Band-pass (100-1000, presumably Hz given fs is passed), then the 'noise' filter stage
filtered_data = preprocessing.ashfilt(selected_signal, [100, 1000], 'bandpass', fs=fs)
denoised_data = preprocessing.ashfilt(filtered_data, None, 'noise', fs=fs)

# Restrict to the analysis window (offset_time, analysis_length) from parameters.yaml
data_zoomed, current_time = preprocessing.zoom_data(denoised_data, fs, offset_time, analysis_length)

# Show the processed trace for visual inspection
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(current_time, data_zoomed)
ax.set_title("Processed Data")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage")
plt.show()

In [3]:
# Automatic Schmidt-trigger detection; its threshold is reported below for reference
spike_times, spike_values, schmidt_threshold = spikes.schmidt_trigger_auto(data_zoomed, current_time)

# Manual threshold adjustment. The original note says adjust_threshold seeds itself
# from schmidt_trigger_auto internally - NOTE(review): confirm in grnsuite.spike_detection.
manual_threshold = spikes.adjust_threshold(data_zoomed)
spike_times, spike_values = spikes.detect_spikes_manual_threshold(data_zoomed, threshold=manual_threshold)

# Report both thresholds so the manual choice can be compared with the automatic one
print(f"Automatic threshold: {schmidt_threshold}, manual threshold: {manual_threshold}")

# Plot the signal with the manually detected spikes (spike_times index into current_time)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(current_time, data_zoomed, label="Signal")
ax.scatter(current_time[spike_times], spike_values, color='red', label="Spikes")
ax.set_title("Detected Spikes (Interactive)")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage")
ax.legend()
plt.show()

## Interactive Workflow Looping Through Files

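This workflow loops through the data files and processes each one first interactively, then automatically, and finally compares the two results. When `process_mode` in `parameters.yaml` is `'all'`, every `.txt` file in the `data` directory is processed; otherwise only the samples listed in `selected_samples` are processed.

Results for each file are saved in their own subdirectory, `results/<sample name>/` (the data filename without its `.txt` extension), alongside the metadata parsed from that filename.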

In [4]:
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import glob

import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

# Load parameters
params = utils.load_parameters('parameters.yaml')

# Build the list of files to process:
#   - process_mode == 'all': every .txt file in data/
#   - otherwise: only the samples named in params['selected_samples']
#     (listed without the .txt extension)
if params['process_mode'] == 'all':
    # glob.glob returns paths in arbitrary, filesystem-dependent order; sort so
    # files are visited in a reproducible order from run to run.
    files = sorted(glob.glob(os.path.join("data", "*.txt")))
    filenames = [os.path.basename(f) for f in files]
else:
    filenames = [f"{sample}.txt" for sample in params['selected_samples']]

if not filenames:
    # Otherwise the loop silently does nothing and later cells fail on missing variables
    print("No files to process - check the data directory and the "
          "'process_mode' / 'selected_samples' settings in parameters.yaml")

# Process each file: interactive pass, automated pass, then a side-by-side comparison
for file_index, filename in enumerate(filenames):
    print(f"\nProcessing {filename}...")

    # Set up paths; each sample's results go to results/<sample_name>/
    filepath = os.path.join("data", filename)
    sample_name = os.path.splitext(filename)[0]  # strip only the final extension
    output_dir = os.path.join("results", sample_name)

    # Get metadata from filename
    metadata = utils.parse_filename_metadata(sample_name, params)

    print("\n=== Interactive Workflow ===")
    # Load and process data interactively
    raw_data = preprocessing.load_ephys_data(filepath)

    # Interactive contact selection
    selected_signal = preprocessing.interactive_contact_selection(raw_data)

    # Filter and denoise, then restrict to the analysis window from parameters.yaml
    filtered_data = preprocessing.ashfilt(selected_signal, [100, 1000], 'bandpass', fs=params['sampling_rate'])
    denoised_data = preprocessing.ashfilt(filtered_data, None, 'noise', fs=params['sampling_rate'])
    data_zoomed, current_time = preprocessing.zoom_data(denoised_data,
                                                       params['sampling_rate'],
                                                       params['offset_time'],
                                                       params['analysis_length'])

    # Plot processed trace
    plt.figure(figsize=(10, 5))
    plt.plot(current_time, data_zoomed)
    plt.title("Processed Data")
    plt.xlabel("Time (s)")
    plt.ylabel("Voltage")
    plt.show()

    # Interactive spike detection (automatic result is superseded by the manual one below)
    spike_times, spike_values, schmidt_threshold = spikes.schmidt_trigger_auto(data_zoomed, current_time)

    # Allow manual threshold adjustment
    manual_threshold = spikes.adjust_threshold(data_zoomed)
    spike_times, spike_values = spikes.detect_spikes_manual_threshold(data_zoomed, threshold=manual_threshold)

    # Plot interactive result (spike_times index into current_time)
    plt.figure(figsize=(10, 5))
    plt.plot(current_time, data_zoomed, label="Signal")
    plt.scatter(current_time[spike_times], spike_values, color='red', label="Spikes")
    plt.title("Detected Spikes (Interactive)")
    plt.xlabel("Time (s)")
    plt.ylabel("Voltage")
    plt.legend()
    plt.show()

    print("\n=== Automated Workflow ===")
    # Process data automatically (with metadata)
    processed_data_path = preprocessing.load_and_process_data(
        filepath,
        output_dir,
        param_file='parameters.yaml',
        metadata=metadata
    )

    # Detect spikes automatically; returns the spike table path and the waveforms path
    spikes_path, waveforms_path = spikes.detect_and_save_spikes(
        processed_data_path,
        output_dir,
        param_file='parameters.yaml'
    )

    # Load and plot results
    processed_data = pd.read_csv(processed_data_path)
    detected_spikes = pd.read_csv(spikes_path)

    plt.figure(figsize=(10, 5))
    plt.plot(processed_data['time'], processed_data['voltage'], label="Signal")
    plt.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'],
                color='red', label="Spikes")
    plt.title("Detected Spikes (Automated)")
    plt.xlabel("Time (s)")
    plt.ylabel("Voltage")
    plt.legend()
    plt.show()

    print("\n=== Compare Results ===")
    print(f"Interactive workflow detected {len(spike_times)} spikes")
    print(f"Automated workflow detected {len(detected_spikes)} spikes")

    # Plot both results overlaid
    plt.figure(figsize=(12, 6))
    plt.plot(current_time, data_zoomed, 'b-', alpha=0.5, label="Signal")

    # Shift interactive markers up (in voltage units) so they don't hide the automated ones
    y_offset = 10  # Adjust this value to change the separation
    plt.scatter(current_time[spike_times], spike_values + y_offset,
                color='red', label="Interactive", alpha=0.6)
    plt.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'],
                color='green', label="Automated", alpha=0.6)

    plt.title("Comparison of Detection Methods")
    plt.xlabel("Time (s)")
    plt.ylabel("Voltage")
    plt.legend()
    plt.show()

    # Offer to continue only when another file remains (previously the prompt
    # also appeared after the last file, where answering it did nothing)
    if file_index < len(filenames) - 1:
        response = input("\nProcess next file? (y/n): ")
        if response.strip().lower() != 'y':
            break


Processing 20231103-M04-sucr-100-Gal-A1-02.txt...

=== Interactive Workflow ===

=== Automated Workflow ===
Processing data\20231103-M04-sucr-100-Gal-A1-02.txt
Using parameters: fs=30000, offset=0.1, length=2.0
Saved metadata to: results\20231103-M04-sucr-100-Gal-A1-02\metadata.json
Loaded data with shape: (399000,)
Found contact artifact at index: 270248
Attempting to save to: results\20231103-M04-sucr-100-Gal-A1-02\processed_data.csv
Successfully saved processed data to: results\20231103-M04-sucr-100-Gal-A1-02\processed_data.csv

=== Compare Results ===
Interactive workflow detected 208 spikes
Automated workflow detected 208 spikes

Processing 20231103-M04-sucr-100-Gal-A2-02.txt...

=== Interactive Workflow ===

=== Automated Workflow ===
Processing data\20231103-M04-sucr-100-Gal-A2-02.txt
Using parameters: fs=30000, offset=0.1, length=2.0
Saved metadata to: results\20231103-M04-sucr-100-Gal-A2-02\metadata.json
Loaded data with shape: (342000,)
Found contact artifact at index: 19929

## Non-Interactive (Automated) Workflow

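This workflow runs without any user intervention, using the same automated processing steps as the Snakemake pipeline. It processes the file named by `filename` in the cells above (after the looping cell, this is the last file processed).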

In [5]:
import pandas as pd

# Process the file named by `filename` from the cells above (after the looping
# cell, this is the last file processed). Rebuild the path from the name so the
# two cannot drift apart through hidden kernel state.
filepath = os.path.join("data", filename)
output_dir = os.path.join("results", os.path.splitext(filename)[0])
processed_data_path = preprocessing.load_and_process_data(
    filepath,
    output_dir,
    param_file='parameters.yaml'  # Only need to specify filepath and output_dir now
)

# Detect spikes automatically. detect_and_save_spikes returns two paths (spike
# table, waveforms). Unpack them so spikes_path is a single CSV path; binding the
# whole result to spikes_path made pd.read_csv below receive a tuple and fail.
spikes_path, waveforms_path = spikes.detect_and_save_spikes(
    processed_data_path,
    output_dir,
    param_file='parameters.yaml'  # Only need to specify data_path and output_dir now
)

# Load and plot results
processed_data = pd.read_csv(processed_data_path)
detected_spikes = pd.read_csv(spikes_path)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(processed_data['time'], processed_data['voltage'], label="Signal")
ax.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'],
           color='red', label="Spikes")
ax.set_title("Detected Spikes (Automated)")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage")
ax.legend()
plt.show()

## Compare Results

Use this section to compare the spikes detected by the interactive and automated workflows, and to decide whether the automated settings in `parameters.yaml` need adjustment.

In [8]:
# Compare the interactive result (spike_times, spike_values, current_time,
# data_zoomed) with the automated result (detected_spikes). Both come from earlier
# cells; make sure they were produced from the same recording before comparing.
print(f"Interactive workflow detected {len(spike_times)} spikes")
print(f"Automated workflow detected {len(detected_spikes)} spikes")

# Plot both results overlaid
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(current_time, data_zoomed, 'b-', alpha=0.5, label="Signal")

# Shift interactive markers up (in voltage units) so they don't hide the automated ones
y_offset = 10  # Adjust this value to change the separation
ax.scatter(current_time[spike_times], spike_values + y_offset,
           color='red', label="Interactive", alpha=0.6)
ax.scatter(detected_spikes['spike_times'], detected_spikes['spike_values'],
           color='green', label="Automated", alpha=0.6)

ax.set_title("Comparison of Detection Methods")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Voltage")
ax.legend()
plt.show()

Interactive workflow detected 208 spikes
Automated workflow detected 208 spikes
