In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
import sklearn.utils
from typing import List, Optional, Dict
from sklearn.metrics import roc_auc_score, brier_score_loss
import collections
import warnings
import pickle
import femr.datasets
import datetime
from hf_ehr.config import EHRSHOT_LABELING_FUNCTION_2_PAPER_NAME
from hf_ehr.notebooks.ehr_specific_properties.utils import (
    get_labels_and_features, 
    get_patient_splits_by_idx
)
import scipy.stats
from femr.labelers import load_labeled_patients, LabeledPatients

warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.validation")

PATH_TO_DATABASE: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/femr/extract'
PATH_TO_FEATURES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/features_ehrshot'
PATH_TO_RESULTS_DIR: str = '/share/pi/nigam/migufuen/ehrshot-benchmark/EHRSHOT_ASSETS/results_ehrshot'
PATH_TO_TOKENIZED_TIMELINES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/tokenized_timelines_ehrshot'
PATH_TO_LABELS_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/benchmark_ehrshot'
PATH_TO_SPLIT_CSV: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/splits_ehrshot/person_id_map.csv'
femr_db = femr.datasets.PatientDatabase(PATH_TO_DATABASE)
os.makedirs('../cache', exist_ok=True)

IS_LOAD_FROM_CACHE: bool = True


# Load Data

First, we need to load all of the patient-level predictions / labels for every model and task.
This takes ~30 min.

1. Load list of task names, model names
2. Load patient-level predictions for each (model, task)
3. Load patient-level ground truth labels for each (task)

In [4]:
# Get list of tasks
valid_tasks = os.listdir(PATH_TO_RESULTS_DIR)
valid_tasks.remove('chexpert')

# Filter results to only valid models
valid_models = [ 
    'llama-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last', 
    'clmbr',
]
print("Tasks:", valid_tasks)
print("# of valid models:", len(valid_models))

Tasks: ['new_hypertension', 'guo_los', 'lab_hypoglycemia', 'new_lupus', 'lab_hyponatremia', 'new_pancan', 'lab_anemia', 'new_acutemi', 'guo_readmission', 'lab_thrombocytopenia', 'new_hyperlipidemia', 'new_celiac', 'lab_hyperkalemia', 'guo_icu']
# of valid models: 17


