In [1]:
import numpy as np
import pandas as pd
import warnings
import scipy
import itertools
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

from sklearn.linear_model import LogisticRegression
from IPython.display import display
from prediction_utils.pytorch_utils.metrics import (
    StandardEvaluator,
    FairOVAEvaluator
)
from prediction_utils.pytorch_utils.datasets import ArrayLoaderGenerator
from prediction_utils.pytorch_utils.models import TorchModel, FixedWidthModel
from prediction_utils.util import patient_split
from prediction_utils.pytorch_utils.group_fairness import group_regularized_model
from prediction_utils.pytorch_utils.lagrangian import group_lagrangian_model
from prediction_utils.pytorch_utils.robustness import group_robust_model

In [2]:
def simulation(
    n=10000, 
    num_features=50,
    survival_group_means = [4, 15],
    censoring_group_means = [5, 5],
    binary_event_horizon=10,
    max_weight=100,
    seed=None
):
    """Simulate a right-censored time-to-event cohort with group-specific rates.

    Each row is assigned a group uniformly at random. Event and censoring times
    are drawn from exponential distributions whose means share one
    feature-dependent linear term (group-specific weight column) and differ
    only by a group-specific shift. Means are floored at a small positive value
    so that they remain valid exponential scales.

    Parameters
    ----------
    n : int
        Number of rows to simulate.
    num_features : int
        Number of standard-normal features per row.
    survival_group_means : sequence of float
        Per-group shift added to the linear term to form the mean event time.
    censoring_group_means : sequence of float
        Per-group shift added to the linear term to form the mean censoring
        time. Must have the same length as ``survival_group_means``.
    binary_event_horizon : float
        Horizon defining the binary outcome ``survival_time < horizon``.
    max_weight : float
        Upper bound applied to the inverse-probability-of-censoring weights.
    seed : int, numpy SeedSequence/BitGenerator/Generator or None
        Passed to ``np.random.default_rng``. ``None`` (the default) draws fresh
        entropy, which is the behaviour the function had before this parameter
        existed; pass an integer for reproducible output.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray)
        The per-row simulation table and the ``(n, num_features)`` feature
        matrix, row-aligned through the ``row_id`` column.

    Raises
    ------
    AssertionError
        If the two group-mean sequences differ in length.
    """
    eps = 1e-6
    # Seedable source of randomness (seed=None keeps the non-deterministic default)
    generator = np.random.default_rng(seed)

    # Convert inputs to numpy so they can be indexed by the group vector
    survival_group_means = np.array(survival_group_means)
    censoring_group_means = np.array(censoring_group_means)
    num_groups = len(survival_group_means)
    assert num_groups == len(censoring_group_means), (
        "survival_group_means and censoring_group_means must have the same length"
    )

    group = generator.integers(num_groups, size=n)

    features = generator.standard_normal(size=(n, num_features))
    weights = generator.standard_normal(size=(num_features, num_groups))

    # Linear term using the weight column of each row's own group; it is shared
    # by the event and censoring processes, so compute it once.
    linear_term = np.dot(features, weights)[np.arange(n), group]

    survival_means = np.maximum(linear_term + survival_group_means[group], eps)
    survival_times = generator.exponential(survival_means)

    censoring_means = np.maximum(linear_term + censoring_group_means[group], eps)
    censoring_times = generator.exponential(censoring_means)

    # The observed times
    observed_times = np.minimum(survival_times, censoring_times)

    # The observed times when also censored by the binary event horizon
    binary_observed_times = np.minimum(observed_times, binary_event_horizon)

    # Indicators for whether the observed time is an event or censoring
    event_indicator = survival_times <= censoring_times
    censored_indicator = ~event_indicator

    # The survival function of the event process, evaluated at the observed time
    survival_function = np.exp(-observed_times / survival_means)

    # The survival function for the censoring process
    censoring_survival_function = np.exp(-observed_times / censoring_means)

    # The survival function for the censoring process when also censored by the binary event horizon
    binary_censoring_survival_function = np.exp(-binary_observed_times / censoring_means)

    # The true value of the binary outcome Y
    true_labels = survival_times < binary_event_horizon
    labels = true_labels

    # Positive label that is actually observed: the event happened before the horizon
    observed_labels = (observed_times < binary_event_horizon) & event_indicator

    # The binary outcome is observed if the event was seen before the horizon,
    # or if follow-up reached the horizon
    binary_event_indicator = observed_labels | (binary_event_horizon <= observed_times)

    df = pd.DataFrame(
        {
            'row_id': np.arange(n),
            'survival_times': survival_times,
            'censoring_times': censoring_times,
            'observed_times': observed_times,
            'binary_observed_times': binary_observed_times,
            'group': group,
            'event_indicator': event_indicator,
            'censored_indicator': censored_indicator,
            'binary_event_indicator': binary_event_indicator,
            'true_labels': true_labels,
            'labels': labels,
            'observed_labels': observed_labels,
            'survival_function': survival_function,
            'censoring_survival_function': censoring_survival_function,
            # Inverse-probability-of-censoring weights, capped at max_weight
            'censoring_weight': np.minimum(1 / censoring_survival_function, max_weight),
            'binary_censoring_weight': np.minimum(1 / binary_censoring_survival_function, max_weight)
        }
    )

    return df, features
    
