In [None]:
# Authors: Joshua Sweet, Lauren Scott, Sravani Ravula

# Imports. Some of these packages may need to be installed before running.
#%matplotlib notebook 
from sklearn.metrics import classification_report as cr
from sklearn.metrics import confusion_matrix as cm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from find_peak import find_peaks # From external file "find_peak.py"
from sklearn.svm import SVC
from wfdb import processing
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sb
import wfdb
import random
import pickle

# The set of MIT record numbers to process. main() reads each record from
# "MIT dataset/<number>". Because this is a set, records are visited in the
# set's own iteration order, not in the numeric order written below.
fileSet = {100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 111, 112, 113,
           114, 115, 116, 117, 118, 119, 121, 122, 123, 124, 200, 201, 202,
           203, 205, 207, 208, 209, 210, 212, 213, 214, 215, 217, 219, 220,
           221, 222, 223, 228, 230, 231, 232, 233, 234}

# Maps each annotation symbol to the integer class label used for training
# (looked up by convert_annotations). Both " " and "N" map to 0, the value
# train_data treats as a normal beat. The codes 1, 15 and 17 are not used.
# NOTE(review): the numbering appears to follow the WFDB annotation code table
# with "N" remapped to 0 -- confirm before adding or renumbering entries.
labels = {
    " ": 0, "N": 0, "L": 2, "R": 3, 
    "a": 4, "V": 5, "F": 6, "J": 7, 
    "A": 8, "S": 9,"E": 10, "j": 11, 
    "/": 12, "Q": 13, "~": 14, "|": 16, 
    "s": 18, "T": 19, "*": 20, "D": 21, 
    "\"": 22, "=": 23, "p": 24, "B": 25, 
    "^": 26, "t": 27, "+": 28, "u": 29, 
    "?": 30, "!": 31, "[": 32, "]": 33, 
    "e": 34, "n": 35, "@": 36, "x": 37, 
    "f": 38, "(": 39
}
    

def main():
    """Build the beat dataset, train a linear SVM on it and report the results.

    For every record number in ``fileSet`` the first signal channel and the
    "atr" annotations are read with wfdb, turned into feature windows by
    ``gather_data`` and plotted by ``plot_figure``. The collected features and
    annotation symbols are pickled by ``create_data``, split into training and
    test sets by ``train_data`` and used to fit and evaluate an ``SVC`` with a
    linear kernel. A classification report, the confusion matrix, the raw
    predictions and a confusion-matrix heat map are produced.

    Returns:
        0 when the pipeline ran, 1 when no beat could be extracted from any
        record (there is nothing to train on in that case).
    """
    # The accumulators are created here and passed explicitly so that the
    # collected beats do not depend on gather_data's default arguments.
    x, y = [], []

    for file in fileSet:
        record = "MIT dataset/" + str(file)
        print("Extracting " + record)

        try:
            # Extract the wfdb data relevant to the current MIT record
            sig, _ = wfdb.rdsamp(record_name=record, channels=[0])
            ann = wfdb.rdann(record_name=record, extension="atr")

            # Gather the annotation and value data from the wfdb fields
            gather_data(sig, ann, x, y)
        except Exception as error:
            # Stay best-effort (one bad record must not stop the run), but
            # report the failure instead of hiding it. Beats appended before
            # the failure are kept.
            print("Could not fully extract " + record + ": " + repr(error))
            continue

        try:
            plot_figure(file)  # Make a graph of the data
        except Exception as error:
            print("Could not plot " + record + ": " + repr(error))

    if not x:
        print("No data could be extracted; nothing to train on")
        return 1

    # Store all the data gathered into a file
    create_data(["features", "annotations"], [x, y])
    print("Data Creation Done")

    # Turn the stored data into a training set and a test set
    trainX, testX, trainY, testY = train_data("features", "annotations")
    print("Training Set Done")

    # Apply the training set to an SVM module.
    svcType = SVC(kernel="linear")
    svcType.fit(trainX, trainY)
    prediction = svcType.predict(testX)
    print("Predicition Done")

    # The confusion matrix has one row/column per label that occurs in the
    # true test labels or in the predictions. Labels that only occur in the
    # training set must not be included, or the heat-map axes would not match
    # the matrix dimensions.
    resultingLabels = sorted(set(testY) | set(prediction.tolist()))

    # Display the relevant information
    print(cr(testY, prediction))
    print("Confusion Matrix: ", end="")
    confusionMatrix = cm(testY, prediction, labels=resultingLabels)
    print(confusionMatrix)
    print("Predicition Values: ", end="")
    print(prediction)

    try:
        plt.figure(figsize=(15, 16))

        # Creates a heat map of the confusion matrix
        confusionMatrixData = pd.DataFrame(
            confusionMatrix, index=resultingLabels, columns=resultingLabels
        )
        sb.heatmap(confusionMatrixData)
        plt.show()
    except Exception as error:
        # Plotting is optional output; report the problem and carry on.
        print("Could not draw the confusion matrix heat map: " + repr(error))

    return 0
    

