In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import xmltodict
import base64
import numpy as np
import array

from tqdm import tqdm

from sklearn.metrics import accuracy_score,precision_score, recall_score, roc_auc_score
from IPython.display import clear_output

# Data Preprocessings

In [2]:
def get_lead(path):
    with open(path, 'rb') as xml:
        ECG = xmltodict.parse(xml.read().decode('utf8'))
    
    augmentLeads = True
    if path.split('/')[-1][0] == '5':
        waveforms = ECG['RestingECG']['Waveform'][1]
    elif path.split('/')[-1][0] == '6':
        waveforms = ECG['RestingECG']['Waveform']
        augmentLeads = False
    else:
        waveforms = ECG['RestingECG']['Waveform']
    
    leads = {}
    
    for lead in waveforms['LeadData']:
        lead_data = lead['WaveFormData']
        lead_b64  = base64.b64decode(lead_data)
        lead_vals = np.array(array.array('h', lead_b64))
        leads[ lead['LeadID'] ] = lead_vals
    
    if augmentLeads:
        leads['III'] = np.subtract(leads['II'], leads['I'])
        leads['aVR'] = np.add(leads['I'], leads['II'])*(-0.5)
        leads['aVL'] = np.subtract(leads['I'], 0.5*leads['II'])
        leads['aVF'] = np.subtract(leads['II'], 0.5*leads['I'])
    
    return leads

In [3]:
error_train = ['6_2_003469_ecg.xml', '6_2_003618_ecg.xml', '6_2_005055_ecg.xml', '8_2_001879_ecg.xml', '8_2_002164_ecg.xml']
error_valid = ['8_2_007281_ecg.xml', '8_2_008783_ecg.xml', '8_2_007226_ecg.xml']


train_data = []
train_labels = []
valid_data = []
valid_labels = []


train_pathes = ['data/train/arrhythmia/', 'data/train/normal/']
valid_pathes = ['data/validation/arrhythmia/', 'data/validation/normal/']

error_decode = []   # 디코딩에 실패한 데이터들..
# error_len = [] # 5000, 4999개를 맞추지 못한 데이터들.. 혹은 12개의 lead가 아닌것들..

# train data
for path in train_pathes:
    for file in os.listdir(path):
        
        if file in error_train:
            print(file)
            continue
        
        try:
            data = get_lead(path + file)
        except Exception as e:
            error_decode.append(path + file)
        
        listed_data = []
        keys = sorted(data.keys())
        for key in keys:
            listed_data.append(data[key])
        
        flag = False
        for idx, i in enumerate(listed_data):
            if len(i) == 5000:
                continue
            elif len(i) == 4999:
                listed_data[idx] = np.append(i, 0)
            else:
                flag = True
        if flag:
            continue
        
        train_data.append(listed_data)
        if 'arrhythmia' in path:
            train_labels.append(1)
        else:
            train_labels.append(0)
            
# valid data
for path in valid_pathes:
    for file in os.listdir(path):
        
        if file in error_valid:
            print(file)
            continue
            
        try:
            data = get_lead(path + file)
        except Exception as e:
            error_decode.append(path + file)
        
        listed_data = []
        keys = sorted(data.keys())
        for key in keys:
            listed_data.append(data[key])
        
        valid_data.append(listed_data)
        if 'arrhythmia' in path:
            valid_labels.append(1)
        else:
            valid_labels.append(0)

print(len(error_decode))


6_2_003469_ecg.xml
8_2_001879_ecg.xml
6_2_003618_ecg.xml
8_2_002164_ecg.xml
6_2_005055_ecg.xml
8_2_007281_ecg.xml
8_2_008783_ecg.xml
8_2_007226_ecg.xml
8


In [4]:
error_lead_len = []
for idx, i in enumerate(train_data):
    if len(i) != 12:
        error_lead_len.append(idx)
for i in error_lead_len:
    del train_data[i]
    del train_labels[i]
    

# 데이터 길이 및 lead 개수 분석

In [5]:
# 데이터의 길이 분포 확인: valid는 모두 5000인것을 확인
# train은 60912개가 4999, 36개가 1249, 1개가 4988
# 위의 테스크는 먼저 4999개만 0의 패딩을 붙이고 나머지는 제외하는식으로 전처리함

c4999 = 0
c5000 = 0
cx = 0

for i in train_data:
    for j in i:
        if len(j) == 4999:
            c4999 +=1
        elif len(j) == 5000:
            c5000 +=1
        else:
            cx +=1

for i in valid_data:
    for j in i:
        if len(j) == 4999:
            c4999 +=1
        elif len(j) == 5000:
            c5000 +=1
        else:
            cx +=1

