In [None]:
# import required modules
import csv
import math

import matplotlib.pyplot as plt
import mne
import numpy as np
import scipy
import scipy.signal

In [None]:
def splitperiod(period, segment_length, sampling_freq):
    """Split a period of samples into consecutive, equal-length segments.

    :param period: Sequence of samples to split.
    :param segment_length: Desired segment length in seconds.
    :param sampling_freq: Sampling frequency in Hz.
    :returns: List of segments ``[[...], [...], ...]``; each segment is a list of
        exactly ``floor(segment_length * sampling_freq)`` consecutive samples.
        Trailing samples that do not fill a complete segment are dropped.
    :raises ValueError: If ``segment_length * sampling_freq`` is less than one sample.
    """
    # Samples per segment; floor so a segment never exceeds the requested duration.
    sample_seglen = math.floor(segment_length * sampling_freq)
    if sample_seglen < 1:
        raise ValueError("segment_length * sampling_freq must correspond to at least 1 sample")
    samples = list(period)
    n_full = len(samples) // sample_seglen
    return [samples[k * sample_seglen:(k + 1) * sample_seglen] for k in range(n_full)]

def find_cluster(x, xval):
    """
    Locate runs (clusters) of consecutive elements of ``x`` that equal ``xval``.


    :param x: The array containing the data for the cluster search.
    :type x: ndarray

    :param xval: The value an element must equal to belong to a cluster.
    :type xval: integer, float


    :returns: 2-tuple

        * i0:
            The start index of every cluster, in order of appearance.

        * clustersize:
            The number of elements in each corresponding cluster.

        Both lists are empty when no element of ``x`` equals ``xval``.

    :rtype: (list, list)


    Example
    -------
        >>> x = np.int32(np.round(np.random.rand(20)+0.1))
        >>> i0, clustersize = find_cluster(x, 1)

    """
    starts = []
    ends = []
    run_start = None  # start index of the run currently being scanned, if any

    for idx, value in enumerate(x):
        if value == xval:
            if run_start is None:
                run_start = idx
            if idx == len(x) - 1:
                # A run that reaches the final element is closed right here.
                starts.append(run_start)
                ends.append(idx)
                run_start = None
        elif run_start is not None:
            # First non-matching element after a run: the run ended one step back.
            starts.append(run_start)
            ends.append(idx - 1)
            run_start = None

    if not starts:
        return [], []

    # Inclusive [start, end] runs, so the length is end - start + 1.
    return starts, list(np.subtract(ends, starts) + 1)


In [None]:
# Load the EDF recording and plot channel 1 for a quick visual check.
# Raw.get_data() returns an array of shape (n_channels, n_times), so raw_data[1]
# is the second channel of the recording.
EDF_PATH = 'Data/chb01_04.edf'
raw = mne.io.read_raw_edf(EDF_PATH)
raw_data = raw.get_data()

fig, ax = plt.subplots()
ax.plot(raw_data[1])
ax.set(
    title=f"{EDF_PATH} - channel {raw.ch_names[1]} (raw)",
    xlabel="Sample index",
    ylabel="Amplitude",
);

In [None]:
# Low-pass filter channel 1 with a 20th-order Butterworth filter (10 Hz cutoff,
# designed for a 256 Hz sampling rate). Second-order-sections form is used
# because high-order filters are numerically fragile in (b, a) form.
sos = scipy.signal.butter(N=20, Wn=10, btype='lp', fs=256, output='sos')
filtered = scipy.signal.sosfilt(sos, raw_data[1])
plt.plot(filtered)

In [None]:

# Detect candidate spikes in the filtered signal, cut a fixed-length window
# around each one, and average the windows into a single spike-and-wave template.
FS = 256             # sampling frequency (Hz); same value used to design the filter
THRESHOLD = 0.00001  # amplitude threshold applied to `filtered`
MIN_SPIKE_S = 0.02   # supra-threshold runs must be longer than this (s)
MAX_SPIKE_S = 0.07   # ...and shorter than this (s)
WINDOW_S = 0.3       # length of the window extracted around each spike (s)

# i0: start index of every run where `filtered` is NOT below THRESHOLD
# (i.e. where `filtered < THRESHOLD` is False); clustersize: length of each run.
i0, clustersize = find_cluster(filtered < THRESHOLD, 0)

# Keep only the runs whose duration matches the spike criteria.
min_len = np.floor(FS * MIN_SPIKE_S)
max_len = np.floor(FS * MAX_SPIKE_S)
i0_spikes = []
spike_cluster = []
for start, size in zip(i0, clustersize):
    if min_len < size < max_len:
        i0_spikes.append(start)
        spike_cluster.append(size)

# Extract a window centred on each *selected* spike. Every window has exactly
# `window_len` samples so the windows can be stacked and averaged; spikes too
# close to either end of the recording are skipped rather than producing
# short or empty slices.
time_to_extract = WINDOW_S * FS
window_len = int(time_to_extract)
extracted = []
for start, size in zip(i0_spikes, spike_cluster):
    lo = math.floor(start + size * 0.5 - time_to_extract * 0.5)
    hi = lo + window_len
    if lo < 0 or hi > len(filtered):
        continue
    extracted.append(filtered[lo:hi])

# Sample-by-sample mean over all windows (every window contributes, over its full length).
N = len(extracted)
print(f"{N} spike windows averaged")
if N == 0:
    raise ValueError("No spikes matched the duration criteria; there is no template to average.")
Sum = np.mean(np.stack(extracted), axis=0).tolist()

fig, ax = plt.subplots()
ax.plot(Sum)
ax.set(
    title=f"Average of {N} spike windows",
    xlabel="Sample within window",
    ylabel="Mean amplitude",
);


In [None]:

# Save the averaged template as a single CSV row.
# newline="" is required by the csv module so its own "\r\n" row terminator is
# not translated again (e.g. into "\r\r\n" on Windows); the `with` block closes
# the file even if writing fails.
PATTERN_PATH = "Pattern_chb01_04.csv"
with open(PATTERN_PATH, "w", newline="") as pattern_file:
    csv.writer(pattern_file).writerow(Sum)