# Draw the synthetic cohort used by every cell below.
df, features = simulation(
    n=100000,
    survival_group_means=[5, 12],
    censoring_group_means=[20, 15],
    max_weight=1e6,
)

# Fold treated as the validation split by the later cells.
fold_id = '1'
# presumably adds the `fold_id` column queried by later cells — verify against patient_split
df = patient_split(df, patient_col='row_id')

# Per-group summary: mean event time, outcome prevalence, and the share of rows
# whose binary outcome is not observed.
display(
    df.groupby('group').agg(
        mean_survival=('survival_times', 'mean'),
        prob_y=('true_labels', 'mean'),
        prob_censored=('binary_event_indicator', lambda flags: 1 - flags.mean()),
    )
)

Unnamed: 0_level_0,mean_survival,prob_y,prob_censored
group,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,5.708497,0.815856,0.138643
1,12.188879,0.609008,0.319714


In [3]:
def concat_merge(x):
    """Merge every DataFrame in a mapping into one, in the mapping's iteration order.

    Each merge uses ``DataFrame.merge`` with its defaults, i.e. the join keys
    are whatever columns the two frames have in common.

    Parameters
    ----------
    x : dict
        Mapping whose values are DataFrames; the keys are ignored.

    Returns
    -------
    pandas.DataFrame
        The first value merged successively with each remaining value. With a
        single entry, that value is returned as-is.

    Raises
    ------
    ValueError
        If ``x`` is empty (previously this surfaced as an UnboundLocalError).
    """
    frames = iter(x.values())
    try:
        result = next(frames)
    except StopIteration:
        raise ValueError("concat_merge requires at least one DataFrame") from None
    for frame in frames:
        result = result.merge(frame)
    return result

def temp_evaluate(df, strata_vars=None, pred_prob_var='pred_probs'):
    """Evaluate predictions under four labelling schemes and merge the results.

    The four evaluations, in order:
      * 'unbiased'   - all rows against `labels` (default result name)
      * 'adjusted'   - rows with binary_event_indicator == 1 against
                       `observed_labels`, weighted by `binary_censoring_weight`
                       (result name 'performance_ipcw')
      * 'biased_obs' - the same rows and labels, unweighted
                       (result name 'performance_biased_obs')
      * 'biased_neg' - all rows against `observed_labels`
                       (result name 'performance_biased_neg')

    Parameters
    ----------
    df : DataFrame holding the label, weight, indicator and prediction columns.
    strata_vars : forwarded to StandardEvaluator.evaluate.
    pred_prob_var : name of the predicted-probability column.

    Returns
    -------
    The four evaluation results combined with concat_merge.
    """
    evaluator = StandardEvaluator(thresholds=[0.1, 0.5, 0.9])
    observed_only = 'binary_event_indicator == 1'
    shared = {'strata_vars': strata_vars, 'pred_prob_var': pred_prob_var}

    result_dict = {}
    result_dict['unbiased'] = evaluator.evaluate(df, label_var='labels', **shared)
    result_dict['adjusted'] = evaluator.evaluate(
        df.query(observed_only),
        label_var='observed_labels',
        weight_var='binary_censoring_weight',
        result_name='performance_ipcw',
        **shared
    )
    result_dict['biased_obs'] = evaluator.evaluate(
        df.query(observed_only),
        label_var='observed_labels',
        result_name='performance_biased_obs',
        **shared
    )
    result_dict['biased_neg'] = evaluator.evaluate(
        df,
        label_var='observed_labels',
        result_name='performance_biased_neg',
        **shared
    )
    return concat_merge(result_dict)