print(c4999)
print(c5000)
print(cx)

0
524436
0


In [6]:
# 딱 한개의 9 lead의 데이터가 존재한다..
v12 = 0
v9 = 0
vx = 0

for i in train_data:
    if len(i) == 12:
        v12 +=1
    elif len(i) == 9:
        v9 += 1
    else:
        vx +=1
        
for i in valid_data:
    if len(i) == 12:
        v12 +=1
    elif len(i) == 9:
        v9 += 1
    else:
        vx +=1
print(v12)
print(v9)
print(vx)

43703
0
0


# Dataset 생성

In [7]:
train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_data).float(), torch.tensor(train_labels))
valid_dataset = torch.utils.data.TensorDataset(torch.tensor(valid_data).float(), torch.tensor(valid_labels))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [94]:
# residual과 dropout 추가 필요
class Classifier(nn.Module):
    def __init__(self, drop_out=0.0):
        
        super(Classifier, self).__init__()
        self.cnn1 = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=7, padding=3)
        self.cnn2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=9, padding=4)
        self.cnn3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=9, padding=4)
    
        self.pool1 = nn.MaxPool1d(4)
        self.pool2 = nn.MaxPool1d(5)
        self.pool3 = nn.MaxPool1d(5)
        
    
        self.fc1 = nn.Linear(128 * 50, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, 1)
        
        self.relu = nn.ReLU()
        
        self.drop_out = nn.Dropout(p=drop_out)
        
        
    def forward(self, x):
        x = self.relu(self.cnn1(x)) # output (batch, 32, 5000)
        x = self.pool1(x)           # (batch, 32, 1250)
        x = self.relu(self.cnn2(x)) # (batch, 64, 1250)
        x = self.pool2(x)           # (batch, 64, 250)
        x = self.relu(self.cnn3(x)) # (batch, 128, 250)
        x = self.pool3(x)           # (batch, 128, 50)

        
        
        x = x.view(-1, 128*50)
        
        x = self.relu(self.fc1(x))
        x = self.drop_out(x)
        x = self.relu(self.fc2(x))
        x = self.drop_out(x)
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        
        x = torch.sigmoid(x)
        
        return x.view(-1)

In [137]:
# residual과 dropout 추가 필요
class RC_Classifier(nn.Module):
    def __init__(self, drop_out=0.0):
        super(RC_Classifier, self).__init__()
        
        n_layers = 2
        channel_size = 12
        hidden_size = 128
        
        self.cnn1 = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=7, padding=3)
        self.cnn2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=7, padding=3)
        self.cnn3 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=9, padding=4)
    
        self.pool1 = nn.MaxPool1d(4)
        self.pool2 = nn.MaxPool1d(5)
        self.pool3 = nn.MaxPool1d(5)
        
        self.rnn = nn.GRU(hidden_size, hidden_size//2, n_layers, batch_first=True, bidirectional=True, dropout=drop_out)
    
        self.fc1 = nn.Linear(128,32)
        self.fc2 = nn.Linear(32, 1)
        
        self.relu = nn.ReLU()
        
        self.drop_out = nn.Dropout(p=drop_out)
        
        
    def forward(self, x):
        x = self.relu(self.cnn1(x)) # output (batch, 32, 5000)
        x = self.pool1(x)           # (batch, 32, 1250)
        x = self.relu(self.cnn2(x)) # (batch, 64, 1250)
        x = self.pool2(x)           # (batch, 64, 250)
        x = self.relu(self.cnn3(x)) # (batch, 128, 250)
        x = self.pool3(x)           # (batch, 128, 50)

        x = x.permute(0,2,1)
        
        output, hidden = self.rnn(x)
        hidden = torch.cat((hidden[-2], hidden[-1]), -1) #(32,256)
        
        x = self.relu(self.fc1(hidden))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(x)
        
        return x.view(-1)

In [73]:
# residual과 dropout 추가 필요
class RNN_Classifier(nn.Module):
    def __init__(self, drop_out=0.0):
        super(RNN_Classifier,self).__init__()
        n_layers = 2
        channel_size = 12
        hidden_size = 128
        self.rnn = nn.GRU(channel_size, hidden_size, n_layers, batch_first=True, bidirectional=True, dropout=drop_out)
    
        self.fc1 = nn.Linear(256,32)
        self.fc2 = nn.Linear(32, 1)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        
        x = x.permute(0,2,1)
        output, hidden = self.rnn(x)
        # hidden ([4, 32, 128])

        hidden = torch.cat((hidden[-2], hidden[-1]), -1) #(32,256)

        x = self.relu(self.fc1(hidden))
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(x)
        
        return x.view(-1)

In [141]:
LR = 0.00001
PATIENCE = 3
FACTOR = 0.95
DROP_OUT = 0.3

In [142]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
# model = Classifier(drop_out=DROP_OUT).to(device)
model = RC_Classifier(drop_out=DROP_OUT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay = 0)
criterion = nn.BCELoss()

use_scheduler = True
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=PATIENCE,factor=FACTOR)

