## Interactive Preprocessing Visualization Notebook

This notebook provides an interactive visualization of the pre-processing steps for a single recording.

Import the necessary libraries. If any of these libraries are not installed, you can install them using pip install grnsuite.

In [16]:
import grnsuite.preprocessing as preprocessing
import grnsuite.spike_detection as spikes
import grnsuite.utils as utils
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import glob
from matplotlib.widgets import Button, RadioButtons
from IPython.display import display, clear_output

import matplotlib
matplotlib.use("TkAgg")  # Ensures the correct interactive backend is used

Load the parameters from the parameters.yaml file and load the file paths to your data files.

In [17]:
# Load parameters
params = utils.load_parameters('parameters.yaml')

# Get list of files to process
if params['process_mode'] == 'all': 
    # Get all .txt files in the data directory
    files = glob.glob(os.path.join("data", "*.txt"))
    filenames = [os.path.basename(f) for f in files]
else:
    # Get specific files from parameters if they exist
    filenames = []
    if 'recordings' in params:
        filenames = [f"{name}.txt" for name in params['recordings'].keys()]
    else:
        raise ValueError("No recordings specified in parameters.yaml and process_mode is not 'all'")

print(f"Found {len(filenames)} files to process:")
for i, f in enumerate(filenames):
    print(f"  {i-1+1}. {f}")

Found 9 files to process:
  0. 20231103-M05-sucr-100-Gal-A2-01.txt
  1. 20231106-M01-sucr-100-Gal-A1-02.txt
  2. 20231103-M04-sucr-100-Gal-A2-02.txt
  3. 20231106-F02-sucr-100-Gal-A2-02.txt
  4. 20231103-M05-sucr-100-Gal-A1-02.txt
  5. 20231103-M05-sucr-100-Gal-A3-02.txt
  6. 20231106-M01-sucr-100-Gal-A2-02.txt
  7. 20231103-M04-sucr-100-Gal-A1-02.txt
  8. 20231106-F02-sucr-100-Gal-A3-02.txt


Select a file to process.

In [18]:
# Let user select a file to process - should be a number between 0 and the number of files.
file_index = int(input(f"Select a file to process (1-{len(filenames)}): ")) - 1
filename = filenames[file_index]
print(f"\nProcessing {filename}...")


Processing 20231103-M05-sucr-100-Gal-A1-02.txt...


Set up paths to the input and output files.

In [19]:
# Set up paths
filepath = os.path.join("data", filename)
sample_name = filename.replace('.txt', '')
output_dir = os.path.join("results", sample_name)
os.makedirs(output_dir, exist_ok=True)

Get the metadata from the file name - note, this relies on the parameters.yaml file that will dicate the order of the metadata in the file name.

In [5]:
# Get metadata from filename
metadata = utils.parse_filename_metadata(sample_name, params)
print(f"Metadata: {metadata}")

Metadata: {'date': '20231106', 'animal_id': 'F02', 'stimulus': 'sucr', 'concentration': '100', 'location': 'Gal', 'sensillum_id': 'A2', 'replicate': '02'}


Make sure the metadata is correct. If it isn't, you can edit the parameters.yaml file to change the order of the metadata in the file name.


## Step 1: Load the raw data

In [20]:
# === STEP 1: Load Raw Data ===
print("\n=== STEP 1: Loading Raw Data ===")
raw_data = preprocessing._load_ephys_data(filepath)

# %matplotlib inline

# Visualize raw data
plt.figure(figsize=(12, 6))
plt.plot(raw_data, 'k-', alpha=0.7)
plt.title("Step 1: Raw Electrophysiology Data")
plt.xlabel("Sample Index")
plt.ylabel("Voltage")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()



=== STEP 1: Loading Raw Data ===


## Step 2: Select the contact artifact

In [23]:
# === STEP 2: Contact Artifact Selection ===
print("\n=== STEP 2: Contact Artifact Selection ===")
print("Use the slider to adjust the starting point of the recording.")
print("The contact artifact should be at the beginning of the red section.")

matplotlib.use("TkAgg")  # Use TkAgg for interactive selection
# The function now handles backend switching internally
selected_signal = preprocessing.interactive_contact_selection(raw_data)




=== STEP 2: Contact Artifact Selection ===
Use the slider to adjust the starting point of the recording.
The contact artifact should be at the beginning of the red section.
Selected segment from index 69970 to 159970