def temp_evaluate_fair(df, group_var_name='group', pred_prob_var='pred_probs'):
    """Run FairOVAEvaluator under four labelling schemes and merge the results.

    The four evaluations, in order:
      * 'unbiased'   - all rows against `labels` (default result name)
      * 'adjusted'   - rows with binary_event_indicator == 1 against
                       `observed_labels`, weighted by `binary_censoring_weight`
                       (result name 'performance_ipcw')
      * 'biased_obs' - the same rows and labels, unweighted
                       (result name 'performance_biased_obs')
      * 'biased_neg' - all rows against `observed_labels`
                       (result name 'performance_biased_neg')

    Parameters
    ----------
    df : DataFrame holding the label, weight, indicator, group and prediction columns.
    group_var_name : name of the group column, forwarded to the evaluator.
    pred_prob_var : name of the predicted-probability column.

    Returns
    -------
    The four evaluation results combined with concat_merge.
    """
    evaluator = FairOVAEvaluator(thresholds=[0.1, 0.5, 0.9])
    observed_only = 'binary_event_indicator == 1'
    shared = {'group_var_name': group_var_name, 'pred_prob_var': pred_prob_var}

    result_dict = {}
    result_dict['unbiased'] = evaluator.evaluate(df, label_var='labels', **shared)
    result_dict['adjusted'] = evaluator.evaluate(
        df.query(observed_only),
        label_var='observed_labels',
        weight_var='binary_censoring_weight',
        result_name='performance_ipcw',
        **shared
    )
    result_dict['biased_obs'] = evaluator.evaluate(
        df.query(observed_only),
        label_var='observed_labels',
        result_name='performance_biased_obs',
        **shared
    )
    result_dict['biased_neg'] = evaluator.evaluate(
        df,
        label_var='observed_labels',
        result_name='performance_biased_neg',
        **shared
    )

    return concat_merge(result_dict)

In [4]:
# Row subsets of the simulated cohort for each phase. `@fold_id` refers to the
# notebook-level `fold_id` variable: that fold is the validation split, the
# "test" fold is held out, and every other fold is used for training.
labels_df_dict = {
    'train': df.query('(fold_id != "test") & (fold_id != @fold_id)'),
    'val': df.query('fold_id == @fold_id'),
    'test': df.query('fold_id == "test"')
}
# Per-phase design matrices: the simulated features for the phase's rows
# (selected via row_id) with the group id appended as one extra column.
features_dict = {
    key: np.concatenate(
        (
            features[value.row_id, ], 
            value.group.values.reshape(-1, 1)
        ), 
        axis=1
    ) 
    for key, value in labels_df_dict.items()
}

# Settings expanded with ** into the loader generators and the model
# constructors in the cells below.
config_dict = {
    'lr': 1e-5, 
    'batch_size': 256, 
    'num_epochs': 20, 
    'print_every': 10,
    'verbose': False,
    'sparse': False,
    # +1 for the group column appended to the features
    'input_dim': features.shape[1] + 1,
    'include_group_in_dataset': True,
    'logging_evaluate_by_group': True,
    'group_var_name': 'group',
    'compute_group_min_max':True,
    # presumably the metric used for model selection / early stopping — TODO confirm
    'selection_metric': 'auc_min',
    'early_stopping': True,
    'early_stopping_patience': 10,
    # matches the two groups produced by the simulation call above
    'num_groups': 2
}

In [6]:
%%time

