## **GRNsuite notebook** 

The package assumes that the raw electrophysiology recording is stored as a `.txt` file in the `data/` folder.

In [1]:
import grnsuite.io as io

# Recording to analyse; the filename itself encodes the recording metadata
# (see the parsed dictionary printed below).
filename = "20231103-M04-sucr-100-Gal-A1-02.txt"

# Parse the filename into a metadata dictionary and display it.
metadata = io.extract_metadata(filename)
print(metadata)


{'date': '20231103', 'animal_id': 'M04', 'stimulus': 'sucr', 'concentration': '100', 'location': 'Gal', 'sensillum': 'A1', 'replicate': '02'}


Set the sampling rate (Hz), the offset time (s), and the analysis length (s).

In [2]:
# Recording / analysis parameters reused by the cells below.
fs = 30000            # sampling rate
offset_time = 0.1     # offset time, in seconds
analysis_length = 2   # analysis length, in seconds



Inspect the raw recording to find where it starts. The start is detected automatically from the large contact artifact and can then be adjusted interactively.

In [3]:
import grnsuite.preprocessing as preprocessing
import os
import matplotlib

# The contact selection below is interactive, so force the Tk GUI backend
# (plots open in their own windows rather than inline).
matplotlib.use("TkAgg")


## 1. Load data -- the recording lives in the data/ folder
filepath = os.path.join("data", filename)
raw_data = preprocessing.load_ephys_data(filepath)
print(raw_data[:10])  # preview the first 10 values

## 2. Define start of recording -- adjust the start interactively, then extract
selected_signal = preprocessing.interactive_contact_selection(raw_data)
print(selected_signal[:10])  # preview the first 10 values


[86.975098 86.975098 87.280273 86.364746 86.975098 85.754395 86.669922
 86.669922 86.364746 86.05957 ]
[3475.646973 3473.510742 3467.712402 3461.914063 3454.284668 3451.23291
 3440.246582 3431.70166  3423.461914 3413.391113]


Filter the raw signal

In [4]:
import matplotlib.pyplot as plt

# Apply a bandpass filter from 100-1000 Hz
filtered_data = preprocessing.ashfilt(selected_signal, [100, 1000], 'bandpass', fs=fs)

# Apply noise removal (removes 50, 100, 150 Hz harmonics)
denoised_data = preprocessing.ashfilt(filtered_data, None, 'noise', fs=fs)

# One panel per processing stage: (trace, legend label, panel title, alpha).
# The panels share the sample axis, so zooming one panel zooms all three.
stage_panels = [
    (selected_signal, "Raw Signal", "Raw Signal", 0.5),
    (filtered_data, "Filtered (100-1000 Hz)", "Filtered Signal", 0.8),
    (denoised_data, "Denoised", "Denoised Signal", 0.8),
]

fig, axes = plt.subplots(len(stage_panels), 1, figsize=(10, 7), sharex=True)
for ax, (trace, trace_label, panel_title, trace_alpha) in zip(axes, stage_panels):
    ax.plot(trace, label=trace_label, alpha=trace_alpha)
    ax.set_title(panel_title)
    ax.set_ylabel("Amplitude")
    ax.legend()
# Traces are plotted without an explicit x array, i.e. against sample index.
axes[-1].set_xlabel("Sample index")

fig.tight_layout()
plt.show()

Zoom in on the region of interest and convert the time axis from samples to seconds.

In [5]:
import matplotlib.pyplot as plt
# Bind the alias used below explicitly; a bare `import grnsuite.preprocessing`
# only binds `grnsuite`, so this cell would silently depend on cell 3's alias.
import grnsuite.preprocessing as preprocessing

# Zoom in on the region of interest (controlled by offset_time and
# analysis_length) and create the matching time array.
data_zoomed, current_time = preprocessing.zoom_data(denoised_data, fs, offset_time, analysis_length)

# Plot the zoomed signal against time.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(current_time, data_zoomed, label="Zoomed Data", alpha=0.8)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
ax.set_title("Zoomed Data")
ax.legend()
plt.show()

Option 1: Spike detection using a Schmitt trigger (`spikes.schmidt_trigger_auto`)

In [6]:
import matplotlib.pyplot as plt
import grnsuite.spike_detection as spikes
import numpy as np

# Apply the Schmitt trigger (grnsuite names it `schmidt_trigger_auto`).
# With current_time=None the spike times are presumably sample indices, which
# is why they are overlaid on the trace plotted against sample index below.
## ** DECREASE t1 and t2 FOR INCREASED SENSITIVITY  **
spike_times, spike_values, schmidt_threshold = spikes.schmidt_trigger_auto(data_zoomed, current_time=None, t1=0.75, t2=1)

# Plot the detected spikes over the zoomed signal.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(data_zoomed, label="Filtered Signal", alpha=0.5)
ax.scatter(spike_times, spike_values, color='red', label="Detected Spikes")
ax.set_xlabel("Sample index")
ax.set_ylabel("Amplitude")
ax.set_title("Schmidt Trigger Spike Detection")
ax.legend()
plt.show()

# Thresholds found by the trigger (index 1 is used as the upper threshold in
# Option 2) and the number of spikes detected.
print(schmidt_threshold)
print(len(spike_times))

[np.float64(44.619782650921614), np.float64(59.51225213642264)]
208


Option 2: Manually adjust the upper threshold found by the Schmitt trigger in Option 1 (so Option 1 must be run first), then re-extract spikes.

In [7]:
import grnsuite.spike_detection as spikes

# Seed the interactive adjustment with the upper threshold found by the
# Schmitt trigger in Option 1, then let the user move it manually.
initial_threshold = schmidt_threshold[1]
manual_threshold = spikes.adjust_threshold(data_zoomed, initial_threshold)

print("User-selected threshold:", manual_threshold)

User-selected threshold: 59.51225213642264


Extract spikes again using manual threshold

In [8]:
import matplotlib.pyplot as plt
import grnsuite.spike_detection as spikes

# re-extract spikes with the new threshold
spike_times, spike_values = spikes.detect_spikes_manual_threshold(data_zoomed, threshold=manual_threshold)

print("User-selected threshold:", manual_threshold)
print("number of spikes detected:", len(spike_times))

# Draw into a fresh figure (the original plotted into whatever figure happened
# to be current, e.g. one left over from the interactive threshold step), using
# the same size as the automatic-detection plot so the two are comparable.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(data_zoomed, label="Filtered Signal", alpha=0.5)
ax.scatter(spike_times, spike_values, color='red', label="Detected Spikes")
ax.set_xlabel("Sample index")
ax.set_ylabel("Amplitude")
ax.set_title("Manual Threshold Spike Detection")
ax.legend()
plt.show()

User-selected threshold: 59.51225213642264
number of spikes detected: 208


Extract spike waveforms around each detected peak. The output below has 60 samples per waveform, which at 30 kHz is a 2 ms window in total.

In [10]:
import grnsuite.spike_detection as spikes

# Extract a waveform snippet around each detected spike, using the spike_times
# from the detection step above.
waveforms = spikes.extract_waveforms(data_zoomed, spike_times, fs)

print("Extracted Waveforms Shape:", waveforms.shape)  # Should be (num_spikes, window_samples)

# Report the snippet length in time units so the window size is explicit.
window_ms = 1000 * waveforms.shape[1] / fs
print("Waveform window length (ms):", window_ms)

# Plot waveforms
spikes.plot_waveforms(waveforms)


Extracted Waveforms Shape: (208, 60)
