In [1]:
import numpy as np 
import pandas as pd 
import os
import pickle
import sys
import torch
from torch import nn
import scipy
import torch.nn.functional as Func
import warnings
from tqdm import tqdm

In [2]:
sys.path.append("/kaggle/input/pytorch-sklearnn")
from pytorch_sklearn import NeuralNetwork
from pytorch_sklearn.callbacks import WeightCheckpoint, Verbose, LossPlot, EarlyStopping, Callback, CallbackInfo
from pytorch_sklearn.utils.func_utils import to_safe_tensor

In [3]:
# Suppress every Python warning for the rest of the session.
# NOTE(review): a blanket 'ignore' also hides numerical RuntimeWarnings (e.g. a
# zero denominator in the metric helpers) — consider narrowing the filter.
warnings.filterwarnings('ignore')

In [4]:
def read_data(path):
    """
    Opens the file at the given path in binary mode and returns the unpickled object.
    """
    with open(path, "rb") as handle:
        # NOTE: unpickling can execute arbitrary code; only load files you trust.
        payload = pickle.load(handle)
    return payload

In [5]:
# Pickle file holding the patient ids that the training loop iterates over.
VALID_PATIENT_IDS_PATH = '/kaggle/input/arrythmia-valid-patient-ids/patient_ids.pkl'
# Loaded once here; the training cell takes len() of it and loops over its items.
valid_patients = read_data(VALID_PATIENT_IDS_PATH)

In [6]:
# Directories containing one "patient_<patient_id>_dataset.pkl" file per patient.
# DATASET_PATH is loaded first (first channel); TRIO_PATH is appended as an extra channel.
DATASET_PATH = '/kaggle/input/arrythmia-dataset/dataset_training/dataset_beat_single_domain_adapted'
TRIO_PATH = '/kaggle/input/arrythmia-dataset/dataset_training/dataset_beat_trios_domain_adapted'

In [7]:
# Root output directory; results are pickled directly into it.
SAVE_PATH = '/kaggle/working/'
# Sub-directory that receives one saved network per (repeat, patient).
save_dir = os.path.join(SAVE_PATH, "nets")
# exist_ok=True keeps the cell re-runnable when the directory already exists.
os.makedirs(save_dir, exist_ok=True)

In [8]:
def load_dataset(patient_id, PATH):
    """
    Loads the pickled ECG dataset of one patient.

    The file is expected inside PATH under the name "patient_<patient_id>_dataset.pkl".
    Returns whatever object the pickle contains.
    """
    file_name = f"patient_{patient_id}_dataset.pkl"
    full_path = os.path.join(PATH, file_name)
    with open(full_path, "rb") as handle:
        # NOTE: unpickling can execute arbitrary code; only load files you trust.
        loaded = pickle.load(handle)
    return loaded
    
def dataset_to_tensors(dataset):
    """
    Converts the given dataset to torch tensors in appropriate data types and shapes.

    For each of the "train", "val" and "test" splits the entries are converted as follows:
        <split>_X   -> float32 tensor reshaped to (-1, 1, X.shape[1]), i.e. a channel axis of size 1 is inserted
        <split>_y   -> int64 tensor
        <split>_ids -> int64 tensor

    Entries are looked up by key, so the result does not depend on the order in which the
    dictionary was built. Labels and ids are converted directly to int64 rather than through
    float32, which cannot represent integers above 2**24 exactly.

    Returns a shallow copy of the dataset; the dictionary passed in is left untouched.
    Raises KeyError if one of the nine expected entries is missing.
    """
    dataset = dataset.copy()
    for split in ("train", "val", "test"):
        X_key, y_key, ids_key = f"{split}_X", f"{split}_y", f"{split}_ids"
        X = dataset[X_key]
        # torch.tensor always copies, so the returned tensors never alias the input arrays.
        dataset[X_key] = torch.tensor(X, dtype=torch.float32).reshape(-1, 1, X.shape[1])
        dataset[y_key] = torch.tensor(dataset[y_key], dtype=torch.long)
        dataset[ids_key] = torch.tensor(dataset[ids_key], dtype=torch.long)
    return dataset