def plot_figure(file):
    """Plot the start of one MIT record with its annotations and peaks.

    The record, its "atr" annotations and its samples are read from
    "MIT dataset/<file>" with ``sampto = 999``. wfdb draws the annotated
    figure; every peak reported by ``find_peaks`` is then marked on the first
    axis with its coordinates, and consecutive peaks are joined by a line.

    Args:
        file: Record number (or name) inside the "MIT dataset/" directory.
    """
    path = "MIT dataset/" + str(file)

    # Get the wfdb record and annotation for that dataset, limited by sampto
    record = wfdb.rdrecord(path, sampto=999)
    ann = wfdb.rdann(path, "atr", sampto=999)
    sig, fields = wfdb.rdsamp(path, sampto=999)

    # NOTE(review): find_peaks comes from the external "find_peak.py"; it is
    # used below as a sequence of indices into the plotted line -- confirm.
    peaks = find_peaks(sig, fields)

    # Create the matplotlib figure
    figure = wfdb.plot_wfdb(record=record, annotation=ann, plot_sym=True,
                            time_units="seconds",
                            title="Dataset " + str(file), return_fig=True)
    figure.set_size_inches(9, 8)

    # Only the first axis is decorated; its first line holds the signal.
    axis = figure.sca(figure.axes[0])
    data = axis.lines[0]

    # Convert the plotted coordinates to plain lists once, instead of once
    # per lookup inside the loop below.
    xpoints = data.get_xdata().tolist()
    ypoints = data.get_ydata().tolist()

    previous = None
    for peak in peaks:
        current = (xpoints[peak], ypoints[peak])

        # Plot a point (with its coordinates) at the current peak
        plot_points(current[0], current[1], 0.05, -0.075)

        if previous is not None:
            # Plot a line from the previous peak to the current peak
            plt.plot([previous[0], current[0]], [previous[1], current[1]])
        previous = current
    
def plot_points(x, y, shiftx, shifty):
    """Mark the point (x, y) and write its rounded coordinates next to it.

    Args:
        x: Horizontal position of the point.
        y: Vertical position of the point.
        shiftx: Horizontal offset of the text label from the point.
        shifty: Vertical offset of the text label from the point.
    """
    # Label text is "(x,y)" with both coordinates rounded to 3 decimals.
    caption = "({!s},{!s})".format(round(x, 3), round(y, 3))

    plt.plot(x, y, ".")
    plt.text(x + shiftx, y + shifty, caption, fontsize=6)
        

def gather_data(signal, annotations, x=[], y=[], *, half_window=180):
    """Cut one fixed-length sample window out of the signal per annotation.

    For every annotation except the last one, a window of ``2 * half_window``
    samples centred on the annotated sample index is appended to ``x`` and the
    annotation symbol is appended to ``y``. A window that would start before
    the signal is shifted to start at sample 0; a window that would run past
    the end is shifted to end at the last sample, so every window has the
    same length.

    Args:
        signal: Array-like of samples, either 1-D or (samples, channels).
            When several channels are present only the first one is used.
        annotations: Object exposing ``sample`` (array of sample indices) and
            ``symbol`` (sequence of annotation symbols), as returned by
            ``wfdb.rdann``.
        x: List the feature windows are appended to.
        y: List the annotation symbols are appended to.
        half_window: Number of samples taken on each side of the annotation.

    Returns:
        The ``(x, y)`` lists that were appended to.

    Raises:
        IndexError: If the signal is shorter than one full window.

    NOTE: the ``x``/``y`` defaults are intentionally left as shared mutable
    lists. Callers that omit them rely on the results accumulating across
    calls; pass fresh lists to get an independent result.
    """
    samples = np.asarray(signal, dtype=float)
    if samples.ndim > 1:
        # Keep the first channel only; each feature is a single float.
        samples = samples.reshape(samples.shape[0], -1)[:, 0]

    window = 2 * half_window
    total = samples.shape[0]
    # NOTE(review): the last annotation is skipped, as in the original code;
    # the reason is not visible here -- confirm it is intended.
    count = annotations.sample.size - 1

    if count > 0 and total < window:
        raise IndexError("signal has fewer samples than one window")

    for i in range(count):
        start = int(annotations.sample[i]) - half_window
        # Clamp the window so it lies completely inside the signal.
        start = min(max(start, 0), total - window)

        x.append(samples[start:start + window].tolist())
        y.append(annotations.symbol[i])
    return x, y