In [5]:
# Load patient-level predictions for each (model, task)
predictions = {} # [key] = (model, task), [value] = df_preds

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/predictions.pkl'):
    print("Loading predictions from cache...")
    predictions = pickle.load(open('../cache/predictions.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        for model in valid_models:
            path = os.path.join(PATH_TO_RESULTS_DIR, task, 'models', model, 'lr_lbfgs', f'subtask={task}', 'k=-1', 'preds.csv')
            if not os.path.exists(path):
                print("Missing path for ", model, task)
            df_preds = pd.read_csv(path)
            assert df_preds['replicate'].nunique() == 1, f"Multiple replicates for {model}, {task}"
            predictions[(model, task)] = df_preds
    # Save results to .pkl file
    with open('../cache/predictions.pkl', 'wb') as f:
        pickle.dump(predictions, f)
print("# of (model, task) preds:", len(predictions))

Loading predictions from cache...
# of (model, task) preds: 238


In [6]:
# Load patient-level labels for each task
# NOTE: Takes ~1 hr
label_data = {} # [key] = task, [value] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/label_data.pkl'):
    print("Loading label data from cache...")
    label_data = pickle.load(open('../cache/label_data.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        # Load labeled patients for this task
        LABELING_FUNCTION: str = task
        PATH_TO_LABELED_PATIENTS: str =  os.path.join(PATH_TO_LABELS_DIR, LABELING_FUNCTION, 'labeled_patients.csv')
        labeled_patients = femr.labelers.load_labeled_patients(PATH_TO_LABELED_PATIENTS)
        
        # Get features for patients
        model: str = 'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last'
        patient_ids, label_values, label_times, feature_matrixes = get_labels_and_features(labeled_patients, 
                                                                                            PATH_TO_FEATURES_DIR, 
                                                                                            PATH_TO_TOKENIZED_TIMELINES_DIR,
                                                                                            models_to_keep=[model,])
        train_pids_idx, val_pids_idx, test_pids_idx = get_patient_splits_by_idx(PATH_TO_SPLIT_CSV, patient_ids)
        label_times = label_times[train_pids_idx + val_pids_idx + test_pids_idx]
        label_values = label_values[train_pids_idx + val_pids_idx + test_pids_idx]
        patient_ids = patient_ids[train_pids_idx + val_pids_idx + test_pids_idx]
        label_times = [ x.astype(datetime.datetime) for x in label_times ] # cast to Python datetime
        label_data[task] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

    # Save results to .pkl file
    with open('../cache/label_data.pkl', 'wb') as f:
        pickle.dump(label_data, f)

Loading label data from cache...


# Calculate CIs

Calculate bootstrapped 95% CIs over the test set (one sample per patient).

1. Generate 1000 bootstrapped resamples of patient IDs across all splits
2. Loop through every (model, task) preds, limit to test set, loop through 1000 bootstrap weightings, loop through each stratification metric, recalculate **terciles** of patients based on their metrics, and calculate Brier score for each **tercile**
3. Calculate 95% CIs

### Do Bootstrapping of Brier Scores across Terciles

In [7]:
# Helpers
strats = {
    'inter_event_times': [ 'std' ],
    'n_gram_count': ['rr_1'], 
    'timeline_lengths': ['n_events'],
}
path_to_ehrshot_metrics_dir: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/stratify/'

def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False):
    values = np.array(values)
    quantiles = np.array(quantiles)
    if sample_weight is None:
        sample_weight = np.ones(len(values))
    sample_weight = np.array(sample_weight)
    assert np.all(quantiles >= 0) and np.all(quantiles <= 1), \
        'quantiles should be in [0, 1]'

    if not values_sorted:
        sorter = np.argsort(values)
        values = values[sorter]
        sample_weight = sample_weight[sorter]

    weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight
    weighted_quantiles /= np.sum(sample_weight)
    return np.interp(quantiles, weighted_quantiles, values)

In [8]:
# Get all patient IDs
all_patients = np.sort(np.unique(np.concatenate([v['patient_ids'] for v in label_data.values()])))
print("# of unique patients:", len(all_patients))

# Patient-level resampling across all patient IDs in train/val/test
# Later, we limit to just test patient IDs per task
bootstrap_weights = []
np.random.seed(342342)
for i in range(1000):
    patient_sample = np.random.choice(list(range(len(all_patients))), len(all_patients), replace=True)
    weights = np.zeros_like(all_patients, dtype=np.float32)
    np.add.at(weights, patient_sample, 1)
    assert np.mean(weights) == 1
    bootstrap_weights.append(weights)

# of unique patients: 6275


In [36]:
# Formatting
def model_to_base(model: str) -> str:
    return model.split('-')[0]
def model_to_ctx(model: str) -> str:
    return int(model.split('--')[0].split('-')[-1]) if model != 'clmbr' else 0
def model_to_name(model: str) -> str:
    return model.split('--')[0]

def clean_df_briers(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df['model_name'] = df['model'].apply(model_to_name)
    df['model_base'] = df['model'].apply(model_to_base)
    df['ctx_length'] = df['model'].apply(model_to_ctx).astype(int)
    if 'brier_true' in df.columns:
        df['formatted_brier_ci_mean'] = df['brier_true'].apply(lambda x: f"{x:.4f}")
        df['formatted_brier'] = df.apply(lambda x: f"{x['brier_true']:.4f} ({x['brier_ci_025']:.4f}, {x['brier_ci_975']:.4f})", axis=1)
        df['formatted_brier_ci'] = df.apply(lambda x: f"({x['brier_ci_025']:.4f}, {x['brier_ci_975']:.4f})", axis=1)
        df['is_brier_win'] = (df['brier_ci_025'] > 0) & (df['brier_ci_975'] > 0)
        df['is_brier_stat_sig'] = (df['brier_ci_025'] > 0) | (df['brier_ci_975'] < 0)
    if 'win_rate_true' in df.columns:
        df['formatted_win_rate_ci_mean'] = df['win_rate_true'].apply(lambda x: f"{x:.4f}")
        df['formatted_win_rate'] = df.apply(lambda x: f"{x['win_rate_true']:.4f} ({x['win_rate_ci_025']:.4f}, {x['win_rate_ci_975']:.4f})", axis=1)
        df['formatted_win_rate'] = df.apply(lambda x: f"{x['win_rate_true']:.4f} ({x['win_rate_ci_025']:.4f}, {x['win_rate_ci_975']:.4f})", axis=1)
        df['is_win_rate_win'] = (df['win_rate_ci_025'] > 0) & (df['win_rate_ci_975'] > 0)
        df['is_win_rate_stat_sig'] = (df['win_rate_ci_025'] > 0) | (df['win_rate_ci_975'] < 0)
    df = df[['model', 'model_name', 'model_base', 'ctx_length'] + [col for col in df.columns if col not in ['model', 'model_name', 'model_base', 'ctx_length']]]
    df = df.sort_values(['model_base', 'ctx_length', ] + ([ 'task' ] if 'task' in df.columns else []))
    return df

def format_df_briers(df: pd.DataFrame, model_start: Optional[str] = None) -> pd.DataFrame:
    if model_start is not None:
        df = df[df['model'].str.startswith(model_start)]
    df = df.drop(columns=['model', 'model_name', 'formatted_brier', 'brier_ci_mean', 'brier_ci_025', 'brier_ci_500', 'brier_ci_975', 'is_brier_win'], errors='ignore') \
        .rename(columns={   'model_base' : 'Model', 
                            'ctx_length' : 'Context Length', 
                            'formatted_brier_ci_mean' : r'$\Delta$ over baseline', 
                            'formatted_brier_ci' : '95% CI', 
                            'formatted_win_rate_ci_mean' : r'Win Rate over baseline', 
                            'formatted_win_rate_ci' : 'Win Rate 95% CI', 
                            'is_brier_stat_sig' : 'Statistically Significant',
                            'is_win_rate_stat_sig' : 'Win Rate Statistically Significant',
                            'task' : 'Task'
                }, errors='ignore') \
        .sort_values(['Model', 'Context Length',] + ([ 'Task' ] if 'task' in df.columns else []))
    return df

def format_df_briers_for_latex(df: pd.DataFrame, model_start: Optional[str] = None) -> str:
    latex = format_df_briers(df, model_start).to_latex(index=False, escape=False)
    for k, v in EHRSHOT_LABELING_FUNCTION_2_PAPER_NAME.items():
        latex = latex.replace(k, v)
    latex = latex.replace('False', '').replace('True', '\checkmark')
    return latex

In [10]:
brier_scores_per_tercile = collections.defaultdict(list) # [key] = (model, task, strat, strat_col, tercile), [value] = brier's across 1k resamples for this tercile
brier_scores_metadata_per_tercile = collections.defaultdict(list) # [key] = (model, task, strat, strat_col, tercile), [value] = df_terciles
true_brier_scores_per_tercile = collections.defaultdict(int) # [key] = (model, task, strat, strat_col, tercile), [value] = True brier score for this model using original raw (non-bootstrapped) dataset

In [None]:
# Calculate Brier scores for each (model, task, tercile) across 1k resamples
if IS_LOAD_FROM_CACHE and os.path.exists('../cache/brier_scores_per_tercile.pkl') and os.path.exists('../cache/true_brier_scores_per_tercile.pkl') and os.path.exists('../cache/brier_scores_metadata_per_tercile.pkl'):
    print("Loading brier_scores_per_tercile from cache...")
    brier_scores_per_tercile = pickle.load(open('../cache/brier_scores_per_tercile.pkl', 'rb'))
    print("Loading true_brier_scores_per_tercile from cache...")
    true_brier_scores_per_tercile = pickle.load(open('../cache/true_brier_scores_per_tercile.pkl', 'rb'))
    print("Loading brier_scores_metadata_per_tercile from cache...")
    brier_scores_metadata_per_tercile = pickle.load(open('../cache/brier_scores_metadata_per_tercile.pkl', 'rb'))
else:
    # Add metrics to each (model, task) predictions
    for (model, task), df_preds in tqdm(predictions.items(), total=len(predictions)):
        # Add (patient ID, label time) to df_preds to align with terciles
        df_preds['pid'] = label_data[task]['patient_ids']
        df_preds['label_time'] = label_data[task]['times']

        # Test split
        df_preds = df_preds[df_preds['split'] == 'test']

        # Merge patient IDs
        for strat, strat_cols in strats.items():
            df_metrics = pd.read_parquet(os.path.join(path_to_ehrshot_metrics_dir, f'df__{task}__{strat}__metrics.parquet'))

            # If stratifying by inter-event times, need to pivot table since 'time' and 'metric' are separate columns
            if strat == 'inter_event_times':
                df_metrics = df_metrics.pivot_table(index=['pid', 'pid_idx', 'label_time', 'sub_task'], columns='metric', values='time').reset_index()

            # Merge metrics with predictions
            df_ = pd.merge(df_preds, df_metrics, on=['pid', 'label_time'])
            if df_.shape[0] != df_preds.shape[0]:
                print(f'{model} | {task} | {strat} | Number of rows in df does not match number of rows in df_preds: {df_.shape[0]} != {df_preds.shape[0]}')
                
            for strat_col in strat_cols:
                if strat_col not in df_metrics.columns:
                    raise ValueError(f'col={strat_col} not in df_metrics columns for strat={strat}.')

                # Metric values
                metric_values = df_[strat_col]

                # Get test patient IDs
                patient_ids = df_['pid'].values
                patient_id_indices = np.searchsorted(all_patients, patient_ids)
                assert np.all(patient_ids == all_patients[patient_id_indices])
                
                # Calculate "true" Brier scores on non-bootstrapped dataset
                df_['tercile'] = pd.qcut(df_[strat_col].fillna(0).rank(method='min'), 3, labels=False)
                assert set(df_['tercile'].unique()) == {0, 1, 2}, f'terciles not 0, 1, 2: {set(df_["tercile"].unique())}'
                for tercile in range(3):
                    # Limit to this tercile
                    df_tercile = df_[df_['tercile'] == tercile]
                    # Labels / Preds
                    y = df_tercile['y'].values.astype(int)
                    pred_proba = df_tercile['pred_proba'].values
                    # Calculate Brier score
                    brier = brier_score_loss(y, pred_proba)
                    true_brier_scores_per_tercile[(model, task, strat, strat_col, tercile)] = brier

                # Do bootstraps
                for weights in bootstrap_weights:
                    weights = weights[patient_id_indices]
                    assert weights.shape[0] == df_.shape[0] and weights.shape[0] == metric_values.shape[0], f"Error - weights shape: {weights.shape[0]}, df shape: {df_.shape[0]}, metric_values shape: {metric_values.shape[0]}"

                    # Calculate terciles
                    quantile_cutoffs = weighted_quantile(metric_values, [0.33, .67, 1], weights)
                    df_['tercile'] = np.searchsorted(quantile_cutoffs, metric_values)
                    df_['weight'] = weights
                    assert set(df_['tercile'].unique()) == {0, 1, 2}, f'terciles not 0, 1, 2: {set(df_metrics["tercile"].unique())}'

                    # Calculate Brier scores
                    for tercile in range(3):
                        # Limit to this tercile
                        df_tercile = df_[df_['tercile'] == tercile]
                        # Labels / Preds
                        y = df_tercile['y'].values.astype(int)
                        pred_proba = df_tercile['pred_proba'].values
                        sample_weight = df_tercile['weight'].values
                        # Calculate Brier score
                        brier = brier_score_loss(y, pred_proba, sample_weight=sample_weight)
                        brier_scores_per_tercile[(model, task, strat, strat_col, tercile)].append(brier)
                        brier_scores_metadata_per_tercile[(model, task, strat, strat_col, tercile)].append({
                            'brier' : ((y - pred_proba) ** 2).astype(np.float32),
                            'sample_weight' : sample_weight.astype(np.int8),
                        })

    # Save results to .pkl file
    with open('../cache/brier_scores_per_tercile.pkl', 'wb') as f:
        pickle.dump(brier_scores_per_tercile, f)
    with open('../cache/true_brier_scores_per_tercile.pkl', 'wb') as f:
        pickle.dump(true_brier_scores_per_tercile, f)
    with open('../cache/brier_scores_metadata_per_tercile.pkl', 'wb') as f:
        pickle.dump(brier_scores_metadata_per_tercile, f)

### Task-Level CIs

In [13]:
# Calculate task-level Brier CIs
df_brier_cis = []
for key in brier_scores_per_tercile.keys():
    model, task, strat, strat_col, tercile = key
    scores = brier_scores_per_tercile[key]
    df_brier_cis.append({
        'model' : model,
        'task' : task,
        'tercile' : tercile,
        'strat' : strat,
        'strat_col' : strat_col,
        'brier_true' : true_brier_scores_per_tercile[key],
        'brier_ci_mean' : np.mean(scores),
        'brier_ci_025' : np.percentile(scores, 2.5),
        'brier_ci_500' : np.percentile(scores, 50),
        'brier_ci_975' : np.percentile(scores, 97.5),
    })
df_brier_cis = pd.DataFrame(df_brier_cis).sort_values(['model', 'task', 'tercile'])
df_brier_cis

Unnamed: 0,model,task,tercile,strat,strat_col,brier_true,brier_ci_mean,brier_ci_025,brier_ci_500,brier_ci_975
2133,clmbr,guo_icu,0,inter_event_times,std,0.031065,0.031765,0.020401,0.031514,0.045442
2136,clmbr,guo_icu,0,n_gram_count,rr_1,0.036430,0.037051,0.024866,0.036749,0.050059
2139,clmbr,guo_icu,0,timeline_lengths,n_events,0.039529,0.039948,0.028213,0.040080,0.053327
2134,clmbr,guo_icu,1,inter_event_times,std,0.029109,0.028930,0.018663,0.028980,0.039850
2137,clmbr,guo_icu,1,n_gram_count,rr_1,0.033230,0.031551,0.020631,0.031499,0.043136
...,...,...,...,...,...,...,...,...,...,...
895,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,1,n_gram_count,rr_1,0.035790,0.034799,0.022949,0.034370,0.048056
898,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,1,timeline_lengths,n_events,0.027915,0.027906,0.018306,0.027600,0.040117
893,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,2,inter_event_times,std,0.026984,0.027351,0.017607,0.026637,0.040013
896,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,2,n_gram_count,rr_1,0.015759,0.015716,0.008298,0.015370,0.025473


### Model-level CIs

In [14]:
# Get test PIDs for each task
task_2_test_pids = {}
for (model, task), df_preds in predictions.items():
    if task in task_2_test_pids:
        continue
    assert label_data[task]['patient_ids'].shape[0] == len(label_data[task]['times'])
    df_preds['pid'] = label_data[task]['patient_ids']
    df_preds['label_time'] = label_data[task]['times']
    test_pids = df_preds[df_preds['split'] == 'test']
    task_2_test_pids[task] = test_pids['pid']
    if len(task_2_test_pids[task]) == len(valid_tasks):
        break
print([ len(task_2_test_pids[task]) for task in valid_tasks ])

[1258, 2195, 100568, 2243, 67028, 2220, 58155, 2127, 2189, 56338, 1317, 2222, 63653, 2037]


In [29]:
# Reweight each bootstrap by the number of labels in the task
n_labels_per_task = []
for weights in bootstrap_weights:
    weights_per_task = [
        weights[np.searchsorted(all_patients, task_2_test_pids[task])] 
        for task in valid_tasks
    ]
    n_labels_per_task.append([w.sum() for w in weights_per_task])
n_labels_per_task = np.array(n_labels_per_task).T
assert n_labels_per_task.shape == (14, 1000)

# Calculate model-level Brier CIs
df_brier_model_cis = []
for strat, strat_cols in strats.items():
    for strat_col in strat_cols:
        for model in valid_models:
            for tercile in range(3):
                ci_scores = [] # scores for this tercile, averaged across all tasks
                true_scores = [] # true scores for this tercile, averaged across all tasks
                for task in valid_tasks:
                    ci_scores.append(brier_scores_per_tercile[(model, task, strat, strat_col, tercile)]) # each is 1 x 1000
                    true_scores.append(true_brier_scores_per_tercile[(model, task, strat, strat_col, tercile)]) # each is 1 x 1
                # Raw scores
                scores = np.vstack(ci_scores) # 14 x 1000
                assert scores.shape == (14, 1000)
                # Macro-average across tasks
                scores = np.mean(scores, axis=0)
                # Micro-average across tasks
                # scores = np.average(scores, axis=0, weights=n_labels_per_task)
                # Win rates
                # t-test
                df_brier_model_cis.append({
                    'model' : model,
                    'task' : 'all',
                    'strat' : strat,
                    'strat_col' : strat_col,
                    'tercile' : tercile,
                    'brier_true' : np.mean(true_scores),
                    'brier_ci_mean' : np.mean(scores),
                    'brier_ci_025' : np.percentile(scores, 2.5),
                    'brier_ci_500' : np.percentile(scores, 50),
                    'brier_ci_975' : np.percentile(scores, 97.5),
                })
df_brier_model_cis = pd.DataFrame(df_brier_model_cis).sort_values(['model', 'tercile'])
df_brier_model_cis

Unnamed: 0,model,task,strat,strat_col,tercile,brier_true,brier_ci_mean,brier_ci_025,brier_ci_500,brier_ci_975
48,clmbr,all,inter_event_times,std,0,0.069993,0.070055,0.063221,0.070332,0.076010
99,clmbr,all,n_gram_count,rr_1,0,0.067175,0.067434,0.063413,0.067454,0.071231
150,clmbr,all,timeline_lengths,n_events,0,0.072197,0.072107,0.068240,0.072184,0.076160
49,clmbr,all,inter_event_times,std,1,0.072023,0.072004,0.067915,0.071931,0.076225
100,clmbr,all,n_gram_count,rr_1,1,0.073277,0.073251,0.068984,0.073308,0.077590
...,...,...,...,...,...,...,...,...,...,...
94,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,n_gram_count,rr_1,1,0.069956,0.069877,0.065870,0.069895,0.074072
145,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,timeline_lengths,n_events,1,0.069587,0.069711,0.065651,0.069711,0.074168
44,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,inter_event_times,std,2,0.072616,0.073025,0.068764,0.073041,0.077507
95,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,n_gram_count,rr_1,2,0.073496,0.073672,0.067804,0.073706,0.079436


In [30]:
df_irregularity = clean_df_briers(df_brier_model_cis[df_brier_model_cis['strat_col'] == 'std'])
# Take the 'tercile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_irregularity = df_irregularity.pivot(index=['model', 'model_base', 'ctx_length'], columns='tercile', values='formatted_brier').reset_index().sort_values(['model_base', 'ctx_length'])

# Limit to first / last for clarity
df_model_latex = df_irregularity[
    (
        (df_irregularity['model_base'].isin(['mamba', 'hyena']) & df_irregularity['ctx_length'].isin([1024, 16384]))
        | (df_irregularity['model_base'].isin(['gpt2', 'llama']) & df_irregularity['ctx_length'].isin([512, 4096]))
    )
]
df_model_latex

tercile,model,model_base,ctx_length,0,1,2
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0664 (0.0603, 0.0720)","0.0700 (0.0661, 0.0742)","0.0726 (0.0688, 0.0771)"
3,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.0665 (0.0604, 0.0723)","0.0692 (0.0655, 0.0734)","0.0752 (0.0713, 0.0801)"
5,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0682 (0.0616, 0.0737)","0.0683 (0.0646, 0.0729)","0.0742 (0.0699, 0.0790)"
6,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0715 (0.0654, 0.0770)","0.0763 (0.0723, 0.0807)","0.0842 (0.0802, 0.0890)"
12,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0710 (0.0646, 0.0762)","0.0706 (0.0668, 0.0753)","0.0747 (0.0706, 0.0793)"
11,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.0678 (0.0616, 0.0736)","0.0691 (0.0654, 0.0732)","0.0734 (0.0695, 0.0782)"
13,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0708 (0.0649, 0.0765)","0.0717 (0.0679, 0.0759)","0.0761 (0.0719, 0.0808)"
14,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.0651 (0.0597, 0.0703)","0.0676 (0.0637, 0.0716)","0.0714 (0.0677, 0.0761)"


In [28]:
df_repetition = clean_df_briers(df_brier_model_cis[df_brier_model_cis['strat_col'] == 'rr_1'])
# Take the 'tercile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_repetition = df_repetition.pivot(index=['model', 'model_base', 'ctx_length'], columns='tercile', values='formatted_brier').reset_index().sort_values(['model_base', 'ctx_length'])

# Limit to first / last for clarity
df_model_latex = df_repetition[
    (
        (df_repetition['model_base'].isin(['mamba', 'hyena']) & df_repetition['ctx_length'].isin([1024, 16384]))
        | (df_repetition['model_base'].isin(['gpt2', 'llama']) & df_repetition['ctx_length'].isin([512, 4096]))
    )
]
df_model_latex

tercile,model,model_base,ctx_length,0,1,2
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0641 (0.0606, 0.0679)","0.0702 (0.0662, 0.0742)","0.0746 (0.0685, 0.0802)"
3,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.0665 (0.0628, 0.0707)","0.0698 (0.0656, 0.0738)","0.0745 (0.0690, 0.0804)"
5,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0652 (0.0616, 0.0691)","0.0703 (0.0661, 0.0744)","0.0753 (0.0695, 0.0809)"
6,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0746 (0.0708, 0.0786)","0.0771 (0.0730, 0.0812)","0.0803 (0.0747, 0.0864)"
12,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0661 (0.0626, 0.0700)","0.0733 (0.0691, 0.0774)","0.0770 (0.0712, 0.0826)"
11,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.0646 (0.0611, 0.0687)","0.0707 (0.0667, 0.0747)","0.0750 (0.0695, 0.0807)"
13,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0675 (0.0637, 0.0713)","0.0743 (0.0702, 0.0785)","0.0768 (0.0710, 0.0828)"
14,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.0626 (0.0592, 0.0666)","0.0685 (0.0643, 0.0724)","0.0729 (0.0672, 0.0785)"


### LaTeX Results

In [20]:
# Model-level delta CIs for LaTeX
latex: str = clean_df_briers(df_brier_model_cis)
print(latex)

                                                model        model_name  \
64                                              clmbr             clmbr   
65                                              clmbr             clmbr   
66                                              clmbr             clmbr   
67                                              clmbr             clmbr   
16  gpt2-base-512--clmbr_train-tokens-total_nonPAD...     gpt2-base-512   
..                                                ...               ...   
59  mamba-tiny-8192--clmbr_train-tokens-total_nonP...   mamba-tiny-8192   
60  mamba-tiny-16384--clmbr_train-tokens-total_non...  mamba-tiny-16384   
61  mamba-tiny-16384--clmbr_train-tokens-total_non...  mamba-tiny-16384   
62  mamba-tiny-16384--clmbr_train-tokens-total_non...  mamba-tiny-16384   
63  mamba-tiny-16384--clmbr_train-tokens-total_non...  mamba-tiny-16384   

   model_base  ctx_length  quartile  brier_mean  brier_ci_025  brier_ci_500  \
64      clmbr       

In [35]:
# Task-level delta CIs for LaTeX
latex: str = format_df_deltas_for_latex(df_task_deltas)
print(latex)

\begin{tabular}{lrlllr}
\toprule
Model & Context Length & Task & $\Delta$ over baseline & 95% CI & Statistically Significant \\
\midrule
clmbr & 0 & ICU Admission & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Long LOS & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & 30-day Readmission & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Anemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyperkalemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hypoglycemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyponatremia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Thrombocytopenia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Acute MI & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Celiac & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hyperlipidemia & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Hypertension & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Lupus & 0.000 & (0.000, 0.000) &  \\
clmbr & 0 & Pancreatic Cancer & 0.000 & (0.000, 0.000) &  \\
gpt2 & 512 & ICU Admission & 0.022 & (-0.005, 0.050) &  \\
gpt2 & 512 & Long LOS & -0.00

## Win Rates

In [37]:
# Calculate model-level win rates between longest and shortest context lengths for each (model, tercile)
df_win_rate_cis = []
for strat, strat_cols in strats.items():
    for strat_col in strat_cols:
        for model in valid_models:

            # Only keep shortest and longest context lengths for ease of comparison
            if model_to_base(model) in ['mamba', 'hyena' ]:
                if model_to_ctx(model) not in [1024, 16384]:
                    continue
            if model_to_base(model) in ['gpt2', 'llama' ]:
                if model_to_ctx(model) not in [512, 4096]:
                    continue
            baseline_model_name: str = model.replace(str(model_to_ctx(model)), str(1024 if model_to_base(model) in ['mamba', 'hyena' ] else 512))

            for tercile in range(3):
                win_rates = [] # scores for this tercile, averaged across all tasks
                true_win_rate = [] # true scores for this tercile, averaged across all tasks
                for task in valid_tasks:
                    ## NOTE: Comparison is '>' b/c lower Brier is better
                    # True win rate
                    true_this_model_score = np.array(true_brier_scores_per_tercile[(model, task, strat, strat_col, tercile)]) # each is 1 x 1
                    true_baseline_model_score = np.array(true_brier_scores_per_tercile[(baseline_model_name, task, strat, strat_col, tercile)]) # each is 1 x 1
                    assert true_baseline_model_score.shape == true_this_model_score.shape
                    true_win_rate.append((true_baseline_model_score > true_this_model_score).astype(bool))
                    # Bootstrap win rates
                    this_model_scores = np.array(brier_scores_per_tercile[(model, task, strat, strat_col, tercile)]) # each is 1 x 1000
                    baseline_model_scores = np.array(brier_scores_per_tercile[(
                        baseline_model_name, 
                        task, strat, strat_col, tercile
                    )]) # each is 1 x 1000
                    assert baseline_model_scores.shape == this_model_scores.shape
                    win_rates.append((baseline_model_scores > this_model_scores).astype(bool))
                # Win rates
                win_rates = np.vstack(win_rates) # 14 x 1000
                assert win_rates.shape == (14, 1000), f"win_rates.shape={win_rates.shape}"
                win_rates = np.mean(win_rates, axis=0)
                assert len(true_win_rate) == 14
                df_win_rate_cis.append({
                    'model' : model,
                    'task' : 'all',
                    'strat' : strat,
                    'strat_col' : strat_col,
                    'tercile' : tercile,
                    'win_rate_true' : np.mean(true_win_rate),
                    'win_rate_ci_mean' : np.mean(win_rates),
                    'win_rate_ci_025' : np.percentile(win_rates, 2.5),
                    'win_rate_ci_500' : np.percentile(win_rates, 50),
                    'win_rate_ci_975' : np.percentile(win_rates, 97.5),
                })
df_win_rate_cis = pd.DataFrame(df_win_rate_cis).sort_values(['model', 'tercile'])
df_win_rate_cis

Unnamed: 0,model,task,strat,strat_col,tercile,win_rate_true,win_rate_ci_mean,win_rate_ci_025,win_rate_ci_500,win_rate_ci_975
24,clmbr,all,inter_event_times,std,0,0.000000,0.000000,0.000000,0.000000,0.000000
51,clmbr,all,n_gram_count,rr_1,0,0.000000,0.000000,0.000000,0.000000,0.000000
78,clmbr,all,timeline_lengths,n_events,0,0.000000,0.000000,0.000000,0.000000,0.000000
25,clmbr,all,inter_event_times,std,1,0.000000,0.000000,0.000000,0.000000,0.000000
52,clmbr,all,n_gram_count,rr_1,1,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...
49,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,n_gram_count,rr_1,1,0.785714,0.790500,0.642857,0.785714,0.928571
76,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,timeline_lengths,n_events,1,0.928571,0.821714,0.642857,0.857143,0.928571
23,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,inter_event_times,std,2,0.857143,0.751571,0.642857,0.785714,0.857143
50,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,n_gram_count,rr_1,2,0.642857,0.668786,0.500000,0.642857,0.857143


In [38]:
df_repetition = clean_df_briers(df_win_rate_cis[df_win_rate_cis['strat_col'] == 'rr_1'])
df_repetition = df_repetition.pivot(index=['model', 'model_base', 'ctx_length'], columns='tercile', values='formatted_win_rate').reset_index().sort_values(['model_base', 'ctx_length'])
df_repetition

tercile,model,model_base,ctx_length,0,1,2
0,clmbr,clmbr,0,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
2,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
1,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.4286 (0.1429, 0.5714)","0.5000 (0.2857, 0.6429)","0.5000 (0.2857, 0.7143)"
3,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
4,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.1429 (0.0714, 0.2143)","0.0714 (0.0000, 0.2857)","0.0714 (0.0000, 0.2857)"
6,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
5,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.5000 (0.5000, 0.7857)","0.6429 (0.5000, 0.8571)","0.7857 (0.5714, 0.9286)"
7,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
8,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.8571 (0.6429, 0.9286)","0.7857 (0.6429, 0.9286)","0.6429 (0.5000, 0.8571)"


In [39]:
df_irregularity = clean_df_briers(df_win_rate_cis[df_win_rate_cis['strat_col'] == 'std'])
# Take the 'quartile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_irregularity = df_irregularity.pivot(index=['model', 'model_base', 'ctx_length'], columns='tercile', values='formatted_win_rate').reset_index().sort_values(['model_base', 'ctx_length'])
df_irregularity

tercile,model,model_base,ctx_length,0,1,2
0,clmbr,clmbr,0,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
2,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
1,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.5000 (0.2143, 0.7143)","0.5000 (0.3571, 0.7857)","0.4286 (0.2143, 0.5714)"
3,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
4,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.1429 (0.0714, 0.3571)","0.0000 (0.0000, 0.2143)","0.1429 (0.0714, 0.2143)"
6,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
5,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.7143 (0.5000, 0.9286)","0.7857 (0.5000, 0.8571)","0.7143 (0.5000, 0.8571)"
7,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
8,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.8571 (0.5714, 0.8571)","0.8571 (0.6429, 0.9286)","0.8571 (0.6429, 0.8571)"


## T-test

## T-test

In [None]:
# Calculate t-test scores for each (model, task, tercile)
raw_brier_scores_per_tercile = collections.defaultdict(list) # [key] = (model, task, strat, strat_col, tercile), [value] = brier scores across all labels for this task and quaratile
if IS_LOAD_FROM_CACHE and os.path.exists('../cache/raw_brier_scores_per_tercile.pkl'):
    print("Loading raw_brier_scores_per_tercile from cache...")
    raw_brier_scores_per_tercile = pickle.load(open('../cache/raw_brier_scores_per_tercile.pkl', 'rb'))
else:
    # Add metrics to each (model, task) predictions
    for (model, task), df_preds in tqdm(predictions.items(), total=len(predictions)):
        # Add (patient ID, label time) to df_preds to align with terciles
        df_preds['pid'] = label_data[task]['patient_ids']
        df_preds['label_time'] = label_data[task]['times']

        # Test split
        df_preds = df_preds[df_preds['split'] == 'test']

        # Merge patient IDs
        for strat, strat_cols in strats.items():
            df_metrics = pd.read_parquet(os.path.join(path_to_ehrshot_metrics_dir, f'df__{task}__{strat}__metrics.parquet'))

            # If stratifying by inter-event times, need to pivot table since 'time' and 'metric' are separate columns
            if strat == 'inter_event_times':
                df_metrics = df_metrics.pivot_table(index=['pid', 'pid_idx', 'label_time', 'sub_task'], columns='metric', values='time').reset_index()

            # Merge metrics with predictions
            df_ = pd.merge(df_preds, df_metrics, on=['pid', 'label_time'])
            if df_.shape[0] != df_preds.shape[0]:
                print(f'{model} | {task} | {strat} | Number of rows in df does not match number of rows in df_preds: {df_.shape[0]} != {df_preds.shape[0]}')

            for strat_col in strat_cols:
                if strat_col not in df_metrics.columns:
                    raise ValueError(f'col={strat_col} not in df_metrics columns for strat={strat}.')

                # Metric values
                metric_values = df_[strat_col]

                # Calculate t-test values against each model for this task
                df_['tercile'] = pd.qcut(df_[strat_col].fillna(0).rank(method='min'), 3, labels=False)
                assert set(df_['tercile'].unique()) == {0, 1, 2}, f'terciles not 0, 1, 2: {set(df_["tercile"].unique())}'
                for tercile in range(3):
                    # Limit to this tercile
                    df_tercile = df_[df_['tercile'] == tercile]
                    # Labels / Preds
                    y = df_tercile['y'].values.astype(int)
                    pred_proba = df_tercile['pred_proba'].values
                    # Calculate brier score per label
                    brier_scores = (y - pred_proba) ** 2
                    raw_brier_scores_per_tercile[(model, task, strat, strat_col, tercile)] = brier_scores

    # Save results to .pkl file
    with open('../cache/raw_brier_scores_per_tercile.pkl', 'wb') as f:
        pickle.dump(raw_brier_scores_per_tercile, f)

In [40]:
# Run t-test for each (model, task, strat, strat_col tercile) against every other model for that task and tercile
df_t_test_results = []
if IS_LOAD_FROM_CACHE and os.path.exists('../cache/df_t_test_results.pkl'):
    print("Loading df_t_test_results from cache...")
    df_t_test_results = pickle.load(open('../cache/df_t_test_results.pkl', 'rb'))
else:
    for (model1, task, strat, strat_col, tercile), brier_scores in tqdm(raw_brier_scores_per_tercile.items(), total=len(raw_brier_scores_per_tercile)):
        for model2 in valid_models:
            if model1 == model2:
                continue
            if model1[:3] != model2[:3]:
                # Only compare models of the same base
                continue
            brier_scores2 = raw_brier_scores_per_tercile[(model2, task, strat, strat_col, tercile)]
            t_test, p_value = scipy.stats.ttest_ind(brier_scores, brier_scores2)
            df_t_test_results.append({
                'model1': model1,
                'model2': model2,
                'task': task,
                'strat': strat,
                'strat_col': strat_col,
                'tercile': tercile,
                't_test': t_test,
                'p_value': p_value,
            })
df_t_test_results = pd.DataFrame(df_t_test_results)
df_t_test_results['model1_name'] = df_t_test_results['model1'].apply(model_to_name)
df_t_test_results['model1_base'] = df_t_test_results['model1'].apply(model_to_base)
df_t_test_results['model1_ctx_length'] = df_t_test_results['model1'].apply(model_to_ctx).astype(int)
df_t_test_results['model2_name'] = df_t_test_results['model2'].apply(model_to_name)
df_t_test_results['model2_base'] = df_t_test_results['model2'].apply(model_to_base)
df_t_test_results['model2_ctx_length'] = df_t_test_results['model2'].apply(model_to_ctx).astype(int)
df_t_test_results

100%|██████████| 2856/2856 [00:14<00:00, 199.02it/s]


Unnamed: 0,model1,model2,task,strat,strat_col,quartile,t_test,p_value,model1_name,model1_base,model1_ctx_length,model2_name,model2_base,model2_ctx_length
0,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-1024--clmbr_train-tokens-total_nonP...,new_hypertension,inter_event_times,std,0,0.514707,0.606939,llama-base-512,llama,512,llama-base-1024,llama,1024
1,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-2048--clmbr_train-tokens-total_nonP...,new_hypertension,inter_event_times,std,0,0.772041,0.440381,llama-base-512,llama,512,llama-base-2048,llama,2048
2,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-4096--clmbr_train-tokens-total_nonP...,new_hypertension,inter_event_times,std,0,0.631369,0.528029,llama-base-512,llama,512,llama-base-4096,llama,4096
3,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-1024--clmbr_train-tokens-total_nonP...,new_hypertension,inter_event_times,std,1,0.516850,0.605443,llama-base-512,llama,512,llama-base-1024,llama,1024
4,llama-base-512--clmbr_train-tokens-total_nonPA...,llama-base-2048--clmbr_train-tokens-total_nonP...,new_hypertension,inter_event_times,std,1,0.429320,0.667838,llama-base-512,llama,512,llama-base-2048,llama,2048
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8059,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-4096--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,2,0.049866,0.960239,mamba-tiny-16384,mamba,16384,mamba-tiny-4096,mamba,4096
8060,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,2,0.077202,0.938478,mamba-tiny-16384,mamba,16384,mamba-tiny-8192,mamba,8192
8061,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,3,-0.326128,0.744395,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
8062,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-4096--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,3,-0.053720,0.957168,mamba-tiny-16384,mamba,16384,mamba-tiny-4096,mamba,4096


In [44]:
df_t_test_results_high_low = df_t_test_results[
    (
        (df_t_test_results['model1_base'].isin(['mamba', 'hyena']) & df_t_test_results['model1_ctx_length'].isin([16384]))
        | (df_t_test_results['model1_base'].isin(['gpt2', 'llama']) & df_t_test_results['model1_ctx_length'].isin([4096]))
    )
    & 
    (
        (df_t_test_results['model2_base'].isin(['mamba', 'hyena']) & df_t_test_results['model2_ctx_length'].isin([1024,]))
        | (df_t_test_results['model2_base'].isin(['gpt2', 'llama']) & df_t_test_results['model2_ctx_length'].isin([512,]))
    )
]
df_t_test_results_high_low

Unnamed: 0,model1,model2,task,strat,strat_col,quartile,t_test,p_value,model1_name,model1_base,model1_ctx_length,model2_name,model2_base,model2_ctx_length
108,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,new_hypertension,inter_event_times,std,0,-0.631369,0.528029,llama-base-4096,llama,4096,llama-base-512,llama,512
111,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,new_hypertension,inter_event_times,std,1,-0.457260,0.647643,llama-base-4096,llama,4096,llama-base-512,llama,512
114,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,new_hypertension,inter_event_times,std,2,-0.254339,0.799317,llama-base-4096,llama,4096,llama-base-512,llama,512
117,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,new_hypertension,inter_event_times,std,3,-0.193799,0.846396,llama-base-4096,llama,4096,llama-base-512,llama,512
120,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,new_hypertension,n_gram_count,rr_1,0,-0.416442,0.677229,llama-base-4096,llama,4096,llama-base-512,llama,512
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8049,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,guo_icu,n_gram_count,rr_1,3,-0.127728,0.898390,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
8052,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,0,-0.446571,0.655280,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
8055,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,1,-0.225352,0.821751,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
8058,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,guo_icu,timeline_lengths,n_events,2,-0.016608,0.986752,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024


In [45]:
df_t_test_results_high_low[df_t_test_results_high_low['p_value'] < 0.05]

Unnamed: 0,model1,model2,task,strat,strat_col,quartile,t_test,p_value,model1_name,model1_base,model1_ctx_length,model2_name,model2_base,model2_ctx_length
2268,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_lupus,inter_event_times,std,0,2.175077,2.983300e-02,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
2301,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_lupus,timeline_lengths,n_events,3,2.715803,6.713266e-03,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
2412,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,lab_hyponatremia,inter_event_times,std,0,-10.210958,1.914469e-24,llama-base-4096,llama,4096,llama-base-512,llama,512
2415,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,lab_hyponatremia,inter_event_times,std,1,-7.703955,1.353396e-14,llama-base-4096,llama,4096,llama-base-512,llama,512
2418,llama-base-4096--clmbr_train-tokens-total_nonP...,llama-base-512--clmbr_train-tokens-total_nonPA...,lab_hyponatremia,inter_event_times,std,2,-7.391177,1.487361e-13,llama-base-4096,llama,4096,llama-base-512,llama,512
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6897,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_celiac,n_gram_count,rr_1,3,-2.369788,1.796855e-02,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
6900,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_celiac,timeline_lengths,n_events,0,-3.865597,1.172455e-04,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
6903,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_celiac,timeline_lengths,n_events,1,-2.989941,2.851927e-03,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
6906,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,new_celiac,timeline_lengths,n_events,2,-3.657793,2.663531e-04,mamba-tiny-16384,mamba,16384,mamba-tiny-1024,mamba,1024
