In [1]:
import pickle as pkl
import numpy as np
import wfdb
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import defaultdict

In [40]:
class Segment:
    def __init__(self, signal=1, label=1, normal_ratio=0):
        self.signal = signal
        self.label = label
        self.normal_ratio = normal_ratio


def create_index_df(desired_segment_len=3600, basic_arr_path="data/mit-bih-arrhythmia-database-1.0.0"):
    desired_segment_len = 3600
    basic_arr_path = "data/mit-bih-arrhythmia-database-1.0.0"
    arr_db = wfdb.get_record_list('mitdb')
    num_samples_in_record = 30 * 60 * 360
    segment_dict = defaultdict(list)
    segment_list = list()
    
    # for selection and sampling 
    segment_dict_ann = {}
    record_count = 0
    
    
    for _, record_id in enumerate(arr_db):
        record_path = os.path.join(basic_arr_path, str(record_id))

        ann = wfdb.rdann(record_path, 'atr', sampto=num_samples_in_record, return_label_elements=['description',
                                                                                                  'symbol', 'label_store'])
        df = pd.DataFrame({'description': ann.description, 'sample': ann.sample, 'symbol': ann.symbol,
                           'label_store': ann.label_store})
        counter = 0
        reset_flag = True
        allowed_labels = ['Normal beat']
        allowed_symbols = ['N']

        normal_counter = 0
        for i in range(1, df.shape[0] - 1):
            curr_label, curr_sample, curr_symbol = df.loc[i, ['description', 'sample', 'symbol']]
#             if curr_label == 'Normal beat':
#                 normal_counter += 1
            if curr_symbol == 'N':
                normal_counter += 1
            if reset_flag:
                start_sample = curr_sample
                ann_num_start = i
                normal_counter = 0
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']
            next_label, next_sample, next_symbol = df.loc[i + 1, ['description', 'sample', 'symbol']]
#             if curr_label == next_label or next_label in allowed_labels or len(allowed_labels) < 2:
            if curr_label == next_symbol or next_symbol in allowed_symbols or len(allowed_symbols) < 2:

#                 if next_label not in allowed_labels:
#                     allowed_labels.append(next_label)
#                     allowed_symbols.append(next_symbol)
                if next_symbol not in allowed_symbols:
                    allowed_labels.append(next_label)
                    allowed_symbols.append(next_symbol)
                ann_num_end = i+1
                counter += next_sample - curr_sample
                reset_flag = False
                if counter > desired_segment_len:
                    segment_dict['record_id'].extend([record_id])
                    segment_dict['start_sample'].extend([start_sample])
                    segment_dict['label'].extend([curr_label])
                    counter = 0
                    reset_flag = True
                    signal = wfdb.rdsamp(record_path, sampfrom=start_sample, sampto=start_sample + 3600)[0][:, 0]
                    normal_ratio = normal_counter / (ann_num_end - ann_num_start)
                    segment_list.append(Segment(signal, allowed_labels[-1], normal_ratio))
                    segment_dict_ann[record_count] = [record_id,  allowed_labels[-1], signal, normal_ratio,
                                                       allowed_symbols[-1]]
                    record_count = record_count + 1
            else:
                counter = 0
                normal_counter = 0
                reset_flag = True
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']


    #     return pd.DataFrame(segment_dict)
#     return segment_list
    return segment_dict_ann

In [41]:
d_dict = create_index_df()
# desired_segment_len = 3600
# basic_arr_path = "data/mit-bih-arrhythmia-database-1.0.0"
# num_samples_in_record = 30 * 60 * 360


# arr_db = wfdb.get_record_list('mitdb')
# record_id = arr_db[0]
# record_path = os.path.join(basic_arr_path, str(record_id))

# ann = wfdb.rdann(record_path, 'atr', sampto=num_samples_in_record, return_label_elements=['description','symbol', 'label_store'])
# ann.symbol

In [42]:
seg_df = pd.DataFrame.from_dict(d_dict, orient='index')
seg_df.rename(columns={0: 'record_id', 1:'label', 2:'signal', 3:'normal_ratio', 4:'symbol'}, inplace= True)
# seg_df['label'].value_counts()

In [43]:
seg_df.symbol.value_counts()

N    3437
V    1252
L     494
R     489
/     370
A     224
~     158
+     124
"      47
F      40
|      32
a      29
x      25
!      11
E      10
f       8
J       3
Q       3
j       2
e       2
Name: symbol, dtype: int64

In [45]:
num_labels_dict = {
'Normal beat': 283, #N
'Left bundle branch block beat': 103, #L
'Atrial premature beat': 66, # A
'Atrial flutter': 20,
'Atrial fibrillation': 135,
'Pre-excitation (WPW)': 21,
'Premature ventricular contraction': 133, #V
'Ventricular bigeminy': 55,
'Ventricular trigeminy': 13,
'Ventricular tachycardia': 10,
'Idioventricular rhythm': 10,
'Ventricular flutter': 10,   #!
'Fusion of ventricular and normal beat': 11, #F
'Second-degree heart block': 10,
'Pacemaker rhythm': 45,
'Supraventricular tachyarrhythmia': 13,
'Right bundle branch block beat': 62, #'R'
                 }

num_labels_we_have = {
'N': 283, #'Normal beat'
'L': 103, #'Left bundle branch block beat'
'A': 66, #'Atrial premature beat':
'V': 133, #'Premature ventricular contraction'
'!': 10,   #'Ventricular flutter'
'F': 11, # 'Fusion of ventricular and normal beat'
'R': 62 # 'Right bundle branch block beat'
               }



In [46]:
import random

# x = seg_df.loc[seg_df[seg_df.label == 'Normal beat'].index]

def sample_per_sym(seg_df, sym, num):
    # filter
    label_df = seg_df.loc[seg_df[seg_df.symbol == sym].index]
    # sample
    print(f' len df: {len(label_df)}')
    sample_df = label_df.sample(n=num)
    signals = np.stack(sample_df.signal.values)
    return sample_df, signals


sample_df = []
signals = []
for sym, num in num_labels_we_have.items():
    print(f'label:{label}, num: {num}')
    sample_df_label, signals_label = sample_per_sym(seg_df, sym, num)
    sample_df.append(sample_df_label)
    signals.append(signals_label)
sample_df = pd.concat(sample_df)


label:Atrial premature beat, num: 283
 len df: 3437
label:Atrial premature beat, num: 103
 len df: 494
label:Atrial premature beat, num: 66
 len df: 224
label:Atrial premature beat, num: 133
 len df: 1252
label:Atrial premature beat, num: 10
 len df: 11
label:Atrial premature beat, num: 11
 len df: 40
label:Atrial premature beat, num: 62
 len df: 489


In [39]:
sample_df.head()

Unnamed: 0,record_id,label,signal,normal_ratio,symbol
6666,234,Normal beat,"[1.47, 1.48, 1.36, 1.105, 0.785, 0.445, 0.14, ...",0.9375,N
559,103,Normal beat,"[1.865, 1.93, 1.815, 1.445, 0.87, 0.275, -0.17...",0.916667,N
5519,220,Normal beat,"[1.255, 1.385, 1.335, 1.055, 0.5, -0.14, -0.78...",0.909091,N
2891,121,Normal beat,"[-0.84, -0.84, -0.845, -0.895, -0.975, -1.065,...",0.9,N
3227,123,Normal beat,"[1.115, 1.19, 1.085, 0.81, 0.365, -0.17, -0.69...",0.888889,N


In [38]:
sample_df.to_pickle("./sampled.pkl")