In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from __future__ import print_function

from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.utils.data
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch import nn, optim
from torch.utils.data.dataset import Dataset

# --- Run configuration ---------------------------------------------------------
seed = 46
is_cuda = False            # set True to train on a GPU
cuda_device_index = 2      # GPU ordinal to use when is_cuda is True
num_epochs = 100
batch_size = 10
log_interval = 10          # print training progress every `log_interval` batches
in_channels_ = 1           # input channels fed to the first Conv1d
num_segments_in_record = 100
segment_len = 3600         # samples per segment; the classifier's Linear(1152, ...) is sized for this length
num_records = 48
num_classes = 16           # width of the classifier's output layer
allow_label_leakage = True  # NOTE(review): not referenced by the visible code — confirm its purpose

# Seed every RNG the notebook relies on: torch (weight init, DataLoader
# shuffling) and numpy (anything numpy-based downstream).
torch.manual_seed(seed)
np.random.seed(seed)

# A hard-coded "cuda:2" crashes on machines with fewer than three GPUs (e.g. a
# single-GPU Colab runtime); fall back to CPU with a message instead.
if is_cuda and not (torch.cuda.is_available() and cuda_device_index < torch.cuda.device_count()):
    print(f"cuda:{cuda_device_index} is not available; falling back to CPU.")
    is_cuda = False
device = torch.device(f"cuda:{cuda_device_index}" if is_cuda else "cpu")


class CustomDatasetFromCSV(Dataset):
    """Map-style dataset over a pickled DataFrame with ``signal`` and ``target`` columns.

    Despite the class name, the file is read with ``pd.read_pickle`` (not as CSV).
    Only load pickle files you trust: unpickling can execute arbitrary code.

    Args:
        data_path: path to the pickled DataFrame.
        transforms_: optional callable applied to each raw signal before it is reshaped.
    """

    def __init__(self, data_path, transforms_=None):
        self.transforms = transforms_
        self.df = pd.read_pickle(data_path)

    def __getitem__(self, index):
        """Return ``(signal, target)`` for positional row ``index``.

        The signal is reshaped to ``(1, signal.shape[0])`` so Conv1d sees an
        explicit channel axis (assumes the stored signal is 1-D — TODO confirm).
        """
        record = self.df.iloc[index]
        sig, label = record['signal'], record['target']
        if self.transforms is not None:
            sig = self.transforms(sig)
        return sig.reshape(1, sig.shape[0]), label

    def __len__(self):
        """Number of rows in the underlying DataFrame."""
        return len(self.df)


# Load the pickle once and split its rows into train / test. Loading the same
# file for both roles made the per-epoch "test" loss a training loss. The split
# is seeded (random_state=46, same value as the torch seed) so it is reproducible.
# NOTE(review): this is a row-level split; if rows are segments of longer
# records, segments of one record can land in both splits — confirm whether a
# record-level split is wanted.
data_path = '/content/drive/MyDrive/Arrhythmia-CNN-master/data/Arrhythmia_dataset.pkl'
full_dataset = CustomDatasetFromCSV(data_path)
train_ids, test_ids = train_test_split(np.arange(len(full_dataset)), train_size=.8, random_state=46)
train_dataset = torch.utils.data.Subset(full_dataset, train_ids)
test_dataset = torch.utils.data.Subset(full_dataset, test_ids)
# Shuffle training batches every epoch; the order follows the global torch seed.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)


class Flatten(torch.nn.Module):
    """Collapse every dimension after the batch dimension into a single one.

    ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``. Uses ``Tensor.view``,
    so the input must be view-compatible (e.g. contiguous).
    """

    def forward(self, x):
        n = x.size(0)
        return x.view(n, -1)


def basic_layer(in_channels, out_channels, kernel_size, batch_norm=False, max_pool=True, conv_stride=1,
                padding=0, pool_stride=2, pool_size=2):
    """Build one Conv1d -> ReLU stage, optionally followed by BatchNorm1d and MaxPool1d.

    Each optional stage wraps everything built so far in a fresh ``nn.Sequential``.
    That nesting is part of the ``state_dict`` key layout (e.g. the conv weight is
    ``0.0.0.weight`` when both extras are enabled), so keep it when editing.

    Args:
        in_channels / out_channels / kernel_size / conv_stride / padding: Conv1d settings.
        batch_norm: append ``BatchNorm1d(out_channels)`` after the ReLU.
        max_pool: append ``MaxPool1d(pool_size, stride=pool_stride)`` last.

    Returns:
        nn.Sequential: the assembled stage.
    """
    stage = nn.Sequential(
        nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                  stride=conv_stride, padding=padding),
        nn.ReLU(),
    )
    extras = []
    if batch_norm:
        extras.append(nn.BatchNorm1d(num_features=out_channels))
    if max_pool:
        extras.append(nn.MaxPool1d(kernel_size=pool_size, stride=pool_stride))
    for extra in extras:
        stage = nn.Sequential(stage, extra)
    return stage


