## Early Classification with Pretrained Weights

In [1]:
import numpy as np
import pandas as pd
from os.path import join as osj
from ecg_utilities import *

from torch.nn import functional as Func

from pytorch_sklearn import NeuralNetwork
from pytorch_sklearn.callbacks import WeightCheckpoint, Verbose, LossPlot, EarlyStopping, Callback, CallbackInfo
from pytorch_sklearn.utils.func_utils import to_safe_tensor

In [2]:
# privacy = False

In [3]:
# Both files are header-less CSVs under ../files; every cell is flattened into a
# single 1-D array of ids. valid_patients drives the evaluation loops below.
patient_ids = pd.read_csv(osj("..", "files", "patient_ids.csv"), header=None).to_numpy().ravel()
valid_patients = pd.read_csv(osj("..", "files", "valid_patients.csv"), header=None).to_numpy().ravel()

In [4]:
# Domain-adapted training data: single-beat dataset and trio dataset share the
# same sub-layout, differing only in the top-level data folder.
DATASET_PATH, TRIO_PATH = (
    osj("..", data_dir, "dataset_training", "domain_adapted")
    for data_dir in ("data_single", "data_trio")
)

# Per-patient dictionaries read by load_dictionary(); their F matrix is used to
# compute the error energy for early classification.
DICT_PATH = osj("..", "data_single", "dictionaries", "5min_sorted")

# Destination of the pickled results (only written when a save cell is enabled).
SAVE_PATH = osj("..", "savefolder", "en_save")

# Networks to evaluate: the pretrained nets.
# Self-trained alternative: LOAD_PATH = osj("..", "savefolder", "nets")
LOAD_PATH = osj("..", "pretrained", "nets")

# Differential-privacy variant (disabled; it was guarded by a `privacy` flag):
#     DATASET_PATH = osj(DATASET_PATH, "domain_adapted_dp_l")   # or "domain_adapted_dp_ln"
#     TRIO_PATH = osj(TRIO_PATH, "domain_adapted_dp_l")         # or "domain_adapted_dp_ln"
#     DICT_PATH = osj("..", "dictionaries", "5min_sorted_dp")
#     SAVE_PATH = osj("..", "savefolder_dp", "ens")
#     LOAD_PATH = osj("..", "savefolder_dp", "nets_l")          # or "nets_ln"; self-trained dp

### Efficiency vs. F1 Score with Domain Adaptation (DA)

In [5]:
# Single-element hyper-parameter lists; element 0 of each is recorded in `config`.
max_epochs, batch_sizes = [-1], [1024]

In [6]:
%%time
# Early classification driven by dictionary error energy, network only.
#
# For every repeat r and every valid patient, the network saved as
# net_{r+1}_{patient_id} under LOAD_PATH predicts the test beats. Then, for each
# percentile p in 0..100, beats whose error energy E is below the p-th
# percentile of E are assigned label 0 (the negative class, given
# pos_is_zero=False) without using the network; all other beats keep the
# network's prediction. p = 0 therefore reproduces the network-only result.
#
# Results:
#   all_cms[r][p]                      confusion matrix summed over patients
#   all_patient_cms[r][p][patient_id]  per-patient confusion matrix
all_patient_cms = []
all_cms = []
repeats = 10
percentiles = list(range(0, 101))

for repeat in range(repeats):
    patient_cms = {percentile: {} for percentile in percentiles}
    cm = {percentile: torch.zeros(2, 2) for percentile in percentiles}

    for i, patient_id in enumerate(valid_patients):
        dataset = load_N_channel_dataset(patient_id, DATASET_PATH, TRIO_PATH)
        train_X, train_y, train_ids, val_X, val_y, val_ids, test_X, test_y, test_ids = dataset.values()

        # Error energy of channel 0 against the patient's dictionary; only F is
        # needed here, so the D part of the dictionary is not converted.
        _, F = load_dictionary(patient_id, DICT_PATH)
        F = torch.Tensor(F)
        # test_X[:, 0, :] is already 2-D. The former .squeeze() was a no-op except
        # for a single-beat test set, where it wrongly dropped the beat axis.
        E = get_error_one_patient(test_X[:, 0, :], F, as_energy=True)

        # Rebuild the architecture and restore the checkpointed weights.
        model = get_base_model(in_channels=train_X.shape[1]).to("cuda")
        crit = nn.CrossEntropyLoss()
        optim = torch.optim.AdamW(params=model.parameters())

        net = NeuralNetwork.load_class(osj(LOAD_PATH, f"net_{repeat+1}_{patient_id}"), model, optim, crit)
        # NOTE: assumes the weight checkpoint is the second registered callback;
        # update this index if the training callbacks change.
        weight_checkpoint_val_loss = net.cbmanager.callbacks[1]
        net.load_weights(weight_checkpoint_val_loss)

        pred_y = net.predict(
            test_X,
            batch_size=1024,
            use_cuda=True,
            fits_gpu=True,
            decision_func=lambda out: out.argmax(dim=1),
        ).cpu()

        # Loop-invariant work hoisted out of the percentile sweep: all 101
        # thresholds come from a single np.percentile call, and the all-zeros
        # "early" prediction is built once per patient.
        thresholds = np.percentile(E, percentiles)
        early_y = torch.zeros_like(pred_y)

        for percentile, thresh in zip(percentiles, thresholds):
            less_than = E < thresh
            greater_than = E >= thresh

            # Below-threshold beats get label 0; the rest keep the network prediction.
            final_pred_y = torch.Tensor(np.select([less_than, greater_than], [early_y, pred_y])).long()
            cm_exp = get_confusion_matrix(final_pred_y, test_y, pos_is_zero=False)

            patient_cms[percentile][patient_id] = cm_exp
            cm[percentile] += cm_exp

        print_progress(i + 1, len(valid_patients), opt=[f"{patient_id}"])

    all_patient_cms.append(patient_cms)
    all_cms.append(cm)

