In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
import sklearn.utils
from typing import List, Optional, Dict
from sklearn.metrics import roc_auc_score, brier_score_loss
import collections
import warnings
import pickle
import femr.datasets
import datetime
from hf_ehr.config import EHRSHOT_LABELING_FUNCTION_2_PAPER_NAME
from hf_ehr.notebooks.ehr_specific_properties.utils import (
    get_labels_and_features, 
    get_patient_splits_by_idx
)
import scipy.stats
from femr.labelers import load_labeled_patients, LabeledPatients

warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.validation")

PATH_TO_DATABASE: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/femr/extract'
PATH_TO_FEATURES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/features_ehrshot'
PATH_TO_RESULTS_DIR: str = '/share/pi/nigam/users/migufuen/ehrshot-benchmark/EHRSHOT_ASSETS/results_ehrshot'
PATH_TO_TOKENIZED_TIMELINES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/tokenized_timelines_ehrshot'
PATH_TO_LABELS_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/benchmark_ehrshot'
PATH_TO_SPLIT_CSV: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/splits_ehrshot/person_id_map.csv'
femr_db = femr.datasets.PatientDatabase(PATH_TO_DATABASE)
os.makedirs('../cache', exist_ok=True)

IS_LOAD_FROM_CACHE: bool = True

# Load Data

First, we need to load all of the patient-level predictions / labels for every model and task.
This takes ~30 min.

1. Load list of task names, model names
2. Load patient-level predictions for each (model, task)
3. Load patient-level ground truth labels for each (task)

In [2]:
# Get list of tasks
valid_tasks = os.listdir(PATH_TO_RESULTS_DIR)
valid_tasks.remove('chexpert')

# Filter results to only valid models
valid_models = [ 
    'llama-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'llama-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-2048--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'gpt2-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'hyena-large-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-8192--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last',
    'mamba-tiny-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last', 
    'clmbr',
]
print("Tasks:", valid_tasks)
print("# of valid models:", len(valid_models))

Tasks: ['new_hypertension', 'guo_los', 'lab_hypoglycemia', 'new_lupus', 'lab_hyponatremia', 'new_pancan', 'lab_anemia', 'new_acutemi', 'guo_readmission', 'lab_thrombocytopenia', 'new_hyperlipidemia', 'new_celiac', 'lab_hyperkalemia', 'guo_icu']
# of valid models: 17


