In [1]:
import os
import shutil
import sys
# Location of the submission checkout whose modules are imported below.
# Set the PHYSIONET_SUBMISSION_DIR environment variable to run this notebook
# against a checkout in a different place; the default is the original path.
source_path = os.environ.get(
    'PHYSIONET_SUBMISSION_DIR',
    '/home/huypham/Projects/ecg/tmp/physionet2020-submission',
)
# Guard against duplicate entries when the cell is re-run.
if source_path not in sys.path:
    sys.path.append(source_path)

In [2]:
import time

In [3]:
from train_12ECG_classifier import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Input data sits in the `input` sub-directory of the submission checkout;
# derive it from `source_path` so the two locations cannot drift apart.
input_directory = os.path.join(source_path, "input")
# Scratch directory for experiment output; it is wiped and re-created below.
output_directory = "/tmp/prna/"

In [5]:
# Start from an empty output directory, then seed it with the saved models
# from the submission checkout so the checkpoint-loading cell can find them.
# Pure-Python calls (instead of `!rm` / `!mkdir` / `!cp`) reuse the configured
# paths and raise immediately if the copy fails.
if os.path.isdir(output_directory):
    shutil.rmtree(output_directory)
os.makedirs(output_directory, exist_ok=True)
shutil.copytree(
    os.path.join(source_path, 'output', 'saved_models'),
    os.path.join(output_directory, 'saved_models'),
)

In [6]:
# Input directory as a Path object; passed as `src_path` to the dataset
# classes when the data loaders are built.
src_path = Path(input_directory)
# Disabled call, kept for reference; `train_classifier` is not defined in the
# visible cells (presumably it comes from the star import — TODO confirm).
# train_classifier(src_path, output_directory, 'test')

In [7]:
# When False, the epoch loop in the next cell is skipped and only the saved
# checkpoint is loaded and evaluated.
do_train = False
# Batch size handed to every DataLoader built in the next cell.
batch_size = 1

In [8]:
# Early-stopping state: `patience_count` is compared against `patience` in the
# training loop; `best_auroc` is initialised alongside it.
best_auroc = 0.
patience_count = 0

# Fixed seeds (CPU, and CUDA when present); all computation is placed on the CPU.
torch.manual_seed(1)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1)
device = torch.device('cpu')

# Train, validation and test splits are selected by name. The earlier
# index-based scheme is kept below for reference:
# val_fold = (tst_fold - 1) % 10
# trn_fold = np.delete(np.arange(10), [val_fold, tst_fold])
trn_fold, val_fold, tst_fold = 'train', 'val', 'test'

for split_tag, split_name in (('trn:', trn_fold), ('val:', val_fold), ('tst:', tst_fold)):
    print(split_tag, split_name)

# Build the CTN network (hyper-parameters are defined outside the visible
# cells) and place it on the selected device.
model = CTN(d_model, nhead, d_ff, num_layers, dropout_rate, deepfeat_sz, nb_feats, nb_demo, classes)
model = model.to(device)

# Glorot / fan_avg initialisation for every parameter with more than one
# dimension; 1-D parameters keep their default initialisation.
for weight in (p for p in model.parameters() if p.dim() > 1):
    nn.init.xavier_uniform_(weight)

# model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

n_params = sum(p.numel() for p in model.parameters())
print(f'Number of params: {n_params}')

# Split the metadata frame by its `fold` column.
trn_df = data_df[data_df.fold == trn_fold]
val_df = data_df[data_df.fold == val_fold]
tst_df = data_df[data_df.fold == tst_fold]

# Debug runs only look at the first five rows of each split.
if debug:
    trn_df, val_df, tst_df = trn_df[:5], val_df[:5], tst_df[:5]

# The padding strategy selects the dataset class. Fail early on an unknown
# value instead of leaving the loaders undefined for later cells.
dataset_classes = {'zero': ECGWindowPaddingDataset, 'qrs': ECGWindowAlignedDataset}
if padding not in dataset_classes:
    raise ValueError(f'Unsupported padding {padding!r}; expected one of {sorted(dataset_classes)}')
