In [2]:
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import mne
import matplotlib.pyplot as plt
import pyvista
import ipywidgets
import ipyevents
import pyvistaqt
import yasa
import os
import random

import scipy.signal as signal
from scipy.signal import hilbert
from scipy.signal import stft
from scipy.interpolate import interpn

import pywt
import cv2

In [3]:
%matplotlib qt

## Importing data

In [4]:
# file paths
large_participants_file = r"C:\eeg\combined_sets\large_participants_raw.fif"

# load raw files
large_participants_raw = mne.io.read_raw_fif(large_participants_file, preload=True)

Opening raw data file C:\eeg\combined_sets\large_participants_raw.fif...
Isotrak not found
    Range : 855000 ... 406561351 =   1710.000 ... 813122.702 secs
Ready.
Reading 0 ... 405706351  =      0.000 ... 811412.702 secs...


## Spindle detection

In [5]:
def detect_spindles_times(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)

    if do_downsample:
        data.resample(downsample_rate)
    
    sfreq = data.info['sfreq']  
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    # threshold is 75th percentile of the smoothed envelope
    # will look at the duration later
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                # so starting from the second index
                # and comparing each index to the one before
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    spindles.append((start_idx / sfreq, end_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((start_idx / sfreq, end_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)
    
    return spindles
    

def detect_spindles_peaks(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)
    
    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']  
    # update to new sampling frequency
    # because used later in the code
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    # 75th percentile as criteria

    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    #spindles.append(f"Spindle detected from {start_idx / sfreq:.2f}s to {end_idx / sfreq:.2f}s, peak at {peak_idx / sfreq:.2f}s")
                    spindles.append((peak_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((peak_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)

    
    return spindles

def detect_spindles_peaks_average(eeg_raw, do_filter=True, do_downsample=False, downsample_rate=100):
    # same as above but returns spindles and stacked spindles
    
    # 1. Filter between 12 and 16 Hz
    
    data = eeg_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=12, h_freq=16)
    
    # 2. Downsample at 100 Hz (100 samples per second)
    
    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']  
    # update to new sampling frequency
    # because used later in the code
    channel_data = data.get_data()[0]
    # extract the filtered data
    
    # 3: Calculate amplitude by applying Hilbert transformation

    hilbert_signal = hilbert(channel_data)
    # apply hilbert transformation to bandpassed data
    # gives analytic signal with amplitude and phase information
    envelope = np.abs(hilbert_signal)
    # take the absolute part of the hilbert signal
    # also the instantaneous power of the signal
    # gives the envelope: amplitude modulation
    # how strength of oscillations change over time
    # size of sliding window
    
    # 4: Perform smoothing with a sliding window of 0.2 seconds
    # this removes high-frequency noise
    
    sliding_window = int(0.2 * sfreq)
    smoothed_envelope = np.convolve(envelope, np.ones(sliding_window) / sliding_window, mode='same')
    # convolving envelope with a uniform filter over the sliding window
    # convolution takes rolling average of 20 samples at a time
    # smooth the signal with the average of values in the window
    # in the smoothed envelope, can detect regions with higher amplitude 
    # which is when a spindle event occurs
    # np.ones: creates a filter kernel
    # have a filter where the sum of all elements equals 1
    # this filter is replaced by the average of the 20 surrounding samples
    # convolution between envelope and averaging filter
    # mode = 'same': so that output of convolution has same length as original envelope

    # 5. Define spindle detection threshold

    threshold = np.percentile(smoothed_envelope, 75)
    spindle_threshold = smoothed_envelope > threshold
    # 75th percentile as criteria

    #threshold = np.mean(smoothed_envelope) + 1.5 * np.std(smoothed_envelope)
    #spindle_threshold = smoothed_envelope > threshold
    
    # 6. Detect spindles and define peaks and troughs for visualisation
    
    spindles = []
    # initialize list with spindles
    above_threshold = np.where(spindle_threshold)[0]
    # returns indices where signal above the threshold
    stacked_spindles = []
    # initialize list for stacking the spindles for the visualisation
    # contains aligned spindles at peak
    
    if len(above_threshold) > 0:
        # checking it's not empty
        start_idx = above_threshold[0]
        # would be the start of a potential spindle
        for i in range(1, len(above_threshold)):
            if above_threshold[i] > above_threshold[i - 1] + 1:  
                # if above threshold[1] > above_threshold[0] + 1
                # because all indices should be separated by 1
                # so here detects gaps
                end_idx = above_threshold[i - 1]
                # so if above condition is true, this is the end of the spindle
                duration = (end_idx - start_idx) / sfreq
                if 0.5 <= duration <= 3:
                    # only keep spindles lasting 0.5 to 3 seconds
                    segment = channel_data[start_idx:end_idx]
                    # extract EEG segment corresponding to detected spindle
                    peak_idx = start_idx + np.argmax(segment) 
                    # extract the peak of the spindle
                    # this will be useful for later
                    #spindles.append(f"Spindle detected from {start_idx / sfreq:.2f}s to {end_idx / sfreq:.2f}s, peak at {peak_idx / sfreq:.2f}s")
                    spindles.append((peak_idx / sfreq))
                    # all the spindles are stored in spindles
                    
                    # Aligning spindles at peak for visualization
                    before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
                    # still in the for loop, so this is the peak index of individual peak
                    after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
                    # extracting 1.5 seconds before and after peak
                    # max and min are used for out of bounds situations at the start and end of EEG data
                    aligned_segment = channel_data[before_peak_idx:after_peak_idx]
                    stacked_spindles.append(aligned_segment)
                    # the aligned segment is saved in stacked spindles
                
                start_idx = above_threshold[i]
                # update the start index for the for loop

        # then need to process the final spindle
        end_idx = above_threshold[-1]
        duration = (end_idx - start_idx) / sfreq
        if 0.5 <= duration <= 3:
            segment = channel_data[start_idx:end_idx]
            peak_idx = start_idx + np.argmax(segment)
            spindles.append((peak_idx / sfreq))

            before_peak_idx = max(0, peak_idx - int(1.5 * sfreq))
            after_peak_idx = min(len(channel_data), peak_idx + int(1.5 * sfreq))
            aligned_segment = channel_data[before_peak_idx:after_peak_idx]
            stacked_spindles.append(aligned_segment)

    
    return spindles, stacked_spindles

In [6]:
def visualize_spindles(stacked_spindles, plot_name):
    max_len = max(len(seg) for seg in stacked_spindles)
    padded_stacked_spindles = [np.pad(seg, (0, max_len - len(seg)), constant_values=np.nan) for seg in stacked_spindles]
    avg_spindle_waveform = np.nanmean(padded_stacked_spindles, axis=0)
    time_axis = np.linspace(-1.5, 1.5, len(avg_spindle_waveform))
    # already stacking 1.5 around each side so keep this

    plt.figure(figsize=(8, 4))
    plt.plot(time_axis, avg_spindle_waveform, color="blue", label="Mean Spindle")
    plt.axvline(0, color="red", linestyle="--", label="Peak (0s)")
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (µV)')
    plt.title(plot_name)
    plt.legend()
    plt.show()

## Slow oscillation detection

In [7]:
def detect_slow_oscillations_times(combined_raw, do_filter=True, do_downsample=False, downsample_rate=100):

    # according to methods from Klinzing et al.(2016)

    data = combined_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)

    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]
    
    # 3. find all positive-to-negative zero-crossings
    
    # zero_crossings = np.where( S!= 0)[0]
    # can also save this somewhere for further detection of spindles
    
    S = np.diff(np.sign(channel_data))
    # np.sign returns an array with 1 (positive), 0 (zero), -1 (negative)
    # np.diff calculates the difference between consecutive elements in an array
    # positive value: transition from negative to positive
    # negative value: transition from positive to negative
    # when it's a zero, means that value stayed the same
    zero_crossings = np.where(S < 0)[0]
    # -2 is when a positive-to-negative zero-crossing occurs
    # goes from 1 to -1 
    # -1 - 1 = -2
    # [0] extracts the actual array
    # extracts the indices of interest from current_data (not S)
    #signs = np.sign(current_data)
    #pos_to_neg = np.where((signs[:-1] > 0) & (signs[1:] < 0))[0]
    # detect +1 to -1
    #neg_to_pos = np.where((signs[:-1] <  0) & (signs[1:] > 0))[0]
    # detect -1 to +1

    # 4. Detect peak potentials in each pair
    slow_oscillations = []
    negative_peaks = []
    positive_peaks = []
    peak_to_peak_amplitudes = []
    candidate_indices = []

    # for loop for each pair
    # to collect all the negative and positive peaks
    # to further apply criteria
    count = 0
    for i in range(0, len(zero_crossings)-1, 1):
        # loop through all the zero_crossings
        # step of 1 (with step of 2, miss some zero_crossings)
        start_idx = zero_crossings[i] + 1
        # assigns index of zero-crossing (representing start of potential SO)
        # to start_idx
        end_idx = zero_crossings[i + 1] + 1
        # assigns index of next zero-crossing (representing end of potential SO)
        # to end_idx

        # find the negative to positive crossing in between
        #mid_crossings = neg_to_pos[(neg_to_pos > start_idx) & (neg_to_pos < end_idx)]

        #if len(mid_crossings) != 1:
            #continue

        #mid_idx = mid_crossings [0]

        #duration = (end_idx - start_idx) / sfreq
        #if not (0.8 <= duration <= 2.0):
  
        
        segment_length = (end_idx - start_idx) / sfreq

        # need to add +1 because of way extract segment later

        # have identified index for the pair
        
        # extract data segment between crossings
        
        # find peaks
        if 0.8 <= segment_length <= 2.0:
            count += 1
            segment = channel_data[start_idx:end_idx]
            positive_peak = np.max(segment)
            negative_peak = np.min(segment)
            peak_to_peak_amplitude = positive_peak - negative_peak

        # store values
            candidate_indices.append((start_idx, end_idx))
            positive_peaks.append(positive_peak)
            negative_peaks.append(negative_peak)
            peak_to_peak_amplitudes.append(peak_to_peak_amplitude)

    # calculate mean values for comparison
    #mean_negative_peak = np.mean(negative_peaks)
    # mean_negative_peak = np.mean(negative_peaks) if negative_peaks else 0
    #mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes)
    # mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes) if peak_to_peak_amplitudes else 0

    negative_peak_threshold = np.percentile(negative_peaks, 25)
    # keep lowest negative peaks (under the 25th percentile)
    peak_to_peak_amplitude_threshold = np.percentile(peak_to_peak_amplitudes, 75)
    # keep largest peak-to-peak amplitude (over 75th percentile)

    for (start_idx, end_idx), negative_peak, peak_to_peak_amplitude in zip(candidate_indices, negative_peaks, peak_to_peak_amplitudes):
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold and negative_peak <= negative_peak_threshold:
            slow_oscillations.append((start_idx / sfreq, end_idx / sfreq))
            
    return slow_oscillations
    # returns a list of tuples, in which each tuple represents the start and end times of
    # a detected slow oscillation

def detect_slow_oscillations_peaks(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100):

    # according to methods from Klinzing et al.(2016)

    data = combined_raw.copy().pick_channels(['Fz'])

    if do_filter:
        data.filter(l_freq=0.16, h_freq=1.25)

    if do_downsample:
        data.resample(downsample_rate)
        
    sfreq = data.info['sfreq']
    channel_data = data.get_data()[0]
    
    # 3. find all positive-to-negative zero-crossings
    
    # zero_crossings = np.where( S!= 0)[0]
    # can also save this somewhere for further detection of spindles
    
    S = np.diff(np.sign(channel_data))
    # np.sign returns an array with 1 (positive), 0 (zero), -1 (negative)
    # np.diff calculates the difference between consecutive elements in an array
    # positive value: transition from negative to positive
    # negative value: transition from positive to negative
    # when it's a zero, means that value stayed the same
    zero_crossings = np.where(S < 0)[0]
    # -2 is when a positive-to-negative zero-crossing occurs
    # goes from 1 to -1 
    # -1 - 1 = -2
    # [0] extracts the actual array
    # extracts the indices of interest from current_data (not S)


    # 4. Detect peak potentials in each pair
    slow_oscillations = []
    slow_oscillations_peaks = []
    negative_peaks = []
    positive_peaks = []
    peak_to_peak_amplitudes = []
    candidate_indices =  []

    # for loop for each pair
    # to collect all the negative and positive peaks
    # to further apply criteria
    count = 0
    for i in range(0, len(zero_crossings) - 1, 1):
        # loop through all the zero_crossings
        # step of 1 (with step of 2, miss some zero_crossings)
        start_idx = zero_crossings[i] + 1
        # assigns index of zero-crossing (representing start of potential SO)
        # to start_idx
        end_idx = zero_crossings[i + 1] + 1
        # assigns index of next zero-crossing (representing end of potential SO)
        # to end_idx
        segment_length = (end_idx - start_idx) / sfreq

        # need to add +1 because of way extract segment later

        # have identified index for the pair
        
        # extract data segment between crossings
        
        # find peaks
        if 0.8 <= segment_length <= 2.0:
            count += 1
            segment = channel_data[start_idx:end_idx]
            positive_peak = np.max(segment)
            negative_peak = np.min(segment)
            peak_to_peak_amplitude = positive_peak - negative_peak

        # store values
            candidate_indices.append((start_idx, end_idx))
            positive_peaks.append(positive_peak)
            negative_peaks.append(negative_peak)
            peak_to_peak_amplitudes.append(peak_to_peak_amplitude)

    # calculate mean values for comparison
    #mean_negative_peak = np.mean(negative_peaks)
    # mean_negative_peak = np.mean(negative_peaks) if negative_peaks else 0
    #mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes)
    # mean_peak_to_peak_amplitude = np.mean(peak_to_peak_amplitudes) if peak_to_peak_amplitudes else 0

    negative_peak_threshold = np.percentile(negative_peaks, 25)
    peak_to_peak_amplitude_threshold = np.percentile(peak_to_peak_amplitudes, 75)

    for (start_idx, end_idx), negative_peak, peak_to_peak_amplitude in zip(candidate_indices, negative_peaks, peak_to_peak_amplitudes):
        if peak_to_peak_amplitude >= peak_to_peak_amplitude_threshold and negative_peak <= negative_peak_threshold:
            slow_oscillations.append((start_idx / sfreq, end_idx / sfreq))
            slow_oscillations_peaks.append((negative_peak, positive_peak))

            
    return slow_oscillations_peaks
    # returns a list of tuples, in which each tuple represents the start and end times of
    # a detected slow oscillation

In [None]:
# now want to visualise slow oscillations
# find peak and trough for each of them, and then stack them all together to visualise

# this function aligns detected slow oscillations at their trough
# creates an average SO waveform

def visualize_and_stack_slow_oscillations_trough(combined_raw, slow_oscillations, plot_name):

    # Apply band-pass filter between 0.3 and 1.25 Hz
    filtered_data = combined_raw.copy().filter(l_freq=0.16, h_freq=1.25)
    # downsampling to 100 Hz
    #filtered_data.resample(100)
    filtered_channel_data = filtered_data.get_data(picks="Fz")[0]
    
    sfreq = filtered_data.info['sfreq']
    
    stacked_data = []
    # loop through each slow oscillation
    for start_time, end_time in slow_oscillations:
        # to convert start and end times to sample indices
        start_idx = int(start_time * sfreq)
        end_idx = int(end_time * sfreq)

        # extract the slow oscillation segment
        segment = filtered_channel_data[start_idx:end_idx]

        global_trough_idx = np.argmin(filtered_channel_data[start_idx:end_idx]) + start_idx
         # argmin finds the index of the min value
        # min finds the min value itself
        
        # calculate indices for 1.5 seconds before and after trough
        before_trough_idx = max(0, global_trough_idx - int(1.5 * sfreq))
        # substracts 1.5 seconds from the trough index
        # max as a safety check, to make sure that before_trough_index never negative
        # to prevent accessing data points before the beginning of the segment
        after_trough_idx = min(len(filtered_channel_data), global_trough_idx + int(1.5 * sfreq))
        # adds 1.5 seconds to the trough index
        # min is another safety check
        
        # extract the segment around the trough
        aligned_segment = filtered_channel_data[before_trough_idx:after_trough_idx]

        # append the aligned segment to the stacked data
        stacked_data.append(aligned_segment)

    # Find the maximum length of the segments
    max_len = max(len(segment) for segment in stacked_data)

    # Pad shorter segments with np.nan
    padded_stacked_data = []
    for segment in stacked_data:
        pad_len = max_len - len(segment)
        # how much padding is needed
        pad_before = pad_len // 2
        pad_after = pad_len - pad_before
        padded_segment = np.pad(segment, (pad_before, pad_after), 'constant', constant_values=np.nan)
        # distribute the padding before and after the segment
        # use NaNs instead of zeros to avoid bias
        padded_stacked_data.append(padded_segment)

    # calculate the average stacked slow oscillation
    average_padded_stacked_data = np.nanmean(padded_stacked_data, axis=0)
    # compute average waveform by ignoring NaNs

    # visualize the average stacked slow oscillation
    time_axis = np.linspace(-1.5, 1.5, len(average_padded_stacked_data))
    # this is to create the time axis
    plt.figure(figsize=(8,4))
    plt.plot(time_axis, average_padded_stacked_data, color="blue", label="Mean SO")
    plt.axvline(0, color="red", linestyle="--", label="Trough (0s)")
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude (µV)')
    plt.title(plot_name)
    plt.legend()
    plt.show()   

## SO-spindle coupling detection

In [8]:
def detect_slow_oscillations_spindles_coupling_so_times(combined_raw, do_filter=True, do_downsample=True, downsample_rate=100):
    slow_oscillations_peaks = detect_slow_oscillations_peaks(combined_raw, do_filter=do_filter, do_downsample=do_downsample, downsample_rate=downsample_rate)
    slow_oscillations_times = detect_slow_oscillations_times(combined_raw, do_filter=do_filter, do_downsample=do_downsample, downsample_rate=downsample_rate)
    spindles_peaks = detect_spindles_peaks(combined_raw, do_filter=do_filter, do_downsample=do_downsample, downsample_rate=downsample_rate)

    coupling_times = []
    coupling_times_so = []

    # first detect the coupling events
    for start_time, end_time in slow_oscillations_times:
    #for start_time, end_time (negative_peak, positive_peak) in zip(slow_oscillations_times, slow_oscillations_peaks):
        for peak in spindles_peaks:
            if start_time < peak < end_time:
                # if negative_peak < peak < end_time:
                coupling_times.append(peak)
                # if the peak of the spindle is between the negative and positive trough
                # add it to list of coupling times

    # then calculate the slow oscillation length
    for start_time, end_time in slow_oscillations_times:
        current_start_time = start_time
        current_end_time = end_time
        for coupling_peak in coupling_times:
            if current_start_time < coupling_peak < current_end_time:
                coupling_times_so.append((current_start_time, current_end_time))

    return coupling_times_so

## Detection

### at 500 Hz

## Visualisations

### Spindle

In [41]:
spindles, stacked = detect_spindles_peaks_average(large_participants_raw)
visualize_spindles(stacked, "Average Spindle for 53 participants")

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)



### Slow oscillation

In [45]:
slow_oscillations_times = detect_slow_oscillations_times(large_participants_raw, do_filter=True, do_downsample=False)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lower transition bandwidth: 0.16 Hz (-6 dB cutoff frequency: 0.08 Hz)
- Upper passband edge: 1.25 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.25 Hz)
- Filter length: 10313 samples (20.626 s)



In [46]:
## if do_downsample=True, need to turn on downsampling in visualisation function

visualize_and_stack_slow_oscillations_trough(
    large_participants_raw,
    slow_oscillations_times,
    plot_name="Average Slow Oscillation for 53 participants",
    do_downsample=False
)

Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lower transition bandwidth: 0.16 Hz (-6 dB cutoff frequency: 0.08 Hz)
- Upper passband edge: 1.25 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.25 Hz)
- Filter length: 10313 samples (20.626 s)