class arrhythmia_classifier(nn.Module):
    """1-D CNN that classifies an ECG segment into ``num_classes`` classes.

    Seven ``basic_layer`` conv stages feed a two-layer MLP head. The network
    returns per-class *log*-probabilities (``LogSoftmax``) so it pairs correctly
    with ``nn.NLLLoss``; ``argmax`` over dim 1 still gives the predicted class.

    Input is ``(batch, in_channels, length)``. The first Linear layer's
    ``in_features=1152`` equals 128 channels x 9 time steps, which is what the
    conv stack produces for a 3600-sample input (``segment_len``); a different
    length may flatten to a different size and fail at that layer.

    Args:
        in_channels: number of input channels; defaults to ``in_channels_``.
    """

    def __init__(self, in_channels=in_channels_):
        super().__init__()
        self.cnn = nn.Sequential(
            basic_layer(in_channels=in_channels, out_channels=128, kernel_size=50, batch_norm=True, max_pool=True,
                        conv_stride=3, pool_stride=3),
            basic_layer(in_channels=128, out_channels=32, kernel_size=7, batch_norm=True, max_pool=True,
                        conv_stride=1, pool_stride=2),
            basic_layer(in_channels=32, out_channels=32, kernel_size=10, batch_norm=False, max_pool=False,
                        conv_stride=1),
            basic_layer(in_channels=32, out_channels=128, kernel_size=5, batch_norm=False, max_pool=True,
                        conv_stride=2, pool_stride=2),
            basic_layer(in_channels=128, out_channels=256, kernel_size=15, batch_norm=False, max_pool=True,
                        conv_stride=1, pool_stride=2),
            basic_layer(in_channels=256, out_channels=512, kernel_size=5, batch_norm=False, max_pool=False,
                        conv_stride=1),
            basic_layer(in_channels=512, out_channels=128, kernel_size=3, batch_norm=False, max_pool=False,
                        conv_stride=1),
            Flatten(),
            nn.Linear(in_features=1152, out_features=512),
            nn.ReLU(),
            nn.Dropout(p=.1),
            nn.Linear(in_features=512, out_features=num_classes),
            # LogSoftmax, not Softmax: NLLLoss expects log-probabilities. It has
            # no parameters, so state_dict keys are the same as with Softmax.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x, ex_features=None):
        """Return ``(batch, num_classes)`` log-probabilities; ``ex_features`` is ignored."""
        return self.cnn(x)


def calc_next_len_conv1d(current_len=112500, kernel_size=16, stride=8, padding=0, dilation=1):
    """Output length of a 1-D convolution (same formula as ``nn.Conv1d`` / ``nn.MaxPool1d``).

    ``floor((L + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)``, returned as ``int``.
    """
    numerator = current_len + 2 * padding - dilation * (kernel_size - 1) - 1
    out_len = np.floor(numerator / stride + 1)
    return int(out_len)


# Build the network in float64 on the configured device, then attach Adam
# (lr 3e-4 with a small weight decay) and the negative log-likelihood loss.
lr = 3e-4
model = arrhythmia_classifier().to(device).double()
num_of_iteration = len(train_dataset) // batch_size  # full batches per epoch

optimizer = optim.Adam(model.parameters(), weight_decay=1e-6, lr=lr)
# NLLLoss (default reduction='mean') expects log-probabilities, so the model's
# final layer has to be LogSoftmax rather than Softmax.
criterion = nn.NLLLoss()


def train(epoch):
    """Run one training epoch over ``train_loader`` and report the loss.

    Uses the module-level ``model``, ``optimizer``, ``criterion``, ``device``,
    ``train_loader`` and ``log_interval``.

    Args:
        epoch: epoch number, used only in the log lines.

    Returns:
        float: per-sample mean training loss for this epoch.
    """
    model.train()
    total_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # `criterion` uses the default 'mean' reduction, so loss.item() is
        # already a per-sample average. Re-weight by the batch size so that
        # dividing by the dataset size below gives a true per-sample mean.
        batch_loss = loss.item()
        total_loss += batch_loss * len(data)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), batch_loss))

    avg_loss = total_loss / max(len(train_loader.dataset), 1)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss


