In [23]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
import sklearn.utils
from typing import List, Optional
from sklearn.metrics import roc_auc_score
import collections
import warnings
import pickle
import femr.datasets
import datetime
from hf_ehr.notebooks.ehr_specific_properties import utils
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.validation")

PATH_TO_DATABASE: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/femr/extract'
PATH_TO_FEATURES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/features_ehrshot'
PATH_TO_RESULTS_DIR: str = '/share/pi/nigam/migufuen/ehrshot-benchmark/EHRSHOT_ASSETS/results_ehrshot'
PATH_TO_TOKENIZED_TIMELINES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/tokenized_timelines_ehrshot'
PATH_TO_LABELS_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/benchmark_ehrshot'
PATH_TO_SPLIT_CSV: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/splits_ehrshot/person_id_map.csv'
femr_db = femr.datasets.PatientDatabase(PATH_TO_DATABASE)
os.makedirs('../cache', exist_ok=True)

IS_LOAD_FROM_CACHE: bool = False


# Load Data

First, we need to load all of the patient-level predictions / labels for every model and task.
This takes ~30 min.

1. Load list of task names, model names
2. Load patient-level predictions for each (model, task)
3. Load patient-level ground truth labels for each (task)

In [10]:
# Get list of tasks
valid_tasks = os.listdir(PATH_TO_RESULTS_DIR)
valid_tasks.remove('chexpert')

LABELING_FUNCTION_2_PAPER_NAME = {
    # Guo et al. 2023
    "guo_los": "Long LOS",
    "guo_readmission": "30-day Readmission",
    "guo_icu": "ICU Admission",
    # New diagnosis
    "new_pancan": "Pancreatic Cancer",
    "new_celiac": "Celiac",
    "new_lupus": "Lupus",
    "new_acutemi": "Acute MI",
    "new_hypertension": "Hypertension",
    "new_hyperlipidemia": "Hyperlipidemia",
    # Instant lab values
    "lab_thrombocytopenia": "Thrombocytopenia",
    "lab_hyperkalemia": "Hyperkalemia",
    "lab_hypoglycemia": "Hypoglycemia",
    "lab_hyponatremia": "Hyponatremia",
    "lab_anemia": "Anemia",
    # Custom tasks
    "chexpert": "Chest X-ray Findings",
}

# Filter results to only valid models
valid_models = [ 
    'llama-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last', 
    'clmbr',
]
print("Tasks:", valid_tasks)
print("# of valid models:", len(valid_models))

Tasks: ['new_hypertension', 'guo_los', 'lab_hypoglycemia', 'new_lupus', 'lab_hyponatremia', 'new_pancan', 'lab_anemia', 'new_acutemi', 'guo_readmission', 'lab_thrombocytopenia', 'new_hyperlipidemia', 'new_celiac', 'lab_hyperkalemia', 'guo_icu']
# of valid models: 17