def create_data(filename, data):
    """Pickle each object in ``data`` to its file under "MIT dataset/".

    ``filename[i]`` is the name of the file that receives ``data[i]``. Any
    number of name/object pairs is accepted; surplus entries in the longer
    sequence are ignored.

    Args:
        filename: Sequence of file names, relative to "MIT dataset/".
        data: Sequence of picklable objects, parallel to ``filename``.
    """
    for name, payload in zip(filename, data):
        # The with-block closes the file; no explicit close() is needed.
        with open("MIT dataset/" + name, "wb") as file:
            pickle.dump(payload, file)
    
def load_data(features, annotations):
    """Load the pickled features and annotations from "MIT dataset/".

    Args:
        features: Name of the pickle file holding the feature windows.
        annotations: Name of the pickle file holding the annotation symbols.

    Returns:
        A tuple ``(features_object, annotations_object)`` with the unpickled
        contents of the two files.
    """
    # SECURITY: pickle.load can execute arbitrary code; only use this on files
    # produced by create_data, never on files from an untrusted source.
    def _read(name):
        # The with-block guarantees the handle is closed after loading.
        with open("MIT dataset/" + name, "rb") as handle:
            return pickle.load(handle)

    return _read(features), _read(annotations)

def train_data(features, annotations):
    """Load the stored beats, scale them and split them into train/test sets.

    The annotation symbols are converted to numeric labels, the feature
    windows are min-max scaled to the range (0, 39), and beats labelled 0
    (normal) are randomly thinned so that they do not swamp the other
    classes.

    Args:
        features: Name of the pickle file holding the feature windows.
        annotations: Name of the pickle file holding the annotation symbols.

    Returns:
        The result of ``train_test_split``: ``trainX, testX, trainY, testY``.
        Each feature row is a 1-D NumPy array of scaled samples.
    """
    x, y = load_data(features, annotations)

    # Create a list of every y annotation converted to its number value
    y = [convert_annotations(symbol) for symbol in y]

    # transform() returns a scaled copy and leaves its input untouched, so the
    # result has to be kept; discarding it leaves the features unscaled.
    # NOTE(review): the scaler is fitted on all beats before the split, so the
    # test rows influence the scaling -- confirm this is acceptable.
    scaler = MinMaxScaler(copy=True, feature_range=(0, 39))
    x = scaler.fit_transform(x)

    xValues = []
    yValues = []
    for row, label in zip(x, y):
        # A normal beat (label 0) is kept only when the random draw exceeds
        # 0.95, which filters out most of the very frequent normal beats.
        # Every other beat is always kept.
        if label == 0 and random.random() <= 0.95:
            continue
        xValues.append(row)
        yValues.append(label)
    return train_test_split(xValues, yValues)

def convert_annotations(y):
    """Convert an annotation symbol to its numeric label.

    Only the first character of ``y`` is looked up in ``labels``.

    Args:
        y: Annotation symbol (a string or other indexable value).

    Returns:
        The label from ``labels``, or 0 when the symbol is unknown, empty or
        not indexable.
    """
    try:
        return labels[y[0]]
    except (KeyError, IndexError, TypeError):
        # Unknown symbol, empty symbol, or a value that cannot be indexed or
        # hashed: fall back to label 0. Other errors are not swallowed.
        return 0

# Run the whole pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


Extracting MIT dataset/230
Extracting MIT dataset/231
Extracting MIT dataset/232
Extracting MIT dataset/233
Extracting MIT dataset/234
Extracting MIT dataset/200
Extracting MIT dataset/201
Extracting MIT dataset/202
Extracting MIT dataset/203
Extracting MIT dataset/205
Extracting MIT dataset/207
Extracting MIT dataset/208
Extracting MIT dataset/209
Extracting MIT dataset/210
Extracting MIT dataset/212
Extracting MIT dataset/213
Extracting MIT dataset/214
Extracting MIT dataset/215
Extracting MIT dataset/217
Extracting MIT dataset/219
Extracting MIT dataset/220
Extracting MIT dataset/221
Extracting MIT dataset/222
Extracting MIT dataset/223
Extracting MIT dataset/100
Extracting MIT dataset/101
Extracting MIT dataset/102
Extracting MIT dataset/103
Extracting MIT dataset/104
Extracting MIT dataset/105
Extracting MIT dataset/106
Extracting MIT dataset/107
Extracting MIT dataset/108
