# In [None]:
import scipy.io
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
from sklearn.decomposition import FastICA

def load_data(file):
    """Read a MATLAB ``.mat`` file and return the first row of its ``val`` entry.

    Parameters
    ----------
    file : str or path-like
        Path to the ``.mat`` file, passed straight to ``scipy.io.loadmat``.

    Returns
    -------
    numpy.ndarray
        ``val[0]`` from the loaded file.
    """
    contents = scipy.io.loadmat(file)
    signal_matrix = contents['val']
    return signal_matrix[0]

def show_single_plot(y_beat):
    """Open a new gridded figure and draw ``y_beat`` on it via ``show_multiplot``."""
    _figure, axes = plt.subplots()
    axes.grid()
    show_multiplot(axes, y_beat)

def show_multiplot(ax, y):
    """Draw ``y`` onto an existing Axes object.

    The x coordinates are ``len(y)`` evenly spaced points from 0 to
    ``len(y)`` (inclusive), and the x-axis is clamped to ``[0, len(y)]``.

    Parameters
    ----------
    ax : matplotlib Axes (or anything with ``set_xlim`` and ``plot``)
        Target axes.
    y : array_like
        Values to plot; if 2-D, ``ax.plot`` draws one line per column.
    """
    n_points = len(y)
    xs = np.linspace(0, n_points, n_points)
    ax.set_xlim([0, n_points])
    ax.plot(xs, y)


def apply_filter(y_transformed, strength):
    """Smooth each column of ``y_transformed`` with a linear Savitzky-Golay filter.

    Parameters
    ----------
    y_transformed : array_like
        Signal(s) to smooth, laid out as ``(n_samples,)`` or
        ``(n_samples, n_components)``; filtering runs along axis 0.
    strength : int
        Savitzky-Golay window length in samples (larger = smoother). It must
        not exceed ``n_samples``; odd values are safe across SciPy versions.

    Returns
    -------
    numpy.ndarray
        Float64 array with the same shape as ``y_transformed``.
    """
    # One vectorised call along axis 0 filters every column independently.
    # This works for any signal length. A per-column np.append/reshape loop
    # would be quadratic and would tie the result to a fixed sample count.
    signals = np.asarray(y_transformed, dtype=float)
    return savgol_filter(signals, strength, 1, axis=0)

def filter_noise(filename, data_dir='heartbeat_data', threshold=2,
                 random_state=None):
    """Decompose a recording with FastICA and keep the spiky components.

    Every independent component whose minimum is below ``-threshold`` or
    whose maximum is above ``threshold`` is kept as a candidate heartbeat
    signal. Each kept component is negated when its most negative value is
    larger in magnitude than its most positive one, so downstream positive
    peak detection sees the dominant excursion.

    Parameters
    ----------
    filename : str
        Name of the ``.mat`` file inside ``data_dir``.
    data_dir : str, optional
        Directory holding the recordings (default ``'heartbeat_data'``).
    threshold : float, optional
        Amplitude a component must exceed (in either direction) to be kept.
        With ``whiten='unit-variance'`` the components have unit variance,
        so this is in units of the component's standard deviation.
    random_state : int or None, optional
        Forwarded to ``FastICA``; pass an int for reproducible results.

    Returns
    -------
    numpy.ndarray
        Kept components stacked row-wise, one row per component. The array
        is empty when no component crosses the threshold.
    """
    y = load_data(os.path.join(data_dir, filename))
    # NOTE(review): this assumes ``val[0]`` is 2-D (channels, samples), so
    # ``.T`` gives FastICA its (samples, features) input. A 1-D row would
    # be rejected by FastICA. TODO confirm against the data files.
    y_transposed = y.T  # columns
    transformer = FastICA(whiten='unit-variance', random_state=random_state)
    components = transformer.fit_transform(y_transposed)

    y_beats = []
    for row in components.T:
        low, high = row.min(), row.max()
        if low < -threshold or high > threshold:
            # Flip sign so the component's dominant excursion is positive.
            if abs(low) > abs(high):
                row = -row
            y_beats.append(row)
    return np.asarray(y_beats)


def filter_fetus(y_beats, height=2):
    """Choose the fetal heartbeat component among the first two candidates.

    The candidate with more detected peaks (``find_peaks`` above ``height``)
    is returned. This presumably relies on the fetal heart rate being higher
    than the maternal one. Ties go to the second candidate. Only
    ``y_beats[0]`` and ``y_beats[1]`` are considered.

    Parameters
    ----------
    y_beats : sequence of 1-D arrays
        Candidate components, e.g. the output of ``filter_noise``.
    height : float, optional
        Minimum peak height passed to ``scipy.signal.find_peaks`` (default 2).

    Returns
    -------
    numpy.ndarray
        The selected component. It is returned as-is when only one candidate
        exists.

    Raises
    ------
    ValueError
        If ``y_beats`` is empty.
    """
    from scipy.signal import find_peaks

    if len(y_beats) == 0:
        raise ValueError("filter_fetus needs at least one candidate component")
    if len(y_beats) == 1:
        return y_beats[0]
    # Count peaks. Summing their indices would weight late peaks more
    # heavily than early ones.
    first_count = len(find_peaks(y_beats[0], height=height)[0])
    second_count = len(find_peaks(y_beats[1], height=height)[0])
    return y_beats[0] if first_count > second_count else y_beats[1]

def main(start=25, stop=26):
    """Find and plot fetal heartbeat peaks for recordings in ``heartbeat_data``.

    Directory entries are visited in sorted order so runs are reproducible.
    Processing ends once the position in that sorted listing reaches
    ``stop``. Entries whose name has no number, or whose first number is
    below ``start``, are skipped. For every processed file this plots the
    selected component and prints the peak sample indices and then the
    peak heights.

    Parameters
    ----------
    start : int, optional
        Lowest record number (first integer in the filename) to process.
    stop : int, optional
        Position in the sorted directory listing at which to stop
        (exclusive).
    """
    from scipy.signal import find_peaks

    filenames = sorted(os.listdir('heartbeat_data'))
    for i, filename in enumerate(filenames):
        # Check the cut-off first so a skipped entry cannot step over it.
        if i >= stop:
            break
        numbers = re.findall(r'\d+', filename)
        # Entries without a record number (e.g. hidden files) are ignored.
        if not numbers or int(numbers[0]) < start:
            continue
        y_beats = filter_noise(filename)
        y_beat = filter_fetus(y_beats)

        # find_peaks returns sample positions plus, because ``height`` is
        # given, the matching amplitudes under 'peak_heights'.
        peak_indices, properties = find_peaks(y_beat, height=2)
        spike_times = peak_indices
        spike_vals = properties['peak_heights']
        show_single_plot(y_beat)
        print(spike_times)
        print(spike_vals)


if __name__ == "__main__":
    main()