In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
import os
import sklearn.utils
from typing import List, Optional
from sklearn.metrics import roc_auc_score
import collections
import warnings
import pickle
import femr.datasets
import datetime
from hf_ehr.notebooks.ehr_specific_properties import utils
from hf_ehr.utils import load_tokenizer_from_path, load_model_from_path
from starr_eda import calc_n_gram_count, calc_inter_event_times
# Hide sklearn's input-validation FutureWarnings so they don't drown out the notebook's output.
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn.utils.validation")

# Input locations inside the shared EHRSHOT asset tree.
# NOTE(review): PATH_TO_RESULTS_DIR points at a different user's copy ('migufuen') than every
# other path ('mwornow') — confirm this is intentional.
PATH_TO_DATABASE: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/femr/extract'
PATH_TO_FEATURES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/features_ehrshot'
PATH_TO_RESULTS_DIR: str = '/share/pi/nigam/migufuen/ehrshot-benchmark/EHRSHOT_ASSETS/results_ehrshot'
PATH_TO_TOKENIZED_TIMELINES_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/tokenized_timelines_ehrshot'
PATH_TO_LABELS_DIR: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/benchmark_ehrshot'
PATH_TO_SPLIT_CSV: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/splits_ehrshot/person_id_map.csv'

# Open the FEMR patient database and make sure the local cache directory exists.
femr_db = femr.datasets.PatientDatabase(PATH_TO_DATABASE)
os.makedirs('../cache', exist_ok=True)

# Output directory
path_to_output_dir: str = '/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/stratify/'
# `args` is never defined in this notebook; it looks left over from a command-line script.
# The unconditional `args.model` therefore raised NameError on a fresh kernel.
# An injected `args` is still honoured; otherwise fall back to a notebook constant.
# The default is the model whose per-patient metrics file the Brier-score cell reads.
DEFAULT_MODEL_NAME: str = 'llama-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=1000000000-persist_chunk:last_embed:last'
MODEL_NAME: str = args.model if 'args' in globals() else DEFAULT_MODEL_NAME
path_to_output_file: str = os.path.join(path_to_output_dir, f'metrics__{MODEL_NAME}__per_patient__all_tasks.csv')
os.makedirs(path_to_output_dir, exist_ok=True)

# NOTE(review): not referenced by the visible cells — wire it up or remove it.
IS_LOAD_FROM_CACHE: bool = False

# Tasks are the entries of the results directory, excluding 'chexpert'.
# Filtering (rather than list.remove) means a results dir without 'chexpert' no longer raises ValueError.
# Sorting makes the processing order reproducible, since os.listdir order is arbitrary.
valid_tasks = sorted(task for task in os.listdir(PATH_TO_RESULTS_DIR) if task != 'chexpert')

# Labeling-function name -> display name used when reporting results.
LABELING_FUNCTION_2_PAPER_NAME = {
    # Guo et al. 2023
    "guo_los": "Long LOS",
    "guo_readmission": "30-day Readmission",
    "guo_icu": "ICU Admission",
    # New diagnosis
    "new_pancan": "Pancreatic Cancer",
    "new_celiac": "Celiac",
    "new_lupus": "Lupus",
    "new_acutemi": "Acute MI",
    "new_hypertension": "Hypertension",
    "new_hyperlipidemia": "Hyperlipidemia",
    # Instant lab values
    "lab_thrombocytopenia": "Thrombocytopenia",
    "lab_hyperkalemia": "Hyperkalemia",
    "lab_hypoglycemia": "Hypoglycemia",
    "lab_hyponatremia": "Hyponatremia",
    "lab_anemia": "Anemia",
    # Custom tasks
    "chexpert": "Chest X-ray Findings",
}

# Models whose results are analyzed.
# - CLMBR-trained checkpoints: one per (architecture, context length), all sharing the same
#   training/checkpoint suffix.
# - The 'clmbr' baseline, appended last.
valid_models = [
    f'{arch}-{ctx_len}--clmbr_train-tokens-total_nonPAD-ckpt_val=2000000000-persist_chunk:last_embed:last'
    for arch, ctx_lens in (
        ('llama-base', (512, 1024, 2048, 4096)),
        ('gpt2-base', (512, 1024, 2048, 4096)),
        ('hyena-large', (1024, 4096, 8192, 16384)),
        ('mamba-tiny', (1024, 4096, 8192, 16384)),
    )
    for ctx_len in ctx_lens
]
valid_models.append('clmbr')

print("Tasks:", valid_tasks)
print("# of valid models:", len(valid_models))