# Unweighted Model
# Trains FixedWidthModel on the rows whose binary outcome is observed
# (binary_event_indicator == 1), with observed_labels as the target and no
# censoring weights. The feature matrix is the simulated features cast to
# float32 with the group id appended as the last column.
loader_generator = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df.query('binary_event_indicator == 1'), 
    label_col='observed_labels',
    fold_id=fold_id,
    eval_key='val', 
    row_id_col='row_id', 
    num_workers=0,
    **config_dict
)

loaders = loader_generator.init_loaders()
model = FixedWidthModel(**config_dict)
# num_epochs=100 is passed here although config_dict sets 'num_epochs': 20;
# the recorded log ("Epoch 0/99") shows the call-site value is the one in effect.
result_dict = model.train(
    loaders,
#     logging_threshold_metrics=['specificity', 'recall'],
#     logging_thresholds=[0.1, 0.5],
    num_epochs=100
)


[0 1]
cuda
Epoch 0/99
----------
Phase: train:
                 metric  group  performance
0                   auc    0.0     0.480775
1                   auc    1.0     0.511709
2                 auprc    0.0     0.848429
3                 auprc    1.0     0.693107
4                 brier    0.0     0.146412
5                 brier    1.0     0.235342
6              loss_bce    0.0     0.474940
7              loss_bce    1.0     0.695305
8       specificity_0.5    0.0     0.019627
9       specificity_0.5    1.0     0.044608
10        precision_0.5    0.0     0.852800
11        precision_0.5    1.0     0.688599
12           recall_0.5    0.0     0.973419
13           recall_0.5    1.0     0.961825
0                  loss    NaN     0.572400
0               auc_min    NaN     0.480775
1             auprc_min    NaN     0.693107
2             brier_min    NaN     0.146412
3          loss_bce_min    NaN     0.474940
4     precision_0.5_min    NaN     0.688599
5        recall_0.5_min    Na

Phase: val:
                 metric  group  performance
0                   auc    0.0     0.702152
1                   auc    1.0     0.638998
2                 auprc    0.0     0.931025
3                 auprc    1.0     0.781701
4                 brier    0.0     0.111518
5                 brier    1.0     0.214469
6              loss_bce    0.0     0.375776
7              loss_bce    1.0     0.621137
8       specificity_0.5    0.0     0.056311
9       specificity_0.5    1.0     0.080530
10        precision_0.5    0.0     0.874028
11        precision_0.5    1.0     0.683287
12           recall_0.5    0.0     0.992348
13           recall_0.5    1.0     0.964321
0                  loss    NaN     0.482234
0               auc_min    NaN     0.638998
1             auprc_min    NaN     0.781701
2             brier_min    NaN     0.111518
3          loss_bce_min    NaN     0.375776
4     precision_0.5_min    NaN     0.683287
5        recall_0.5_min    NaN     0.964321
6   specificity_0.5_

Epoch 70/99
----------
Phase: train:
                 metric  group  performance
0                   auc    0.0     0.715886
1                   auc    1.0     0.667754
2                 auprc    0.0     0.933804
3                 auprc    1.0     0.806339
4                 brier    0.0     0.117482
5                 brier    1.0     0.203835
6              loss_bce    0.0     0.384087
7              loss_bce    1.0     0.594834
8       specificity_0.5    0.0     0.056992
9       specificity_0.5    1.0     0.170504
10        precision_0.5    0.0     0.858833
11        precision_0.5    1.0     0.703827
12           recall_0.5    0.0     0.987145
13           recall_0.5    1.0     0.930392
0                  loss    NaN     0.477746
0               auc_min    NaN     0.667754
1             auprc_min    NaN     0.806339
2             brier_min    NaN     0.117482
3          loss_bce_min    NaN     0.384087
4     precision_0.5_min    NaN     0.703827
5        recall_0.5_min    NaN     0.93

In [6]:
# Evaluation of unweighted model
# Scores the 'test' fold with `model` (trained in the previous cell) and reports
# the four labelling schemes of temp_evaluate / temp_evaluate_fair.
loader_generator = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df, 
    fold_id=fold_id,
    eval_key='test', 
    row_id_col='row_id', 
    label_col='observed_labels',
    **config_dict
)

loaders_predict = loader_generator.init_loaders_predict()

predict_dict = model.predict(loaders_predict, phases=['test'])