In [24]:
# Display the result
plt.figure(figsize=(12, 6))
plt.plot(selected_signal, 'r-')
plt.title("Step 2: Selected Signal After Contact Artifact Selection")
plt.xlabel("Sample Index")
plt.ylabel("Voltage")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()


print(f"Selected signal length: {len(selected_signal)} samples")

Selected signal length: 90000 samples


## Step 3: Filter the signal

In [28]:
# === STEP 3: Filtering ===
print("\n=== STEP 3: Filtering the Signal ===")
print(f"Applying bandpass filter ({params['filter_low']}-{params['filter_high']} Hz) and noise removal")

filtered_signal = preprocessing.filter_signal(selected_signal, params)

# Visualize filtering effects
fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)

# Original signal
axes[0].plot(selected_signal, 'k-', alpha=0.7)
axes[0].set_title("Before Filtering (Raw Selected Signal)")
axes[0].set_ylabel("Voltage")
axes[0].grid(True, alpha=0.3)

# After bandpass filter
bandpass_only = preprocessing._apply_filter(
    selected_signal, 
    [params['filter_low'], params['filter_high']], 
    'bandpass', 
    fs=params['sampling_rate']
)
axes[1].plot(bandpass_only, 'b-', alpha=0.7)
axes[1].set_title(f"After Bandpass Filter ({params['filter_low']}-{params['filter_high']} Hz)")
axes[1].set_ylabel("Voltage")
axes[1].grid(True, alpha=0.3)

# After noise removal
axes[2].plot(filtered_signal, 'g-', alpha=0.7)
axes[2].set_title("After Noise Removal (Final Filtered Signal)")
axes[2].set_xlabel("Sample Index")
axes[2].set_ylabel("Voltage")
axes[2].grid(True, alpha=0.3)

plt.tight_layout()

filename = os.path.join(output_dir, f"filtering_step.png") 
# Save the figure
plt.savefig(filename, dpi=300, bbox_inches='tight')


plt.show()


print(f"Filtered signal properties:")
print(f"  - Mean: {np.mean(filtered_signal):.4f}")
print(f"  - Std Dev: {np.std(filtered_signal):.4f}")


=== STEP 3: Filtering the Signal ===
Applying bandpass filter (100-1000 Hz) and noise removal
Filtered signal properties:
  - Mean: -0.0914
  - Std Dev: 49.6228


## Step 4: Zoom to region of interest

In [29]:
# === STEP 4: Zoom to Region of Interest ===
print("\n=== STEP 4: Zooming to Region of Interest ===")
print(f"Selecting portion from {params['offset_time']}s to {params['offset_time'] + params['analysis_length']}s")

# Zoom and automatically save the figure
zoomed_signal, current_time = preprocessing.zoom_to_region(
    filtered_signal, 
    params,
    output_dir=output_dir
)

# Visualize zooming effect
fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Full filtered signal with highlighted region
time_full = np.arange(len(filtered_signal)) / params['sampling_rate']
axes[0].plot(time_full, filtered_signal, 'b-', alpha=0.5)

# Calculate indices for the zoomed region
start_idx = int(params['offset_time'] * params['sampling_rate'])
end_idx = int((params['offset_time'] + params['analysis_length']) * params['sampling_rate'])

# Highlight the zoomed region
axes[0].axvspan(params['offset_time'], params['offset_time'] + params['analysis_length'], 
                color='yellow', alpha=0.3)
axes[0].set_title("Full Filtered Signal with Highlighted Analysis Region")
axes[0].set_xlabel("Time (s)")
axes[0].set_ylabel("Voltage")
axes[0].grid(True, alpha=0.3)

# Zoomed region
axes[1].plot(current_time, zoomed_signal, 'g-')
axes[1].set_title("Zoomed Region for Analysis")
axes[1].set_xlabel("Time (s)")
axes[1].set_ylabel("Voltage")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Zoomed signal length: {len(zoomed_signal)} samples")
print(f"Time range: {current_time[0]:.2f}s to {current_time[-1]:.2f}s")


=== STEP 4: Zooming to Region of Interest ===
Selecting portion from 0.1s to 2.1s
Zoomed data figure saved to: results/20231103-M05-sucr-100-Gal-A1-02/zoomed_data.png
Zoomed signal length: 60000 samples
Time range: 0.10s to 2.10s