In [None]:

        # NOTE(review): this cell is the body of an outer loop whose header is not in the cell
        # (everything is indented). It reads the following from that outer scope:
        #   LABELING_FUNCTION, strats, patient_ids, label_times, label_values,
        #   train/val/test_pids_idx.
        # On a fresh kernel the cell raises IndentationError — TODO restore the header.

        # Per-task directory with one sub-directory of results per model.
        # It must be assigned BEFORE it is listed. It used to be assigned only inside the loop
        # body, so the `for` header read a stale value, or raised NameError the first time.
        path_to_results_dir: str = os.path.join(PATH_TO_RESULTS_DIR, LABELING_FUNCTION, 'models')
        # Classification head whose predictions are analyzed.
        head: str = 'lr_lbfgs'
        # Label indices in the row order used to interpret preds.csv: train, then val, then test.
        # Presumably this matches the order preds.csv was written in — TODO confirm.
        # The order does not change across models, so it is computed once here.
        ordered_pids_idx = train_pids_idx + val_pids_idx + test_pids_idx

        # For each model...
        for model in os.listdir(path_to_results_dir):
            print(f'  Processing model: {model}')
            path_to_output_file: str = os.path.join(path_to_output_dir, f'df_buckets__model={model}__task={LABELING_FUNCTION}.csv')
            # Create the directory the CSV is written into.
            # dirname() of the output *directory* only created its parent whenever the
            # directory path had no trailing '/'.
            os.makedirs(os.path.dirname(path_to_output_file), exist_ok=True)

            # Load EHRSHOT results
            ## Each model has its own results; only examine results for k = -1
            path_to_results_file: str = os.path.join(path_to_results_dir,
                                                     model,
                                                     head,
                                                     f'subtask={LABELING_FUNCTION}',
                                                     'k=-1', # always use k=-1
                                                     'preds.csv')
            assert os.path.exists(path_to_results_file), f'Path to results file does not exist: {path_to_results_file}'
            df_preds = pd.read_csv(path_to_results_file)

            ## Map each result to proper patient id / label time, then keep only test rows
            df_preds['pid_idx'] = ordered_pids_idx
            df_preds['pid'] = patient_ids[ordered_pids_idx]
            df_preds['label_time'] = label_times[ordered_pids_idx]
            df_preds['label_value'] = label_values[ordered_pids_idx]
            df_preds = df_preds[df_preds['pid_idx'].isin(test_pids_idx)]

            # Calculate quartiles based on each stratification metric
            df_results = []
            for strat, metrics in tqdm(strats.items(), desc=f'Stratifying {LABELING_FUNCTION}'):
                # Load metric for each patient
                strat_cols = metrics['strat_cols']
                df_metrics = pd.read_parquet(os.path.join(path_to_output_dir, f'df__{LABELING_FUNCTION}__{strat}__metrics.parquet'))

                # Inter-event times are stored long-form, with separate 'metric' and 'time' columns.
                # Pivot ONCE, before any column lookup. Pivoting inside the per-column loop meant:
                # - the column check ran against the un-pivoted frame;
                # - a second column re-pivoted an already-wide frame, which has no 'metric' column.
                if strat == 'inter_event_times':
                    df_metrics = df_metrics.pivot_table(index=['pid', 'pid_idx', 'label_time', 'sub_task'], columns='metric', values='time').reset_index()

                # For every metric, calculate quartiles
                for strat_col in strat_cols:
                    if strat_col not in df_metrics.columns:
                        print(f'{strat_col} not in df_metrics columns. Skipping...')
                        continue
                    assert df_metrics.shape[0] == df_preds.shape[0], f'Number of rows in df_metrics does not match number of rows in df_preds: {df_metrics.shape[0]} != {df_preds.shape[0]}'

                    # Merge the predictions with the stratification metric
                    df_merged = pd.merge(df_preds, df_metrics[['pid', 'label_time', strat_col]], on=['pid', 'label_time'])
                    assert df_merged.shape[0] == df_preds.shape[0], f'Number of rows in merged DataFrame does not match number of rows in df_preds: {df_merged.shape[0]} != {df_preds.shape[0]}'

                    # Quartile buckets (0-3) are computed on the metric's rank; tied values share the lowest rank.
                    df_merged['metric_name'] = f'{strat_col}'
                    df_merged['quartile'] = pd.qcut(df_merged[strat_col].rank(method='min'), 4, labels=False)
                    df_merged = df_merged.rename(columns={strat_col: 'metric_value'})

                    # Save results
                    df_results.append(df_merged)

            # pd.concat([]) raises ValueError, so skip models for which no stratification column was found.
            if not df_results:
                print(f'    No stratification metrics found; skipping {path_to_output_file}')
                continue
            df_results = pd.concat(df_results, ignore_index=True)
            df_results.to_csv(path_to_output_file, index=False)
            print('    Saved to:', path_to_output_file)


In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import brier_score_loss

# Read in the patient splits.
# NOTE(review): `df_splits` is not used by this cell or the ones below it.
# No test-split filtering happens here — TODO confirm the metrics CSV already holds test rows only.
print("Loading patient splits data...")
df_splits = pd.read_csv('/share/pi/nigam/mwornow/ehrshot-benchmark/EHRSHOT_ASSETS/splits_ehrshot/person_id_map.csv')