output_df_eval, result_df_eval = (
    predict_dict["outputs"],
    predict_dict["performance"],
)

# Swap the `labels` column returned by predict for the simulation's columns on
# the test rows (merge joins on the columns the two frames share), then cast
# `labels` to an integer dtype.
# np.int64 replaces np.long, which is deprecated since NumPy 1.20 and removed in 1.24.
output_df_eval = output_df_eval.drop(columns=['labels']).merge(
    labels_df_dict['test'],
).assign(labels = lambda x: x.labels.astype(np.int64))

# Overall metrics
display(temp_evaluate(df=output_df_eval, strata_vars=None))

# Metrics stratified by group
display(temp_evaluate(df=output_df_eval, strata_vars=['group']))

# Fairness (one-vs-all) metrics by group
display(temp_evaluate_fair(df=output_df_eval))

[0 1]
Evaluating on phase: test


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  ).assign(labels = lambda x: x.labels.astype(np.long))


Unnamed: 0,metric,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc,0.622285,0.621265,0.634921,0.603093
1,auprc,0.798715,0.801957,0.8565,0.683574
2,loss_bce,0.590732,0.587509,0.508092,0.741713
3,ace_rmse_logistic_log,0.061973,0.060066,0.041414,0.169586
4,ace_abs_logistic_log,0.051935,0.048313,0.035597,0.166235
5,specificity_0.1,0.0,0.0,0.0,0.0
6,specificity_0.5,0.03331,0.033888,0.036927,0.028035
7,specificity_0.9,0.95021,0.952326,0.954735,0.936421
8,precision_0.1,0.7148,0.717697,0.781494,0.6005
9,precision_0.5,0.718731,0.72158,0.78543,0.603856


Unnamed: 0,metric,group,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc,0,0.645581,0.648762,0.659302,0.623265
1,auc,1,0.574092,0.569253,0.578054,0.545134
2,auprc,0,0.879789,0.88333,0.913797,0.809174
3,auprc,1,0.68666,0.687934,0.753744,0.510867
4,loss_bce,0,0.465846,0.461858,0.408307,0.572183
5,loss_bce,1,0.713145,0.712078,0.632043,0.907886
6,ace_rmse_logistic_log,0,0.040859,0.041428,0.074163,0.059499
7,ace_rmse_logistic_log,1,0.139105,0.138604,0.084505,0.288727
8,ace_abs_logistic_log,0,0.035496,0.03608,0.065228,0.053629
9,ace_abs_logistic_log,1,0.12565,0.122753,0.069958,0.276613


Unnamed: 0,metric,phase,group,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc_ova,test,1,-0.048192,-0.052012,-0.056867,-0.057959
1,auc_ova,test,0,0.023296,0.027497,0.024382,0.020172
2,auprc_ova,test,1,-0.112055,-0.114023,-0.102756,-0.172706
3,auprc_ova,test,0,0.081074,0.081373,0.057297,0.125601
4,loss_bce_ova,test,1,0.122413,0.124568,0.123951,0.166173
5,loss_bce_ova,test,0,-0.124886,-0.125651,-0.099784,-0.16953
6,ace_rmse_logistic_log_ova,test,1,0.077132,0.078538,0.043091,0.119141
7,ace_rmse_logistic_log_ova,test,0,-0.021114,-0.018638,0.032749,-0.110087
8,ace_abs_logistic_log_ova,test,1,0.073715,0.07444,0.034361,0.110378
9,ace_abs_logistic_log_ova,test,0,-0.016438,-0.012233,0.02963,-0.112606


In [6]:
# Unweighted Model - training for fairness
# Same data setup as the unweighted model (rows with an observed binary
# outcome, observed_labels as target, no censoring weights), but the model
# class comes from group_lagrangian_model('multi') and is configured with
# lagrangian_config.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so
# `model` may be only partially trained when later cells use it.
loader_generator = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df.query('binary_event_indicator == 1'), 
    label_col='observed_labels',
    fold_id=fold_id,
    eval_key='val', 
    row_id_col='row_id', 
    num_workers=0,
    **config_dict
)

loaders = loader_generator.init_loaders()


