In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import matplotlib.pyplot as plt
import os
import json
from typing import List, Dict, Tuple, Optional
from tqdm import tqdm
import seaborn as sns
# Hook tqdm into pandas (per the tqdm docs this registers the
# `progress_apply` family of methods on pandas objects).
tqdm.pandas()
# Use seaborn's "paper" plotting context for any figures in this notebook.
sns.set_context("paper")

In [2]:
# Root directory of the benchmark results; each entry directly beneath it is
# treated as a task name by the cells below.
# NOTE(review): this is a hard-coded absolute cluster path -- consider making
# it configurable before sharing the notebook.
path_to_results_dir: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/results_ehrshot'
# NOTE(review): os.listdir returns entries in arbitrary order and does not
# distinguish files from directories, so `tasks` may differ in order between runs.
tasks = os.listdir(path_to_results_dir)
print("Tasks: ", tasks)

Tasks:  ['new_hypertension', 'guo_los', 'lab_hypoglycemia', 'new_lupus', 'lab_hyponatremia', 'new_pancan', 'lab_anemia', 'new_acutemi', 'chexpert', 'guo_readmission', 'lab_thrombocytopenia', 'new_hyperlipidemia', 'new_celiac', 'lab_hyperkalemia', 'guo_icu']


In [3]:
def _children(parent_dir: str):
    """Yield ``(entry_name, full_path)`` for every entry directly under `parent_dir`."""
    for entry_name in os.listdir(parent_dir):
        yield entry_name, os.path.join(parent_dir, entry_name)


# Walk the results tree, which is laid out as
#   <results>/<task>/models/<model>/<head>/<subtask>/<k>
# and record one dict per leaf (<k>) entry.
paths = []
for task in tqdm(tasks, desc='Finding paths...'):
    path_to_task_dir: str = os.path.join(path_to_results_dir, task, 'models')
    if os.path.exists(path_to_task_dir):
        # Tasks without a `models` directory contribute nothing.
        paths.extend(
            {
                'path' : path_to_k_dir,
                'task' : task,
                'model' : model,
                'head' : head,
                'k' : k,
                'subtask' : subtask,
            }
            for model, path_to_model_dir in _children(path_to_task_dir)
            for head, path_to_head_dir in _children(path_to_model_dir)
            for subtask, path_to_subtask_dir in _children(path_to_head_dir)
            for k, path_to_k_dir in _children(path_to_subtask_dir)
        )
print("Found {} paths".format(len(paths)))

Finding paths...: 100%|██████████| 15/15 [01:19<00:00,  5.28s/it]

Found 1585 paths





In [4]:
# Back-of-the-envelope count of result directories to compare against the
# number of paths found.
# NOTE(review): the factors are presumably heads * k-settings * (subtasks of
# two groups) -- the notebook does not say; confirm and name them.
n_expected = 4 * 12 * (15 + 13)
print("# of expected tasks (just counting 'count' + 'clmbr' models):", n_expected)

# of expected tasks (just counting 'count' + 'clmbr' models): 1344


In [5]:
# Load the hyperparameters recorded for every result directory in `paths` and
# collect them, together with the path metadata, into one DataFrame row each.
# Hyperparameter columns are prefixed with `hparam_`.
results = []
for p in tqdm(paths, desc='Loading results...'):
    path_to_hparams_json: str = os.path.join(p['path'], 'model_hparams.json')
    # Parse the file once, inside a context manager: the previous version
    # opened and parsed it twice per directory and never closed either handle.
    with open(path_to_hparams_json, 'r') as f:
        payload: dict = json.load(f)
    hparams: dict = payload.get("model_hparams", {})
    # NOTE(review): the AUROC is extracted but (as before) not stored in
    # `results`; add it to the record below if it is wanted as a column.
    auroc: Optional[float] = payload.get("scores", {}).get("auroc", {}).get("score")
    # `name` rather than `k` so the comprehension does not shadow the `k` metadata field.
    hparams = { f"hparam_{name}": value for name, value in hparams.items() }
    results.append({
        **p,
        **hparams
    })
df = pd.DataFrame(results)

Loading results...:   0%|          | 0/1585 [00:00<?, ?it/s]

Loading results...: 100%|██████████| 1585/1585 [14:45<00:00,  1.79it/s]


In [6]:
# Restrict the analysis to the 'count' and 'clmbr' models, then show the
# resulting (rows, columns).
keep_model = df['model'].isin(['count', 'clmbr'])
df = df.loc[keep_model]
df.shape

(1333, 49)

### Logistic Regression (LogReg)

In [7]:
# How often each `hparam_C` value occurs among rows whose head is 'lr_lbfgs'.
df.loc[df['head'] == 'lr_lbfgs', 'hparam_C'].value_counts()

hparam_C
1.000000e-08    194
1.000000e-02    131
1.000000e-01     89
1.000000e-03     63
1.000000e+04     33
1.000000e+02     31
1.000000e+00     28
1.000000e+06     25
1.000000e-04     25
1.000000e+03     23
1.000000e+05     19
1.000000e-05      7
1.000000e-06      4
Name: count, dtype: int64

### Random Forest

In [8]:
# How often each `hparam_n_estimators` value occurs among rows whose head is 'rf'.
df.loc[df['head'] == 'rf', 'hparam_n_estimators'].value_counts()

hparam_n_estimators
300.0    112
100.0     62
10.0      52
50.0      51
20.0      48
Name: count, dtype: int64

In [9]:
# How often each `hparam_max_depth` value occurs among rows whose head is 'rf'.
df.loc[df['head'] == 'rf', 'hparam_max_depth'].value_counts()

hparam_max_depth
3.0     162
5.0      72
10.0     48
20.0     33
50.0     10
Name: count, dtype: int64

### GBM

In [10]:
# How often each `hparam_max_depth` value occurs among rows whose head is 'gbm'.
df.loc[df['head'] == 'gbm', 'hparam_max_depth'].value_counts()

hparam_max_depth
 3.0    207
 6.0     89
-1.0     40
Name: count, dtype: int64

In [11]:
# How often each `hparam_learning_rate` value occurs among rows whose head is 'gbm'.
df.loc[df['head'] == 'gbm', 'hparam_learning_rate'].value_counts()

hparam_learning_rate
0.02    128
0.50    120
0.10     88
Name: count, dtype: int64

In [12]:
# How often each `hparam_num_leaves` value occurs among rows whose head is 'gbm'.
df.loc[df['head'] == 'gbm', 'hparam_num_leaves'].value_counts()

hparam_num_leaves
10.0     287
25.0      32
100.0     17
Name: count, dtype: int64