In [3]:
# Load patient-level predictions for each (model, task)
predictions = {} # [key] = (model, task), [value] = df_preds

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/predictions.pkl'):
    print("Loading predictions from cache...")
    predictions = pickle.load(open('../cache/predictions.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        for model in valid_models:
            path = os.path.join(PATH_TO_RESULTS_DIR, task, 'models', model, 'lr_lbfgs', f'subtask={task}', 'k=-1', 'preds.csv')
            if not os.path.exists(path):
                print("Missing path for ", model, task)
            df_preds = pd.read_csv(path)
            assert df_preds['replicate'].nunique() == 1, f"Multiple replicates for {model}, {task}"
            predictions[(model, task)] = df_preds
    # Save results to .pkl file
    with open('../cache/predictions.pkl', 'wb') as f:
        pickle.dump(predictions, f)
print("# of (model, task) preds:", len(predictions))

Loading predictions from cache...


# of (model, task) preds: 238


In [4]:
# Load patient-level labels for each task
# NOTE: Takes ~1 hr
label_data = {} # [key] = task, [value] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

if IS_LOAD_FROM_CACHE and os.path.exists('../cache/label_data.pkl'):
    print("Loading label data from cache...")
    label_data = pickle.load(open('../cache/label_data.pkl', 'rb'))
else:
    for task in tqdm(valid_tasks):
        # Load labeled patients for this task
        LABELING_FUNCTION: str = task
        PATH_TO_LABELED_PATIENTS: str =  os.path.join(PATH_TO_LABELS_DIR, LABELING_FUNCTION, 'labeled_patients.csv')
        labeled_patients = femr.labelers.load_labeled_patients(PATH_TO_LABELED_PATIENTS)
        
        # Get features for patients
        model: str = 'gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last'
        patient_ids, label_values, label_times, feature_matrixes = get_labels_and_features(labeled_patients, 
                                                                                            PATH_TO_FEATURES_DIR, 
                                                                                            PATH_TO_TOKENIZED_TIMELINES_DIR,
                                                                                            models_to_keep=[model,])
        train_pids_idx, val_pids_idx, test_pids_idx = get_patient_splits_by_idx(PATH_TO_SPLIT_CSV, patient_ids)
        label_times = label_times[train_pids_idx + val_pids_idx + test_pids_idx]
        label_values = label_values[train_pids_idx + val_pids_idx + test_pids_idx]
        patient_ids = patient_ids[train_pids_idx + val_pids_idx + test_pids_idx]
        label_times = [ x.astype(datetime.datetime) for x in label_times ] # cast to Python datetime
        label_data[task] = {'times': label_times, 'values': label_values, 'patient_ids': patient_ids}

    # Save results to .pkl file
    with open('../cache/label_data.pkl', 'wb') as f:
        pickle.dump(label_data, f)

Loading label data from cache...


# Calculate CIs

Calculate bootstrapped 95% CIs over the test set (one sample per patient).

1. Generate 1000 bootstrapped resamples of patient IDs across all splits
2. Loop through every (model, task) preds, limit to test set, loop through 1000 bootstrap weightings, loop through each stratification metric, recalculate quartiles of patients based on their metrics, and calculate Brier score for each quartile
3. Calculate 95% CIs

### Do Bootstrapping of Brier Scores across Quartiles

In [5]:
# Helpers
strats = {
    'inter_event_times': [ 'std' ],
    'n_gram_count': ['rr_1'], 
    'timeline_lengths': ['n_events'],
}
path_to_ehrshot_metrics_dir: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/stratify/'

def weighted_quantile(values, quantiles, sample_weight=None, values_sorted=False):
    values = np.array(values)
    quantiles = np.array(quantiles)
    if sample_weight is None:
        sample_weight = np.ones(len(values))
    sample_weight = np.array(sample_weight)
    assert np.all(quantiles >= 0) and np.all(quantiles <= 1), \
        'quantiles should be in [0, 1]'

    if not values_sorted:
        sorter = np.argsort(values)
        values = values[sorter]
        sample_weight = sample_weight[sorter]

    weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight
    weighted_quantiles /= np.sum(sample_weight)
    return np.interp(quantiles, weighted_quantiles, values)

In [6]:
# Get all patient IDs
all_patients = np.sort(np.unique(np.concatenate([v['patient_ids'] for v in label_data.values()])))
print("# of unique patients:", len(all_patients))

# Patient-level resampling across all patient IDs in train/val/test
# Later, we limit to just test patient IDs per task
bootstrap_weights = []
np.random.seed(342342)
for i in range(1000):
    patient_sample = np.random.choice(list(range(len(all_patients))), len(all_patients), replace=True)
    weights = np.zeros_like(all_patients, dtype=np.float32)
    np.add.at(weights, patient_sample, 1)
    assert np.mean(weights) == 1
    bootstrap_weights.append(weights)

# of unique patients: 6275


In [7]:
# Formatting
def model_to_base(model: str) -> str:
    return model.split('-')[0]
def model_to_ctx(model: str) -> str:
    return int(model.split('--')[0].split('-')[-1]) if model != 'clmbr' else 0
def model_to_name(model: str) -> str:
    return model.split('--')[0]

def clean_df_briers(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df['model_name'] = df['model'].apply(model_to_name)
    df['model_base'] = df['model'].apply(model_to_base)
    df['ctx_length'] = df['model'].apply(model_to_ctx).astype(int)
    if 'brier_true' in df.columns:
        df['formatted_brier_ci_mean'] = df['brier_true'].apply(lambda x: f"{x:.4f}")
        df['formatted_brier'] = df.apply(lambda x: f"{x['brier_true']:.4f} ({x['brier_ci_025']:.4f}, {x['brier_ci_975']:.4f})", axis=1)
        df['formatted_brier_ci'] = df.apply(lambda x: f"({x['brier_ci_025']:.4f}, {x['brier_ci_975']:.4f})", axis=1)
        df['is_brier_win'] = (df['brier_ci_025'] > 0) & (df['brier_ci_975'] > 0)
        df['is_brier_stat_sig'] = (df['brier_ci_025'] > 0) | (df['brier_ci_975'] < 0)
    if 'win_rate_true' in df.columns:
        df['formatted_win_rate_ci_mean'] = df['win_rate_true'].apply(lambda x: f"{x:.4f}")
        df['formatted_win_rate'] = df.apply(lambda x: f"{x['win_rate_true']:.4f} ({x['win_rate_ci_025']:.4f}, {x['win_rate_ci_975']:.4f})", axis=1)
        df['formatted_win_rate'] = df.apply(lambda x: f"{x['win_rate_true']:.4f} ({x['win_rate_ci_025']:.4f}, {x['win_rate_ci_975']:.4f})", axis=1)
        df['is_win_rate_win'] = (df['win_rate_ci_025'] > 0) & (df['win_rate_ci_975'] > 0)
        df['is_win_rate_stat_sig'] = (df['win_rate_ci_025'] > 0) | (df['win_rate_ci_975'] < 0)
    df = df[['model', 'model_name', 'model_base', 'ctx_length'] + [col for col in df.columns if col not in ['model', 'model_name', 'model_base', 'ctx_length']]]
    df = df.sort_values(['model_base', 'ctx_length', ] + ([ 'task' ] if 'task' in df.columns else []))
    return df

def format_df_briers(df: pd.DataFrame, model_start: Optional[str] = None) -> pd.DataFrame:
    if model_start is not None:
        df = df[df['model'].str.startswith(model_start)]
    df = df.drop(columns=['model', 'model_name', 'formatted_brier', 'brier_ci_mean', 'brier_ci_025', 'brier_ci_500', 'brier_ci_975', 'is_brier_win'], errors='ignore') \
        .rename(columns={   'model_base' : 'Model', 
                            'ctx_length' : 'Context Length', 
                            'formatted_brier_ci_mean' : r'$\Delta$ over baseline', 
                            'formatted_brier_ci' : '95% CI', 
                            'formatted_win_rate_ci_mean' : r'Win Rate over baseline', 
                            'formatted_win_rate_ci' : 'Win Rate 95% CI', 
                            'is_brier_stat_sig' : 'Statistically Significant',
                            'is_win_rate_stat_sig' : 'Win Rate Statistically Significant',
                            'task' : 'Task'
                }, errors='ignore') \
        .sort_values(['Model', 'Context Length',] + ([ 'Task' ] if 'task' in df.columns else []))
    return df

def format_df_briers_for_latex(df: pd.DataFrame, model_start: Optional[str] = None) -> str:
    latex = format_df_briers(df, model_start).to_latex(index=False, escape=False)
    for k, v in EHRSHOT_LABELING_FUNCTION_2_PAPER_NAME.items():
        latex = latex.replace(k, v)
    latex = latex.replace('False', '').replace('True', '\checkmark')
    return latex

In [8]:
brier_scores_per_quartile = collections.defaultdict(list) # [key] = (model, task, strat, strat_col, quartile), [value] = brier's across 1k resamples for this quartile
brier_scores_metadata_per_quartile = collections.defaultdict(list) # [key] = (model, task, strat, strat_col, quartile), [value] = metadata
true_brier_scores_per_quartile = collections.defaultdict(int) # [key] = (model, task, strat, strat_col, quartile), [value] = True brier score for this model using original raw (non-bootstrapped) dataset

In [9]:
# Calculate Brier scores for each (model, task, quartile) across 1k resamples
if IS_LOAD_FROM_CACHE and os.path.exists('../cache/brier_scores_per_quartile.pkl') and os.path.exists('../cache/true_brier_scores_per_quartile.pkl') and os.path.exists('../cache/brier_scores_metadata_per_quartile.pkl'):
    print("Loading brier_scores_per_quartile from cache...")
    brier_scores_per_quartile = pickle.load(open('../cache/brier_scores_per_quartile.pkl', 'rb'))
    # print("Loading brier_scores_metadata_per_quartile from cache...")
    # brier_scores_metadata_per_quartile = pickle.load(open('../cache/brier_scores_metadata_per_quartile.pkl', 'rb'))
    print("Loading true_brier_scores_per_quartile from cache...")
    true_brier_scores_per_quartile = pickle.load(open('../cache/true_brier_scores_per_quartile.pkl', 'rb'))
else:
    # Add metrics to each (model, task) predictions
    for (model, task), df_preds in tqdm(predictions.items(), total=len(predictions)):
        # Add (patient ID, label time) to df_preds to align with quartiles
        df_preds['pid'] = label_data[task]['patient_ids']
        df_preds['label_time'] = label_data[task]['times']

        # Test split
        df_preds = df_preds[df_preds['split'] == 'test']

        # Merge patient IDs
        for strat, strat_cols in strats.items():
            df_metrics = pd.read_parquet(os.path.join(path_to_ehrshot_metrics_dir, f'df__{task}__{strat}__metrics.parquet'))

            # If stratifying by inter-event times, need to pivot table since 'time' and 'metric' are separate columns
            if strat == 'inter_event_times':
                df_metrics = df_metrics.pivot_table(index=['pid', 'pid_idx', 'label_time', 'sub_task'], columns='metric', values='time').reset_index()

            # Merge metrics with predictions
            df_ = pd.merge(df_preds, df_metrics, on=['pid', 'label_time'])
            if df_.shape[0] != df_preds.shape[0]:
                print(f'{model} | {task} | {strat} | Number of rows in df does not match number of rows in df_preds: {df_.shape[0]} != {df_preds.shape[0]}')
                
            for strat_col in strat_cols:
                if strat_col not in df_metrics.columns:
                    raise ValueError(f'col={strat_col} not in df_metrics columns for strat={strat}.')

                # Metric values
                metric_values = df_[strat_col]

                # Get test patient IDs
                patient_ids = df_['pid'].values
                patient_id_indices = np.searchsorted(all_patients, patient_ids)
                assert np.all(patient_ids == all_patients[patient_id_indices])
                
                # Calculate "true" Brier scores on non-bootstrapped dataset
                df_['quartile'] = pd.qcut(df_[strat_col].fillna(0).rank(method='min'), 4, labels=False)
                assert set(df_['quartile'].unique()) == {0, 1, 2, 3}, f'Quartiles not 0, 1, 2, 3: {set(df_["quartile"].unique())}'
                for quartile in range(4):
                    # Limit to this quartile
                    df_quartile = df_[df_['quartile'] == quartile]
                    # Labels / Preds
                    y = df_quartile['y'].values.astype(int)
                    pred_proba = df_quartile['pred_proba'].values
                    # Calculate Brier score
                    brier = brier_score_loss(y, pred_proba)
                    true_brier_scores_per_quartile[(model, task, strat, strat_col, quartile)] = brier

                # Do bootstraps
                for weights in bootstrap_weights:
                    weights = weights[patient_id_indices]
                    assert weights.shape[0] == df_.shape[0] and weights.shape[0] == metric_values.shape[0], f"Error - weights shape: {weights.shape[0]}, df shape: {df_.shape[0]}, metric_values shape: {metric_values.shape[0]}"

                    # Calculate quartiles
                    quantile_cutoffs = weighted_quantile(metric_values, [0.25, .5, .75, 1], weights)
                    df_['quartile'] = np.searchsorted(quantile_cutoffs, metric_values)
                    df_['weight'] = weights
                    assert set(df_['quartile'].unique()) == {0, 1, 2, 3}, f'Quartiles not 0, 1, 2, 3: {set(df_metrics["quartile"].unique())}'

                    # Calculate Brier scores
                    for quartile in range(4):
                        # Limit to this quartile
                        df_quartile = df_[df_['quartile'] == quartile]
                        # Labels / Preds
                        y = df_quartile['y'].values.astype(int)
                        pred_proba = df_quartile['pred_proba'].values
                        sample_weight = df_quartile['weight'].values
                        # Calculate Brier score
                        brier = brier_score_loss(y, pred_proba, sample_weight=sample_weight)
                        brier_scores_per_quartile[(model, task, strat, strat_col, quartile)].append(brier)
                        brier_scores_metadata_per_quartile[(model, task, strat, strat_col, quartile)].append({
                            'brier' : ((y - pred_proba) ** 2).astype(np.float32),
                            'sample_weight' : sample_weight.astype(np.int8),
                        })
    # Save results to .pkl file
    with open('../cache/brier_scores_per_quartile.pkl', 'wb') as f:
        pickle.dump(brier_scores_per_quartile, f)
    with open('../cache/true_brier_scores_per_quartile.pkl', 'wb') as f:
        pickle.dump(true_brier_scores_per_quartile, f)
    with open('../cache/brier_scores_metadata_per_quartile.pkl', 'wb') as f:
        pickle.dump(brier_scores_metadata_per_quartile, f)

Loading brier_scores_per_quartile from cache...
Loading true_brier_scores_per_quartile from cache...


### Task-Level CIs

In [10]:
# Calculate task-level Brier CIs
df_brier_cis = []
for key in brier_scores_per_quartile.keys():
    model, task, strat, strat_col, quartile = key
    scores = brier_scores_per_quartile[key]
    df_brier_cis.append({
        'model' : model,
        'task' : task,
        'quartile' : quartile,
        'strat' : strat,
        'strat_col' : strat_col,
        'brier_true' : true_brier_scores_per_quartile[key],
        'brier_ci_mean' : np.mean(scores),
        'brier_ci_025' : np.percentile(scores, 2.5),
        'brier_ci_500' : np.percentile(scores, 50),
        'brier_ci_975' : np.percentile(scores, 97.5),
    })
df_brier_cis = pd.DataFrame(df_brier_cis).sort_values(['model', 'task', 'quartile'])
df_brier_cis

Unnamed: 0,model,task,quartile,strat,strat_col,brier_true,brier_ci_mean,brier_ci_025,brier_ci_500,brier_ci_975
2844,clmbr,guo_icu,0,inter_event_times,std,0.035984,0.035149,0.019668,0.035239,0.052189
2848,clmbr,guo_icu,0,n_gram_count,rr_1,0.040230,0.039553,0.026063,0.039179,0.053784
2852,clmbr,guo_icu,0,timeline_lengths,n_events,0.044137,0.044539,0.030298,0.044418,0.059930
2845,clmbr,guo_icu,1,inter_event_times,std,0.023934,0.024645,0.014234,0.024329,0.037353
2849,clmbr,guo_icu,1,n_gram_count,rr_1,0.032956,0.033599,0.020306,0.033245,0.048665
...,...,...,...,...,...,...,...,...,...,...
1194,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,2,n_gram_count,rr_1,0.026120,0.026648,0.013950,0.026183,0.041951
1198,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,2,timeline_lengths,n_events,0.024836,0.023678,0.014339,0.023457,0.035831
1191,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,3,inter_event_times,std,0.026012,0.025991,0.015477,0.025312,0.040845
1195,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,new_pancan,3,n_gram_count,rr_1,0.014126,0.014282,0.005617,0.013576,0.026770


### Model-level CIs

In [11]:
# Get test PIDs for each task
task_2_test_pids = {}
for (model, task), df_preds in predictions.items():
    if task in task_2_test_pids:
        continue
    assert label_data[task]['patient_ids'].shape[0] == len(label_data[task]['times'])
    df_preds['pid'] = label_data[task]['patient_ids']
    df_preds['label_time'] = label_data[task]['times']
    test_pids = df_preds[df_preds['split'] == 'test']
    task_2_test_pids[task] = test_pids['pid']
    if len(task_2_test_pids[task]) == len(valid_tasks):
        break
print([ len(task_2_test_pids[task]) for task in valid_tasks ])

[1258, 2195, 100568, 2243, 67028, 2220, 58155, 2127, 2189, 56338, 1317, 2222, 63653, 2037]


In [12]:
# Reweight each bootstrap by the number of labels in the task
n_labels_per_task = []
for weights in bootstrap_weights:
    weights_per_task = [
        weights[np.searchsorted(all_patients, task_2_test_pids[task])] 
        for task in valid_tasks
    ]
    n_labels_per_task.append([w.sum() for w in weights_per_task])
n_labels_per_task = np.array(n_labels_per_task).T
assert n_labels_per_task.shape == (14, 1000)

# Calculate model-level Brier CIs
df_brier_model_cis = []
for strat, strat_cols in strats.items():
    for strat_col in strat_cols:
        for model in valid_models:
            for quartile in range(4):
                ci_scores = [] # scores for this quartile, averaged across all tasks
                true_scores = [] # true scores for this quartile, averaged across all tasks
                for task in valid_tasks:
                    ci_scores.append(brier_scores_per_quartile[(model, task, strat, strat_col, quartile)]) # each is 1 x 1000
                    true_scores.append(true_brier_scores_per_quartile[(model, task, strat, strat_col, quartile)]) # each is 1 x 1
                # Raw scores
                scores = np.vstack(ci_scores) # 14 x 1000
                assert scores.shape == (14, 1000)
                # Macro-average across tasks
                scores = np.mean(scores, axis=0)
                # Micro-average across tasks
                # scores = np.average(scores, axis=0, weights=n_labels_per_task)
                df_brier_model_cis.append({
                    'model' : model,
                    'task' : 'all',
                    'strat' : strat,
                    'strat_col' : strat_col,
                    'quartile' : quartile,
                    'brier_true' : np.mean(true_scores),
                    'brier_ci_mean' : np.mean(scores),
                    'brier_ci_025' : np.percentile(scores, 2.5),
                    'brier_ci_500' : np.percentile(scores, 50),
                    'brier_ci_975' : np.percentile(scores, 97.5),
                })
df_brier_model_cis = pd.DataFrame(df_brier_model_cis).sort_values(['model', 'quartile'])
df_brier_model_cis

Unnamed: 0,model,task,strat,strat_col,quartile,brier_true,brier_ci_mean,brier_ci_025,brier_ci_500,brier_ci_975
64,clmbr,all,inter_event_times,std,0,0.068110,0.068768,0.061678,0.068846,0.075391
132,clmbr,all,n_gram_count,rr_1,0,0.064749,0.064870,0.060491,0.064838,0.069401
200,clmbr,all,timeline_lengths,n_events,0,0.070006,0.069885,0.065220,0.069992,0.074372
65,clmbr,all,inter_event_times,std,1,0.073972,0.073590,0.068396,0.073618,0.078932
133,clmbr,all,n_gram_count,rr_1,1,0.071764,0.071972,0.067029,0.071968,0.076659
...,...,...,...,...,...,...,...,...,...,...
126,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,n_gram_count,rr_1,2,0.070774,0.071371,0.065286,0.071519,0.077010
194,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,timeline_lengths,n_events,2,0.073094,0.072354,0.067307,0.072333,0.077191
59,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,inter_event_times,std,3,0.073400,0.073705,0.069028,0.073706,0.078748
127,mamba-tiny-8192--clmbr_train-tokens-total_nonP...,all,n_gram_count,rr_1,3,0.075551,0.075227,0.069464,0.075280,0.081099


In [13]:
df_irregularity = clean_df_briers(df_brier_model_cis[df_brier_model_cis['strat_col'] == 'std'])
# Take the 'quartile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_irregularity = df_irregularity.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_brier').reset_index().sort_values(['model_base', 'ctx_length'])

# Limit to first / last for clarity
df_model_latex = df_irregularity[
    (
        (df_irregularity['model_base'].isin(['mamba', 'hyena']) & df_irregularity['ctx_length'].isin([1024, 16384]))
        | (df_irregularity['model_base'].isin(['gpt2', 'llama']) & df_irregularity['ctx_length'].isin([512, 4096]))
    )
]
df_model_latex

quartile,model,model_base,ctx_length,0,1,2,3
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0654 (0.0591, 0.0723)","0.0693 (0.0641, 0.0743)","0.0703 (0.0656, 0.0752)","0.0736 (0.0692, 0.0786)"
3,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.0653 (0.0594, 0.0722)","0.0699 (0.0646, 0.0747)","0.0701 (0.0656, 0.0749)","0.0759 (0.0715, 0.0812)"
5,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0666 (0.0604, 0.0736)","0.0702 (0.0649, 0.0748)","0.0692 (0.0646, 0.0739)","0.0751 (0.0704, 0.0804)"
6,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0698 (0.0641, 0.0767)","0.0755 (0.0702, 0.0803)","0.0788 (0.0740, 0.0837)","0.0853 (0.0804, 0.0907)"
12,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0694 (0.0644, 0.0757)","0.0730 (0.0679, 0.0775)","0.0713 (0.0669, 0.0757)","0.0749 (0.0708, 0.0797)"
11,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.0664 (0.0604, 0.0733)","0.0705 (0.0655, 0.0753)","0.0694 (0.0650, 0.0744)","0.0740 (0.0695, 0.0792)"
13,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0693 (0.0632, 0.0759)","0.0729 (0.0678, 0.0779)","0.0731 (0.0685, 0.0779)","0.0764 (0.0718, 0.0816)"
14,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.0641 (0.0588, 0.0704)","0.0678 (0.0628, 0.0728)","0.0679 (0.0635, 0.0729)","0.0723 (0.0680, 0.0773)"


In [14]:
df_repetition = clean_df_briers(df_brier_model_cis[df_brier_model_cis['strat_col'] == 'rr_1'])
# Take the 'quartile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_repetition = df_repetition.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_brier').reset_index().sort_values(['model_base', 'ctx_length'])

# Limit to first / last for clarity
df_model_latex = df_repetition[
    (
        (df_repetition['model_base'].isin(['mamba', 'hyena']) & df_repetition['ctx_length'].isin([1024, 16384]))
        | (df_repetition['model_base'].isin(['gpt2', 'llama']) & df_repetition['ctx_length'].isin([512, 4096]))
    )
]
df_model_latex

quartile,model,model_base,ctx_length,0,1,2,3
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0619 (0.0579, 0.0660)","0.0691 (0.0646, 0.0741)","0.0710 (0.0655, 0.0775)","0.0765 (0.0706, 0.0818)"
3,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.0643 (0.0601, 0.0691)","0.0692 (0.0648, 0.0740)","0.0711 (0.0658, 0.0774)","0.0765 (0.0704, 0.0817)"
5,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0636 (0.0594, 0.0680)","0.0681 (0.0636, 0.0728)","0.0718 (0.0661, 0.0786)","0.0776 (0.0716, 0.0829)"
6,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0733 (0.0688, 0.0777)","0.0759 (0.0715, 0.0812)","0.0780 (0.0724, 0.0839)","0.0822 (0.0762, 0.0879)"
12,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0640 (0.0600, 0.0684)","0.0710 (0.0665, 0.0758)","0.0743 (0.0688, 0.0807)","0.0792 (0.0730, 0.0846)"
11,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.0627 (0.0587, 0.0670)","0.0687 (0.0644, 0.0735)","0.0721 (0.0666, 0.0785)","0.0770 (0.0711, 0.0824)"
13,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0644 (0.0601, 0.0688)","0.0737 (0.0691, 0.0785)","0.0744 (0.0687, 0.0811)","0.0790 (0.0732, 0.0847)"
14,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.0605 (0.0567, 0.0648)","0.0670 (0.0625, 0.0720)","0.0700 (0.0650, 0.0757)","0.0746 (0.0687, 0.0801)"


In [15]:
df_repetition = clean_df_briers(df_brier_model_cis[df_brier_model_cis['strat_col'] == 'n_events'])
# Take the 'quartile' column and give each unique value its own column. Use the `formatted_brier` column as the values.
df_repetition = df_repetition.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_brier').reset_index().sort_values(['model_base', 'ctx_length'])

# Limit to first / last for clarity
df_model_latex = df_repetition[
    (
        (df_repetition['model_base'].isin(['mamba', 'hyena']) & df_repetition['ctx_length'].isin([1024, 16384]))
        | (df_repetition['model_base'].isin(['gpt2', 'llama']) & df_repetition['ctx_length'].isin([512, 4096]))
    )
]
df_model_latex

quartile,model,model_base,ctx_length,0,1,2,3
4,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0645 (0.0603, 0.0691)","0.0705 (0.0658, 0.0751)","0.0745 (0.0691, 0.0785)","0.0692 (0.0629, 0.0772)"
3,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.0671 (0.0629, 0.0721)","0.0720 (0.0673, 0.0765)","0.0736 (0.0680, 0.0776)","0.0685 (0.0624, 0.0764)"
5,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0670 (0.0628, 0.0718)","0.0706 (0.0658, 0.0753)","0.0730 (0.0673, 0.0768)","0.0703 (0.0635, 0.0790)"
6,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0759 (0.0713, 0.0807)","0.0815 (0.0768, 0.0862)","0.0792 (0.0733, 0.0830)","0.0728 (0.0674, 0.0808)"
12,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0666 (0.0623, 0.0715)","0.0728 (0.0682, 0.0771)","0.0758 (0.0700, 0.0799)","0.0733 (0.0670, 0.0819)"
11,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.0651 (0.0608, 0.0698)","0.0715 (0.0669, 0.0759)","0.0738 (0.0682, 0.0779)","0.0700 (0.0639, 0.0785)"
13,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0676 (0.0633, 0.0724)","0.0742 (0.0696, 0.0790)","0.0772 (0.0719, 0.0811)","0.0726 (0.0662, 0.0811)"
14,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.0633 (0.0592, 0.0680)","0.0689 (0.0643, 0.0735)","0.0718 (0.0665, 0.0758)","0.0681 (0.0628, 0.0756)"


### LaTeX Results

In [16]:
# Model-level delta CIs for LaTeX
latex: str = clean_df_briers(df_model_latex)
print(latex.to_latex(index=False, escape=False))

\begin{tabular}{lllrllll}
\toprule
model & model_name & model_base & ctx_length & 0 & 1 & 2 & 3 \\
\midrule
gpt2-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last & gpt2-base-512 & gpt2 & 512 & 0.0645 (0.0603, 0.0691) & 0.0705 (0.0658, 0.0751) & 0.0745 (0.0691, 0.0785) & 0.0692 (0.0629, 0.0772) \\
gpt2-base-4096--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last & gpt2-base-4096 & gpt2 & 4096 & 0.0671 (0.0629, 0.0721) & 0.0720 (0.0673, 0.0765) & 0.0736 (0.0680, 0.0776) & 0.0685 (0.0624, 0.0764) \\
hyena-large-1024--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last & hyena-large-1024 & hyena & 1024 & 0.0670 (0.0628, 0.0718) & 0.0706 (0.0658, 0.0753) & 0.0730 (0.0673, 0.0768) & 0.0703 (0.0635, 0.0790) \\
hyena-large-16384--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last & hyena-large-16384 & hyena & 16384 & 0.0759 (0.0713, 0.0807) & 0.0815 (0.0768, 0

## Win Rates

In [17]:
# Calculate model-level win rates between longest and shortest context lengths for each (model, quartile)
df_win_rate_cis = []
for strat, strat_cols in strats.items():
    for strat_col in strat_cols:
        for model in valid_models:

            # Only keep shortest and longest context lengths for ease of comparison
            if model_to_base(model) in ['mamba', 'hyena' ]:
                if model_to_ctx(model) not in [1024, 16384]:
                    continue
            if model_to_base(model) in ['gpt2', 'llama' ]:
                if model_to_ctx(model) not in [512, 4096]:
                    continue
            baseline_model_name: str = model.replace(str(model_to_ctx(model)), str(1024 if model_to_base(model) in ['mamba', 'hyena' ] else 512))

            for quartile in range(4):
                win_rates = [] # scores for this quartile, averaged across all tasks
                true_win_rate = [] # true scores for this quartile, averaged across all tasks
                for task in valid_tasks:
                    ## NOTE: Comparison is '>' b/c lower Brier is better
                    # True win rate
                    true_this_model_score = np.array(true_brier_scores_per_quartile[(model, task, strat, strat_col, quartile)]) # each is 1 x 1
                    true_baseline_model_score = np.array(true_brier_scores_per_quartile[(baseline_model_name, task, strat, strat_col, quartile)]) # each is 1 x 1
                    assert true_baseline_model_score.shape == true_this_model_score.shape
                    true_win_rate.append((true_baseline_model_score > true_this_model_score).astype(bool))
                    # Bootstrap win rates
                    this_model_scores = np.array(brier_scores_per_quartile[(model, task, strat, strat_col, quartile)]) # each is 1 x 1000
                    baseline_model_scores = np.array(brier_scores_per_quartile[(
                        baseline_model_name, 
                        task, strat, strat_col, quartile
                    )]) # each is 1 x 1000
                    assert baseline_model_scores.shape == this_model_scores.shape
                    win_rates.append((baseline_model_scores > this_model_scores).astype(bool))
                # Win rates
                win_rates = np.vstack(win_rates) # 14 x 1000
                assert win_rates.shape == (14, 1000), f"win_rates.shape={win_rates.shape}"
                win_rates = np.mean(win_rates, axis=0)
                assert len(true_win_rate) == 14
                df_win_rate_cis.append({
                    'model' : model,
                    'task' : 'all',
                    'strat' : strat,
                    'strat_col' : strat_col,
                    'quartile' : quartile,
                    'win_rate_true' : np.mean(true_win_rate),
                    'win_rate_ci_mean' : np.mean(win_rates),
                    'win_rate_ci_025' : np.percentile(win_rates, 2.5),
                    'win_rate_ci_500' : np.percentile(win_rates, 50),
                    'win_rate_ci_975' : np.percentile(win_rates, 97.5),
                })
df_win_rate_cis = pd.DataFrame(df_win_rate_cis).sort_values(['model', 'quartile'])
df_win_rate_cis

Unnamed: 0,model,task,strat,strat_col,quartile,win_rate_true,win_rate_ci_mean,win_rate_ci_025,win_rate_ci_500,win_rate_ci_975
32,clmbr,all,inter_event_times,std,0,0.000000,0.000000,0.000000,0.000000,0.000000
68,clmbr,all,n_gram_count,rr_1,0,0.000000,0.000000,0.000000,0.000000,0.000000
104,clmbr,all,timeline_lengths,n_events,0,0.000000,0.000000,0.000000,0.000000,0.000000
33,clmbr,all,inter_event_times,std,1,0.000000,0.000000,0.000000,0.000000,0.000000
69,clmbr,all,n_gram_count,rr_1,1,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...
66,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,n_gram_count,rr_1,2,0.785714,0.756143,0.571429,0.785714,0.858929
102,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,timeline_lengths,n_events,2,0.857143,0.783500,0.642857,0.785714,0.928571
31,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,inter_event_times,std,3,0.714286,0.713500,0.571429,0.714286,0.857143
67,mamba-tiny-16384--clmbr_train-tokens-total_non...,all,n_gram_count,rr_1,3,0.714286,0.704357,0.500000,0.714286,0.857143


In [18]:
df_repetition = clean_df_briers(df_win_rate_cis[df_win_rate_cis['strat_col'] == 'rr_1'])
df_repetition = df_repetition.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_win_rate').reset_index().sort_values(['model_base', 'ctx_length'])
df_repetition

quartile,model,model_base,ctx_length,0,1,2,3
0,clmbr,clmbr,0,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
2,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
1,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.4286 (0.1429, 0.5018)","0.5000 (0.3571, 0.6429)","0.5714 (0.2143, 0.6429)","0.5000 (0.2857, 0.6429)"
3,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
4,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.0714 (0.0714, 0.2143)","0.1429 (0.0000, 0.2857)","0.0714 (0.0000, 0.2143)","0.1429 (0.0696, 0.2857)"
6,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
5,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.6429 (0.5000, 0.8571)","0.7857 (0.5000, 0.8571)","0.6429 (0.5000, 0.8571)","0.8571 (0.5714, 0.9286)"
7,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
8,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.7857 (0.5714, 0.8571)","0.8571 (0.6429, 0.9286)","0.7857 (0.5714, 0.8589)","0.7143 (0.5000, 0.8571)"


In [19]:
df_irregularity = clean_df_briers(df_win_rate_cis[df_win_rate_cis['strat_col'] == 'std'])
df_irregularity = df_irregularity.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_win_rate').reset_index().sort_values(['model_base', 'ctx_length'])
df_irregularity

quartile,model,model_base,ctx_length,0,1,2,3
0,clmbr,clmbr,0,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
2,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
1,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.5000 (0.2857, 0.7143)","0.4286 (0.2857, 0.6429)","0.5714 (0.3571, 0.7857)","0.4286 (0.2143, 0.5714)"
3,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
4,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.1429 (0.0714, 0.3571)","0.1429 (0.0000, 0.3571)","0.0000 (0.0000, 0.2143)","0.1429 (0.0714, 0.2143)"
6,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
5,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.7143 (0.5000, 0.8571)","0.7143 (0.5000, 0.8571)","0.8571 (0.5714, 0.9286)","0.4286 (0.4286, 0.7857)"
7,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
8,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.8571 (0.5696, 0.8571)","0.8571 (0.5714, 0.9286)","0.8571 (0.6429, 0.9286)","0.7143 (0.5714, 0.8571)"


In [20]:
df_length = clean_df_briers(df_win_rate_cis[df_win_rate_cis['strat_col'] == 'n_events'])
df_length = df_length.pivot(index=['model', 'model_base', 'ctx_length'], columns='quartile', values='formatted_win_rate').reset_index().sort_values(['model_base', 'ctx_length'])
df_length

quartile,model,model_base,ctx_length,0,1,2,3
0,clmbr,clmbr,0,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
2,gpt2-base-512--clmbr_train-tokens-total_nonPAD...,gpt2,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
1,gpt2-base-4096--clmbr_train-tokens-total_nonPA...,gpt2,4096,"0.3571 (0.2143, 0.5714)","0.3571 (0.2839, 0.6429)","0.3571 (0.3571, 0.7857)","0.5000 (0.2857, 0.7857)"
3,hyena-large-1024--clmbr_train-tokens-total_non...,hyena,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
4,hyena-large-16384--clmbr_train-tokens-total_no...,hyena,16384,"0.1429 (0.0714, 0.2857)","0.0714 (0.0000, 0.2143)","0.0714 (0.0000, 0.2857)","0.2143 (0.1429, 0.4286)"
6,llama-base-512--clmbr_train-tokens-total_nonPA...,llama,512,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
5,llama-base-4096--clmbr_train-tokens-total_nonP...,llama,4096,"0.7143 (0.5714, 0.8571)","0.7143 (0.5000, 0.7857)","0.7143 (0.5000, 0.8571)","0.7857 (0.5000, 0.8571)"
7,mamba-tiny-1024--clmbr_train-tokens-total_nonP...,mamba,1024,"0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)","0.0000 (0.0000, 0.0000)"
8,mamba-tiny-16384--clmbr_train-tokens-total_non...,mamba,16384,"0.7143 (0.5000, 0.8571)","0.8571 (0.6429, 0.9286)","0.8571 (0.6429, 0.9286)","0.7857 (0.5000, 0.8571)"
