In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import xmltodict
import base64
import numpy as np
import array

from sklearn.metrics import accuracy_score,precision_score, recall_score
from IPython.display import clear_output

# Data Preprocessings

In [2]:
def get_lead(path):
    with open(path, 'rb') as xml:
        ECG = xmltodict.parse(xml.read().decode('utf8'))
    
    augmentLeads = True
    if path.split('/')[-1][0] == '5':
        waveforms = ECG['RestingECG']['Waveform'][1]
    elif path.split('/')[-1][0] == '6':
        waveforms = ECG['RestingECG']['Waveform']
        augmentLeads = False
    else:
        waveforms = ECG['RestingECG']['Waveform']
    
    leads = {}
    
    for lead in waveforms['LeadData']:
        lead_data = lead['WaveFormData']
        lead_b64  = base64.b64decode(lead_data)
        lead_vals = np.array(array.array('h', lead_b64))
        leads[ lead['LeadID'] ] = lead_vals
    
    if augmentLeads:
        leads['III'] = np.subtract(leads['II'], leads['I'])
        leads['aVR'] = np.add(leads['I'], leads['II'])*(-0.5)
        leads['aVL'] = np.subtract(leads['I'], 0.5*leads['II'])
        leads['aVF'] = np.subtract(leads['II'], 0.5*leads['I'])
    
    return leads

In [3]:
error_train = ['6_2_003469_ecg.xml', '6_2_003618_ecg.xml', '6_2_005055_ecg.xml', '8_2_001879_ecg.xml', '8_2_002164_ecg.xml']
error_valid = ['8_2_007281_ecg.xml', '8_2_008783_ecg.xml', '8_2_007226_ecg.xml']


train_data = []
train_labels = []
valid_data = []
valid_labels = []


train_pathes = ['data/train/arrhythmia/', 'data/train/normal/']
valid_pathes = ['data/validation/arrhythmia/', 'data/validation/normal/']

error_decode = []   # 디코딩에 실패한 데이터들..
# error_len = [] # 5000, 4999개를 맞추지 못한 데이터들.. 혹은 12개의 lead가 아닌것들..

# train data
for path in train_pathes:
    for file in os.listdir(path):
        
        if file in error_train:
            print(file)
            continue
        
        try:
            data = get_lead(path + file)
        except Exception as e:
            error_decode.append(path + file)
        
        listed_data = []
        keys = sorted(data.keys())
        for key in keys:
            listed_data.append(data[key])
        
        flag = False
        for idx, i in enumerate(listed_data):
            if len(i) == 5000:
                continue
            elif len(i) == 4999:
                listed_data[idx] = np.append(i, 0)
            else:
                flag = True
        if flag:
            continue
        
        train_data.append(listed_data)
        if 'arrhythmia' in path:
            train_labels.append(1)
        else:
            train_labels.append(0)
            
# valid data
for path in valid_pathes:
    for file in os.listdir(path):
        
        if file in error_valid:
            print(file)
            continue
            
        try:
            data = get_lead(path + file)
        except Exception as e:
            error_decode.append(path + file)
        
        listed_data = []
        keys = sorted(data.keys())
        for key in keys:
            listed_data.append(data[key])
        
        valid_data.append(listed_data)
        if 'arrhythmia' in path:
            valid_labels.append(1)
        else:
            valid_labels.append(0)

print(len(error_decode))


8_2_001879_ecg.xml
6_2_003618_ecg.xml
6_2_005055_ecg.xml
6_2_003469_ecg.xml
8_2_002164_ecg.xml
8_2_007281_ecg.xml
8_2_008783_ecg.xml
8_2_007226_ecg.xml
8


In [4]:
error_lead_len = []
for idx, i in enumerate(train_data):
    if len(i) != 12:
        error_lead_len.append(idx)
for i in error_lead_len:
    del train_data[i]
    del train_labels[i]
    

# 데이터 길이 및 lead 개수 분석

In [69]:
# 데이터의 길이 분포 확인: valid는 모두 5000인것을 확인
# train은 60912개가 4999, 36개가 1249, 1개가 4988
# 위의 테스크는 먼저 4999개만 0의 패딩을 붙이고 나머지는 제외하는식으로 전처리함

c4999 = 0
c5000 = 0
cx = 0

for i in train_data:
    for j in i:
        if len(j) == 4999:
            c4999 +=1
        elif len(j) == 5000:
            c5000 +=1
        else:
            cx +=1