def add_dataset(patient_id, dataset, DATASET_PATH):
    """
    Adds another dataset to an already existing one, increasing the number of channels.

    The patient's dataset stored under DATASET_PATH is loaded, converted to tensors and
    concatenated to the "train_X", "val_X" and "test_X" entries of `dataset` along dim 1.
    All other entries of `dataset` are kept as they are.

    Entries are accessed by key, so the result does not depend on dictionary ordering.

    Returns a shallow copy of `dataset` with the widened inputs; the argument is not modified.
    Raises AssertionError if the labels of any split differ between the two datasets.
    """
    dataset = dataset.copy()
    dataset_other = dataset_to_tensors(load_dataset(patient_id, DATASET_PATH))

    # Both datasets must describe the same beats in the same order, otherwise the
    # channels stacked below would belong to different samples.
    split_labels = {"train": "Training", "val": "Validation", "test": "Test"}
    for split, label in split_labels.items():
        assert torch.equal(dataset[f"{split}_y"], dataset_other[f"{split}_y"]), \
            f"{label} ground truths are different. Possibly shuffled differently."

    for split in split_labels:
        X_key = f"{split}_X"
        dataset[X_key] = torch.cat((dataset[X_key], dataset_other[X_key]), dim=1)
    return dataset

def load_N_channel_dataset(patient_id, DEFAULT_PATH, *PATHS):
    """
    Builds a multi-channel ECG dataset for the given patient.

    The dataset under DEFAULT_PATH is loaded and converted to tensors first. The dataset
    found under each entry of PATHS is then appended as additional channel(s), in the
    order the paths were given.
    """
    merged = dataset_to_tensors(load_dataset(patient_id, DEFAULT_PATH))
    for extra_path in PATHS:
        merged = add_dataset(patient_id, merged, extra_path)
    return merged

In [9]:
def get_performance_metrics(cm):
    """
    Derives binary classification metrics from a confusion matrix.

    The matrix is read as [[TP, FP], [FN, TN]]. The returned dictionary holds:
        "acc": accuracy
        "rec": true positive rate (recall, sensitivity)
        "spe": specificity (1 - false positive rate)
        "pre": positive predictive value (PPV, precision)
        "npv": negative predictive value (NPV)
        "f1":  F1-score

    Denominators are not guarded against zero.
    """
    # np.array copies, so the caller's matrix is never touched.
    counts = np.array(cm)
    true_pos = counts[0, 0]
    false_pos = counts[0, 1]
    true_neg = counts[1, 1]
    false_neg = counts[1, 0]

    accuracy = (true_pos + true_neg) / (true_pos + true_neg + false_pos + false_neg)
    recall = true_pos / (true_pos + false_neg)
    specificity = true_neg / (true_neg + false_pos)
    precision = true_pos / (true_pos + false_pos)
    neg_pred_value = true_neg / (true_neg + false_neg)
    f1_score = (2 * precision * recall) / (precision + recall)

    return {
        "acc": accuracy,
        "rec": recall,
        "spe": specificity,
        "pre": precision,
        "npv": neg_pred_value,
        "f1": f1_score,
    }