## Step 5: Signal Normalization

In [31]:
# === STEP 5: Signal Normalization ===
print("\n=== STEP 5: Signal Normalization (MAD) ===")
print("Normalizing signal using Median Absolute Deviation")

# Calculate median and MAD for display
median = np.median(zoomed_signal)
mad = np.median(np.abs(zoomed_signal - median))
mad_factor = 1.4826 * mad

mad_signal = preprocessing.normalize_signal(zoomed_signal)

# Visualize normalization effect
fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Before normalization
axes[0].plot(current_time, zoomed_signal, 'b-')
axes[0].axhline(y=median, color='r', linestyle='-', label=f"Median = {median:.4f}")
axes[0].axhline(y=median + mad_factor, color='g', linestyle='--', 
                label=f"Median + MAD = {median + mad_factor:.4f}")
axes[0].axhline(y=median - mad_factor, color='g', linestyle='--', 
                label=f"Median - MAD = {median - mad_factor:.4f}")
axes[0].set_title("Before MAD Normalization")
axes[0].set_ylabel("Voltage")
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# After normalization
axes[1].plot(current_time, mad_signal, 'g-')
axes[1].axhline(y=0, color='r', linestyle='-', label="Median = 0")
axes[1].axhline(y=1, color='g', linestyle='--', label="Median + MAD = 1")
axes[1].axhline(y=-1, color='g', linestyle='--', label="Median - MAD = -1")
axes[1].set_title("After MAD Normalization")
axes[1].set_xlabel("Time (s)")
axes[1].set_ylabel("Normalized Voltage")
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
filename = os.path.join(output_dir, f"normalized_step.png") 
# Save the figure
plt.savefig(filename, dpi=300, bbox_inches='tight')

plt.show()

print("MAD normalization statistics:")
print(f"  - Original median: {median:.4f}")
print(f"  - Original MAD: {mad_factor:.4f}")
print(f"  - Normalized median: {np.median(mad_signal):.4f}")
print(f"  - Normalized MAD: {np.median(np.abs(mad_signal - np.median(mad_signal))) * 1.4826:.4f}")

# Save the processed data
processed_data_path = os.path.join(output_dir, 'interactive_processed.csv')
preprocessing.save_results(mad_signal, current_time, processed_data_path)
print(f"\nSaved processed data to: {processed_data_path}")


=== STEP 5: Signal Normalization (MAD) ===
Normalizing signal using Median Absolute Deviation
MAD normalization statistics:
  - Original median: -5.3655
  - Original MAD: 23.4957
  - Normalized median: -0.2284
  - Normalized MAD: 1.0000
Successfully saved processed data to: results/20231103-M05-sucr-100-Gal-A1-02/interactive_processed.csv

Saved processed data to: results/20231103-M05-sucr-100-Gal-A1-02/interactive_processed.csv


## Step 6: Spike Detection

In [34]:
# === STEP 6: Interactive Spike Detection ===
print("\n=== STEP 6: Interactive Spike Detection ===")
print("Adjust the threshold slider to set the optimal spike detection threshold")

# Perform interactive threshold adjustment
manual_threshold = spikes.adjust_threshold(mad_signal)

# Detect spikes with the manual threshold
spikes_file, spike_times, spike_values = spikes.detect_spikes_manual_threshold(
    processed_data_path,
    output_dir,
    threshold=manual_threshold
)

y_offset = 0.5  # Adjust this value to change the separation

# Plot the final result with detected spikes
plt.figure(figsize=(12, 6))
plt.plot(current_time, mad_signal, 'b-', alpha=0.7, label="Normalized Signal")
plt.scatter(spike_times, spike_values + y_offset, color='r', s=50, label="Detected Spikes")
plt.axhline(y=manual_threshold, color='g', linestyle='--', label=f"Threshold = {manual_threshold:.2f}")
plt.title("Final Result: Normalized Signal with Detected Spikes")
plt.xlabel("Time (s)")
plt.ylabel("Normalized Voltage")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()

filename = os.path.join(output_dir, f"trace_with_spikes.png") 
# Save the figure
plt.savefig(filename, dpi=300, bbox_inches='tight')

plt.show()

print(f"Detected {len(spike_times)} spikes using threshold = {manual_threshold:.2f}")
print(f"Spike details saved to: {spikes_file}")