In [25]:
# Load patient-level predictions for each (model, task)
predictions = {} # [key] = (model, task), [value] = df_preds

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/predictions.pkl'):
    print("Loading predictions from cache...")
    predictions = pickle.load(open('../cache/predictions.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        for model in valid_models:
            path = os.path.join(PATH_TO_RESULTS_DIR, task, 'models', model, 'lr_lbfgs', f'subtask={task}', 'k=-1', 'preds.csv')
            if not os.path.exists(path):
                print("Missing path for ", model, task)
            df_preds = pd.read_csv(path)
            assert df_preds['replicate'].nunique() == 1, f"Multiple replicates for {model}, {task}"
            predictions[(model, task)] = df_preds
    # Save results to .pkl file
    with open('../cache/predictions.pkl', 'wb') as f:
        pickle.dump(predictions, f)
print("# of (model, task) preds:", len(predictions))

100%|██████████| 14/14 [02:12<00:00,  9.47s/it]


# of (model, task) preds: 238


In [26]:
# Load patient-level labels for each task
# NOTE: Takes ~1 hr
label_data = {} # [key] = task, [value] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/label_data.pkl'):
    print("Loading label data from cache...")
    label_data = pickle.load(open('../cache/label_data.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        # Load labeled patients for this task
        LABELING_FUNCTION: str = task
        PATH_TO_LABELED_PATIENTS: str =  os.path.join(PATH_TO_LABELS_DIR, LABELING_FUNCTION, 'labeled_patients.csv')
        labeled_patients = femr.labelers.load_labeled_patients(PATH_TO_LABELED_PATIENTS)
        
        # Get features for patients
        model: str = 'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last'
        patient_ids, label_values, label_times, feature_matrixes = utils.get_labels_and_features(labeled_patients, 
                                                                                            PATH_TO_FEATURES_DIR, 
                                                                                            PATH_TO_TOKENIZED_TIMELINES_DIR,
                                                                                            models_to_keep=[model,])
        train_pids_idx, val_pids_idx, test_pids_idx = utils.get_patient_splits_by_idx(PATH_TO_SPLIT_CSV, patient_ids)
        label_times = label_times[train_pids_idx + val_pids_idx + test_pids_idx]
        label_values = label_values[train_pids_idx + val_pids_idx + test_pids_idx]
        patient_ids = patient_ids[train_pids_idx + val_pids_idx + test_pids_idx]
        label_times = [ x.astype(datetime.datetime) for x in label_times ] # cast to Python datetime
        label_data[task] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

    # Save results to .pkl file
    with open('../cache/label_data.pkl', 'wb') as f:
        pickle.dump(label_data, f)

  0%|          | 0/14 [00:00<?, ?it/s]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


  7%|▋         | 1/14 [00:15<03:27, 15.94s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 14%|█▍        | 2/14 [00:39<04:05, 20.47s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 21%|██▏       | 3/14 [14:40<1:12:25, 395.04s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 29%|██▊       | 4/14 [15:03<41:23, 248.33s/it]  

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 36%|███▌      | 5/14 [24:28<54:23, 362.60s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 43%|████▎     | 6/14 [24:54<33:03, 247.97s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 50%|█████     | 7/14 [33:07<38:17, 328.22s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 57%|█████▋    | 8/14 [33:33<23:12, 232.04s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 64%|██████▍   | 9/14 [33:58<13:55, 167.05s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 71%|███████▏  | 10/14 [41:57<17:34, 263.52s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 79%|███████▊  | 11/14 [42:15<09:25, 188.44s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 86%|████████▌ | 12/14 [42:39<04:36, 138.33s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


 93%|█████████▎| 13/14 [51:29<04:17, 257.04s/it]

Processing features for model: gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last


100%|██████████| 14/14 [51:53<00:00, 222.37s/it]


# Calculate CIs

Calculate bootstrapped 95% CIs over the test set (one sample per patient).

1. Generate 1000 bootstrapped resamples of patient IDs across all splits
2. Loop through every (model, task) preds, limit to test set, loop through 1000 bootstrap weightings and calculate AUROC with patients reweighted accordingly (ignoring non-test patients)
3. Calculate AUROC delta between each (model, task) and a **baseline** model (e.g. CLMBR, gpt-base-512, etc.)
4. Calculate 95% CIs

### Do Bootstrapping of AUROC Deltas

In [13]:
# Get all patient IDs
all_patients = np.sort(np.unique(np.concatenate([v['patient_ids'] for v in label_data.values()])))
print(len(all_patients))

# Patient-level resampling across all patient IDs in train/val/test
# Later, we limit to just test patient IDs per task
bootstrap_weights = []
np.random.seed(342342)
for i in range(1000):
    patient_sample = np.random.choice(list(range(len(all_patients))), len(all_patients), replace=True)
    weights = np.zeros_like(all_patients, dtype=np.float32)
    np.add.at(weights, patient_sample, 1)
    assert np.mean(weights) == 1
    bootstrap_weights.append(weights)

6275


In [27]:
# Calculate AUROC scores for each (model, task) over bootstrapped resamples
# NOTE: Takes ~1 hr
auroc_scores = collections.defaultdict(list) # [key] = (model, task), [value] = auroc's across 1k resamples

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/auroc_scores.pkl'):
    print("Loading auroc scores from cache...")
    auroc_scores = pickle.load(open('../cache/auroc_scores.pkl', 'rb'))
else:
    # Do bootstrap resampling for each (model, task)
    for (model, task), df_preds in tqdm(predictions.items(), total=len(predictions)):
        # Test split
        df_preds['patient_ids'] = label_data[task]['patient_ids']
        df_ = df_preds[df_preds['split'] == 'test']

        # Labels / Preds
        y = df_['y'].values
        pred_proba = df_['pred_proba'].values

        # patient_ids 
        patient_ids = df_['patient_ids'].values
        patient_id_indices = np.searchsorted(all_patients, patient_ids)
        assert np.all(patient_ids == all_patients[patient_id_indices])
        
        for weights in bootstrap_weights:
            auroc = roc_auc_score(y, pred_proba, sample_weight=weights[patient_id_indices])
            auroc_scores[(model, task)].append(auroc)

    # Save results to .pkl file
    with open('../cache/auroc_scores.pkl', 'wb') as f:
        pickle.dump(dict(auroc_scores), f)

100%|██████████| 238/238 [37:50<00:00,  9.54s/it] 


In [28]:
# Calculate delta in AUROC scores for each (model, task) v. baseline
baseline_auroc_deltas = {} # [key] = baseline, [value] = auroc_deltas for all (model, task)
for baseline in [
    'clmbr',
    'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
]:
    auroc_deltas = collections.defaultdict(list) # [key] = (model, task), [value] = diff in auroc v. CLMBR across 1k resamples
    if IS_LOAD_FROM_CACHE and os.path.exists(f'../cache/auroc_deltas_{baseline}.pkl'):
        print(f"Loading auroc deltas for {baseline} from cache...")
        auroc_deltas = pickle.load(open(f'../cache/auroc_deltas_{baseline}.pkl', 'rb'))
    else:
        for (model, task), score in tqdm(auroc_scores.items(), total=len(auroc_scores)):
            deltas = np.array(auroc_scores[(model, task)]) - np.array(auroc_scores[(baseline, task)])
            auroc_deltas[(model, task)] = deltas.tolist()

        # Save results to .pkl file
        with open(f'../cache/auroc_deltas_{baseline}.pkl', 'wb') as f:
            pickle.dump(dict(auroc_deltas), f)
    baseline_auroc_deltas[baseline] = auroc_deltas

100%|██████████| 238/238 [00:00<00:00, 9060.53it/s]
100%|██████████| 238/238 [00:00<00:00, 9001.87it/s]


In [41]:
# Formatting
def clean_df_deltas(df: pd.DataFrame) -> pd.DataFrame:
    df['model_name'] = df['model'].apply(lambda x: x.split('--')[0])
    df['model_base'] = df['model'].apply(lambda x: x.split('-')[0])
    df['ctx_length'] = df['model'].apply(lambda x: int(x.split('--')[0].split('-')[-1]) if x != 'clmbr' else 0).astype(int)
    df['formatted_auroc_delta_mean'] = df['auroc_delta_mean'].apply(lambda x: f"{x:.3f}")
    df['formatted_auroc_delta'] = df.apply(lambda x: f"{x['auroc_delta_mean']:.3f} ({x['auroc_delta_ci_025']:.3f}, {x['auroc_delta_ci_975']:.3f})", axis=1)
    df['formatted_auroc_ci'] = df.apply(lambda x: f"({x['auroc_delta_ci_025']:.3f}, {x['auroc_delta_ci_975']:.3f})", axis=1)
    df['is_auroc_win'] = (df['auroc_delta_ci_025'] > 0) & (df['auroc_delta_ci_975'] > 0)
    df['is_stat_sig'] = (df['auroc_delta_ci_025'] > 0) | (df['auroc_delta_ci_975'] < 0)
    df = df[['model', 'model_name', 'model_base', 'ctx_length'] + [col for col in df.columns if col not in ['model', 'model_name', 'model_base', 'ctx_length']]]
    df = df.sort_values(['model_base', 'ctx_length', ] + ([ 'task' ] if 'task' in df.columns else []))
    return df

def format_df_deltas(df: pd.DataFrame, model_start: Optional[str] = None) -> pd.DataFrame:
    if model_start is not None:
        df = df[df['model'].str.startswith(model_start)]
    df = df.drop(columns=['model', 'model_name', 'formatted_auroc_delta', 'auroc_delta_mean', 'auroc_delta_ci_025', 'auroc_delta_ci_500', 'auroc_delta_ci_975', 'is_auroc_win'], errors='ignore') \
        .rename(columns={   'model_base' : 'Model', 
                            'ctx_length' : 'Context Length', 
                            'formatted_auroc_delta_mean' : r'$\Delta$ over baseline', 
                            'formatted_auroc_ci' : '95% CI', 
                            'is_stat_sig' : 'Statistically Significant'} | ({ 'task' : 'Task'} if 'task' in df.columns else {})) \
        .sort_values(['Model', 'Context Length',] + ([ 'Task' ] if 'task' in df.columns else []))
    return df

def format_df_deltas_for_latex(df: pd.DataFrame, model_start: Optional[str] = None) -> str:
    latex = format_df_deltas(df, model_start).to_latex(index=False, escape=False)
    for k, v in LABELING_FUNCTION_2_PAPER_NAME.items():
        latex = latex.replace(k, v)
    latex = latex.replace('False', '').replace('True', '\checkmark')
    return latex

### Model-Level CIs

In [30]:
# Calculate model-level delta CIs
baseline: str = 'clmbr'

df_model_deltas = []
for model in valid_models:
    # We care about the actual delta between model and baseline
    aurocs = np.zeros(len(bootstrap_weights)) 
    for task in valid_tasks:
        aurocs += baseline_auroc_deltas[baseline][(model, task)]
    aurocs /= len(valid_tasks)
    
    # We only care about BINARY of whether model BEATS the baseline
    win_aurocs = np.zeros(len(bootstrap_weights))
    for task in valid_tasks:
        win_aurocs += np.sign(baseline_auroc_deltas[baseline][(model, task)])
    win_aurocs /= len(valid_tasks)

    df_model_deltas.append({
        'model' : model,
        'win_mean' : np.mean(win_aurocs),
        'win_ci_025' : np.percentile(win_aurocs, 2.5),
        'win_ci_500' : np.percentile(win_aurocs, 50),
        'win_ci_975' : np.percentile(win_aurocs, 97.5),
        'auroc_delta_mean' : np.mean(aurocs),
        'auroc_delta_ci_025' : np.percentile(aurocs, 2.5),
        'auroc_delta_ci_500' : np.percentile(aurocs, 50),
        'auroc_delta_ci_975' : np.percentile(aurocs, 97.5),
    })
df_model_deltas = pd.DataFrame(df_model_deltas)
df_model_deltas = clean_df_deltas(df_model_deltas)
df_model_deltas

Unnamed: 0,model,model_name,model_base,ctx_length,win_mean,win_ci_025,win_ci_500,win_ci_975,auroc_delta_mean,auroc_delta_ci_025,auroc_delta_ci_500,auroc_delta_ci_975,formatted_auroc_delta_mean,formatted_auroc_delta,formatted_auroc_ci,is_auroc_win,is_stat_sig
16,clmbr,clmbr,clmbr,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2-base-512,gpt2,512,0.196143,-0.142857,0.142857,0.571429,0.017038,0.005149,0.017045,0.028389,0.017,"0.017 (0.005, 0.028)","(0.005, 0.028)",True,True
5,gpt2-base-1024--clmbr_train-tokens-total_nonPA...,gpt2-base-1024,gpt2,1024,0.005714,-0.285714,0.0,0.285714,0.008263,-0.002871,0.007939,0.021149,0.008,"0.008 (-0.003, 0.021)","(-0.003, 0.021)",False,False
6,gpt2-base-2048--clmbr_train-tokens-total_nonPA...,gpt2-base-2048,gpt2,2048,0.2165,-0.142857,0.285714,0.571429,0.024723,0.008462,0.02503,0.041175,0.025,"0.025 (0.008, 0.041)","(0.008, 0.041)",True,True
7,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2-base-4096,gpt2,4096,0.423929,0.0,0.428571,0.714286,0.019307,0.007289,0.019119,0.03244,0.019,"0.019 (0.007, 0.032)","(0.007, 0.032)",True,True
8,hyena-large-1024--clmbr_train-tokens-total_non...,hyena-large-1024,hyena,1024,0.1315,-0.142857,0.142857,0.428571,0.019351,0.006742,0.019035,0.033037,0.019,"0.019 (0.007, 0.033)","(0.007, 0.033)",True,True
9,hyena-large-4096--clmbr_train-tokens-total_non...,hyena-large-4096,hyena,4096,0.112429,-0.142857,0.142857,0.428571,0.021339,0.007195,0.021063,0.035929,0.021,"0.021 (0.007, 0.036)","(0.007, 0.036)",True,True
10,hyena-large-8192--clmbr_train-tokens-total_non...,hyena-large-8192,hyena,8192,-0.266714,-0.428571,-0.285714,0.0,-0.008174,-0.024471,-0.008385,0.009969,-0.008,"-0.008 (-0.024, 0.010)","(-0.024, 0.010)",False,False
11,hyena-large-16384--clmbr_train-tokens-total_no...,hyena-large-16384,hyena,16384,-0.571929,-0.714286,-0.571429,-0.428571,-0.042903,-0.059552,-0.042844,-0.025433,-0.043,"-0.043 (-0.060, -0.025)","(-0.060, -0.025)",False,True
0,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-512,llama,512,-0.054214,-0.428571,0.0,0.285714,0.013981,0.001354,0.013903,0.026618,0.014,"0.014 (0.001, 0.027)","(0.001, 0.027)",True,True


In [44]:
format_df_deltas(df_model_deltas)

Unnamed: 0,Model,Context Length,win_mean,win_ci_025,win_ci_500,win_ci_975,$\Delta$ over baseline,95% CI,Statistically Significant
16,clmbr,0,0.0,0.0,0.0,0.0,0.0,"(0.000, 0.000)",False
4,gpt2,512,0.196143,-0.142857,0.142857,0.571429,0.017,"(0.005, 0.028)",True
5,gpt2,1024,0.005714,-0.285714,0.0,0.285714,0.008,"(-0.003, 0.021)",False
6,gpt2,2048,0.2165,-0.142857,0.285714,0.571429,0.025,"(0.008, 0.041)",True
7,gpt2,4096,0.423929,0.0,0.428571,0.714286,0.019,"(0.007, 0.032)",True
8,hyena,1024,0.1315,-0.142857,0.142857,0.428571,0.019,"(0.007, 0.033)",True
9,hyena,4096,0.112429,-0.142857,0.142857,0.428571,0.021,"(0.007, 0.036)",True
10,hyena,8192,-0.266714,-0.428571,-0.285714,0.0,-0.008,"(-0.024, 0.010)",False
11,hyena,16384,-0.571929,-0.714286,-0.571429,-0.428571,-0.043,"(-0.060, -0.025)",True
0,llama,512,-0.054214,-0.428571,0.0,0.285714,0.014,"(0.001, 0.027)",True


### Task-Level CIs

In [45]:
# Calculate task-level delta CIs
baseline: str = 'clmbr'

df_task_deltas = []
for model in valid_models:
    for task in valid_tasks:
        aurocs = baseline_auroc_deltas[baseline][(model, task)]
        wins = np.sign(baseline_auroc_deltas[baseline][(model, task)])
        df_task_deltas.append({
            'model' : model,
            'task' : task,
            'auroc_delta_mean' : np.mean(aurocs),
            'auroc_delta_ci_025' : np.percentile(aurocs, 2.5),
            'auroc_delta_ci_500' : np.percentile(aurocs, 50),
            'auroc_delta_ci_975' : np.percentile(aurocs, 97.5),
        })
df_task_deltas = pd.DataFrame(df_task_deltas).sort_values(['model', 'task'])
df_task_deltas = clean_df_deltas(df_task_deltas)
df_task_deltas

Unnamed: 0,model,model_name,model_base,ctx_length,task,auroc_delta_mean,auroc_delta_ci_025,auroc_delta_ci_500,auroc_delta_ci_975,formatted_auroc_delta_mean,formatted_auroc_delta,formatted_auroc_ci,is_auroc_win,is_stat_sig
237,clmbr,clmbr,clmbr,0,guo_icu,0.000000,0.000000,0.000000,0.000000,0.000,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
225,clmbr,clmbr,clmbr,0,guo_los,0.000000,0.000000,0.000000,0.000000,0.000,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
232,clmbr,clmbr,clmbr,0,guo_readmission,0.000000,0.000000,0.000000,0.000000,0.000,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
230,clmbr,clmbr,clmbr,0,lab_anemia,0.000000,0.000000,0.000000,0.000000,0.000,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
236,clmbr,clmbr,clmbr,0,lab_hyperkalemia,0.000000,0.000000,0.000000,0.000000,0.000,"0.000 (0.000, 0.000)","(0.000, 0.000)",False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
221,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-16384,mamba,16384,new_celiac,0.193925,0.108023,0.184678,0.333350,0.194,"0.194 (0.108, 0.333)","(0.108, 0.333)",True,True
220,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-16384,mamba,16384,new_hyperlipidemia,0.022537,-0.012559,0.022141,0.057557,0.023,"0.023 (-0.013, 0.058)","(-0.013, 0.058)",False,False
210,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-16384,mamba,16384,new_hypertension,0.003078,-0.017625,0.002916,0.022696,0.003,"0.003 (-0.018, 0.023)","(-0.018, 0.023)",False,False
213,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-16384,mamba,16384,new_lupus,0.036555,-0.056182,0.035704,0.132338,0.037,"0.037 (-0.056, 0.132)","(-0.056, 0.132)",False,False


In [46]:
format_df_deltas(df_task_deltas, model_start='mamba-tiny-16384')

Unnamed: 0,Model,Context Length,Task,$\Delta$ over baseline,95% CI,Statistically Significant
223,mamba,16384,guo_icu,0.007,"(-0.028, 0.040)",False
211,mamba,16384,guo_los,0.013,"(-0.005, 0.029)",False
218,mamba,16384,guo_readmission,0.005,"(-0.008, 0.017)",False
216,mamba,16384,lab_anemia,0.002,"(0.001, 0.003)",True
222,mamba,16384,lab_hyperkalemia,0.03,"(0.019, 0.042)",True
212,mamba,16384,lab_hypoglycemia,0.006,"(-0.006, 0.019)",False
214,mamba,16384,lab_hyponatremia,0.07,"(0.061, 0.079)",True
219,mamba,16384,lab_thrombocytopenia,0.008,"(0.004, 0.013)",True
217,mamba,16384,new_acutemi,0.016,"(-0.005, 0.036)",False
221,mamba,16384,new_celiac,0.194,"(0.108, 0.333)",True


### LaTeX Results

In [34]:
# Model-level delta CIs for LaTeX
latex: str = format_df_deltas_for_latex(df_model_deltas)
print(latex)

\begin{tabular}{lrrrrrllr}
\toprule
Model & Context Length & win_mean & win_ci_025 & win_ci_500 & win_ci_975 & $\Delta$ over baseline & 95% CI & Statistically Significant \\
\midrule
clmbr & 0 & 0.000000 & 0.000000 & 0.000000 & 0.000000 & 0.000 & (0.000, 0.000) &  \\
gpt2 & 512 & 0.196143 & -0.142857 & 0.142857 & 0.571429 & 0.017 & (0.005, 0.028) & \checkmark \\
gpt2 & 1024 & 0.005714 & -0.285714 & 0.000000 & 0.285714 & 0.008 & (-0.003, 0.021) &  \\
gpt2 & 2048 & 0.216500 & -0.142857 & 0.285714 & 0.571429 & 0.025 & (0.008, 0.041) & \checkmark \\
gpt2 & 4096 & 0.423929 & 0.000000 & 0.428571 & 0.714286 & 0.019 & (0.007, 0.032) & \checkmark \\
hyena & 1024 & 0.131500 & -0.142857 & 0.142857 & 0.428571 & 0.019 & (0.007, 0.033) & \checkmark \\
hyena & 4096 & 0.112429 & -0.142857 & 0.142857 & 0.428571 & 0.021 & (0.007, 0.036) & \checkmark \\
hyena & 8192 & -0.266714 & -0.428571 & -0.285714 & 0.000000 & -0.008 & (-0.024, 0.010) &  \\
hyena & 16384 & -0.571929 & -0.714286 & -0.571429 & -0.42857

In [35]:
# Task-level delta CIs for LaTeX
latex: str = format_df_deltas_for_latex(df_task_deltas)
print(latex)

\begin{tabular}{lrlllr}
\toprule
Model & Context Length & Task & $\Delta$ over baseline & 95% CI & Statistically Significant \\
\midrule
clmbr & 0 & ICU Admission & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Long LOS & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & 30-day Readmission & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Anemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyperkalemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hypoglycemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyponatremia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Thrombocytopenia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Acute MI & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Celiac & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyperlipidemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hypertension & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Lupus & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Pancreatic Cancer & 0.000 & (0.000, 0.000) &  \\
gpt2 & 512 & ICU Admission & 0.022 & (-0.005, 0.050) &  \\
gpt2 & 512 & Long LOS & -0.00

# Raw Win Rates of Models

Count how many times a model beats CLMBR / best baseline. No bootstrapping or CIs.

In [36]:
df_ctx_orig = pd.read_csv('../cache/results_nonchexpert.csv')
df_chexpert_orig = pd.read_csv('../cache/results_chexpert.csv')

def preprocess_df(df, is_keep_max_tokens=False, drop_tasks=[]):
    assert df.drop_duplicates(subset=['model', 'task_name']).shape == df.shape, "Duplicate rows found"
    df['Baseline_AUROC'] = df[['Logistic_AUROC', 'GBM_AUROC', 'Random_Forest_AUROC']].max(axis=1)
    df['delta_clmbr'] = df['value_mean'] - df['CLMBR_AUROC']
    df['delta_baseline'] = df['value_mean'] - df['Baseline_AUROC']
    df['is_max_tokens'] = df['model'].apply(lambda x : x.endswith('max'))
    df['model_name'] = df['model'].apply(lambda x : x.replace('-max', '').replace("-" + x.replace('-max', '').split("-")[-1], "").strip())
    df['ctx_length'] = df['model'].apply(lambda x : int(x.replace('-max', '').split("-")[-1]))
    
    
    # Best GPT
    for task in df['task_name'].unique():
        best_ = df[(df['task_name'] == task) & (df['model'].str.contains('gpt2-base'))]['value_mean'].max()
        df.loc[df['task_name'] == task, 'BestGPT2_AUROC'] = best_
    for task in df['task_name'].unique():
        best_ = df[(df['task_name'] == task) & (df['model'].str.contains('mamba-tiny'))]['value_mean'].max()
        df.loc[df['task_name'] == task, 'BestMamba_AUROC'] = best_
    for task in df['task_name'].unique():
        best_ = df[(df['task_name'] == task) & (df['model'].str.contains('llama-base'))]['value_mean'].max()
        df.loc[df['task_name'] == task, 'BestLlama_AUROC'] = best_
    for task in df['task_name'].unique():
        best_ = df[(df['task_name'] == task) & (df['model'].str.contains('hyena-large'))]['value_mean'].max()
        df.loc[df['task_name'] == task, 'BestHyena_AUROC'] = best_
    df['delta_bestGPT2'] = df['value_mean'] - df['BestGPT2_AUROC']
    df['delta_bestMamba'] = df['value_mean'] - df['BestMamba_AUROC']
    df['delta_bestLlama'] = df['value_mean'] - df['BestLlama_AUROC']
    df['delta_bestHyena'] = df['value_mean'] - df['BestHyena_AUROC']
    
    # Count win rates v. CLMBR / Baseline / Models
    df['is_beat_clmbr'] = df['delta_clmbr'] > 0
    df['is_beat_baseline'] = df['delta_baseline'] > 0
    df['is_beat_bestGPT2'] = df['delta_bestGPT2'] > 0
    df['is_beat_bestMamba'] = df['delta_bestMamba'] > 0
    df['is_beat_bestLlama'] = df['delta_bestLlama'] > 0
    df['is_beat_bestHyena'] = df['delta_bestHyena'] > 0
    
    # Filtering
    ## Only keep 2B models
    if not is_keep_max_tokens:
        df = df[~df['is_max_tokens']]
    ## Drop tasks (usually Lupus and Celiac b/c so low n)
    if len(drop_tasks) > 0:
        df = df[~df['task_name'].isin(drop_tasks)]
    
    # Sanity check baseline / CLMBR numbers
    for task in df['task_name'].unique():
        df_ = df[df['task_name'] == task]
        assert all(x == df_['Logistic_AUROC'].tolist()[0] for x in df_['Logistic_AUROC'].tolist()), "Logistic_AUROC is not uniform"
        assert all(x == df_['GBM_AUROC'].tolist()[0] for x in df_['GBM_AUROC'].tolist()), "GBM_AUROC is not uniform"
        assert all(x == df_['Random_Forest_AUROC'].tolist()[0] for x in df_['Random_Forest_AUROC'].tolist()), "Random_Forest_AUROC is not uniform"
        assert all(x == df_['Baseline_AUROC'].tolist()[0] for x in df_['Baseline_AUROC'].tolist()), "Baseline Model AUROC is not uniform"
        assert all(x == df_['CLMBR_AUROC'].tolist()[0] for x in df_['CLMBR_AUROC'].tolist()), "CLMBR AUROC is not uniform"
    return df

df_ctx = preprocess_df(df_ctx_orig, drop_tasks=['Lupus', 'Celiac'])
df_chexpert = preprocess_df(df_chexpert_orig, drop_tasks=['Lupus', 'Celiac'])

In [37]:
# NOTE: We're missing gpt2-base-1024 for some tasks...
set(df_ctx['task_name'].unique())- set(df_ctx[df_ctx['model'] == 'gpt2-base-1024']['task_name'].tolist()) 

{'Anemia', 'Hypoglycemia', 'Hyponatremia'}

In [38]:
1+1

2

### Save Results Tables

In [136]:
df_ctx['delta_clmbr']
print(df_ctx[df_ctx['task_name'] == '30-Day Readmission'][['model_name', 'ctx_length', 'delta_clmbr', 'delta_baseline']].round({'delta_clmbr': 4, 'delta_baseline': 4}).sort_values(['model_name', 'ctx_length'], ascending=[True, True]).to_markdown(tablefmt="grid", index=False))

+--------------+--------------+---------------+------------------+
| model_name   |   ctx_length |   delta_clmbr |   delta_baseline |
| gpt2-base    |          512 |       -0.0016 |           0.0334 |
+--------------+--------------+---------------+------------------+
| gpt2-base    |         1024 |        0.004  |           0.039  |
+--------------+--------------+---------------+------------------+
| gpt2-base    |         2048 |        0.0016 |           0.0366 |
+--------------+--------------+---------------+------------------+
| gpt2-base    |         4096 |        0.0038 |           0.0388 |
+--------------+--------------+---------------+------------------+
| hyena-large  |         1024 |       -0.0007 |           0.0343 |
+--------------+--------------+---------------+------------------+
| hyena-large  |         4096 |        0.0018 |           0.0368 |
+--------------+--------------+---------------+------------------+
| hyena-large  |         8192 |       -0.0163 |           0.01

In [164]:
# Non-Chexpert
os.makedirs('../cache/tables_ehrshot/chexpert/', exist_ok=True)
for task in df_ctx['task_name'].unique():
    df_task_ = df_ctx[df_ctx['task_name'] == task][['model_name', 'ctx_length', 'delta_clmbr', 'delta_baseline']].round({'delta_clmbr': 4, 'delta_baseline': 4})
    df_delta_ordered = df_task_.sort_values(['delta_clmbr'], ascending=False)
    df_model_ordered = df_task_.sort_values(['model_name', 'ctx_length'], ascending=[True, False])
    with open(f'../cache/tables_ehrshot/{task}.md', 'w') as fd:
        fd.write(f"## {task}\n\n")
        fd.write("### Delta Ordered\n")
        fd.write(df_delta_ordered.to_markdown(tablefmt="grid", index=False))
        fd.write("\n\n### Model Ordered\n")
        fd.write(df_model_ordered.to_markdown(tablefmt="grid", index=False))
# Chexpert
for task in df_chexpert['task_name'].unique():
    df_task_ = df_chexpert[df_chexpert['task_name'] == task][['model_name', 'ctx_length', 'delta_clmbr', 'delta_baseline']].round({'delta_clmbr': 4, 'delta_baseline': 4})
    df_delta_ordered = df_task_.sort_values(['delta_clmbr'], ascending=False)
    df_model_ordered = df_task_.sort_values(['model_name', 'ctx_length'], ascending=[True, False])
    with open(f'../cache/tables_ehrshot/chexpert/{task}.md', 'w') as fd:
        fd.write(f"## {task}\n\n")
        fd.write("### Delta Ordered\n")
        fd.write(df_delta_ordered.to_markdown(tablefmt="grid", index=False))
        fd.write("\n\n### Model Ordered\n")
        fd.write(df_model_ordered.to_markdown(tablefmt="grid", index=False))

### Win Counts v. Baseline

Count how many times a model beats CLMBR / best baseline. No bootstrapping or CIs.

In [139]:
def count_win_rates(df):
    # Count # of times a model beats CLMBR
    df_clmbr_beats = df.groupby('model')['is_beat_clmbr'].sum().reset_index().sort_values('is_beat_clmbr', ascending=False)
    df_mean_clmbr_delta = df.groupby('model')['delta_clmbr'].mean().reset_index().sort_values('delta_clmbr', ascending=False)

    # Count # of times a model beats best baseline (max(LR, RF, XGB))
    df_baseline_beats = df.groupby('model')['is_beat_baseline'].sum().reset_index().sort_values('is_beat_baseline', ascending=False)
    df_mean_baseline_delta = df.groupby('model')['delta_baseline'].mean().reset_index().sort_values('delta_baseline', ascending=False)

    # Count # of times a model beats best GPT2-base
    df_gpt2_beats = df.groupby('model')['is_beat_bestGPT2'].sum().reset_index().sort_values('is_beat_bestGPT2', ascending=False)
    df_mean_gpt2_delta = df.groupby('model')['delta_bestGPT2'].mean().reset_index().sort_values('delta_bestGPT2', ascending=False)

    # Count # of times a model beats best Llama-base
    df_llama_beats = df.groupby('model')['is_beat_bestLlama'].sum().reset_index().sort_values('is_beat_bestLlama', ascending=False)
    df_mean_llama_delta = df.groupby('model')['delta_bestLlama'].mean().reset_index().sort_values('delta_bestLlama', ascending=False)
    return {
        'clmbr_beats' : df_clmbr_beats,
        'mean_clmbr_delta' : df_mean_clmbr_delta,
        'baseline_beats' : df_baseline_beats,
        'mean_baseline_delta' : df_mean_baseline_delta,
        'gpt2_beats' : df_gpt2_beats,
        'mean_gpt2_delta' : df_mean_gpt2_delta,
        'llama_beats' : df_llama_beats,
        'mean_llama_delta' : df_mean_llama_delta,
    }

#### Non-Chexpert tasks

In [140]:
df_ctx_win_rates = count_win_rates(df_ctx)

Win rate v. CLMBR

In [141]:
print("# of tasks that model achieves higher AUROC than CLMBR (higher is better)")
df_ctx_win_rates['clmbr_beats']

# of tasks that model achieves higher AUROC than CLMBR (higher is better)


Unnamed: 0,model,is_beat_clmbr
13,mamba-tiny-16384,12
14,mamba-tiny-4096,11
2,gpt2-base-4096,10
9,llama-base-2048,9
15,mamba-tiny-8192,9
10,llama-base-4096,8
12,mamba-tiny-1024,8
6,hyena-large-4096,7
1,gpt2-base-2048,6
3,gpt2-base-512,6


In [142]:
print("Mean AUROC diff between model and CLMBR (higher is better)")
df_ctx_win_rates['mean_clmbr_delta']

Mean AUROC diff between model and CLMBR (higher is better)


Unnamed: 0,model,delta_clmbr
13,mamba-tiny-16384,0.018845
14,mamba-tiny-4096,0.014324
15,mamba-tiny-8192,0.011398
2,gpt2-base-4096,0.009947
9,llama-base-2048,0.008483
10,llama-base-4096,0.007948
6,hyena-large-4096,0.007604
3,gpt2-base-512,0.006575
4,hyena-large-1024,0.005338
12,mamba-tiny-1024,0.004107


Win rate v. GPT2

In [143]:
df_ctx[df_ctx['model'] == 'mamba-tiny-16384'][['task_name', 'delta_bestGPT2']]

Unnamed: 0,task_name,delta_bestGPT2
5,30-Day Readmission,0.00072
19,Acute MI,0.010173
32,Anemia,0.005698
63,Hyperkalemia,0.007539
81,Hyperlipidemia,0.010968
97,Hypertension,-0.001608
113,Hypoglycemia,0.004088
126,Hyponatremia,0.02435
143,ICU Prediction,-0.014959
158,Long LOS,0.013017


In [144]:
print("# of tasks that model achieves higher AUROC than best GPT2-base (higher is better)")
df_ctx_win_rates['gpt2_beats']

# of tasks that model achieves higher AUROC than best GPT2-base (higher is better)


Unnamed: 0,model,is_beat_bestGPT2
13,mamba-tiny-16384,9
14,mamba-tiny-4096,8
10,llama-base-4096,6
9,llama-base-2048,6
15,mamba-tiny-8192,6
6,hyena-large-4096,4
12,mamba-tiny-1024,4
4,hyena-large-1024,4
11,llama-base-512,2
7,hyena-large-8192,2


In [145]:
print("Mean AUROC diff between model and best GPT2 (higher is better)")
df_ctx_win_rates['mean_gpt2_delta']

Mean AUROC diff between model and best GPT2 (higher is better)


Unnamed: 0,model,delta_bestGPT2
13,mamba-tiny-16384,0.006164
14,mamba-tiny-4096,0.001642
15,mamba-tiny-8192,-0.001284
2,gpt2-base-4096,-0.002734
9,llama-base-2048,-0.004198
10,llama-base-4096,-0.004733
6,hyena-large-4096,-0.005077
3,gpt2-base-512,-0.006106
4,hyena-large-1024,-0.007343
12,mamba-tiny-1024,-0.008574


Win rate v. Best Llama

In [146]:
df_ctx[df_ctx['model'] == 'mamba-tiny-16384'][['task_name', 'delta_bestLlama']]

Unnamed: 0,task_name,delta_bestLlama
5,30-Day Readmission,-0.008647
19,Acute MI,-0.007046
32,Anemia,0.001165
63,Hyperkalemia,0.005238
81,Hyperlipidemia,0.001584
97,Hypertension,-0.000847
113,Hypoglycemia,-0.005177
126,Hyponatremia,0.033643
143,ICU Prediction,0.001561
158,Long LOS,-0.001396


In [147]:
print("# of tasks that model achieves higher AUROC than best Llama-base (higher is better)")
df_ctx_win_rates['llama_beats']

# of tasks that model achieves higher AUROC than best Llama-base (higher is better)


Unnamed: 0,model,is_beat_bestLlama
13,mamba-tiny-16384,7
2,gpt2-base-4096,5
6,hyena-large-4096,4
3,gpt2-base-512,4
14,mamba-tiny-4096,4
15,mamba-tiny-8192,4
4,hyena-large-1024,4
7,hyena-large-8192,2
0,gpt2-base-1024,2
1,gpt2-base-2048,1


In [148]:
print("Mean AUROC diff between model and best Llama (higher is better)")
df_ctx_win_rates['mean_llama_delta']

Mean AUROC diff between model and best Llama (higher is better)


Unnamed: 0,model,delta_bestLlama
13,mamba-tiny-16384,0.006305
14,mamba-tiny-4096,0.001784
15,mamba-tiny-8192,-0.001142
2,gpt2-base-4096,-0.002593
9,llama-base-2048,-0.004056
10,llama-base-4096,-0.004592
6,hyena-large-4096,-0.004936
3,gpt2-base-512,-0.005965
4,hyena-large-1024,-0.007202
12,mamba-tiny-1024,-0.008433


Win rate v. Best Baseline

In [149]:
print("# of tasks that model achieves higher AUROC than best baseline (higher is better)")
df_ctx_win_rates['baseline_beats']

# of tasks that model achieves higher AUROC than best baseline (higher is better)


Unnamed: 0,model,is_beat_baseline
6,hyena-large-4096,10
15,mamba-tiny-8192,10
13,mamba-tiny-16384,10
14,mamba-tiny-4096,10
12,mamba-tiny-1024,10
10,llama-base-4096,10
9,llama-base-2048,10
8,llama-base-1024,10
2,gpt2-base-4096,9
3,gpt2-base-512,9


In [150]:
print("Mean AUROC diff between model and best baseline (higher is better)")
df_ctx_win_rates['mean_baseline_delta']

Mean AUROC diff between model and best baseline (higher is better)


Unnamed: 0,model,delta_baseline
13,mamba-tiny-16384,0.063262
14,mamba-tiny-4096,0.05874
15,mamba-tiny-8192,0.055814
2,gpt2-base-4096,0.054364
9,llama-base-2048,0.0529
10,llama-base-4096,0.052365
6,hyena-large-4096,0.052021
3,gpt2-base-512,0.050992
4,hyena-large-1024,0.049755
12,mamba-tiny-1024,0.048524


Best model win rate v. best other model

In [151]:
df_ctx

Unnamed: 0,task_name,model,value_mean,lower_bound,upper_bound,Logistic_AUROC,GBM_AUROC,Random_Forest_AUROC,CLMBR_AUROC,Baseline_AUROC,...,delta_bestGPT2,delta_bestMamba,delta_bestLlama,delta_bestHyena,is_beat_clmbr,is_beat_baseline,is_beat_bestGPT2,is_beat_bestMamba,is_beat_bestLlama,is_beat_bestHyena
0,30-Day Readmission,llama-base-4096,0.823392,0.788580,0.852653,0.751,0.741,0.775,0.810,0.775,...,0.009367,0.007630,0.000000,0.009046,True,True,True,True,False,True
1,30-Day Readmission,llama-base-2048,0.820290,0.787215,0.849638,0.751,0.741,0.775,0.810,0.775,...,0.006265,0.004528,-0.003102,0.005944,True,True,True,True,False,True
2,30-Day Readmission,llama-base-512,0.817426,0.782813,0.846699,0.751,0.741,0.775,0.810,0.775,...,0.003402,0.001665,-0.005966,0.003081,True,True,True,True,False,True
3,30-Day Readmission,llama-base-1024,0.816984,0.781272,0.848693,0.751,0.741,0.775,0.810,0.775,...,0.002959,0.001222,-0.006408,0.002638,True,True,True,True,False,True
4,30-Day Readmission,mamba-tiny-4096,0.815761,0.779548,0.847507,0.751,0.741,0.775,0.810,0.775,...,0.001737,0.000000,-0.007630,0.001416,True,True,True,False,False,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
216,Thrombocytopenia,llama-base-4096,0.851733,0.834264,0.867966,0.753,0.815,0.811,0.852,0.815,...,-0.021453,-0.008391,0.000000,-0.018525,False,True,False,False,False,False
217,Thrombocytopenia,llama-base-2048,0.851180,0.833910,0.867854,0.753,0.815,0.811,0.852,0.815,...,-0.022006,-0.008943,-0.000552,-0.019078,False,True,False,False,False,False
218,Thrombocytopenia,llama-base-1024,0.848140,0.831247,0.864922,0.753,0.815,0.811,0.852,0.815,...,-0.025046,-0.011984,-0.003593,-0.022118,False,True,False,False,False,False
219,Thrombocytopenia,llama-base-512,0.846590,0.829281,0.863533,0.753,0.815,0.811,0.852,0.815,...,-0.026595,-0.013533,-0.005142,-0.023668,False,True,False,False,False,False


In [152]:
base_models = ['gpt2-base', 'hyena-large', 'llama-base', 'mamba-tiny']
df_base_model_max = df_ctx.groupby(['model_name', 'task_name']).agg({ 'value_mean': 'max' }).reset_index().pivot(index='model_name', columns='task_name', values='value_mean').T.reset_index()
for model in base_models:
    df_base_model_max[f'{model}_rank'] = df_base_model_max[base_models].rank(axis=1, method='min', ascending=False)[model].astype(int)
df_base_model_max

model_name,task_name,gpt2-base,hyena-large,llama-base,mamba-tiny,gpt2-base_rank,hyena-large_rank,llama-base_rank,mamba-tiny_rank
0,30-Day Readmission,0.814025,0.811818,0.823392,0.815761,3,4,1,2
1,Acute MI,0.735222,0.742005,0.752441,0.746761,4,3,1,2
2,Anemia,0.959098,0.959973,0.963632,0.964796,4,3,2,1
3,Hyperkalemia,0.814182,0.817473,0.816482,0.821721,4,2,3,1
4,Hyperlipidemia,0.678603,0.689662,0.687986,0.697883,4,2,3,1
5,Hypertension,0.721966,0.694439,0.721205,0.720358,1,4,2,3
6,Hypoglycemia,0.796039,0.790018,0.805305,0.800128,3,4,1,2
7,Hyponatremia,0.805744,0.826314,0.79645,0.830094,3,2,4,1
8,ICU Prediction,0.868563,0.82239,0.852043,0.853604,1,4,3,2
9,Long LOS,0.813097,0.808156,0.82751,0.826115,3,4,1,2


In [153]:
for model in base_models:
    print(f"Mean rank of {model}:", df_base_model_max[model + '_rank'].mean())

Mean rank of gpt2-base: 2.8333333333333335
Mean rank of hyena-large: 3.0
Mean rank of llama-base: 2.4166666666666665
Mean rank of mamba-tiny: 1.75


#### Chexpert tasks

In [154]:
df_chexpert_win_rates = count_win_rates(df_chexpert)

Win rate v. CLMBR

In [155]:
print("# of tasks that model achieves higher AUROC than CLMBR (higher is better)")
df_chexpert_win_rates['clmbr_beats']

# of tasks that model achieves higher AUROC than CLMBR (higher is better)


Unnamed: 0,model,is_beat_clmbr
1,gpt2-base-2048,13
3,gpt2-base-512,13
6,hyena-large-4096,13
11,llama-base-512,13
13,mamba-tiny-16384,13
14,mamba-tiny-4096,13
2,gpt2-base-4096,12
0,gpt2-base-1024,12
12,mamba-tiny-1024,12
15,mamba-tiny-8192,12


In [156]:
print("Mean AUROC diff between model and CLMBR (higher is better)")
df_chexpert_win_rates['mean_clmbr_delta']

Mean AUROC diff between model and CLMBR (higher is better)


Unnamed: 0,model,delta_clmbr
1,gpt2-base-2048,0.017374
14,mamba-tiny-4096,0.015817
15,mamba-tiny-8192,0.015726
13,mamba-tiny-16384,0.015642
3,gpt2-base-512,0.013601
2,gpt2-base-4096,0.013119
12,mamba-tiny-1024,0.012177
9,llama-base-2048,0.011759
10,llama-base-4096,0.010881
6,hyena-large-4096,0.010164


Win rate v. GPT2

In [159]:
df_chexpert[df_chexpert['model'] == 'mamba-tiny-16384'][['task_name', 'delta_bestGPT2']]

Unnamed: 0,task_name,delta_bestGPT2
6,Atelectasis,-0.005638
16,Cardiomegaly,0.001889
34,Consolidation,-0.001827
48,Edema,0.006886
73,Enlarged Cardiomediastinum,-0.028122
93,Fracture,-0.069331
99,Lung Lesion,-0.001337
115,Lung Opacity,0.002343
135,No Finding,-0.006456
147,Pleural Effusion,0.001405


In [160]:
print("# of tasks that model achieves higher AUROC than best GPT2-base (higher is better)")
df_chexpert_win_rates['gpt2_beats']

# of tasks that model achieves higher AUROC than best GPT2-base (higher is better)


Unnamed: 0,model,is_beat_bestGPT2
14,mamba-tiny-4096,8
13,mamba-tiny-16384,7
15,mamba-tiny-8192,7
6,hyena-large-4096,6
4,hyena-large-1024,4
12,mamba-tiny-1024,3
7,hyena-large-8192,2
10,llama-base-4096,2
9,llama-base-2048,2
8,llama-base-1024,1


In [161]:
print("Mean AUROC diff between model and best GPT2 (higher is better)")
df_chexpert_win_rates['mean_gpt2_delta']

Mean AUROC diff between model and best GPT2 (higher is better)


Unnamed: 0,model,delta_bestGPT2
1,gpt2-base-2048,-0.005717
14,mamba-tiny-4096,-0.007273
15,mamba-tiny-8192,-0.007364
13,mamba-tiny-16384,-0.007449
3,gpt2-base-512,-0.009489
2,gpt2-base-4096,-0.009972
12,mamba-tiny-1024,-0.010913
9,llama-base-2048,-0.011332
10,llama-base-4096,-0.012209
6,hyena-large-4096,-0.012926


Win rate v. Llama

In [162]:
print("# of tasks that model achieves higher AUROC than best llama-base (higher is better)")
df_chexpert_win_rates['llama_beats']

# of tasks that model achieves higher AUROC than best llama-base (higher is better)


Unnamed: 0,model,is_beat_bestLlama
15,mamba-tiny-8192,10
14,mamba-tiny-4096,8
13,mamba-tiny-16384,7
12,mamba-tiny-1024,7
1,gpt2-base-2048,6
2,gpt2-base-4096,5
0,gpt2-base-1024,3
6,hyena-large-4096,3
4,hyena-large-1024,2
7,hyena-large-8192,2


In [163]:
print("Mean AUROC diff between model and best Llama (higher is better)")
df_chexpert_win_rates['mean_llama_delta']

Mean AUROC diff between model and best Llama (higher is better)


Unnamed: 0,model,delta_bestLlama
1,gpt2-base-2048,0.000882
14,mamba-tiny-4096,-0.000675
15,mamba-tiny-8192,-0.000766
13,mamba-tiny-16384,-0.00085
3,gpt2-base-512,-0.00289
2,gpt2-base-4096,-0.003373
12,mamba-tiny-1024,-0.004315
9,llama-base-2048,-0.004733
10,llama-base-4096,-0.00561
6,hyena-large-4096,-0.006328