In [143]:
epoches = 100

best_val_acc = 0
best_epoch = -1

best_acc_pred = []
train_acc_list = []
train_loss_list = []
val_acc_list = []
val_loss_list = []

for i in tqdm(range(epoches)):
    # Train
    loss_sum = 0
    true_labels = []
    pred_labels = []
    model.train()
    for e_num, (x,y) in enumerate(train_dataloader):
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)
        model.zero_grad()
        pred_y = model(x)
        
        loss=criterion(pred_y,y)
        loss_sum+=loss.detach()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))
        
    
    acc=roc_auc_score(true_labels,pred_labels)
    print(f'epoch: {i}')
#     print(f'train \t\t loss mean {loss_sum/e_num} accuracy: {acc}')
    
    # Valid
    loss_sum=0
    true_labels=[]
    pred_labels=[]
    model.eval()
    for e_num, (x,y) in enumerate(val_dataloader):
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)

        pred_y = model(x)
        loss=criterion(pred_y,y)
        
        loss_sum+=loss.detach()
        
        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))
    
    acc=roc_auc_score(true_labels,pred_labels)
    
    # update scheduler -EY
    if use_scheduler:
        scheduler.step(acc)
    
    if best_val_acc < acc:
        print("NEW RECODE!")
        best_acc_pred = pred_labels
        best_val_acc = acc
        best_epoch = i
        torch.save(model.state_dict(), f'cnn_model_{best_val_acc}.h5')
        
#     print(f'validataion \t loss mean {loss_sum/e_num} accuracy: {acc} ',end='\n\n')
    
print(f'best validation acc = {best_val_acc}, in epoch {best_epoch}')

  0%|                                                   | 0/100 [00:00<?, ?it/s]

epoch: 0


  1%|▍                                          | 1/100 [00:18<30:48, 18.67s/it]

NEW RECODE!
epoch: 1


  2%|▊                                          | 2/100 [00:37<30:21, 18.58s/it]

NEW RECODE!
epoch: 2


  3%|█▎                                         | 3/100 [00:55<30:06, 18.62s/it]

epoch: 3


  4%|█▋                                         | 4/100 [01:14<30:03, 18.79s/it]

NEW RECODE!
epoch: 4


  5%|██▏                                        | 5/100 [01:33<29:42, 18.76s/it]

NEW RECODE!
epoch: 5


  6%|██▌                                        | 6/100 [01:52<29:22, 18.75s/it]

epoch: 6


  7%|███                                        | 7/100 [02:11<29:04, 18.75s/it]

epoch: 7


  8%|███▍                                       | 8/100 [02:29<28:32, 18.62s/it]

epoch: 8


  9%|███▊                                       | 9/100 [02:47<28:07, 18.55s/it]

epoch: 9


 10%|████▏                                     | 10/100 [03:06<27:49, 18.55s/it]

NEW RECODE!
epoch: 10


 11%|████▌                                     | 11/100 [03:25<27:41, 18.67s/it]

epoch: 11


 12%|█████                                     | 12/100 [03:44<27:28, 18.73s/it]

epoch: 12


 13%|█████▍                                    | 13/100 [04:02<27:05, 18.68s/it]

epoch: 13


 14%|█████▉                                    | 14/100 [04:21<26:49, 18.72s/it]

epoch: 14


 15%|██████▎                                   | 15/100 [04:40<26:47, 18.92s/it]

epoch: 15


 16%|██████▋                                   | 16/100 [04:59<26:25, 18.88s/it]

epoch: 16


 17%|███████▏                                  | 17/100 [05:18<26:10, 18.92s/it]

epoch: 17


 18%|███████▌                                  | 18/100 [05:37<25:49, 18.89s/it]

NEW RECODE!
epoch: 18


 19%|███████▉                                  | 19/100 [05:56<25:32, 18.92s/it]

epoch: 19


 20%|████████▍                                 | 20/100 [06:15<25:16, 18.96s/it]

epoch: 20


 21%|████████▊                                 | 21/100 [06:34<24:54, 18.92s/it]

epoch: 21


 22%|█████████▏                                | 22/100 [06:53<24:33, 18.89s/it]

epoch: 22


 23%|█████████▋                                | 23/100 [07:11<24:10, 18.83s/it]