# Settings forwarded to the lagrangian model constructor; also reused by the
# weighted fair-model cell further down.
# NOTE(review): the meaning of each key is defined by group_lagrangian_model — confirm there.
lagrangian_config={
    'lr_lambda': 1e-1,
    'constraint_slack': 0.01,
    'multiplier_bound': 1,
    'additive_update': False,
    'use_exact_constraints': True,
    'thresholds': [0.075, 0.2],
    'constraint_metrics': ['tpr', 'fpr']
}

# thresholds = [0.075, 0.2]

model_class = group_lagrangian_model(
    'multi'
)

# The same thresholds are used for the constraints and for logged threshold metrics.
model = model_class(
    **config_dict, 
    **lagrangian_config,
    logging_thresholds=lagrangian_config.get('thresholds'), 
)

# lr, num_epochs and the early-stopping settings repeat the values already in config_dict.
result_dict = model.train(
    loaders,
    lr=1e-5,
    num_epochs=20,
    early_stopping=True,
    early_stopping_patience=10,
    print_debug=False,
)

[0 1]
cuda
Epoch 0/19
----------
Phase: train:
               metric  group  performance
0                 auc    0.0     0.528574
1                 auc    1.0     0.573145
2               auprc    0.0     0.858299
3               auprc    1.0     0.728893
4               brier    0.0     0.255799
5               brier    1.0     0.256318
6            loss_bce    0.0     0.711067
7            loss_bce    1.0     0.713486
8   specificity_0.075    0.0     0.000449
9   specificity_0.075    1.0     0.001088
10    specificity_0.2    0.0     0.033693
11    specificity_0.2    1.0     0.055223
12    precision_0.075    0.0     0.844199
13    precision_0.075    1.0     0.675015
14      precision_0.2    0.0     0.844457
15      precision_0.2    1.0     0.680350
16       recall_0.075    0.0     0.999503
17       recall_0.075    1.0     0.998821
18         recall_0.2    0.0     0.968164
19         recall_0.2    1.0     0.968046
0                loss    NaN     0.703229
1          supervised    NaN 

KeyboardInterrupt: 

In [None]:
# Ad-hoc probe: calls get_surrogate_fn() on `model` and applies the returned callable to 2.
# NOTE(review): this cell was never executed in the saved notebook (In [None]) and
# nothing below uses its result — looks like leftover debugging; confirm before keeping.
model.get_surrogate_fn()(2)

In [8]:
# Evaluation of unweighted fair model
# Scores the 'test' fold with `model` (the lagrangian model from the training
# cell above) and reports the four labelling schemes of temp_evaluate /
# temp_evaluate_fair.
# NOTE(review): unlike the other evaluation cells, this loader is built without
# **config_dict and with label_col='true_labels' — confirm that is intended.
loader_generator = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df, 
    fold_id=fold_id,
    eval_key='test', 
    row_id_col='row_id', 
    label_col='true_labels',
    include_group_in_dataset=True,
    group_var_name='group'
)

loaders_predict = loader_generator.init_loaders_predict()

predict_dict = model.predict(loaders_predict, phases=['test'])

output_df_eval, result_df_eval = (
    predict_dict["outputs"],
    predict_dict["performance"],
)

# Swap the `labels` column returned by predict for the simulation's columns on
# the test rows (merge joins on the columns the two frames share), then cast
# `labels` to an integer dtype.
# np.int64 replaces np.long, which is deprecated since NumPy 1.20 and removed in 1.24.
output_df_eval = output_df_eval.drop(columns=['labels']).merge(
    labels_df_dict['test'], 
).assign(labels = lambda x: x.labels.astype(np.int64))

# Overall metrics
display(temp_evaluate(df=output_df_eval, strata_vars=None))

# Metrics stratified by group
display(temp_evaluate(df=output_df_eval, strata_vars=['group']))

# Fairness (one-vs-all) metrics by group
display(temp_evaluate_fair(df=output_df_eval))


[0 1]
Evaluating on phase: test


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  ).assign(labels = lambda x: x.labels.astype(np.long))


