In [4]:
import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io


In [None]:
def plot_sample(eeg_array, sample, ax, title=None):
    """Plot one trial of ``eeg_array`` on the given axes.

    Parameters
    ----------
    eeg_array : np.ndarray
        Indexed as ``eeg_array[sample, :, :]``. That 2-D slice is transposed
        before plotting, so each of its rows becomes one line. (Presumably
        ``(trials, channels, time)`` -- confirm against the data files.)
    sample : int
        Index along axis 0 to plot.
    ax : matplotlib Axes
        Target subplot; anything exposing ``plot`` and ``set_title`` works.
    title : str, optional
        Prefix for the subplot title. When omitted, the title is just
        ``"Sample <n>"`` rather than ``"None - Sample <n>"``.
    """
    ax.plot(eeg_array[sample, :, :].T)
    label = f"Sample {sample}"
    ax.set_title(label if title is None else f"{title} - {label}")

In [2]:
def detect_discontinuities(eeg_array, interval=20, max_span=50):
    """Flag trials that contain a short, monotonic, nearly flat run of values.

    A window of ``interval + 1`` consecutive points along axis 2 is flagged
    when it is monotonic (non-decreasing or non-increasing) and its first and
    last values differ by less than ``max_span``. Windows that contain NaN are
    never flagged. (NOTE(review): the apparent intent is to catch padded or
    interpolated gaps in the recording -- confirm with the data source.)

    Parameters
    ----------
    eeg_array : np.ndarray
        3-D array indexed ``[sample, channel, time]``.
    interval : int
        Number of steps in a window (window length minus one). Must be >= 0.
    max_span : float
        Exclusive upper bound on ``|first - last|`` for a flagged window.
        Defaults to 50, the value previously hard-coded.

    Returns
    -------
    list of (sample, channel, start) tuples
        At most one entry per sample: the first channel (in index order) that
        contains a flagged window, with that window's earliest start index.
        Samples appear in ascending order.

    Raises
    ------
    ValueError
        If ``interval`` is negative.
    """
    if interval < 0:
        raise ValueError(f"interval must be >= 0, got {interval}")

    num_samples = eeg_array.shape[0]
    num_channels = eeg_array.shape[1]
    num_signals = eeg_array.shape[2]

    discontinuities = []
    num_windows = num_signals - interval
    if num_windows <= 0:
        return discontinuities

    starts = slice(0, num_windows)
    ends = slice(interval, num_signals)

    for sample in range(num_samples):
        for channel in range(num_channels):
            # Work in float64 so the span below cannot wrap around for
            # unsigned or narrow integer dtypes.
            trace = np.asarray(eeg_array[sample, channel, :], dtype=np.float64)
            steps = np.diff(trace)

            # Prefix counts of steps that break a rising (or falling) run. A
            # window is monotonic exactly when it contains none of them. NaN
            # steps fail both comparisons, so they count as breaks.
            up_breaks = np.concatenate(([0], np.cumsum(~(steps >= 0))))
            down_breaks = np.concatenate(([0], np.cumsum(~(steps <= 0))))
            monotonic = (up_breaks[ends] == up_breaks[starts]) | (
                down_breaks[ends] == down_breaks[starts]
            )
            small_span = np.abs(trace[ends] - trace[starts]) < max_span

            hits = np.flatnonzero(monotonic & small_span)
            if hits.size:
                discontinuities.append((sample, channel, int(hits[0])))
                # One hit is enough to flag the whole trial.
                break

    return discontinuities

In [3]:
def get_discontinuous_files(file_list, interval=50, verbose=True):
    """Scan ``.mat`` recordings and collect the trials flagged by
    ``detect_discontinuities``.

    Parameters
    ----------
    file_list : iterable of str
        Paths readable by ``scipy.io.loadmat``. Each must contain a ``'data'``
        variable.
    interval : int
        Passed to ``detect_discontinuities``. Defaults to 50, the value
        previously hard-coded.
    verbose : bool
        When true (the default, matching the previous behavior), print each
        array's shape and any discontinuities found.

    Returns
    -------
    discontinuous_files : list of (path, discontinuities)
        Only the files with at least one flagged trial, in input order.
    total_found : int
        Total number of flagged trials across all files.
    """
    discontinuous_files = []

    for file in file_list:
        data = scipy.io.loadmat(file)['data']
        if verbose:
            print(data.shape)

        discontinuities = detect_discontinuities(data, interval)
        if discontinuities:
            if verbose:
                print(discontinuities)
            discontinuous_files.append((file, discontinuities))

    total_found = sum(len(found) for _, found in discontinuous_files)
    return discontinuous_files, total_found