epoch: 23


 24%|██████████                                | 24/100 [07:30<23:49, 18.81s/it]

epoch: 24


 25%|██████████▌                               | 25/100 [07:50<23:46, 19.02s/it]

epoch: 25


 26%|██████████▉                               | 26/100 [08:08<23:17, 18.89s/it]

epoch: 26


 27%|███████████▎                              | 27/100 [08:27<22:52, 18.80s/it]

epoch: 27


 28%|███████████▊                              | 28/100 [08:45<22:28, 18.73s/it]

epoch: 28


 29%|████████████▏                             | 29/100 [09:04<22:09, 18.72s/it]

NEW RECODE!
epoch: 29


 30%|████████████▌                             | 30/100 [09:23<21:47, 18.68s/it]

epoch: 30


 31%|█████████████                             | 31/100 [09:41<21:28, 18.68s/it]

epoch: 31


 32%|█████████████▍                            | 32/100 [10:00<21:11, 18.70s/it]

epoch: 32


 33%|█████████████▊                            | 33/100 [10:19<20:49, 18.64s/it]

epoch: 33


 34%|██████████████▎                           | 34/100 [10:38<20:36, 18.73s/it]

epoch: 34


 35%|██████████████▋                           | 35/100 [10:57<20:22, 18.81s/it]

epoch: 35


 36%|███████████████                           | 36/100 [11:15<20:04, 18.81s/it]

epoch: 36


 37%|███████████████▌                          | 37/100 [11:34<19:46, 18.83s/it]

epoch: 37


 38%|███████████████▉                          | 38/100 [11:53<19:26, 18.82s/it]

epoch: 38


 39%|████████████████▍                         | 39/100 [12:12<19:10, 18.85s/it]

epoch: 39


 40%|████████████████▊                         | 40/100 [12:31<18:53, 18.89s/it]

epoch: 40


 41%|█████████████████▏                        | 41/100 [12:50<18:43, 19.05s/it]

NEW RECODE!
epoch: 41


 42%|█████████████████▋                        | 42/100 [13:10<18:29, 19.14s/it]

epoch: 42


 43%|██████████████████                        | 43/100 [13:28<17:59, 18.94s/it]

epoch: 43


 44%|██████████████████▍                       | 44/100 [13:47<17:34, 18.83s/it]

epoch: 44


 45%|██████████████████▉                       | 45/100 [14:05<17:12, 18.78s/it]

epoch: 45


 46%|███████████████████▎                      | 46/100 [14:24<16:51, 18.74s/it]

epoch: 46


 47%|███████████████████▋                      | 47/100 [14:43<16:30, 18.68s/it]

epoch: 47


 48%|████████████████████▏                     | 48/100 [15:01<16:10, 18.66s/it]

epoch: 48


 49%|████████████████████▌                     | 49/100 [15:20<15:50, 18.64s/it]

epoch: 49


 50%|█████████████████████                     | 50/100 [15:38<15:31, 18.63s/it]

epoch: 50


 51%|█████████████████████▍                    | 51/100 [15:57<15:13, 18.64s/it]

epoch: 51


 52%|█████████████████████▊                    | 52/100 [16:16<14:54, 18.63s/it]

epoch: 52


 53%|██████████████████████▎                   | 53/100 [16:35<14:37, 18.68s/it]

NEW RECODE!
epoch: 53


 54%|██████████████████████▋                   | 54/100 [16:53<14:20, 18.70s/it]

epoch: 54


 55%|███████████████████████                   | 55/100 [17:12<14:05, 18.79s/it]

epoch: 55


 56%|███████████████████████▌                  | 56/100 [17:31<13:42, 18.70s/it]

epoch: 56


 57%|███████████████████████▉                  | 57/100 [17:49<13:23, 18.69s/it]

epoch: 57


 58%|████████████████████████▎                 | 58/100 [18:08<13:06, 18.73s/it]

epoch: 58


 59%|████████████████████████▊                 | 59/100 [18:27<12:46, 18.71s/it]

epoch: 59


 60%|█████████████████████████▏                | 60/100 [18:46<12:27, 18.69s/it]

epoch: 60


 61%|█████████████████████████▌                | 61/100 [19:04<12:08, 18.68s/it]

epoch: 61


 62%|██████████████████████████                | 62/100 [19:23<11:50, 18.70s/it]

epoch: 62


 63%|██████████████████████████▍               | 63/100 [19:42<11:30, 18.66s/it]

NEW RECODE!
epoch: 63


 64%|██████████████████████████▉               | 64/100 [20:00<11:10, 18.62s/it]