def get_confusion_matrix(pred_y, true_y, pos_is_zero=False):
    """
    Builds the 2x2 confusion matrix of binary predictions against the ground truth.

    Rows index the predicted class and columns the true class, with the positive class
    first, so the result reads [[TP, FP], [FN, TN]].

    Set pos_is_zero to True if the positive sample's class index is 0.
    In the case of our ECG work, positive means an abnormal beat, and has a class index of 1.
    """
    predicted = torch.as_tensor(pred_y, dtype=torch.long)
    actual = torch.as_tensor(true_y, dtype=torch.long)

    # Encode every (true, predicted) pair as a single code in [0, 3]:
    #   true 0 / pred 0 -> 0,  true 1 / pred 0 -> 1,  true 0 / pred 1 -> 2,  true 1 / pred 1 -> 3
    pair_codes = actual + 2 * predicted
    counts = torch.zeros(4).long()
    counts += torch.bincount(pair_codes, minlength=4)
    table = counts.reshape(2, 2)

    # With class 0 as the positive class the table is already ordered positive-first;
    # otherwise both axes are reversed so that class 1 comes first.
    if pos_is_zero:
        return table
    return table.flip((0, 1))

In [10]:
def get_base_model(in_channels):
    """
    Returns the model from paper: Personalized Monitoring and Advance Warning System for Cardiac Arrhythmias.

    The network is three Conv1d(kernel 7) -> MaxPool1d(3) -> Tanh stages followed by a
    two-layer classifier that outputs 2 logits.
    """
    # Sequence length per stage for a 128-sample input:
    # 128 -> 122 -> 40 -> 34 -> 11 -> 5 -> 1, leaving 16 features after flattening.
    conv_channels = [(in_channels, 32), (32, 16), (16, 16)]

    layers = []
    for channels_in, channels_out in conv_channels:
        layers.append(nn.Conv1d(channels_in, channels_out, kernel_size=7, padding=0, bias=True))
        layers.append(nn.MaxPool1d(3))
        layers.append(nn.Tanh())

    layers.append(nn.Flatten())
    layers.append(nn.Linear(16, 32, bias=True))
    layers.append(nn.ReLU())
    layers.append(nn.Linear(32, 2, bias=True))

    # Modules are created in the same order and land at the same Sequential indices,
    # so state_dict keys stay compatible with previously saved weights.
    return nn.Sequential(*layers)

In [11]:
# Hyper-parameters kept as one-element lists; the training cell reads element 0 of each.
# NOTE(review): -1 presumably tells pytorch_sklearn not to cap the number of epochs,
# leaving early stopping to end training — confirm against the library.
max_epochs = [-1]
batch_sizes = [1024]

In [12]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [13]:

all_patient_cms = []  # one {patient_id: confusion matrix} dict per repeat
all_cms = []          # one summed 2x2 confusion matrix per repeat
repeats = 10

# Train on the device selected earlier instead of assuming that a GPU is present.
use_cuda = device.type == "cuda"

# NOTE(review): no random seed is set, so weights and results differ between runs.

for repeat in range(repeats):
    patient_cms = {}
    # Integer accumulator: a float32 total cannot count exactly past 2**24.
    cm = torch.zeros(2, 2, dtype=torch.long)

    for patient_id in tqdm(valid_patients, total=len(valid_patients), desc=f"Repeat {repeat+1} Progress"):
        dataset = load_N_channel_dataset(patient_id, DATASET_PATH, TRIO_PATH)
        # Look the splits up by key rather than relying on the dictionary's value order.
        train_X, train_y = dataset["train_X"], dataset["train_y"]
        val_X, val_y = dataset["val_X"], dataset["val_y"]
        test_X, test_y = dataset["test_X"], dataset["test_y"]

        # TRAIN THE NEURAL NETWORK
        # A fresh model, optimizer and loss are created for every patient and repeat.
        model = get_base_model(in_channels=train_X.shape[1])
        model = model.to(device)
        crit = nn.CrossEntropyLoss()
        optim = torch.optim.AdamW(params=model.parameters())

        net = NeuralNetwork(model, optim, crit)
        weight_checkpoint_val_loss = WeightCheckpoint(tracked="val_loss", mode="min")
        early_stopping = EarlyStopping(tracked="val_loss", mode="min", patience=15)

        net.fit(
            train_X=train_X,
            train_y=train_y,
            validate=True,
            val_X=val_X,
            val_y=val_y,
            max_epochs=max_epochs[0],
            batch_size=batch_sizes[0],
            use_cuda=use_cuda,
            # NOTE(review): presumably only meaningful when use_cuda is True — confirm.
            fits_gpu=True,
            callbacks=[weight_checkpoint_val_loss, early_stopping],
        )

        # Evaluate with the weights recorded by the val_loss checkpoint callback.
        net.load_weights(weight_checkpoint_val_loss)
        pred_y = net.predict(test_X, decision_func=lambda pred_y: pred_y.argmax(dim=1)).cpu()

        # SAVE TRAINED WEIGHTS
        NeuralNetwork.save_class(net, os.path.join(save_dir, f"net_{repeat+1}_{patient_id}"))

        cur_cm = get_confusion_matrix(pred_y, test_y, pos_is_zero=False)
        patient_cms[patient_id] = cur_cm
        cm += cur_cm

    all_patient_cms.append(patient_cms)
    all_cms.append(cm)