=== STEP 6: Interactive Spike Detection ===
Adjust the threshold slider to set the optimal spike detection threshold
Initial threshold calculated: 2.4326678633141707
Data range: min=-4.425654472774549, max=10.593441813494758
Selected threshold: 2.4326678633141707
Final threshold value: 2.4326678633141707
Using thresholds: upper=2.4327, lower=1.9461
Detected 212 spikes using threshold = 2.43
Spike details saved to: results/20231103-M05-sucr-100-Gal-A1-02/detected_spikes.csv


## Step 7: Extract waveforms

In [36]:
# === STEP 7: Waveform Extraction and Analysis ===
print("\n=== STEP 7: Waveform Extraction and Analysis ===")
print("Extracting waveforms around each detected spike")

# Define waveform extraction parameters
pre_peak_ms = 1.5  # milliseconds before spike
post_peak_ms = 1.5  # milliseconds after spike

# Use the existing extract_waveforms function
waveforms_file = spikes.extract_waveforms(
    data_file=processed_data_path,
    spikes_file=spikes_file,
    output_dir=output_dir,
    pre_peak_length=pre_peak_ms,
    post_peak_length=post_peak_ms
)

# Load the saved waveforms
waveforms_df = pd.read_csv(waveforms_file)
waveforms = waveforms_df.values

# Use the enhanced plot_waveforms function that now includes average and std
print("Plotting waveforms with average and standard deviation:")
avg_waveform, std_waveform = spikes.plot_waveforms(
    waveforms, 
    pre_peak_ms=pre_peak_ms, 
    post_peak_ms=post_peak_ms,
    show_time_axis=True,
    output_dir=output_dir
)

# Create a raster plot to visualize spike timing
plt.figure(figsize=(12, 4))
for i, spike_time in enumerate(spike_times):
    plt.plot([spike_time, spike_time], [i-0.4, i+0.4], 'k-')
    
plt.title("Spike Raster Plot")
plt.xlabel("Time (s)")
plt.ylabel("Spike Number")
plt.xlim(current_time[0], current_time[-1])
plt.tight_layout()
filename = os.path.join(output_dir, f"rasters.png") 
# Save the figure
plt.savefig(filename, dpi=300, bbox_inches='tight')
plt.show()

# Calculate time axis for metric calculations
time_ms = np.linspace(-pre_peak_ms, post_peak_ms, waveforms.shape[1])

# Calculate waveform metrics
peak_idx = np.argmax(avg_waveform)
trough_idx = np.argmin(avg_waveform)
peak_amplitude = avg_waveform[peak_idx]
trough_amplitude = avg_waveform[trough_idx]
peak_to_trough = peak_amplitude - trough_amplitude

# Calculate timing metrics
peak_time = time_ms[peak_idx]
trough_time = time_ms[trough_idx]
peak_to_trough_duration = abs(trough_time - peak_time)

print(f"\nExtracted {waveforms.shape[0]} waveforms")
print(f"Waveform duration: {pre_peak_ms + post_peak_ms:.1f} ms ({waveforms.shape[1]} samples)")
print(f"Saved waveforms to: {waveforms_file}")

print("\nWaveform metrics:")
print(f"Peak amplitude: {peak_amplitude:.3f}")
print(f"Trough amplitude: {trough_amplitude:.3f}")
print(f"Peak-to-trough amplitude: {peak_to_trough:.3f}")
print(f"Peak time: {peak_time:.2f} ms")
print(f"Trough time: {trough_time:.2f} ms")
print(f"Peak-to-trough duration: {peak_to_trough_duration:.2f} ms")


=== STEP 7: Waveform Extraction and Analysis ===
Extracting waveforms around each detected spike
Plotting waveforms with average and standard deviation:
Waveform figure saved to: results/20231103-M05-sucr-100-Gal-A1-02/waveforms_plot.png

Extracted 212 waveforms
Waveform duration: 3.0 ms (90 samples)
Saved waveforms to: results/20231103-M05-sucr-100-Gal-A1-02/waveforms.csv

Waveform metrics:
Peak amplitude: 5.947
Trough amplitude: -2.416
Peak-to-trough amplitude: 8.363
Peak time: 0.02 ms
Trough time: 1.50 ms
Peak-to-trough duration: 1.48 ms