CPU times: user 13min 32s, sys: 13min 41s, total: 27min 14s
Wall time: 45.8 s


In [7]:
# Metadata for the network-only run, pickled alongside the confusion matrices.
# `optim`, `crit` and `weight_checkpoint_val_loss` are whatever the loop above
# left bound (i.e. from the last patient processed). learning_rate and
# early_stopping are recorded as fixed values, not read from those objects.
config = {
    "learning_rate": 0.001,
    "max_epochs": max_epochs[0],
    "batch_size": batch_sizes[0],
    "optimizer": type(optim).__name__,
    "loss": type(crit).__name__,
    "early_stopping": "true",
    "checkpoint_on": weight_checkpoint_val_loss.tracked,
    "dataset": "default+trio",
    "info": "2-channel run, domain adapted, early classification performance based on efficiency.",
}

In [8]:
# Repeat 1 (index 0), percentile 0: no beat has energy below the minimum, so
# nothing is classified early and this is the network-only confusion matrix.
all_cms[0][0] # pre-trained

tensor([[ 7204.,   690.],
        [  759., 55738.]])

In [None]:
# Same expression as the cell above. The recorded output below is labelled
# self-trained, presumably from a run with LOAD_PATH = osj("..", "savefolder", "nets");
# re-running shows results for whatever LOAD_PATH is currently set.
all_cms[0][0] # self-trained

tensor([[ 7399.,   676.],
        [  564., 55752.]])

In [9]:
# Repeat 10 (index 9), percentile 39: beats with error energy below the 39th
# percentile are labelled 0 without the network. This is a single repeat, not
# an average over repeats.
all_cms[9][39] # pre-trained

tensor([[ 6821.,   643.],
        [ 1142., 55785.]])

In [67]:
# Same index as the cell above (repeat 10, percentile 39). The recorded output
# below matches the metrics labelled "on self-trained nets" further down, so it
# presumably comes from a run with self-trained nets.
# NOTE(review): percentile keys are range(0, 101), so 39 is the 39th percentile;
# confirm this is what the paper calls the efficiency threshold of 40.
all_cms[9][39] # as specified in the paper the efficiency threshold of 40 was chosen

tensor([[ 6813.,   596.],
        [ 1150., 55832.]])

In [10]:
# Repeat 10, percentile 98: almost every beat is below the threshold and is
# labelled 0 early, so very few beats reach the network.
all_cms[9][98] # pre-trained

tensor([[  682.,    71.],
        [ 7281., 56357.]])

In [None]:
# Same expression as the cell above; the recorded output is labelled as coming
# from self-trained nets (a different LOAD_PATH than the one set in this notebook).
all_cms[9][98] # self-trained

tensor([[  677.,    77.],
        [ 7286., 56351.]])

In [11]:
# Summary metrics (acc, rec, spe, pre, npv, f1) for repeat 10 at percentile 39.
get_performance_metrics(all_cms[9][39]) # on pre-trained nets

{'acc': 0.9722787346057679,
 'rec': 0.8565867135501696,
 'spe': 0.9886049478982065,
 'pre': 0.9138531618435155,
 'npv': 0.9799392204050802,
 'f1': 0.8842937706618266}

In [68]:
# Same call as above; the recorded output is labelled self-trained. Re-running
# reports metrics for the nets currently selected by LOAD_PATH.
get_performance_metrics(all_cms[9][39]) # on self-trained nets

{'acc': 0.9728844093118604,
 'rec': 0.8555820670601532,
 'spe': 0.9894378677252428,
 'pre': 0.919557295181536,
 'npv': 0.9798181881997824,
 'f1': 0.8864168618266978}