dataset_cls = dataset_classes[padding]


def _make_loader(df, nb_windows, cache_dir, shuffle):
    """Wrap `df` in the selected dataset class and return a DataLoader over it."""
    dataset = dataset_cls(df, window, nb_windows=nb_windows, src_path=src_path, cache_dir=cache_dir)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=8, collate_fn=collate_fn)


# Only the training loader shuffles; the test loader is built without a cache directory.
trnloader = _make_loader(trn_df, nb_windows=1, cache_dir='cache/train', shuffle=True)
valloader = _make_loader(val_df, nb_windows=10, cache_dir='cache/val', shuffle=False)
tstloader = _make_loader(tst_df, nb_windows=20, cache_dir=None, shuffle=False)

# Adam starts at lr=0 and is wrapped by NoamOpt, which drives the optimisation
# steps (arguments are presumably model size, factor, warm-up steps and the
# wrapped optimizer — TODO confirm against NoamOpt's definition).
adam = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
optimizer = NoamOpt(d_model, 1, 4000, adam)

# Create dir structure and init logs for this test fold.
results_loc, sw = create_experiment_directory(output_directory)
fold_loc = create_fold_dir(results_loc, tst_fold)
start_log(fold_loc, tst_fold)

print(fold_loc)

# Optional training run: up to 100 epochs of train + validate, with results
# written to the fold log and to the summary writer `sw`.
if do_train:
    for epoch in tqdm(range(100), desc='epoch'):
        trn_loss, trn_auroc = train(epoch, model, trnloader, optimizer)
        val_loss, val_auroc = validate(epoch, model, valloader, optimizer, fold_loc)
        write_log(fold_loc, tst_fold, epoch, trn_loss, trn_auroc, val_loss, val_auroc)
        print(f'Train - loss: {trn_loss}, auroc: {trn_auroc}')
        print(f'Valid - loss: {val_loss}, auroc: {val_auroc}')

        sw.add_scalar(f'{tst_fold}/trn/loss', trn_loss, epoch)
        sw.add_scalar(f'{tst_fold}/trn/auroc', trn_auroc, epoch)
        sw.add_scalar(f'{tst_fold}/val/loss', val_loss, epoch)
        sw.add_scalar(f'{tst_fold}/val/auroc', val_auroc, epoch)

        # Early-stopping bookkeeping. validate() is not defined in this
        # notebook, so it cannot rebind the notebook-level `patience_count`
        # and `best_auroc`; without updating them here the check below could
        # never trigger.
        if val_auroc > best_auroc:
            best_auroc = val_auroc
            patience_count = 0
        else:
            patience_count += 1

        # Early stopping
        if patience_count >= patience:
            print(f'Early stopping invoked at epoch, #{epoch}')
            break
    
# Training done, choose threshold...
# model = load_best_model(str(f'{fold_loc}/{model_name}.tar'), model)
# Load the checkpoint from the saved_models tree inside `output_directory`.
ckpt = os.path.join(output_directory, 'saved_models', 'ctn', 'fold_test', 'ctn.tar')
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Tensors are mapped onto the device the model lives on.
checkpoint = torch.load(ckpt, map_location=device)

state_dict = checkpoint['model_state_dict']

# Keys saved from a DataParallel-wrapped model start with 'module.'. Strip
# that prefix only when it leads the key, so a 'module.' substring occurring
# deeper inside a parameter name is left intact.
dp_prefix = 'module.'
for key in list(state_dict.keys()):
    if key.startswith(dp_prefix):
        state_dict[key[len(dp_prefix):]] = state_dict.pop(key)

model.load_state_dict(state_dict)

# Rebuild the evaluation loaders; validation now uses nb_windows=20 (it was 10
# when the loaders were first built). The 'zero' branch passes a cache_dir,
# the 'qrs' branch relies on the dataset's default.
eval_loader_opts = dict(batch_size=batch_size, shuffle=False, num_workers=8, collate_fn=collate_fn)
if padding == 'zero':
    val_ds = ECGWindowPaddingDataset(val_df, window, nb_windows=20, src_path=src_path, cache_dir='cache/val')
    valloader = DataLoader(val_ds, **eval_loader_opts)
    tst_ds = ECGWindowPaddingDataset(tst_df, window, nb_windows=20, src_path=src_path, cache_dir=None)
    tstloader = DataLoader(tst_ds, **eval_loader_opts)