Unnamed: 0,metric,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc,0.568866,0.563097,0.574014,0.551908
1,auprc,0.756919,0.757462,0.820254,0.638077
2,loss_bce,0.614556,0.614901,0.562608,0.711465
3,ace_rmse_logistic_log,0.100523,0.105608,0.142514,0.123773
4,ace_abs_logistic_log,0.081527,0.084911,0.115682,0.103484
5,specificity_0.1,0.0,0.0,0.0,0.0
6,specificity_0.5,0.159537,0.155032,0.162597,0.142678
7,specificity_0.9,0.980715,0.982175,0.983919,0.979474
8,precision_0.1,0.7148,0.717697,0.781494,0.6005
9,precision_0.5,0.727676,0.729151,0.792717,0.610884


Unnamed: 0,metric,group,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc,0,0.581981,0.586873,0.592107,0.578773
1,auc,1,0.566123,0.553334,0.56362,0.539201
2,auprc,0,0.85011,0.855302,0.889768,0.779831
3,auprc,1,0.667155,0.66282,0.734163,0.50146
4,loss_bce,0,0.538362,0.534493,0.503177,0.595323
5,loss_bce,1,0.689241,0.694617,0.636432,0.825306
6,ace_rmse_logistic_log,0,0.166284,0.166406,0.201516,0.106406
7,ace_rmse_logistic_log,1,0.106028,0.114823,0.099568,0.233933
8,ace_abs_logistic_log,0,0.140305,0.141002,0.176404,0.086189
9,ace_abs_logistic_log,1,0.088231,0.096233,0.081879,0.209766


Unnamed: 0,metric,phase,group,performance,performance_ipcw,performance_biased_obs,performance_biased_neg
0,auc_ova,test,1,-0.002742,-0.009763,-0.010394,-0.012707
1,auc_ova,test,0,0.013116,0.023776,0.018093,0.026865
2,auprc_ova,test,1,-0.089764,-0.094643,-0.08609,-0.136617
3,auprc_ova,test,0,0.093191,0.09784,0.069514,0.141755
4,loss_bce_ova,test,1,0.074685,0.079715,0.073825,0.113842
5,loss_bce_ova,test,0,-0.076194,-0.080408,-0.059431,-0.116142
6,ace_rmse_logistic_log_ova,test,1,0.005505,0.009214,-0.042946,0.11016
7,ace_rmse_logistic_log_ova,test,0,0.065761,0.060798,0.059001,-0.017367
8,ace_abs_logistic_log_ova,test,1,0.006705,0.011321,-0.033804,0.106281
9,ace_abs_logistic_log_ova,test,0,0.058778,0.056091,0.060721,-0.017295


In [9]:
# Weighted Model
# Same data setup as the unweighted model, but the loader is asked to carry the
# binary_censoring_weight column and the model is built with weighted_loss and
# weighted_evaluation enabled.
# NOTE(review): the recorded run of this cell failed with KeyError: 'weights'.
# This cell passes the weight column as `weight_var_name`, whereas the other
# weighted cells in this notebook pass `weight_var` — confirm which keyword
# ArrayLoaderGenerator expects and make the cells consistent.
loader_generator_weighted = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df.query('binary_event_indicator == 1'), 
    label_col='observed_labels',
    fold_id=fold_id,
    eval_key='val', 
    row_id_col='row_id', 
    weight_var_name='binary_censoring_weight',
    num_workers=0,
    **config_dict
)

loaders_weighted = loader_generator_weighted.init_loaders()
model_weighted = FixedWidthModel(weighted_loss=True, weighted_evaluation=True, **config_dict)
result_dict_weighted = model_weighted.train(
    loaders_weighted,
    num_epochs=20,
)

[0 1]
cuda
Epoch 0/19
----------


KeyError: 'weights'

In [None]:
# Evaluation of weighted model
# Scores the 'test' fold with `model_weighted` and reports the four labelling
# schemes of temp_evaluate / temp_evaluate_fair.
# NOTE(review): the weight column is passed as `weight_var` here but as
# `weight_var_name` in the weighted training cell — confirm which keyword
# ArrayLoaderGenerator expects.
loader_generator_predict = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df, 
    fold_id=fold_id,
    eval_key='test', 
    row_id_col='row_id', 
    label_col='observed_labels',
    weight_var='binary_censoring_weight',
    **config_dict
)

