In [1]:
from __future__ import absolute_import, division, print_function

import argparse
from datetime import datetime
import numpy as np
import os

from keras import backend as K
from keras.callbacks import EarlyStopping, TensorBoard
import numpy as np
from sklearn.metrics import roc_auc_score

from data_handler import DataHandler
from models import create_grud_model, load_grud_model
from nn_utils.callbacks import ModelCheckpointwithBestWeights

import warnings
warnings.filterwarnings("ignore")


Using TensorFlow backend.


In [2]:
# Configure the TensorFlow session that Keras runs in (TF 1.x API:
# tf.ConfigProto / tf.Session). Skipped entirely for other Keras backends.
if K.backend() == 'tensorflow':
    import tensorflow as tf
    config = tf.ConfigProto()
    gpu_opts = config.gpu_options
    # Cap this process at 10% of the GPU's memory so other jobs can share the card ...
    gpu_opts.per_process_gpu_memory_fraction = 0.1
    # ... and let the allocator grow on demand inside that cap instead of
    # reserving the whole budget up front.
    gpu_opts.allow_growth = True
    session = tf.Session(config=config)
    K.set_session(session)

In [3]:
# parse arguments
def _int_list(value):
    """Parse a comma-separated list of ints, e.g. '64,32' -> [64, 32].

    An empty string yields [] (so ``--hidden_dim ''`` means "no hidden layers");
    empty items such as a trailing comma are skipped.
    """
    return [int(item) for item in value.split(',') if item.strip()]


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes', '0'.

    argparse hands ``type`` callables the raw string, so without this a value
    like 'False' would be stored as the non-empty (hence truthy) string 'False'.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got {!r}.'.format(value))


## general
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--working_path', default='.')

## data
# NOTE: these two are required positional arguments, so argparse never applies
# their `default=` values; the defaults only document typical choices.
arg_parser.add_argument('dataset_name', default='mimic3',
                        help='The data files should be saved in [working_path]/data/[dataset_name] directory.')
arg_parser.add_argument('label_name', default='mortality')
arg_parser.add_argument('--max_timesteps', type=int, default=200,
                        help='Time series of at most # time steps are used. Default: 200.')
arg_parser.add_argument('--max_timestamp', type=int, default=48*60*60,
                        help='Time series of at most # seconds are used. Default: 48 (hours).')

## model
# String defaults are passed through `type`, so the default '64' becomes [64].
arg_parser.add_argument('--recurrent_dim', type=_int_list, default='64')
arg_parser.add_argument('--hidden_dim', type=_int_list, default='64')
arg_parser.add_argument('--model', default='GRUD', choices=['GRUD', 'GRUforward', 'GRU0', 'GRUsimple'])
# Accepts an explicit value ('--use_bidirectional_rnn False') or the bare flag (-> True).
arg_parser.add_argument('--use_bidirectional_rnn', type=_str2bool, nargs='?',
                        const=True, default=False)

## training
# The path (relative to working_path) is formatted with `i_fold=<fold index>`,
# e.g. 'model_{i_fold}.h5', so each fold loads its own model.
arg_parser.add_argument('--pretrained_model_file', default=None,
                        help='If pre-trained model is provided, training will be skipped.')
arg_parser.add_argument('--epochs', type=int, default=100)
arg_parser.add_argument('--early_stopping_patience', type=int, default=10)
arg_parser.add_argument('--batch_size', type=int, default=32)


## set the actual arguments if running in notebook
if not (__name__ == '__main__' and '__file__' in globals()):
    ARGS = arg_parser.parse_args([
        'sample',
        'taskname',
        '--model', 'GRUD',
        '--hidden_dim', '',
        '--epochs', '100'
    ])
else:
    ARGS = arg_parser.parse_args()

print('Arguments:', ARGS)

Arguments: Namespace(batch_size=32, dataset_name='sample', early_stopping_patience=10, epochs=100, hidden_dim=[], label_name='taskname', max_timestamp=172800, max_timesteps=200, model='GRUD', pretrained_model_file=None, recurrent_dim=[64], use_bidirectional_rnn=False, working_path='.')


In [4]:
# get dataset: the data files live under [working_path]/data/[dataset_name].
data_path = os.path.join(ARGS.working_path, 'data', ARGS.dataset_name)
dataset = DataHandler(data_path=data_path,
                      label_name=ARGS.label_name,
                      max_steps=ARGS.max_timesteps,
                      max_timestamp=ARGS.max_timestamp)

In [5]:
# k-fold cross-validation: for every fold, load (or train) a model, then score it
# with ROC AUC on that fold's training, validation and testing splits.
def _to_object_grid(nested):
    """Pack a list of equal-length lists into a 2-D numpy object array.

    Each cell keeps its original object (here: one prediction array per split).
    Filling the array cell by cell stops numpy from trying to broadcast the
    ragged per-split arrays, which newer numpy versions reject with ValueError.
    """
    n_cols = len(nested[0]) if nested else 0
    grid = np.empty((len(nested), n_cols), dtype=object)
    for i_row, row in enumerate(nested):
        for i_col, item in enumerate(row):
            grid[i_row, i_col] = item
    return grid


pred_y_list_all = []
auc_score_list_all = []

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
print('Timestamp: {}'.format(timestamp))