# Read in the per-patient metrics with compact dtypes to keep memory down.
print("Loading metrics data...")
dtypes = {
    'sub_task': 'category',
    'model': 'category',
    'y': 'int8',
    'pred_proba': 'float32',
    'pid': 'int64',
    'pid_idx': 'int32',
    'metric_value': 'float32',
    'brier_score': 'float32',
    'metric_name': 'category',
    'quartile': 'float32'
}
df = pd.read_csv('/share/pi/nigam/mwornow/ehrshot-benchmark/ehrshot/stratify/metrics__llama-base-512--clmbr_train-tokens-total_nonPAD-ckpt_val=1000000000-persist_chunk:last_embed:last__per_patient__all_tasks.csv',
                 dtype=dtypes)

# The set of patients resampled by the patient-level bootstrap below.
# np.unique returns the ids SORTED. That is required: ids are later located in this array with
# np.searchsorted, which silently returns wrong positions on unsorted input.
# Series.unique() returns ids in order of first appearance, which is not sorted.
all_patients = np.unique(df['pid'].to_numpy())

print("Metrics data loaded.")

# Patient-level bootstrap.
# - Each of the 1000 resamples draws len(all_patients) patient positions with replacement.
# - A patient's weight in a resample is how many times its position was drawn (float32 vector).
print("Generating bootstrap weights...")
np.random.seed(342342)
bootstrap_weights = [
    np.bincount(
        np.random.choice(len(all_patients), len(all_patients), replace=True),
        minlength=len(all_patients),
    ).astype(np.float32)
    for _ in range(1000)
]
print("Bootstrap weights generated.")

# Brier scores per (model, task, metric_name), bootstrapped at the patient level and stratified
# by quartile. brier_scores[(model, task, metric_name)][quartile] is a list holding one score
# per usable bootstrap resample.
print("Calculating Brier scores...")
brier_scores = {}
# The grouping keys are categorical. observed=True yields only combinations present in the data,
# not empty cross-product groups.
grouped = df.groupby(['model', 'sub_task', 'metric_name'], observed=True)

# Maps a patient id to its position in `all_patients`, which is how the bootstrap weight vectors
# are indexed. Unlike np.searchsorted, Index.get_indexer does not require `all_patients` to be
# sorted, and it returns -1 for ids that are absent.
patient_position = pd.Index(all_patients)

for (model, task, metric_name), df_preds in tqdm(grouped, total=len(grouped)):
    # Labels / predictions / quartiles / bootstrap positions
    y = df_preds['y'].to_numpy()
    pred_proba = df_preds['pred_proba'].to_numpy()
    quartiles = df_preds['quartile'].to_numpy()
    patient_id_indices = patient_position.get_indexer(df_preds['pid'].to_numpy())

    # Drop two kinds of rows:
    # - rows whose patient is not in the bootstrap population;
    # - rows without a quartile. NaN never equals itself, so these would produce a bucket
    #   that never receives a score.
    valid_mask = (patient_id_indices >= 0) & ~np.isnan(quartiles)
    patient_id_indices = patient_id_indices[valid_mask]
    y = y[valid_mask]
    pred_proba = pred_proba[valid_mask]
    quartiles = quartiles[valid_mask]

    # Slice each quartile once; these slices are identical across all bootstrap resamples.
    quartile_slices = {}
    for quartile in np.unique(quartiles):
        mask = quartiles == quartile
        quartile_slices[quartile] = (y[mask], pred_proba[mask], patient_id_indices[mask])

    results = {quartile: [] for quartile in quartile_slices}
    brier_scores[(model, task, metric_name)] = results

    # Weighted Brier score for each bootstrap resample and each quartile
    for weights in bootstrap_weights:
        for quartile, (y_quartile, pred_proba_quartile, idx_quartile) in quartile_slices.items():
            weights_quartile = weights[idx_quartile]
            # A resample can miss every patient of a small quartile. All-zero weights leave the
            # weighted average undefined (sklearn raises), so skip this resample for this bucket.
            if weights_quartile.sum() == 0:
                continue
            results[quartile].append(
                brier_score_loss(y_quartile, pred_proba_quartile, sample_weight=weights_quartile)
            )

# Report each bucket as the bootstrap mean plus a 95% percentile interval.
print("Brier score calculation completed. Results:")
for (model, task, metric_name), quartile_results in brier_scores.items():
    print(f'Results for model {model}, task {task}, metric {metric_name}:')
    for quartile, scores in quartile_results.items():
        # np.percentile raises on an empty sequence, and a bucket can be empty when no resample
        # yielded a usable score. Report such buckets instead of crashing.
        if len(scores) == 0:
            print(f'  Quartile {quartile}: no valid bootstrap samples')
            continue
        mean_brier = np.mean(scores)
        ci_lower, ci_upper = np.percentile(scores, [2.5, 97.5])
        print(f'  Quartile {quartile}: Mean Brier = {mean_brier:.4f}, 95% CI = ({ci_lower:.4f}, {ci_upper:.4f})')
