In [1]:
from os import getcwd
from os.path import join
import sys
# Put the sibling "module_code" directory first on the import path so the
# project imports below resolve. The path is built from the current working
# directory, so this only works when the kernel's cwd is the notebook's folder.
sys.path.insert(0, join(getcwd(), "../module_code"))

import pandas as pd
from mlflow.tracking import MlflowClient

from models.static_models import CRRTStaticPredictor
from data.sklearn_loaders import SklearnCRRTDataModule
from main import get_mlflow_model_uri
from cli_utils import init_cli_args, load_cli_args
from data.argparse_utils import string_list_to_list, string_dict_to_dict
from data.subpopulation_utils import generate_filters, combine_filters

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Locate the logged MLflow run -------------------------------------------
run_id = "cdcbd1d31e104b448da2f3020aaae4e8"
client = MlflowClient("/home/davina/Private/repos/CRRT/mlruns/")
run = client.get_run(run_id)
logged_args = run.data.tags

# --- Rebuild the command line the run was launched with ---------------------
# Start from a bare argv (program name only) so stray kernel arguments are
# dropped, then let load_cli_args process the options file.
# NOTE(review): the membership check below suggests load_cli_args appends
# flags to sys.argv — confirm against cli_utils.
sys.argv = [sys.argv[0]]
load_cli_args("../options.yml")
for tag_name, tag_value in logged_args.items():
    flag = f"--{tag_name}"
    # Skip tags logged as the string "None" and flags already on the
    # command line; every other tag is appended as "--name value".
    if tag_value == "None" or flag in sys.argv:
        continue
    sys.argv.extend((flag, tag_value))
args = init_cli_args()

# --- Restore the trained model and its data module --------------------------
path = get_mlflow_model_uri(run)
model = CRRTStaticPredictor.load(args, path)

data = SklearnCRRTDataModule.from_argparse_args(args)
data.setup(args)

In [4]:
# Build the available subpopulation filters, then combine the
# "disease_indicator" group with the "sex" group into one filter mapping.
filters = generate_filters()
disease_filters = filters["disease_indicator"]
sex_filters = filters["sex"]
dz_sex = combine_filters(disease_filters, sex_filters)

In [5]:
# Wrap the test features in a DataFrame that carries the data module's column
# names and shares the labels' index, so filters can be applied to it.
test_features, y = data.test
X = pd.DataFrame(test_features, columns=data.columns, index=y.index)

# One mask per combined subgroup: each dz_sex entry supplies the positional
# arguments forwarded to data.get_filter.
masks = {
    group_name: data.get_filter(X, *filter_args)
    for group_name, filter_args in dz_sex.items()
}

# Attach the masks to the data module before evaluating on the test split.
data.test_filters = masks
model.evaluate(data, "test")

{'heart_female': IP_PATIENT_ID                     Start Date
 00000064C94859EB973942C862550ABA  2020-08-10    False
 000008E26F8F14E5099662DDB778C698  2017-09-23    False
 00000AC44FF1345FEE4B73057620EC0F  2017-10-02    False
 000014471A50BA85E236FC7DDA593FE6  2017-11-26    False
 000015CA913525E268FA991FD4AED560  2012-05-03    False
                                                 ...  
 0117DEA6B296A4BF14116169C72C9248  2013-04-25    False
 02E1D0B786923B55CCD9684C0BB305D4  2009-03-05    False
 045A0C05D0C3B66C71A4B8EC722A31EF  2010-05-04    False
 16FC1EF68F3CAD586FA58E7D8912555D  2011-09-24    False
 1FA1377CDE80E07206C12D0C4613ABAE  2010-08-24    False
 Length: 3560, dtype: bool,
 'heart_male': IP_PATIENT_ID                     Start Date
 00000064C94859EB973942C862550ABA  2020-08-10     True
 000008E26F8F14E5099662DDB778C698  2017-09-23    False
 00000AC44FF1345FEE4B73057620EC0F  2017-10-02     True
 000014471A50BA85E236FC7DDA593FE6  2017-11-26     True
 000015CA913525E268FA991F

In [6]:
# Inspect the raw test split. Per the output below it is a tuple whose first
# element is a 2-D feature array and whose second element is indexed by
# (IP_PATIENT_ID, Start Date).
data.test

(array([[6.30000000e+01, 5.77634358e+01, 2.02000000e+03, ...,
         0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
        [3.00000000e+01, 5.77634358e+01, 2.01700000e+03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [4.70000000e+01, 5.77634358e+01, 2.01700000e+03, ...,
         0.00000000e+00, 0.00000000e+00, 1.00000000e+00],
        ...,
        [9.10000000e+01, 5.77634358e+01, 2.01000000e+03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.50000000e+01, 5.77634358e+01, 2.01100000e+03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
        [6.60000000e+01, 5.77634358e+01, 2.01000000e+03, ...,
         0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]),
 IP_PATIENT_ID                     Start Date
 00000064C94859EB973942C862550ABA  2020-08-10    0
 000008E26F8F14E5099662DDB778C698  2017-09-23    0
 00000AC44FF1345FEE4B73057620EC0F  2017-10-02    0
 000014471A50BA85E236FC7DDA593FE6  2017-11-26    0
 000015CA913525