In [1]:
import os

import numpy as np
import torch

# `sklearn.externals.joblib` was deprecated in scikit-learn 0.21 and removed in
# 0.23. Keep the original import first so environments where it still works
# behave exactly as before, and fall back to the standalone `joblib` package
# (a hard dependency of scikit-learn) on newer installs.
try:
    from sklearn.externals import joblib
except ImportError:
    import joblib

from pytorch_utils.datasets import ArrayDataset
from pytorch_utils.models import SparseModel, SparseModelEmbed
import pytorch_utils

In [2]:
# Root directory that holds every artifact read by this notebook.
data_path = 'data/'

# Labels live directly under <data_path>/labels.
label_path = os.path.join(data_path, 'labels')

# Features live in a numbered sub-folder; this notebook reads folder "0".
# NOTE(review): the meaning of the index (fold? feature-set version?) is not
# visible here -- confirm against the feature-extraction pipeline.
features_path = os.path.join(data_path, 'features', str(0))

In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# Load the pickled feature and label dictionaries, in that order.
# joblib.load unpickles its input, so only point it at trusted files.
features_dict, label_dict = (
    joblib.load(os.path.join(directory, filename))
    for directory, filename in (
        (features_path, 'features.pkl'),
        (label_path, 'label_dict.pkl'),
    )
)

In [5]:
# Name of the label to model; used below as the key into each split's entry
# of `label_dict`.
outcome = 'mortality'

In [6]:
# Per-split feature matrices: pull the 'features' entry out of each split.
data_dict = {
    split: split_payload['features']
    for split, split_payload in features_dict.items()
}

# Per-split targets for the selected outcome.
outcome_dict = {
    split: split_labels[outcome]
    for split, split_labels in label_dict.items()
}

In [7]:
# Hyperparameters handed to the model constructors below.
# NOTE(review): how `pytorch_utils` interprets `iters_per_epoch=None` is not
# visible in this notebook -- confirm in the library.
config_dict = dict(
    input_dim=data_dict['train'].shape[1],  # number of feature columns
    output_dim=2,
    lr=1e-3,
    num_epochs=3,
    batch_size=256,
    iters_per_epoch=None,
)

In [8]:
# Build a reduced copy of the data for quick experiments: only the 'train'
# split is truncated to its first `num_samples` rows; every other split is
# passed through untouched.
num_samples = 1000

small_data_dict = {
    key: (matrix[:num_samples] if key == 'train' else matrix)
    for key, matrix in data_dict.items()
}

# Keyed by the splits of `data_dict` so both small dicts share the same keys.
small_outcome_dict = {
    key: (outcome_dict[key][:num_samples] if key == 'train' else outcome_dict[key])
    for key in data_dict
}

small_data_dict

{'train': <1000x368117 sparse matrix of type '<class 'numpy.float32'>'
 	with 285517 stored elements in Compressed Sparse Row format>,
 'test': <12963x368117 sparse matrix of type '<class 'numpy.float32'>'
 	with 3694263 stored elements in Compressed Sparse Row format>,
 'val': <12964x368117 sparse matrix of type '<class 'numpy.float32'>'
 	with 3668580 stored elements in Compressed Sparse Row format>}

In [9]:
%%time
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
model1 = SparseModelEmbed(config_dict)
result1 = model1.train(data_dict, outcome_dict)
# result = model.train(small_data_dict, small_outcome_dict)
print(model1.predict(data_dict, outcome_dict, keys = ['test']))

Epoch 0/2
----------
Phase: train:
 loss: 0.101979,
 auc: 0.790042, auprc: 0.121907, brier: 0.022461,
Phase: val:
 loss: 0.085918,
 auc: 0.854748, auprc: 0.173890, brier: 0.020091,
Best model updated
Epoch 1/2
----------
Phase: train:
 loss: 0.059117,
 auc: 0.953227, auprc: 0.513295, brier: 0.015136,
Phase: val:
 loss: 0.084807,
 auc: 0.854411, auprc: 0.197855, brier: 0.019558,
Best model updated
Epoch 2/2
----------
Phase: train:
 loss: 0.045614,
 auc: 0.977480, auprc: 0.698895, brier: 0.011768,
Phase: val:
 loss: 0.086255,
 auc: 0.852349, auprc: 0.209351, brier: 0.019425,
Best val performance: 0.084807
({'test': {'outputs': array([[ 1.9271259 , -1.6779692 ],
       [ 3.256986  , -2.981594  ],
       [ 3.3569183 , -3.0938375 ],
       ...,
       [ 1.4305469 , -1.184189  ],
       [ 0.64315784, -0.35853004],
       [ 3.1437624 , -2.893929  ]], dtype=float32), 'pred_probs': array([[0.9735346 , 0.0264654 ],
       [0.99805117, 0.00194882],
       [0.9984232 , 0.00157684],
       ...,
  

In [10]:
%%time
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

model2 = SparseModel(config_dict)
result2 = model2.train(data_dict, outcome_dict)
# result = model.train(small_data_dict, small_outcome_dict)
print(model2.predict(data_dict, outcome_dict, keys = ['test']))

Epoch 0/2
----------
Phase: train:
 loss: 0.116151,
 auc: 0.755395, auprc: 0.088173, brier: 0.025765,
Phase: val:
 loss: 0.090624,
 auc: 0.835400, auprc: 0.145769, brier: 0.020806,
Best model updated
Epoch 1/2
----------
Phase: train:
 loss: 0.063015,
 auc: 0.942981, auprc: 0.462766, brier: 0.015970,
Phase: val:
 loss: 0.087119,
 auc: 0.846306, auprc: 0.178966, brier: 0.020002,
Best model updated
Epoch 2/2
----------
Phase: train:
 loss: 0.047944,
 auc: 0.973369, auprc: 0.669487, brier: 0.012277,
Phase: val:
 loss: 0.087486,
 auc: 0.847393, auprc: 0.197522, brier: 0.019704,
Best val performance: 0.087119
({'test': {'outputs': array([[ 1.8706224 , -1.24862   ],
       [ 3.4919553 , -2.7763348 ],
       [ 3.6205487 , -2.983833  ],
       ...,
       [ 1.6864917 , -1.115145  ],
       [ 0.69924116, -0.07837975],
       [ 3.276644  , -2.7013903 ]], dtype=float32), 'pred_probs': array([[0.9576795 , 0.04232046],
       [0.99810815, 0.00189188],
       [0.9986474 , 0.00135259],
       ...,
  