def test(epoch):
    """Evaluate ``model`` on ``test_loader``; print mean loss, accuracy and learning rate.

    Uses the module-level ``model``, ``test_loader``, ``criterion``, ``device``
    and ``optimizer`` (the latter only to report the current learning rate).

    Args:
        epoch: accepted for symmetry with ``train``; not used.

    Returns:
        float: per-sample mean test loss.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # `criterion` averages over the batch; re-weight by batch size so the
            # division by the dataset size below yields a per-sample mean.
            total_loss += criterion(output, target).item() * target.size(0)
            # argmax over the class dim is the prediction for both
            # probabilities and log-probabilities.
            correct += (output.argmax(dim=1) == target).sum().item()

    num_samples = max(len(test_loader.dataset), 1)
    test_loss = total_loss / num_samples
    print('====> Test set loss: {:.5f}'.format(test_loss))
    print('====> Test set accuracy: {:.2f}%'.format(100. * correct / num_samples))
    print(f'Learning rate: {optimizer.param_groups[0]["lr"]:.6f}')
    return test_loss


if __name__ == "__main__":
    # For every epoch (1-based): one training pass, then one evaluation pass.
    for epoch in range(1, num_epochs + 1):
        for step in (train, test):
            step(epoch)

In [None]:
# Persist the whole trained module (architecture + weights, pickled); the
# evaluation cell reloads this file with torch.load. torch.save fails when the
# parent folder is missing, so create it first.
model_path = Path("/content/drive/MyDrive/meeting_ECG/save_model/temp.pt")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)

In [None]:
# Single source of truth for the label vocabulary:
# (full annotation name, short code, count).
_label_table = [
    ('Normal beat', 'N', 283),
    ('Left bundle branch block beat', 'L', 103),
    ('Atrial premature beat', 'A', 66),
    ('Atrial flutter', 'AFL', 20),
    ('Atrial fibrillation', 'AFIB', 135),
    ('Pre-excitation (WPW)', 'PREX', 21),
    ('Premature ventricular contraction', 'V', 133),
    ('Ventricular bigeminy', 'B', 55),
    ('Ventricular trigeminy', 'T', 13),
    ('Ventricular tachycardia', 'VT', 10),
    ('Idioventricular rhythm', 'IVR', 10),
    ('Ventricular flutter', 'VFL', 10),
    ('Fusion of ventricular and normal beat', 'F', 11),
    ('Second-degree heart block', 'SHB', 10),
    ('Pacemaker rhythm', 'P', 45),
    ('Supraventricular tachyarrhythmia', 'SVTA', 13),
    ('Right bundle branch block beat', 'R', 62),
]

# full name -> count, and full name -> short code (both in table order)
num_labels_dict = {name: count for name, _code, count in _label_table}
dict_mapping = {name: code for name, code, _count in _label_table}

# Codes kept for this experiment (V, SHB, F, VT and SVTA are left out), in this order.
_count_by_code = {code: count for _name, code, count in _label_table}
num_labels_we_have = {
    code: _count_by_code[code]
    for code in ('N', 'L', 'A', 'R', 'AFL', 'AFIB', 'PREX', 'B', 'T', 'IVR', 'VFL', 'P')
}

# short code -> integer class index
label_numeric_dict = {
    code: idx
    for idx, code in enumerate(('N', 'A', 'AFL', 'AFIB', 'PREX', 'B', 'T', 'IVR', 'VFL', 'L', 'R', 'P'))
}
# NOTE(review): this is 17 (all codes), while the model has num_classes=16 outputs
# and label_numeric_dict defines 12 indices — confirm which count is intended.
classes_ = len(dict_mapping)




# Reload the whole pickled nn.Module written by the save cell.
# - weights_only=False: since PyTorch 2.6 torch.load defaults to weights_only=True,
#   which rejects full-module pickles. Only do this for files you trust, because
#   unpickling can execute arbitrary code.
# - map_location=device: lets a checkpoint saved on a GPU load on the configured device.
# Alternative checkpoint: "/content/drive/MyDrive/meeting_ECG/save_model/lr03_epoch8.pt"
model = torch.load("/content/drive/MyDrive/meeting_ECG/save_model/temp.pt",
                   map_location=device, weights_only=False)

# Evaluate row by row on a different pickle file. Note: this rebinds the
# module-level test_dataset / test_loader, so a later call to test() uses this file.
test_dataset = CustomDatasetFromCSV('/content/drive/MyDrive/Arrhythmia-CNN-master/data/Arrhythmia_dataset_yystest_noseed1.pkl')
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
print(test_dataset.df)


# Per class index, count how many test rows carry each target (count_) and how
# many the model predicts as each class (count_p). Both are sized by
# num_classes, the width of the model's output layer, so every argmax index fits.
dict_truth_prediction_map = {}  # batch index -> (target tensor, model output tensor)
count_ = [0] * num_classes
count_p = [0] * num_classes

model.eval()
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        dict_truth_prediction_map[batch_idx] = (target, output)
        # Predicted class = argmax over the class dimension (valid for both
        # probabilities and log-probabilities).
        predicted = output.argmax(dim=1)
        # Walk the rows individually so this works for any test_loader batch size.
        for true_idx, pred_idx in zip(target.tolist(), predicted.tolist()):
            count_[int(true_idx)] += 1
            count_p[int(pred_idx)] += 1

print(count_)
print(count_p)