Repeat 1 Progress: 100%|██████████| 34/34 [10:48<00:00, 19.08s/it]
Repeat 2 Progress: 100%|██████████| 34/34 [10:03<00:00, 17.76s/it]
Repeat 3 Progress: 100%|██████████| 34/34 [10:14<00:00, 18.06s/it]
Repeat 4 Progress: 100%|██████████| 34/34 [10:35<00:00, 18.69s/it]
Repeat 5 Progress: 100%|██████████| 34/34 [10:45<00:00, 19.00s/it]
Repeat 6 Progress: 100%|██████████| 34/34 [10:27<00:00, 18.47s/it]
Repeat 7 Progress: 100%|██████████| 34/34 [10:24<00:00, 18.38s/it]
Repeat 8 Progress: 100%|██████████| 34/34 [10:29<00:00, 18.53s/it]
Repeat 9 Progress: 100%|██████████| 34/34 [10:50<00:00, 19.14s/it]
Repeat 10 Progress: 100%|██████████| 34/34 [10:48<00:00, 19.07s/it]


In [14]:
# Description of the run, pickled next to the results.
# optim, crit and weight_checkpoint_val_loss are the objects left over from the last
# iteration of the training loop; every iteration builds them the same way.
config = dict(
    # Read from the optimizer that was actually used so the record cannot drift from it.
    learning_rate=optim.param_groups[0]["lr"],
    max_epochs=max_epochs[0],
    batch_size=batch_sizes[0],
    optimizer=optim.__class__.__name__,
    loss=crit.__class__.__name__,
    early_stopping="true",
    checkpoint_on=weight_checkpoint_val_loss.tracked,
    dataset="default+trio",
    info="2-channel run, domain adapted, consulting with default dictionary, and trying all thresholds, saves weights"
)

In [15]:
# Stack the per-repeat 2x2 totals into one integer array with the repeat as first axis.
# Note that this rebinds all_cms from a list of tensors to a NumPy array.
all_cms = np.stack(all_cms).astype(int)

In [16]:
# Metrics pooled over all repeats and patients: the confusion matrices are summed
# along the repeat axis before the metrics are computed.
get_performance_metrics(all_cms.sum(axis=0))

{'acc': 0.9906384432607042,
 'rec': 0.9872912219012935,
 'spe': 0.9911107960586942,
 'pre': 0.9400243919936868,
 'npv': 0.998193747367369,
 'f1': 0.963078204625637}

In [17]:
# Pickle the stacked confusion matrices, the run configuration and the per-patient
# confusion matrices into SAVE_PATH, one file each.
outputs = (
    ("cms.pkl", all_cms),
    ("config.pkl", config),
    ("patient_cms.pkl", all_patient_cms),
)
for filename, payload in outputs:
    with open(os.path.join(SAVE_PATH, filename), "wb") as f:
        pickle.dump(payload, f)