## Coupling visualisation

In [44]:
def visualize_so_spindle_coupling(raw, so_spindle_coupling_times, plot_name, do_downsample=False):
    # here compute the STFT for spindles
    # not all the data
    # did apply per-frequency z score normalisation

    # slow oscillation 
    so_filtered_data = raw.copy().filter(l_freq=0.16, h_freq=1.25)
    if do_downsample:
        so_filtered_data.resample(100)
    so_channel_data = so_filtered_data.get_data(picks="Fz")[0]
    sfreq = so_filtered_data.info['sfreq']
    # define the sampling frequency
    # after or before downsampling

    # spindle
    spindle_filtered_data = raw.copy().filter(l_freq=12, h_freq=16)
    if do_downsample:
        spindle_filtered_data.resample(100)
    spindle_channel_data = spindle_filtered_data.get_data(picks="Fz")[0]

    # raw
    raw_channel_data = raw.get_data(picks="Fz")[0]


    # align everything to the SO trough
    aligned_so_segments = []
    aligned_spindle_segments = []
    aligned_raw_segments = []
    for start_time, end_time in so_spindle_coupling_times:
        # extract start and end times of SOs
        # do this for every coupled SO
        start_idx = int(start_time * sfreq)
        end_idx = int(end_time * sfreq)

        # trough is argmin
        so_segment = so_channel_data[start_idx:end_idx]
        trough_idx = np.argmin(so_segment) + start_idx

        # keep -1.5 and 1.5 seconds around the trough
        before_trough_idx = max(0, trough_idx - int(1.5 * sfreq))
        after_trough_idx = min(len(so_channel_data), trough_idx + int(1.5 * sfreq))
        aligned_so_segment = so_channel_data[before_trough_idx:after_trough_idx]
        aligned_so_segments.append(aligned_so_segment)

        # extract spindle data
        aligned_spindle_segment = spindle_channel_data[before_trough_idx:after_trough_idx]
        # extract the data between 12 and 16 Hz that occurs during that same time window
        aligned_spindle_segments.append(aligned_spindle_segment)


    # pad and average the SO (as in previous function)
    max_len_so = max(len(seg) for seg in aligned_so_segments)
    padded_so_segments = [np.pad(seg, (0, max_len_so - len(seg)), constant_values=np.nan) for seg in aligned_so_segments]
    avg_so_waveform = np.nanmean(padded_so_segments, axis=0)
    time_axis = np.linspace(-1.5, 1.5, len(avg_so_waveform))

    # average spectrogram for spindle and pad
    max_len_spindle = max(len(seg) for seg in aligned_spindle_segments)
    padded_spindle_segments = [np.pad(seg, (0, max_len_spindle- len(seg)), constant_values=np.nan) for seg in aligned_spindle_segments]

    # STFT for spindle
    baseline = np.mean(spindle_channel_data[:50])
    # first 50 time points of data
    Sxx_list = []
    for seg in padded_spindle_segments:
        # NaNs are replaced with 0
        if np.isnan(seg).any():
            seg = np.nan_to_num(seg) 
        # baseline correction
        seg = seg - baseline
        freqs, times, Sxx = signal.stft(seg, fs=sfreq, nperseg=int(sfreq/4), noverlap=int(sfreq/8))
        # z-score normalisation
        m = np.mean(Sxx)
        s = np.std(Sxx)
        Sxx = (Sxx - m) / s
        Sxx_list.append(np.abs(Sxx))

    # then average all the spectrograms
    avg_Sxx = np.nanmean(Sxx_list, axis=0)

    # interpolation (cubic)
    #tim_interp = np.linspace(times.min(), times.max(), 5120)
    #freq_interp = np.linspace(freqs.min(), freqs.max(), 1024)
    #interp_tf = RegularGridInterpolator(times, freqs, avg_Sxx, kind='cubic')
    #pow_interp_tf = interp_tf(tim_interp, freq_interp)

    # interpolation (cubic)
    tim_interp = np.linspace(times.min(), times.max(), 5120) 
    freq_interp = np.linspace(freqs.min(), freqs.max(), 1024)
    grid = (freqs, times) 
    points = np.array(np.meshgrid(freq_interp, tim_interp, indexing='ij')).reshape(2, -1).T
    pow_interp_tf = interpn(grid, avg_Sxx, points, method='cubic')
    pow_interp_tf = pow_interp_tf.reshape(len(freq_interp), len(tim_interp))

    # plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # plot spectrogram
    #im = ax.pcolormesh(times - 1.5, freqs, avg_Sxx, shading='gouraud', cmap='viridis')
    im = ax.imshow(pow_interp_tf, aspect='auto', extent=[times.min()-1.5, times.max()-1.5, freqs.min(), freqs.max()], origin='lower', cmap='viridis', vmin=0, vmax=8)
    fig.colorbar(im, ax=ax, label='Power/Frequency (dB/Hz)', pad=0.1)
    #cbar = fig.colorbar(im, ax=ax, label='Power/Frequency (dB/Hz)', pad=0.1)
    #cbar.set_label('Z-scored Power', rotation=270, labelpad=15)
    

    # Overlay SO waveform
    ax2 = ax.twinx()
    ax2.plot(time_axis, avg_so_waveform, color='red', linewidth=2, label='Average SO')
    ax2.set_ylabel('Amplitude (µV)', color='red')
    ax2.tick_params(axis='y', labelcolor='red')
    ax2.axvline(0, color='white', linestyle='--', label="Trough (0s)")


    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_ylim(0, 30)
    ax.set_title(plot_name)
    ax.legend()
    #ax.legend(loc='upper left')
    #ax2.legend(loc='upper right')
    plt.show()

