In [1]:
import os
import sys
from datetime import datetime, timedelta

import pandas as pd

# Make the repository root (the parent of this notebook's working directory)
# importable, so the `vbridge` imports below resolve from this checkout.
# The membership check keeps re-runs of this cell from adding duplicate entries.
root = os.path.abspath(os.pardir)
if not root in sys.path:
    sys.path.append(root)
    
from vbridge.data_loader.data import create_entityset
from vbridge.data_loader.utils import load_entityset
from vbridge.featurization.feature import Featurization
from vbridge.modeling.model import Model, ModelManager

## 0. Building the EntitySet

In [2]:
# Construct the EntitySet for the 'pic' dataset.
# Alternative invocation with extra options, kept for reference:
#     create_entityset('pic', verbose=False, save=True)
es = create_entityset(
    'pic',
)

## 1. Featurization

In [3]:
# Each row of the target entity's table becomes one prediction instance
# (its index is used as `instance_id` when building cutoff times below).
target_entity = "ADMISSIONS"
feat = Featurization(
    es,
    target_entity,
)

In [4]:
# Offset added to each instance's time index to obtain its cutoff time.
# NOTE(review): presumably features are computed only from data before the
# cutoff, i.e. a 48-hour observation window per admission; confirm against
# Featurization.generate_features.
PREDICTION_WINDOW = timedelta(hours=48)

# Build a (instance_id, time) table with one row per target-entity instance.
target = es[target_entity]
cutoff_time = target.df.loc[:, [target.index, target.time_index]].copy()  # never mutate the entityset's frame
cutoff_time.columns = ['instance_id', 'time']
cutoff_time['time'] += PREDICTION_WINDOW

In [5]:
# Compute features for every instance at its cutoff time.
# NOTE(review): presumably `fm` is the feature matrix (one row per instance) and
# `fl` the feature definitions; only `fm` is used in the cells shown here.
fm, fl = feat.generate_features(
    cutoff_time=cutoff_time,
)

In [6]:
# Prediction target: the HOSPITAL_EXPIRE_FLAG column of the target entity
# (a binary outcome, per the 2x2 confusion matrix reported at evaluation).
# Read through `target_entity` so the label stays tied to the same entity
# used for featurization and cutoff times.
label = es[target_entity].df['HOSPITAL_EXPIRE_FLAG']

## 2. Model Training

In [7]:
# Stringify the feature-matrix index before handing it to the ModelManager.
# NOTE(review): `label` keeps the entity's original index dtype; confirm that
# ModelManager aligns it correctly against the stringified `fm.index`.
fm.index = fm.index.astype('str')

model_manager = ModelManager(fm)
# Pass both arguments by keyword: a positional argument may not follow a
# keyword argument (that form is a SyntaxError and the cell would never run).
# NOTE(review): assumes the model-name parameter is called `name`; confirm
# against ModelManager.add_model.
model_manager.add_model(label=label, name='mortality')

In [8]:
# Train every model registered on the manager, then display the per-model
# metrics table (AUROC, accuracy, confusion matrix, F1, precision, recall).
model_manager.fit_all()
evaluation = model_manager.evaluate()
evaluation

Unnamed: 0,AUROC,Accuracy,Confusion Matrix,F1 Macro,Precision,Recall
mortality,0.983333,0.985972,"[[490, 5], [2, 2]]",0.678272,0.640825,0.744949


## 3. Model Explanation

In [9]:
# SHAP contributions from the 'mortality' model, with one column per feature.
# NOTE(review): rows are presumably the explained instances; confirm.
shap_values = model_manager.explain(
    target='mortality',
)

In [10]:
# Net (signed) SHAP contribution of each feature summed over all rows, ordered
# from most negative to most positive. Positive and negative contributions
# cancel, so this reflects net direction rather than overall importance.
shap_values.sum(axis=0).sort_values(ascending=True)

MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5136)                        -1730.733154
MEAN(CHARTEVENTS.VALUENUM WHERE ITEMID = 1004)                      -1470.505981
MEAN(CHARTEVENTS.VALUENUM WHERE ITEMID = 1006)                      -1279.977051
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5099)                         -853.524658
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5002)                         -408.303589
                                                                        ...     
MEAN(SURGERY_VITAL_SIGNS.VALUE WHERE ITEMID = SV17)                     0.848345
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5212)                            1.367452
TREND(SURGERY_VITAL_SIGNS.VALUE, MONITOR_TIME WHERE ITEMID = SV3)       1.819924
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 5225)                            9.055139
MEAN(LABEVENTS.VALUENUM WHERE ITEMID = 6472)                           42.819359
Length: 105, dtype: float32