epoch: 64


 65%|███████████████████████████▎              | 65/100 [20:19<10:54, 18.70s/it]

epoch: 65


 66%|███████████████████████████▋              | 66/100 [20:38<10:36, 18.71s/it]

epoch: 66


 67%|████████████████████████████▏             | 67/100 [20:56<10:17, 18.70s/it]

epoch: 67


 68%|████████████████████████████▌             | 68/100 [21:15<09:57, 18.68s/it]

epoch: 68


 69%|████████████████████████████▉             | 69/100 [21:34<09:39, 18.68s/it]

epoch: 69


 70%|█████████████████████████████▍            | 70/100 [21:52<09:21, 18.71s/it]

epoch: 70


 71%|█████████████████████████████▊            | 71/100 [22:12<09:06, 18.83s/it]

epoch: 71


 72%|██████████████████████████████▏           | 72/100 [22:30<08:46, 18.82s/it]

epoch: 72


 73%|██████████████████████████████▋           | 73/100 [22:49<08:27, 18.81s/it]

epoch: 73


 74%|███████████████████████████████           | 74/100 [23:08<08:09, 18.83s/it]

epoch: 74


 75%|███████████████████████████████▌          | 75/100 [23:28<07:55, 19.04s/it]

epoch: 75


 76%|███████████████████████████████▉          | 76/100 [23:46<07:35, 18.96s/it]

epoch: 76


 77%|████████████████████████████████▎         | 77/100 [24:05<07:15, 18.93s/it]

epoch: 77


 78%|████████████████████████████████▊         | 78/100 [24:24<06:54, 18.86s/it]

epoch: 78


 79%|█████████████████████████████████▏        | 79/100 [24:43<06:35, 18.82s/it]

epoch: 79


 80%|█████████████████████████████████▌        | 80/100 [25:01<06:16, 18.81s/it]

epoch: 80


 81%|██████████████████████████████████        | 81/100 [25:20<05:56, 18.77s/it]

epoch: 81


 82%|██████████████████████████████████▍       | 82/100 [25:39<05:37, 18.74s/it]

epoch: 82


 83%|██████████████████████████████████▊       | 83/100 [25:58<05:18, 18.76s/it]

epoch: 83


 84%|███████████████████████████████████▎      | 84/100 [26:16<04:59, 18.74s/it]

epoch: 84


 85%|███████████████████████████████████▋      | 85/100 [26:35<04:40, 18.73s/it]

epoch: 85


 86%|████████████████████████████████████      | 86/100 [26:54<04:22, 18.74s/it]

epoch: 86


 87%|████████████████████████████████████▌     | 87/100 [27:13<04:04, 18.83s/it]

epoch: 87


 88%|████████████████████████████████████▉     | 88/100 [27:32<03:45, 18.82s/it]

epoch: 88


 89%|█████████████████████████████████████▍    | 89/100 [27:50<03:27, 18.82s/it]

epoch: 89


 90%|█████████████████████████████████████▊    | 90/100 [28:09<03:08, 18.83s/it]

epoch: 90


 91%|██████████████████████████████████████▏   | 91/100 [28:28<02:49, 18.86s/it]

epoch: 91


 92%|██████████████████████████████████████▋   | 92/100 [28:47<02:30, 18.79s/it]

epoch: 92


 93%|███████████████████████████████████████   | 93/100 [29:06<02:11, 18.78s/it]

epoch: 93


 94%|███████████████████████████████████████▍  | 94/100 [29:24<01:52, 18.81s/it]

epoch: 94


 95%|███████████████████████████████████████▉  | 95/100 [29:43<01:33, 18.76s/it]

epoch: 95


 96%|████████████████████████████████████████▎ | 96/100 [30:02<01:15, 18.79s/it]

epoch: 96


 97%|████████████████████████████████████████▋ | 97/100 [30:21<00:56, 18.79s/it]

epoch: 97


 98%|█████████████████████████████████████████▏| 98/100 [30:40<00:37, 18.97s/it]

epoch: 98


 99%|█████████████████████████████████████████▌| 99/100 [31:00<00:19, 19.18s/it]

epoch: 99


100%|█████████████████████████████████████████| 100/100 [31:19<00:00, 18.79s/it]

best validation acc = 0.975647232215507, in epoch 62





In [53]:
def plot_roc_curve(fper, tper):
    plt.plot(fper, tper, color='red', label='ROC')
    plt.plot([0, 1], [0, 1], color='green', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic Curve')
    plt.legend()
    plt.show();

In [None]:
import  matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

fper, tper, thresholds = roc_curve(true_labels,best_acc_pred)
plot_roc_curve(fper, tper)