for i in valid_data:
    for j in i:
        if len(j) == 4999:
            c4999 +=1
        elif len(j) == 5000:
            c5000 +=1
        else:
            cx +=1

print(c4999)
print(c5000)
print(cx)

60912
463928
37


In [75]:
# 딱 한개의 9 lead의 데이터가 존재한다..
v12 = 0
v9 = 0
vx = 0

for i in train_data:
    if len(i) == 12:
        v12 +=1
    elif len(i) == 9:
        v9 += 1
    else:
        vx +=1
        
for i in valid_data:
    if len(i) == 12:
        v12 +=1
    elif len(i) == 9:
        v9 += 1
    else:
        vx +=1
print(v12)
print(v9)
print(vx)

43703
1
0


# Dataset 생성

In [5]:
train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_data).float(), torch.tensor(train_labels))
valid_dataset = torch.utils.data.TensorDataset(torch.tensor(valid_data).float(), torch.tensor(valid_labels))

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=32, shuffle=False)

  train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_data).float(), torch.tensor(train_labels))


In [101]:
# residual과 dropout 추가 필요
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier,self).__init__()
        self.cnn1 = nn.Conv1d(in_channels=12, out_channels=32, kernel_size=5, padding=2) 
        self.cnn2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
        self.cnn3 = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=5, padding=2)
    
        self.pool1 = nn.MaxPool1d(4)
        self.pool2 = nn.MaxPool1d(5)
        
    
        self.fc1 = nn.Linear(64 * 50, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)
        
        self.relu = nn.ReLU()
        
        
    def forward(self, x):
        x = self.relu(self.cnn1(x))
        x = self.pool1(x)
        x = self.relu(self.cnn2(x))
        x = self.pool2(x)
        x = self.relu(self.cnn3(x))
        x = self.pool2(x)

        x = x.view(-1, 64*50)
        
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        
        x = torch.sigmoid(x)
        
        return x.view(-1)
        


In [106]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Classifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
criterion = nn.BCELoss()


In [None]:
epoches = 50
for i in range(epoches):
    
    # Train
    loss_sum = 0
    true_labels = []
    pred_labels = []
    model.train()
    for e_num, (x,y) in enumerate(train_dataloader):
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)
        model.zero_grad()
        pred_y = model(x)
        
        loss=criterion(pred_y,y)
        loss_sum+=loss.detach()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))
        
        
    acc=accuracy_score(true_labels,pred_labels)
    print(f'train \t\t loss mean {loss_sum/e_num} accuracy: {acc}')
    
    # Valid
    loss_sum=0
    true_labels=[]
    pred_labels=[]
    model.eval()
    for e_num, (x,y) in enumerate(val_dataloader):
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)

        pred_y = model(x)
        loss=criterion(pred_y,y)
        
        loss_sum+=loss.detach()
        
        true_labels.extend(y.cpu().numpy())
        pred_labels.extend(np.around(pred_y.cpu().detach().numpy()))
        
        
    acc=accuracy_score(true_labels,pred_labels)
    print(f'validataion \t loss mean {loss_sum/e_num} accuracy: {acc} ',end='\n\n')

train 		 loss mean 0.5464473366737366 accuracy: 0.8162908042426115
validataion 	 loss mean 0.46484997868537903 accuracy: 0.9117102284420663 

train 		 loss mean 0.4627467095851898 accuracy: 0.9137061064771909
validataion 	 loss mean 0.437894344329834 accuracy: 0.9415517596213212 

train 		 loss mean 0.4433371424674988 accuracy: 0.9329626197096077
validataion 	 loss mean 0.4295545518398285 accuracy: 0.9528709611031076 

train 		 loss mean 0.4339553415775299 accuracy: 0.9406858202038925
validataion 	 loss mean 0.4264028072357178 accuracy: 0.9454620292241201 

train 		 loss mean 0.4276195466518402 accuracy: 0.9459118525383585
validataion 	 loss mean 0.42061182856559753 accuracy: 0.9584276600123482 

train 		 loss mean 0.4218957722187042 accuracy: 0.95064874884152
validataion 	 loss mean 0.4220859408378601 accuracy: 0.956163819715991 

train 		 loss mean 0.4186888039112091 accuracy: 0.9529399649881578
validataion 	 loss mean 0.4181143045425415 accuracy: 0.9627495369417576 

train 		 loss m