In [None]:
# Persist the results of the network-only run. Disabled by default; change
# `False` to `True` to write the three pickles into SAVE_PATH.
# NOTE(review): the ensemble run's save cell writes the same three file names
# into the same SAVE_PATH, so enabling both overwrites these results — point
# SAVE_PATH elsewhere between the two runs if both are needed.
if False:
    import os
    import pickle

    # open(..., "wb") raises FileNotFoundError if the folder does not exist yet.
    os.makedirs(SAVE_PATH, exist_ok=True)

    for file_name, obj in (
        ("cms.pkl", all_cms),
        ("config.pkl", config),
        ("patient_cms.pkl", all_patient_cms),
    ):
        with open(osj(SAVE_PATH, file_name), "wb") as f:
            pickle.dump(obj, f)

### Efficiency vs. F1 Score with Domain Adaptation (DA) and the Ensemble Classifier

In [12]:
# Pickled confidence thresholds for the ensemble, looked up per repeat and per
# patient in the loop below (confs[repeat][i]).
# Self-trained alternative: CONF_PATH = osj("..", "savefolder", "Ens_val", "confidences.pkl")
CONF_PATH = osj("..", "pretrained", "confidences.pkl")

In [13]:
# Re-declared for the ensemble run; element 0 of each list is recorded in `config`.
max_epochs, batch_sizes = [-1], [1024]

In [14]:
# Load the confidence thresholds used by the ensemble loop below, where they are
# indexed as confs[repeat][i] (i = position of the patient in valid_patients).
# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# trusted pickles. `pickle` is not imported explicitly in this notebook —
# presumably it comes from `from ecg_utilities import *`; confirm.
with open(CONF_PATH, "rb") as f:
    confs = pickle.load(f)

In [15]:
#%%time
# Early classification combined with the confidence-based ensemble.
#
# Same sweep as the network-only experiment, but the base prediction per beat
# comes from an ensemble. Where the network's top softmax probability is below
# the patient's confidence threshold confs[repeat][i], the beat takes the label
# of a probabilistic "consultant" fitted on training error energies. Otherwise
# the network's prediction is kept. For each percentile p in 0..100, beats whose
# test error energy is below the p-th percentile are then labelled 0 early.
#
# Results:
#   all_cms[r][p]                      confusion matrix summed over patients
#   all_patient_cms[r][p][patient_id]  per-patient confusion matrix
all_patient_cms = []
all_cms = []
repeats = 10
percentiles = list(range(0, 101))

for repeat in range(repeats):
    patient_cms = {percentile: {} for percentile in percentiles}
    cm = {percentile: torch.zeros(2, 2) for percentile in percentiles}

    for i, patient_id in enumerate(valid_patients):
        dataset = load_N_channel_dataset(patient_id, DATASET_PATH, TRIO_PATH)
        train_X, train_y, train_ids, val_X, val_y, val_ids, test_X, test_y, test_ids = dataset.values()

        # Error energy of channel 0 against the patient's dictionary (only F is used).
        _, F = load_dictionary(patient_id, DICT_PATH)
        F = torch.Tensor(F)

        ## Consultant: exponential fit on E_healthy, Gaussian fit on E_arrhyth,
        ## combined by BayesianFit.
        BF = BayesianFit()
        EF = ExponentialFit()
        GF = GaussianFit()

        # With y given, get_error_one_patient also returns two energy subsets,
        # presumably split by train_y. The slices are already 2-D; the former
        # .squeeze() only mattered (wrongly) for a single-beat set.
        train_E, E_healthy, E_arrhyth = get_error_one_patient(train_X[:, 0, :], F, y=train_y, as_energy=True)
        test_E = get_error_one_patient(test_X[:, 0, :], F, as_energy=True)

        EF.fit(E_healthy)
        GF.fit(E_arrhyth)
        # NOTE(review): assumes a BF.predict score <= 0.5 means label 1 — confirm.
        consult_test_y = torch.Tensor(BF.predict(test_E, EF, GF) <= 0.5).long()

        # Rebuild the architecture and restore the checkpointed weights.
        model = get_base_model(in_channels=train_X.shape[1]).to("cuda")
        crit = nn.CrossEntropyLoss()
        optim = torch.optim.AdamW(params=model.parameters())

        net = NeuralNetwork.load_class(osj(LOAD_PATH, f"net_{repeat+1}_{patient_id}"), model, optim, crit)
        # NOTE: assumes the weight checkpoint is the second registered callback;
        # update this index if the training callbacks change.
        weight_checkpoint_val_loss = net.cbmanager.callbacks[1]
        net.load_weights(weight_checkpoint_val_loss)

        # Test predictions and per-beat confidence (top softmax probability).
        pred_y = net.predict(
            test_X,
            batch_size=1024,
            use_cuda=True,
            fits_gpu=True,
            decision_func=lambda out: out.argmax(dim=1),
        ).cpu()
        # NOTE(review): softmax is applied to predict_proba's output; if predict_proba
        # already returns probabilities, this double softmax flattens the confidences.
        # predict_proba also runs with its default batching/device arguments, unlike
        # predict above — confirm both against pytorch_sklearn.
        prob_y = net.predict_proba(test_X).cpu()
        softmax_prob_y = Func.softmax(prob_y, dim=1).max(dim=1).values

        # The ensemble prediction does not depend on the percentile, so build it
        # once per patient instead of once per percentile (101x less work).
        conf = confs[repeat][i]  # relies on confs following the valid_patients order
        low_confidence = softmax_prob_y < conf
        high_confidence = softmax_prob_y >= conf
        ensemble_pred_y = torch.Tensor(np.select([low_confidence, high_confidence], [consult_test_y, pred_y])).long()

        # All 101 thresholds in one call; the all-zeros "early" prediction is shared.
        thresholds = np.percentile(test_E, percentiles)
        early_y = torch.zeros_like(ensemble_pred_y)

        for percentile, thresh in zip(percentiles, thresholds):
            less_than = test_E < thresh
            greater_than = test_E >= thresh

            # Below-threshold beats are classified early (label 0); the rest keep
            # the ensemble prediction.
            final_pred_y = torch.Tensor(np.select([less_than, greater_than], [early_y, ensemble_pred_y])).long()
            cm_exp = get_confusion_matrix(final_pred_y, test_y, pos_is_zero=False)

            patient_cms[percentile][patient_id] = cm_exp
            cm[percentile] += cm_exp

        print_progress(i + 1, len(valid_patients), opt=[f"{patient_id}"])

    all_patient_cms.append(patient_cms)
    all_cms.append(cm)