loaders_predict = loader_generator_predict.init_loaders_predict()

predict_dict = model_weighted.predict(loaders_predict, phases=['test'])

output_df_eval, result_df_eval = (
    predict_dict["outputs"],
    predict_dict["performance"],
)

# Swap the `labels` column returned by predict for the simulation's columns on
# the test rows (merge joins on the columns the two frames share), then cast
# `labels` to an integer dtype.
# np.int64 replaces np.long, which is deprecated since NumPy 1.20 and removed in 1.24.
output_df_eval = output_df_eval.drop(columns=['labels']).merge(
    labels_df_dict['test'], 
).assign(labels = lambda x: x.labels.astype(np.int64))

# Overall metrics
display(temp_evaluate(df=output_df_eval, strata_vars=None))

# Metrics stratified by group
display(temp_evaluate(df=output_df_eval, strata_vars=['group']))

# Fairness (one-vs-all) metrics by group
display(temp_evaluate_fair(df=output_df_eval))

In [None]:
# Weighted Model - fair
# Lagrangian model from group_lagrangian_model('multi'), trained on the rows
# with an observed binary outcome, with weighted_loss and weighted_evaluation
# enabled and binary_censoring_weight supplied to the loader.
# Relies on `lagrangian_config` defined in the unweighted fairness-training cell.
# NOTE(review): the weight column is passed as `weight_var` here but as
# `weight_var_name` in the "Weighted Model" cell — confirm which keyword
# ArrayLoaderGenerator expects.
loader_generator_weighted = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df.query('binary_event_indicator == 1'), 
    fold_id=fold_id,
    eval_key='val', 
    row_id_col='row_id', 
    label_col='observed_labels',
    weight_var='binary_censoring_weight',
    num_workers=0,
    **config_dict
)

loaders_weighted = loader_generator_weighted.init_loaders()

model_class = group_lagrangian_model(
    'multi'
)

# Rebinds `model_weighted`, replacing the FixedWidthModel from the "Weighted Model" cell.
model_weighted = model_class(
    **config_dict, 
    **lagrangian_config,
    logging_thresholds=lagrangian_config.get('thresholds'), 
    weighted_loss=True,
    weighted_evaluation=True
)
result_dict_weighted = model_weighted.train(
    loaders_weighted,
    lr=1e-5,
    num_epochs=20,
    print_debug=False,
)

In [None]:
# Evaluation of weighted fair model
# Scores the 'test' fold with `model_weighted` (at this point the lagrangian
# model from the preceding cell) and reports the four labelling schemes of
# temp_evaluate / temp_evaluate_fair.
# NOTE(review): the weight column is passed as `weight_var` here but as
# `weight_var_name` in the "Weighted Model" training cell — confirm which
# keyword ArrayLoaderGenerator expects.
loader_generator_predict = ArrayLoaderGenerator(
    features=np.concatenate((features.astype(np.float32), df.group.values.reshape(-1, 1).astype(np.float32)), axis=1), 
    cohort=df, 
    fold_id=fold_id,
    eval_key='test', 
    row_id_col='row_id', 
    label_col='observed_labels',
    weight_var='binary_censoring_weight',
    **config_dict
)

loaders_predict = loader_generator_predict.init_loaders_predict()

predict_dict = model_weighted.predict(loaders_predict, phases=['test'])

output_df_eval, result_df_eval = (
    predict_dict["outputs"],
    predict_dict["performance"],
)

# Swap the `labels` column returned by predict for the simulation's columns on
# the test rows (merge joins on the columns the two frames share), then cast
# `labels` to an integer dtype.
# np.int64 replaces np.long, which is deprecated since NumPy 1.20 and removed in 1.24.
output_df_eval = output_df_eval.drop(columns=['labels']).merge(
    labels_df_dict['test'], 
).assign(labels = lambda x: x.labels.astype(np.int64))

# Overall metrics
display(temp_evaluate(df=output_df_eval, strata_vars=None))

# Metrics stratified by group
display(temp_evaluate(df=output_df_eval, strata_vars=['group']))

# Fairness (one-vs-all) metrics by group
display(temp_evaluate_fair(df=output_df_eval))