for i_fold in range(dataset.folds):
    print('{}-th fold...'.format(i_fold))
    # Run name shared by this fold's checkpoint, TensorBoard and saved-model paths.
    fold_tag = timestamp + '_' + str(i_fold)

    # Load or train the model.
    if ARGS.pretrained_model_file is not None:
        # The file name is formatted with the fold index ('{i_fold}' placeholder).
        model = load_grud_model(os.path.join(ARGS.working_path,
                                             ARGS.pretrained_model_file.format(i_fold=i_fold)))
    else:
        model = create_grud_model(input_dim=dataset.input_dim,
                                  output_dim=dataset.output_dim,
                                  output_activation=dataset.output_activation,
                                  recurrent_dim=ARGS.recurrent_dim,
                                  hidden_dim=ARGS.hidden_dim,
                                  predefined_model=ARGS.model,
                                  use_bidirectional_rnn=ARGS.use_bidirectional_rnn
                                 )
        if i_fold == 0:
            model.summary()
        model.compile(optimizer='adam', loss=dataset.loss_function)
        model_dir = os.path.join(ARGS.working_path, 'model', fold_tag)
        model.fit_generator(
            generator=dataset.training_generator(i_fold, batch_size=ARGS.batch_size),
            steps_per_epoch=dataset.training_steps(i_fold, batch_size=ARGS.batch_size),
            epochs=ARGS.epochs,
            verbose=1,
            validation_data=dataset.validation_generator(i_fold, batch_size=ARGS.batch_size),
            validation_steps=dataset.validation_steps(i_fold, batch_size=ARGS.batch_size),
            callbacks=[
                EarlyStopping(patience=ARGS.early_stopping_patience),
                ModelCheckpointwithBestWeights(file_dir=model_dir),
                TensorBoard(log_dir=os.path.join(ARGS.working_path, 'tb_logs', fold_tag))
            ]
        )
        # Make sure the fold directory exists before saving; whether the
        # checkpoint callback already created it is not visible from here.
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model.save(os.path.join(model_dir, 'model.h5'))

    # Evaluate the model on the training, validation and testing splits, in that order.
    split_accessors = [
        (dataset.training_y, dataset.training_generator_x, dataset.training_steps),
        (dataset.validation_y, dataset.validation_generator_x, dataset.validation_steps),
        (dataset.testing_y, dataset.testing_generator_x, dataset.testing_steps),
    ]
    true_y_list = [get_y(i_fold) for get_y, get_x, get_steps in split_accessors]
    pred_y_list = [
        model.predict_generator(get_x(i_fold, batch_size=ARGS.batch_size),
                                steps=get_steps(i_fold, batch_size=ARGS.batch_size))
        for get_y, get_x, get_steps in split_accessors
    ]
    # One float per split: roc_auc_score's default average='macro' reduces
    # multi-task labels to a single number.
    auc_score_list = [roc_auc_score(ty, py) for ty, py in zip(true_y_list, pred_y_list)]
    print('AUC score of this fold: {}'.format(auc_score_list))
    pred_y_list_all.append(pred_y_list)
    auc_score_list_all.append(auc_score_list)

print('Finished!', '='*20)
auc_score_list_all = np.stack(auc_score_list_all, axis=0)  # shape (n_folds, 3)
print('Mean AUC score: {}; Std AUC score: {}'.format(
    np.mean(auc_score_list_all, axis=0),
    np.std(auc_score_list_all, axis=0)))

result_path = os.path.join(ARGS.working_path, 'results', timestamp)
if not os.path.exists(result_path):
    os.makedirs(result_path)
# Predictions are stored as an (n_folds, 3) object array of per-split arrays;
# reading it back requires np.load(..., allow_pickle=True).
np.savez_compressed(os.path.join(result_path, 'predictions.npz'),
                    pred_y_list_all=_to_object_grid(pred_y_list_all))
np.savez_compressed(os.path.join(result_path, 'auroc_score.npz'),
                    auc_score_list_all=auc_score_list_all)


Timestamp: 20230505_184652_394531
0-th fold...
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, 7)      0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 7)      0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, None, 1)      0                                            
__________________________________________________________________________________________________
external_masking_1 (ExternalMas (None, None, 7)      0           input_1[0][0]                    
                                                              

Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
..........................AUC score of this fold: [0.8884166666666666, 0.8310439560439561, 0.8583453583453583]
1-th fold...
.Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100

Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
..............................AUC score of this fold: [0.8989487467265245, 0.6987179487179488, 0.7529189560439561]
2-th fold...
.Epoch 1/100


Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100

Epoch 98/100
Epoch 99/100
Epoch 100/100
..............................AUC score of this fold: [0.8761281799476244, 0.7290140415140415, 0.7393162393162394]
3-th fold...
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/

Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
...............................AUC score of this fold: [0.9136755332407507, 0.7151274651274652, 0.8125381562881563]
4-th fold...
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47

Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78/100
Epoch 79/100
Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
...............................AUC score of this fold: [0.8559557726224393, 0.8742183742183741, 0.731000481000481]
Mean AUC score: [0.88662498 0.76962436 0.77882384]; Std AUC score: [0.01968553 0.06979843 0.04894755]