In [3]:
def trim_discontinuous_files(file_list, discontinuous_files, total_found, print_out=False, num_points=500):
    """Load each recording, drop flagged trials, truncate, and build labels.

    Parameters
    ----------
    file_list : iterable of str
        ``.mat`` paths. Each holds a ``'data'`` array whose axis 0 indexes
        trials.
    discontinuous_files : iterable of (path, discontinuities)
        As returned by ``get_discontinuous_files``. ``discontinuities[k][0]`` is
        a trial index into that file's original (untrimmed) array.
    total_found : int
        Number of subplot panels to allocate when ``print_out`` is true.
    print_out : bool
        Plot every removed trial on a 4-column grid. Plotting is skipped when
        ``total_found`` is 0, because a zero-row grid cannot be created.
    num_points : int
        Number of leading points kept along axis 2. Defaults to 500, the value
        previously hard-coded.

    Returns
    -------
    X : list of np.ndarray
        Per-file trials with the flagged ones removed, truncated to
        ``num_points`` along axis 2.
    Y : list of np.ndarray
        Per-file one-hot labels of shape ``(X[i].shape[0], 2)``. The label is
        ``[1, 0]`` when the file name starts with ``'P'``, otherwise ``[0, 1]``.
    """
    # Gather the flagged trial indices per file once, instead of rescanning
    # ``discontinuous_files`` for every file.
    flagged = {}
    for entry in discontinuous_files:
        flagged.setdefault(entry[0], []).extend(d[0] for d in entry[1])

    show = print_out and total_found > 0
    if show:
        ncols = 4
        num_rows = (total_found + ncols - 1) // ncols
        # squeeze=False keeps ``axes`` 2-D even for a single row.
        fig, axes = plt.subplots(nrows=num_rows, ncols=ncols, figsize=(12, 3 * num_rows), squeeze=False)
        fig.suptitle('Removed samples:')
        if num_rows == 1:
            fig.subplots_adjust(top=0.8)
        else:
            fig.subplots_adjust(hspace=0.3)
        panels = axes.ravel()
        subplot_index = 0

    X = []
    Y = []
    for file in file_list:
        x = scipy.io.loadmat(file)['data']
        name = os.path.basename(file)

        # Sorted, unique positions in the *original* array. Deleting them in a
        # single call avoids shifting indices after each removal.
        drop = np.array(sorted(set(flagged.get(file, ()))), dtype=np.intp)

        if show:
            for index in drop:
                # Wraps around (reusing panels) if more trials are flagged
                # than ``total_found`` panels were allocated.
                plot_sample(x, int(index), panels[subplot_index % panels.size], title=name)
                subplot_index += 1

        x = np.delete(x, drop, axis=0)[:, :, :num_points]

        y = np.zeros((x.shape[0], 2))
        y[:, 0 if name.startswith('P') else 1] = 1

        X.append(x)
        Y.append(y)

    return X, Y

In [6]:
def trim_manually(kick_out, X, Y, file_dict, print_out=False):
    """Remove hand-picked trials from already loaded data and labels.

    Parameters
    ----------
    kick_out : iterable of (file_index, trial_indices)
        ``file_index`` selects ``X[file_index]`` and ``Y[file_index]``.
        ``trial_indices`` are 0-based positions along axis 0 of those arrays
        as passed in. Order does not matter; a repeated index is removed once.
    X, Y : sequence of np.ndarray
        Per-file data and labels. They are not modified.
    file_dict : mapping or sequence
        ``file_dict[file_index]`` is used as the subplot title. It is only read
        when ``print_out`` is true.
    print_out : bool
        Plot every removed trial on a 3-column grid. Plotting is skipped when
        ``kick_out`` lists no indices, because a zero-row grid cannot be
        created.

    Returns
    -------
    new_X, new_Y : list of np.ndarray
        Copies of ``X`` and ``Y`` with the selected trials deleted.
    """
    # Collect the indices per file so each array is trimmed by one np.delete on
    # its original positions. This keeps the removed trial identical to the
    # plotted one, with no index-shift bookkeeping to get off by one.
    drops = {}
    for kick in kick_out:
        drops.setdefault(kick[0], []).extend(kick[1])

    num_graphs = sum(len(indices) for indices in drops.values())
    show = print_out and num_graphs > 0
    if show:
        ncols = 3
        num_rows = (num_graphs + ncols - 1) // ncols
        # squeeze=False keeps ``axes`` 2-D even when there is a single row.
        fig, axes = plt.subplots(nrows=num_rows, ncols=ncols, figsize=(12, 4 * num_rows), squeeze=False)
        if num_rows == 1:
            fig.subplots_adjust(top=0.8)
        else:
            fig.subplots_adjust(hspace=0.3)
        fig.suptitle("Removed samples")
        panels = axes.ravel()
        subplot_index = 0

    new_X = []
    new_Y = []
    for i, x in enumerate(X):
        drop = np.array(sorted(set(drops.get(i, ()))), dtype=np.intp)
        if show:
            for index in drop:
                plot_sample(x, int(index), panels[subplot_index % panels.size], title=file_dict[i])
                subplot_index += 1
        new_X.append(np.delete(x, drop, axis=0))
        new_Y.append(np.delete(Y[i], drop, axis=0))

    if show:
        # Final vertical spacing between subplot rows.
        fig.subplots_adjust(hspace=0.5)

    return new_X, new_Y

In [7]:
def print_eliminations(file_list, new_X, new_Y, print_shape=False):
    """Report how many trials were removed per file and overall.

    Each original ``'data'`` array is reloaded from disk, and its length along
    axis 0 is compared with the matching entry of ``new_X``.

    Parameters
    ----------
    file_list : iterable of str
        ``.mat`` paths, in the same order as ``new_X`` and ``new_Y``.
    new_X, new_Y : sequence of np.ndarray
        Trimmed data and labels, indexed in parallel with ``file_list``.
    print_shape : bool
        Also print ``old shape -> new data shape, label shape`` for each file.
    """
    kept_counts = []
    removed_counts = []
    for idx, path in enumerate(file_list):
        trimmed, labels = new_X[idx], new_Y[idx]
        original = scipy.io.loadmat(path)['data']

        if print_shape:
            print(original.shape, "->", trimmed.shape, labels.shape)

        kept_counts.append(trimmed.shape[0])
        removed_counts.append(original.shape[0] - trimmed.shape[0])

    print("Total samples removed: ", sum(removed_counts))
    print("Total samples remaining: ", sum(kept_counts))