In [16]:
# Metadata for the ensemble ("consulted") run, pickled alongside the confusion
# matrices. `optim`, `crit` and `weight_checkpoint_val_loss` are whatever the
# loop above left bound (from the last patient processed). learning_rate and
# early_stopping are recorded as fixed values, not read from those objects.
config = {
    "learning_rate": 0.001,
    "max_epochs": max_epochs[0],
    "batch_size": batch_sizes[0],
    "optimizer": type(optim).__name__,
    "loss": type(crit).__name__,
    "early_stopping": "true",
    "checkpoint_on": weight_checkpoint_val_loss.tracked,
    "dataset": "default+trio",
    "info": "2-channel run, domain adapted, consulted, early classification performance based on efficiency.",
}

In [17]:
# Ensemble run, repeat 10 (index 9), percentile 39. This is a single repeat,
# not an average over repeats.
all_cms[9][39] # pre-trained

tensor([[ 7013.,   665.],
        [  950., 55763.]])

In [None]:
# Same expression as the cell above. The recorded output is labelled
# self-trained, i.e. from a run with a different LOAD_PATH/CONF_PATH than the
# ones set in this notebook.
all_cms[9][39] # self-trained

tensor([[ 6887.,   618.],
        [ 1076., 55810.]])

In [18]:
# Summary metrics (acc, rec, spe, pre, npv, f1) for the ensemble run, repeat 10,
# percentile 39.
get_performance_metrics(all_cms[9][39]) # on pre-trained nets

{'acc': 0.9749188551195043,
 'rec': 0.8806982293105613,
 'spe': 0.9882150705323598,
 'pre': 0.91338890336025,
 'npv': 0.9832489905312715,
 'f1': 0.896745732370053}

In [None]:
# Same call as above; the recorded output is labelled self-trained. Re-running
# reports metrics for the nets currently selected by LOAD_PATH.
get_performance_metrics(all_cms[9][39]) # on self-trained nets

{'acc': 0.9736919755866503,
 'rec': 0.8648750470928043,
 'spe': 0.9890479903593961,
 'pre': 0.9176548967355097,
 'npv': 0.9810849769714869,
 'f1': 0.8904835790018102}

In [None]:
# Persist the results of the ensemble run. Disabled by default; change `False`
# to `True` to write the three pickles into SAVE_PATH.
# NOTE(review): the network-only run's save cell writes the same three file
# names into the same SAVE_PATH, so enabling both overwrites one set of results —
# point SAVE_PATH elsewhere between the two runs if both are needed.
if False:
    import os
    import pickle

    # open(..., "wb") raises FileNotFoundError if the folder does not exist yet.
    os.makedirs(SAVE_PATH, exist_ok=True)

    for file_name, obj in (
        ("cms.pkl", all_cms),
        ("config.pkl", config),
        ("patient_cms.pkl", all_patient_cms),
    ):
        with open(osj(SAVE_PATH, file_name), "wb") as f:
            pickle.dump(obj, f)