def visualize_so_no_spindle_coupling(raw, slow_oscillations_times, spindles_peaks, plot_name, do_downsample=False):

    # identify non-coupled SOs
    # as those that don't respect the definition
    non_coupled_so_times = []
    for start_time, end_time in slow_oscillations_times:
        coupled = False
        for peak in spindles_peaks:
            if start_time < peak < end_time:
                coupled = True
                break
        if not coupled:
            non_coupled_so_times.append((start_time, end_time))

    # slow oscillation
    so_filtered_data = raw.copy().filter(l_freq=0.16, h_freq=1.25)
    if do_downsample:
        so_filtered_data.resample(100)
    so_channel_data = so_filtered_data.get_data(picks="Fz")[0]
    sfreq = so_filtered_data.info['sfreq']

    # spindle
    spindle_filtered_data = raw.copy().filter(l_freq=12, h_freq=16)
    if do_downsample:
        spindle_filtered_data.resample(100)
    spindle_channel_data = spindle_filtered_data.get_data(picks="Fz")[0]

    # raw
    raw_channel_data = raw.get_data(picks="Fz")[0]

    # align everything to the SO trough
    aligned_so_segments = []
    aligned_spindle_segments = []
    for start_time, end_time in non_coupled_so_times:
        start_idx = int(start_time * sfreq)
        end_idx = int(end_time * sfreq)

        # trough is argming
        so_segment = so_channel_data[start_idx:end_idx]
        trough_idx = np.argmin(so_segment) + start_idx

        # same logic as above
        before_trough_idx = max(0, trough_idx - int(1.5 * sfreq))
        after_trough_idx = min(len(so_channel_data), trough_idx + int(1.5 * sfreq))
        aligned_so_segment = so_channel_data[before_trough_idx:after_trough_idx]
        aligned_so_segments.append(aligned_so_segment)

        # spindle aligned to SO trough
        aligned_spindle_segment = spindle_channel_data[before_trough_idx:after_trough_idx]
        aligned_spindle_segments.append(aligned_spindle_segment)


    # SO average
    max_len_so = max(len(seg) for seg in aligned_so_segments)
    padded_so_segments = [np.pad(seg, (0, max_len_so - len(seg)), constant_values=np.nan) for seg in aligned_so_segments]
    avg_so_waveform = np.nanmean(padded_so_segments, axis=0)
    time_axis = np.linspace(-1.5, 1.5, len(avg_so_waveform))

    # spindle spectogram average and pad
    max_len_spindle = max(len(seg) for seg in aligned_spindle_segments)
    padded_spindle_segments = [np.pad(seg, (0, max_len_spindle- len(seg)), constant_values=np.nan) for seg in aligned_spindle_segments]

    # STFT for spindles
    Sxx_list = []
    baseline = np.mean(spindle_channel_data[:50])
    for seg in padded_spindle_segments:
        # Handle potential NaNs from padding
        if np.isnan(seg).any():
            seg = np.nan_to_num(seg) # Replace NaNs with 0 for STFT
        seg = seg - baseline
        freqs, times, Sxx = signal.stft(seg, fs=sfreq, nperseg=int(sfreq/4), noverlap=int(sfreq/8))
        # z-score normalisation
        m = np.mean(Sxx)
        s = np.std(Sxx)
        Sxx = (Sxx - m) / s
        Sxx_list.append(np.abs(Sxx))

    # spectrogram average
    avg_Sxx = np.nanmean(Sxx_list, axis=0)

    # interpolation (cubic)
    tim_interp = np.linspace(times.min(), times.max(), 5120) 
    freq_interp = np.linspace(freqs.min(), freqs.max(), 1024)
    grid = (freqs, times) 
    points = np.array(np.meshgrid(freq_interp, tim_interp, indexing='ij')).reshape(2, -1).T
    pow_interp_tf = interpn(grid, avg_Sxx, points, method='cubic')
    pow_interp_tf = pow_interp_tf.reshape(len(freq_interp), len(tim_interp))


    # plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # spectogram plot
    #im = ax.pcolormesh(times - 1.5, freqs, avg_Sxx, shading='gouraud', cmap='viridis')
    im = ax.imshow(pow_interp_tf, aspect='auto', extent=[times.min()-1.5, times.max()-1.5, freqs.min(), freqs.max()], origin='lower', cmap='viridis', vmin=0, vmax=8)
    fig.colorbar(im, ax=ax, label='Power/Frequency (dB/Hz)', pad=0.1)
    #cbar = fig.colorbar(im, ax=ax, label='Power/Frequency (dB/Hz)', pad=0.1)
    #cbar.set_label('Z-scored Power', rotation=270, labelpad=15)

    # SO waveform
    ax2 = ax.twinx()
    ax2.plot(time_axis, avg_so_waveform, color='red', linewidth=2, label='Average SO')
    ax2.set_ylabel('Amplitude (µV)', color='red')
    ax2.tick_params(axis='y', labelcolor='red')
    ax2.axvline(0, color='white', linestyle='--', label="Trough (0s)")


    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_ylim(0, 30)
    ax.set_title(plot_name)
    ax.legend()
    #ax.legend(loc='upper left')
    #ax2.legend(loc='upper right')
    plt.show()



