In [1]:
import os
import sys
from datetime import datetime, timedelta

import pandas as pd

# Put the parent directory of the notebook's working directory on the import
# path (once) so that the `vbridge` package imported below can be resolved.
root = os.path.abspath(os.pardir)
if root not in sys.path:
    sys.path.append(root)
    
from vbridge.data_loader.data import create_entityset
from vbridge.data_loader.utils import load_entityset
from vbridge.featurization.feature import Featurization
from vbridge.modeling.model import Model, ModelManager

## 0. Building EntitySet

In [2]:
# Load the entity set used by every later cell.
# NOTE(review): presumably this reads the copy cached by an earlier
# `create_entityset(..., save=True)` run — confirm. To rebuild it, run instead:
# es = create_entityset('pic', verbose=False, save=True)
es = load_entityset()

## 1. Featurization

In [3]:
# Entity whose index identifies the prediction instances (it is used as
# `instance_id` when the cutoff times are built).
target_entity = 'ADMISSIONS'
# Featurization helper bound to the entity set and the target entity.
feat = Featurization(es, target_entity)

In [4]:
# Cutoff table with one row per target-entity instance: its id plus the
# entity's time-index value shifted forward by 48 hours.
cutoff_time = (
    es[target_entity].df
    .loc[:, [es[target_entity].index, es[target_entity].time_index]]
    .rename(columns={
        es[target_entity].index: 'instance_id',
        es[target_entity].time_index: 'time',
    })
    .assign(time=lambda frame: frame['time'] + timedelta(hours=48))
)

In [5]:
# Generate features for every instance as of its cutoff time.
# `fm` is used below as the feature matrix; `fl` is presumably the matching list
# of feature definitions — confirm against Featurization.generate_features.
# The "Remove: [...]" line in the cell output is printed by this call.
fm, fl = feat.generate_features(cutoff_time=cutoff_time)

Remove:  ['MEAN(CHARTEVENTS.VALUENUM)', 'STD(CHARTEVENTS.VALUENUM)', 'TREND(CHARTEVENTS.VALUENUM, CHARTTIME)', 'MEAN(SURGERY_VITAL_SIGNS.VALUE)', 'STD(SURGERY_VITAL_SIGNS.VALUE)', 'TREND(SURGERY_VITAL_SIGNS.VALUE, MONITOR_TIME)', 'MEAN(LABEVENTS.VALUENUM)']


In [6]:
# Prediction target: the HOSPITAL_EXPIRE_FLAG column of the target entity.
# Reading it through `target_entity` (instead of repeating the entity name)
# keeps the labels tied to the same entity the features were generated for.
label = es[target_entity].df['HOSPITAL_EXPIRE_FLAG']

## 2. Model Training

In [7]:
# Cast the feature-matrix index to strings (this mutates `fm` in place, and the
# later SHAP cell receives the string-indexed frame).
# NOTE(review): presumably ModelManager expects string instance ids — confirm.
fm.index = fm.index.astype('str')

# Register one Model built with default arguments, together with the label
# series, under the name 'mortality' (the row name shown in the evaluation output).
model_manager = ModelManager(fm)
model = Model()
model_manager.add_model(model, label, 'mortality')

In [8]:
# Train the registered model(s), then display the evaluation metrics (the last
# expression of the cell is rendered as its output).
# NOTE(review): in the saved output the confusion matrix is [[478, 8], [11, 2]]:
# the minority class has 13 of 499 instances and only 2 of them are predicted
# correctly, so the 0.96 accuracy mostly reflects class imbalance — prefer the
# AUROC / macro-F1 columns when judging this model.
model_manager.fit_all()
model_manager.evaluate()

Unnamed: 0,AUROC,Accuracy,Confusion Matrix,F1 Macro,Precision,Recall
mortality,0.881766,0.961924,"[[478, 8], [11, 2]]",0.577213,0.588753,0.568693


## 3. Model Explanation

In [9]:
# Compute SHAP values for the feature matrix with the trained model.
# NOTE(review): judging by the following cell's output this returns a table with
# one column per feature (105 of them) — confirm against Model.SHAP.
shap_values = model.SHAP(fm)

In [10]:
# Total each feature's SHAP values over all instances and list the features
# from the most negative to the most positive total.
# NOTE(review): positive and negative contributions cancel in a signed sum; if a
# magnitude ranking is wanted, `shap_values.abs().mean()` may be closer — confirm intent.
shap_values.sum().sort_values()

MEAN(CHARTEVENTS.VALUENUM WHERE ITEMID = 1006)                      -2757.764648
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5136)                        -1180.693115
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5099)                         -739.795776
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5002)                         -717.263916
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5141)                         -661.361267
                                                                        ...     
STD(SURGERY_VITAL_SIGNS.VALUE WHERE ITEMID = SV6)                       0.000000
TREND(SURGERY_VITAL_SIGNS.VALUE, MONITOR_TIME WHERE ITEMID = SV1)       0.000000
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5218)                            1.311347
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5212)                           13.428649
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5257)                           17.537210
Length: 105, dtype: float32