elif padding == 'qrs':
    val_ds = ECGWindowAlignedDataset(val_df, window, nb_windows=20, src_path=src_path)
    valloader = DataLoader(val_ds, **eval_loader_opts)
    tst_ds = ECGWindowAlignedDataset(tst_df, window, nb_windows=20, src_path=src_path)
    tstloader = DataLoader(tst_ds, **eval_loader_opts)

trn: train
val: val
tst: test
Number of params: 13643885
/tmp/prna/saved_models/ctn/fold_test


In [9]:
# Evaluate on the test split; the inference cell below iterates `dataloader`.
dataloader = tstloader
# Number of batches; only referenced by a commented-out progress print in the
# inference cell, which otherwise calls len(dataloader) itself.
total = len(dataloader)

In [10]:
# Inference pass: run the model over `dataloader`, collecting sigmoid
# probabilities and labels, and timing feature preparation ("processing") and
# the forward passes ("predicting") separately for every batch.
model.eval()
probs, lbls = [], []
processing_time = []
predicting_time = []

# The age normalisation statistics do not change between batches, so compute
# them once instead of once per batch inside the timed section.
age_mean = data_df.Age.mean()
age_std = data_df.Age.std()

for i, (inp_windows_t, feats_t, lbl_t, hdr, filename) in tqdm(enumerate(dataloader), total=len(dataloader), desc='get_probs', disable=False):
    # perf_counter() is the monotonic, high-resolution clock intended for
    # measuring short durations; time.time() can jump with wall-clock changes.
    start_processing = time.perf_counter()
    # Get normalized data
    inp_windows_t, lbl_t = inp_windows_t.float().to(device), lbl_t.float().to(device)

    # Get (normalized) demographic data and append to top (normalized) features
    # Be careful not to double count Age/Gender in future
    # NOTE(review): hdr is iterated as one header per record, with element 13
    # fed to get_age and element 14 searched for 'Female' — confirm layout.
    ages = [header[13] for header in hdr]
    age_t = torch.FloatTensor((get_age(ages)[None].T - age_mean) / age_std)
    sex_t = torch.FloatTensor(
        [1. if 'Female' in header[14] else 0 for header in hdr]
    )[None].T

    wide_feats = torch.cat([age_t, sex_t, feats_t.squeeze(1).float()], dim=1).to(device)

    stop_processing = time.perf_counter()
    processing_time.append(stop_processing - start_processing)

    # Predict
    start_predicting = time.perf_counter()
    outs = []
    with torch.no_grad():
        # Loop over nb_windows
        for inp_t in inp_windows_t.transpose(1, 0):
            outs.append(model(inp_t, wide_feats))
        out = torch.stack(outs).mean(dim=0)   # take the average of the sequence windows

    # Collect probs and labels (detach() replaces the legacy `.data` accessor)
    probs.append(out.sigmoid().detach().cpu().numpy())
    lbls.append(lbl_t.detach().cpu().numpy())
    stop_predicting = time.perf_counter()
    predicting_time.append(stop_predicting - start_predicting)

# Consolidate probs and labels
lbls = np.concatenate(lbls)
probs = np.concatenate(probs)

get_probs: 100%|█████████████████████████████████████████████████████████████████████████████████████| 8621/8621 [2:01:04<00:00,  1.19it/s]


In [12]:
# Mean per-batch durations, measured in seconds and reported in milliseconds.
mean_processing_s = np.mean(processing_time)
mean_predicting_s = np.mean(predicting_time)
print('Processing:', mean_processing_s * 1000)
print('Predicting:', mean_predicting_s * 1000)
print('Total:', 1000 * (mean_processing_s + mean_predicting_s))

Processing: 4.921544885541347
Predicting: 833.5824819469021
Total: 838.5040268324434