In [45]:
spindles_peaks = detect_spindles_peaks(
    large_participants_raw,
    do_filter=True,
    do_downsample=False
)

visualize_so_spindle_coupling(
    large_participants_raw,
    so_spindle_coupling_times,
    "Average Coupled SO waveform overlaid on Spindle STFT representation for 53 participants",
    do_downsample=False
)

visualize_so_no_spindle_coupling(
    large_participants_raw,
    slow_oscillations_times,
    spindles_peaks,
    "Average Coupled SO waveform overlaid on Spindle STFT representation for 53 participants",
    do_downsample=False
)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper passband edge: 16.00 Hz
- Upper transition bandwidth: 4.00 Hz (-6 dB cutoff frequency: 18.00 Hz)
- Filter length: 551 samples (1.102 s)

Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lo

  ax.legend()


Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 0.16 - 1.2 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.16
- Lower transition bandwidth: 0.16 Hz (-6 dB cutoff frequency: 0.08 Hz)
- Upper passband edge: 1.25 Hz
- Upper transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 2.25 Hz)
- Filter length: 10313 samples (20.626 s)

Filtering raw data in 123 contiguous segments
Setting up band-pass filter from 12 - 16 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 12.00
- Lower transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 10.50 Hz)
- Upper pas

  ax.legend()
