In [None]:
import pickle
from itertools import product, combinations

import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import xgboost


In [None]:
# Load the multiply-imputed dataset and display it.
# NOTE(review): the displayed frame starts with an 'Unnamed: 0' column (a saved row
# index); it is not dropped here, so it ends up in X as a feature — confirm intended.
df = pd.read_csv(filepath_or_buffer='/content/multiple_imputation.csv')
df

Unnamed: 0,hospital_id,hospital_death,age,bmi,ethnicity,gender,height,hospital_admit_source,icu_admit_source,icu_id,...,aids,cirrhosis,diabetes_mellitus,hepatic_failure,immunosuppression,leukemia,lymphoma,solid_tumor_with_metastasis,apache_3j_bodysystem,apache_2_bodysystem
0,118,0,68.000000,22.730000,Caucasian,M,180.3,Floor,Floor,92,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Sepsis,Cardiovascular
1,81,0,77.000000,27.420000,Caucasian,F,160.0,Floor,Floor,90,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Respiratory,Respiratory
2,118,0,25.000000,31.950000,Caucasian,F,172.7,Emergency Department,Accident & Emergency,93,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Metabolic,Metabolic
3,118,0,81.000000,22.640000,Caucasian,F,165.1,Operating Room,Operating Room / Recovery,92,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Cardiovascular,Cardiovascular
4,33,0,19.000000,23.609790,Caucasian,M,188.0,Emergency Department,Accident & Emergency,91,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Trauma,Trauma
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
91708,30,0,75.000000,23.060250,Caucasian,M,177.8,Acute Care/Floor,Floor,927,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,Sepsis,Cardiovascular
91709,121,0,56.000000,47.179671,Caucasian,F,183.0,Emergency Department,Floor,925,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Sepsis,Cardiovascular
91710,195,0,48.000000,27.236914,Caucasian,M,170.2,Emergency Department,Accident & Emergency,908,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,Metabolic,Metabolic
91711,66,0,71.165995,23.297481,Caucasian,F,154.9,Emergency Department,Accident & Emergency,922,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,Respiratory,Respiratory


In [None]:
# Separate the binary outcome from the features.
target_variable = 'hospital_death'
y = df[target_variable]
X = df.drop(columns=[target_variable])

In [None]:
def preprocess_data(X, categorical_cols=None, is_training=True, scaler=None, dummy_cols=None):
    """
    One-hot encode categorical columns and standardize all resulting features.

    In training mode a new ``StandardScaler`` is fitted and the post-encoding
    column layout is recorded. In inference mode the frame is aligned to that
    layout and transformed with the supplied, already-fitted scaler.

    Args:
        X (pd.DataFrame): Raw feature frame (not modified).
        categorical_cols (list[str] | None): Columns to one-hot encode. When None,
            every ``object``/``category`` column of ``X`` is encoded.
        is_training (bool): Fit a new scaler (True) or reuse ``scaler`` (False).
        scaler: Fitted scaler exposing ``transform``; required when
            ``is_training`` is False.
        dummy_cols (list[str] | None): Column layout returned by the training
            call; used to align inference data.

    Returns:
        tuple: ``(X_scaled, scaler, dummy_cols)``. ``X_scaled`` keeps the
        index of ``X``; ``dummy_cols`` is the training column layout.

    Raises:
        ValueError: If ``is_training`` is False and no scaler is supplied.
    """
    X_processed = X.copy()

    if categorical_cols is None:
        categorical_cols = X_processed.select_dtypes(include=['object', 'category']).columns.tolist()

    if is_training:
        if categorical_cols:
            X_processed = pd.get_dummies(X_processed, columns=categorical_cols, drop_first=True)
        # Record the layout even when nothing was encoded, so inference can align to it.
        dummy_cols = X_processed.columns.tolist()
        scaler = StandardScaler()
        values = scaler.fit_transform(X_processed)
    else:
        if scaler is None:
            raise ValueError("A fitted scaler is required when is_training=False")
        if categorical_cols:
            # Encode every level (no drop_first): with drop_first the level dropped would be
            # whichever sorts first in *this* frame, and if that is not the training
            # baseline its rows would become all-zero, i.e. look like the baseline.
            # Reindexing below removes the baseline and any unseen levels and zero-fills
            # levels that are absent here.
            X_processed = pd.get_dummies(X_processed, columns=categorical_cols)
        if dummy_cols is not None:
            X_processed = X_processed.reindex(columns=dummy_cols, fill_value=0)
        values = scaler.transform(X_processed)

    X_scaled = pd.DataFrame(values, columns=X_processed.columns, index=X_processed.index)
    return X_scaled, scaler, dummy_cols

In [None]:
# Columns with object/category dtype are the ones that get one-hot encoded.
categorical_cols = list(X.select_dtypes(include=['object', 'category']).columns)

In [None]:
# Hold out 20% of rows for testing, stratified jointly on outcome, gender and
# ethnicity (rows of `df` share the index of X and y).
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    stratify=df[['hospital_death', 'gender', 'ethnicity']],
)

In [None]:

# Preprocess data for all models (dummy variables only, no scaling)
def preprocess_data_noscale(X, categorical_cols, is_training=True, dummy_cols=None):
    """
    One-hot encode categorical columns without any scaling.

    Args:
        X (pd.DataFrame): Raw feature frame.
        categorical_cols (list[str]): Columns to encode; an empty/falsy value means
            nothing is encoded and ``X`` is returned unchanged.
        is_training (bool): When True the encoded layout defines ``dummy_cols``;
            when False the frame is aligned to the given ``dummy_cols``.
        dummy_cols (pd.Index | list[str] | None): Column layout from the training call.

    Returns:
        tuple: ``(X_encoded, dummy_cols)`` where ``dummy_cols`` is the training
        column layout (a ``pd.Index`` when produced by a training call).
    """
    if not categorical_cols:
        return X, X.columns

    if is_training:
        X_dummies = pd.get_dummies(X, columns=categorical_cols, drop_first=True)
        return X_dummies, X_dummies.columns

    # Encode every level, then project onto the training layout. Using drop_first
    # here would drop whichever level sorts first in *this* frame; when that is not
    # the training baseline, its rows turn all-zero and are silently read as the
    # baseline. reindex drops the baseline/unseen levels and zero-fills missing ones.
    X_dummies = pd.get_dummies(X, columns=categorical_cols)
    return X_dummies.reindex(columns=dummy_cols, fill_value=0), dummy_cols

# Encode the training split (which fixes the dummy layout), then align the test split to it.
X_train_processed, dummy_cols = preprocess_data_noscale(X_train, categorical_cols, is_training=True)
X_test_processed, _ = preprocess_data_noscale(
    X_test, categorical_cols, is_training=False, dummy_cols=dummy_cols
)

# Every model consumes the same unscaled frames. The "_scaled" and "_xgb" names are
# plain aliases (the same objects, not scaled copies), kept because later cells use them.
X_train_scaled = X_train_xgb = X_train_processed
X_test_scaled = X_test_xgb = X_test_processed


In [None]:
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    NOTE(review): pickle.load can execute arbitrary code; only load model files
    produced by a trusted run.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pre-trained models (file names suggest the best KNN / LR / XGBoost configurations).
knn_path = "/content/best_knn.pkl"
knn = _load_pickle(knn_path)

lr_path = "/content/best_logistic_regression.pkl"
lr = _load_pickle(lr_path)

# NOTE(review): this rebinds `xgboost`, shadowing the xgboost module imported in the
# first cell; later cells use `xgboost` to mean this fitted model.
xgboost_path = "/content/best_xgboost.pkl"
xgboost = _load_pickle(xgboost_path)

In [None]:
def find_empty_subgroups(subpopulation_queries, X):
    """
    Return the subgroup queries that select no rows of ``X``.

    Args:
        subpopulation_queries (list): Each entry is either one ``DataFrame.query``
            string or a tuple of condition strings that are AND-ed together.
        X (pandas.DataFrame): Data the queries are evaluated against.

    Returns:
        list: The query objects (in input order) whose selection is empty.
    """
    def _as_expression(query):
        # Tuples of conditions are joined into a single conjunctive expression.
        return " and ".join(query) if isinstance(query, tuple) else query

    return [
        query
        for query in subpopulation_queries
        if X.query(_as_expression(query)).shape[0] == 0
    ]

In [None]:
def generate_subgroups(query_lists, orders=(2, 3, 4)):
    """
    Build intersectional subgroups by crossing lists of conditions.

    For every order ``k`` in ``orders``, every combination of ``k`` distinct
    lists from ``query_lists`` (in input order) is taken, and the Cartesian
    product of those lists is emitted, one tuple per subgroup.

    Args:
        query_lists (list of lists): Each inner list holds the alternative
            conditions for one attribute (e.g. all gender conditions).
        orders (iterable of int): How many attributes to intersect. The default
            ``(2, 3, 4)`` reproduces the previous hard-coded behaviour.

    Returns:
        list[tuple]: Subgroups grouped by order, then by attribute combination.
    """
    return [
        subgroup
        for k in orders
        for lists in combinations(query_lists, k)
        for subgroup in product(*lists)
    ]

In [None]:
from itertools import product

# 1. Gender queries
gender_queries = ['gender_M == 1', 'gender_M == 0']

# 2. Age queries (4 bins). Ages are continuous (imputed values such as 71.17 occur
#    in the data), so bins are half-open (lo, hi]. Integer-style bins like [31, 50]
#    would leave ages in (30, 31), (50, 51) and (70, 71) in no age group at all.
bins = [(0, 30), (30, 50), (50, 70)]
age_queries = [
    f'age >= {lo} and age <= {hi}' if i == 0 else f'age > {lo} and age <= {hi}'
    for i, (lo, hi) in enumerate(bins)
]
age_queries.append(f'age > {bins[-1][1]}')  # 71+ group

# 3. Explicit ethnicity columns
explicit_ethnicity_cols = [
    'ethnicity_Asian',
    'ethnicity_Caucasian',
    'ethnicity_Hispanic',
    'ethnicity_Native American',
    'ethnicity_Other/Unknown'
]

# 4. Ethnicity queries + African American. African American has no explicit dummy
#    here (presumably the level removed by drop_first — confirm), so it is selected
#    by every explicit ethnicity dummy being 0. Backticks quote names with spaces/'/'.
ethnicity_queries = [f'`{col}` == 1' for col in explicit_ethnicity_cols]
african_american_query = ' and '.join([f'`{col}` == 0' for col in explicit_ethnicity_cols])
ethnicity_queries.append(african_american_query)

# --- INTERSECTIONAL QUERIES ---

# 5. 3D: Gender × Age × Ethnicity
intersectional_3d = [
    f'({g}) and ({a}) and ({e})'
    for g, a, e in product(gender_queries, age_queries, ethnicity_queries)
]

# 6. 2D: Gender × Age
intersectional_gender_age = [
    f'({g}) and ({a})' for g, a in product(gender_queries, age_queries)
]

# 7. 2D: Gender × Ethnicity
intersectional_gender_ethnicity = [
    f'({g}) and ({e})' for g, e in product(gender_queries, ethnicity_queries)
]

# 8. 2D: Age × Ethnicity
intersectional_age_ethnicity = [
    f'({a}) and ({e})' for a, e in product(age_queries, ethnicity_queries)
]

# --- MARGINAL 1D QUERIES ---

# 9. Gender only
marginal_gender = gender_queries

# 10. Age only
marginal_age = age_queries

# 11. Ethnicity only
marginal_ethnicity = ethnicity_queries

# --- GLOBAL GROUP ---
global_query = ['age >= 0']

# --- COMBINE ALL ---
all_subgroups = (
    intersectional_3d +
    intersectional_gender_age +
    intersectional_gender_ethnicity +
    intersectional_age_ethnicity +
    marginal_gender +
    marginal_age +
    marginal_ethnicity +
    global_query
)


# Filtering helper: which subgroup queries select no rows?
def find_empty_subgroups(subgroup_queries, X_df):
    """
    Return the subgroup queries that select no rows of ``X_df``.

    A query is either a ``DataFrame.query`` string or a tuple of condition
    strings that are AND-ed together (the convention used by the calibration
    and evaluation helpers in this notebook). Queries that fail to evaluate
    are reported and treated as empty, so they are filtered out downstream.

    Args:
        subgroup_queries (list): Query strings and/or tuples of conditions.
        X_df (pandas.DataFrame): Data the queries are evaluated against.

    Returns:
        list: The queries (in input order) that are empty or failed to evaluate.
    """
    empty = []
    for query in subgroup_queries:
        try:
            # DataFrame.query rejects non-strings, so tuples are joined first;
            # otherwise every tuple query would be misreported as empty.
            expression = " and ".join(query) if isinstance(query, tuple) else query
            n_rows = X_df.query(expression).shape[0]
        except Exception as e:
            print(f"Query error: {query} -> {e}")
            empty.append(query)
            continue
        if n_rows == 0:
            empty.append(query)
    return empty

# Drop subgroups that select no rows of the encoded test set.
empty_subgroups = find_empty_subgroups(all_subgroups, X_test_xgb)
empty_lookup = set(empty_subgroups)

# Keep non-empty queries in construction order (dict.fromkeys de-duplicates, as the
# former set difference did), then sort by expression length. sorted() is stable, so
# ties keep that fixed order. Sorting a *set* left ties in string-hash order, which
# Python randomizes per process, so the subgroup order (plot x-axis and the order of
# sequential multi-calibration updates) changed from run to run.
valid_subgroup_queries = sorted(
    (q for q in dict.fromkeys(all_subgroups) if q not in empty_lookup),
    key=len,
)

# Shorter expressions (fewer conditions, roughly broader subgroups) come first.
sorted_queries = list(valid_subgroup_queries)
print(f'Valid subgroups = {len(sorted_queries)}')


Valid subgroups = 105


In [None]:
class CalibratedPredictor:
    """
    Wrap a probabilistic classifier and post-process its scores with an
    iterative multi-calibration pass over a fixed list of subgroups.

    The calibration is fitted on the frame passed to ``predict_proba`` using the
    labels passed there (or, by default, the ``y_test`` given at construction),
    so the calibrated scores are in-sample with respect to those labels.
    NOTE(review): when those labels are the test labels, evaluation on the same
    test set is optimistic.

    Args:
        model: Fitted classifier whose ``predict_proba(X)[:, 1]`` is the
            positive-class probability.
        subgroups (list): ``DataFrame.query`` strings, or tuples of condition
            strings that are AND-ed together.
        y_test (array-like | pd.Series): Default labels used for calibration.
        threshold (float): Decision threshold used by ``predict``
            (default 0.2, the value that was previously hard-coded).
    """

    def __init__(self, model, subgroups, y_test, threshold=0.2):
        self.model = model
        self.subgroups = subgroups
        self.y_test = y_test
        self.threshold = threshold
        self.proba = None          # last calibrated probabilities, shape (n, 2)
        self._proba_index = None   # index of the frame `self.proba` belongs to

    def predict(self, X, y=None):
        """
        Return hard 0/1 labels for ``X`` by thresholding calibrated scores.

        Cached probabilities are reused only when they were computed for a frame
        with the same index as ``X``; otherwise calibration is re-run, so the
        result always has exactly one label per row of ``X``.
        """
        cached_index = getattr(self, '_proba_index', None)
        if self.proba is None or cached_index is None or not cached_index.equals(X.index):
            self.predict_proba(X, y)
        p = self.proba[:, 1]
        return np.where(p >= getattr(self, 'threshold', 0.2), 1, 0)

    def predict_proba(self, X, y=None):
        """Run multi-calibration on ``X`` and return an ``(n, 2)`` array ``[1 - p, p]``."""
        # `y or self.y_test` raises for a pandas Series / ndarray ("truth value is
        # ambiguous"), so test for None explicitly.
        labels = self.y_test if y is None else y
        return self.multi_calibrate_predictor(X, labels, self.subgroups)

    def multi_calibrate_predictor(
        self,
        X,
        y,
        subpopulation_queries,
        alpha=0.001,
        max_iter=1000,
        learning_rate=0.6,
        normalize_by_size=True,
        decay=False
    ):
        """
        Shift predicted probabilities until every subgroup's mean prediction is
        within ``alpha`` of its mean label.

        Each iteration visits the subgroups in order. For subgroup ``S`` with
        residual ``delta = mean(y[S] - p[S])`` and ``|delta| > alpha``, every
        ``p[S]`` is moved by ``lr * delta`` (times ``|S| / n`` when
        ``normalize_by_size``). Probabilities are clipped to [0, 1] after each
        full pass. Stops when a pass makes no update, or after ``max_iter``.

        Args:
            X (pd.DataFrame): Frame to score; subgroup queries run against it.
            y (array-like | pd.Series): Labels; a Series is aligned to ``X.index``.
            subpopulation_queries (list): Query strings or tuples of conditions.
            alpha (float): Tolerance on the per-subgroup mean residual.
            max_iter (int): Maximum number of passes.
            learning_rate (float): Step size.
            normalize_by_size (bool): Scale each step by the subgroup's share of rows.
            decay (bool): Use step ``learning_rate / (1 + 0.01 * iteration)``.

        Returns:
            np.ndarray: ``(n, 2)`` array ``[1 - p, p]``; also stored in ``self.proba``.
        """
        y = pd.Series(y, index=X.index)
        p = pd.Series(self.model.predict_proba(X)[:, 1], index=X.index)
        n_total = len(p)

        # X never changes during calibration, so resolve each subgroup once
        # instead of re-running every query on every iteration.
        subgroup_indices = []
        for query in subpopulation_queries:
            try:
                combined_query = " and ".join(query) if isinstance(query, tuple) else query
                S = X.query(combined_query).index
            except Exception as e:
                print(f"⚠️ Error in subgroup {query}: {e}")
                continue
            if len(S) > 0:
                subgroup_indices.append((query, S))

        for iteration in range(max_iter):
            step = learning_rate / (1 + 0.01 * iteration) if decay else learning_rate
            done = True

            for query, S in subgroup_indices:
                try:
                    delta_S = (y.loc[S] - p.loc[S]).mean()
                    if abs(delta_S) > alpha:
                        adj = step * delta_S
                        if normalize_by_size:
                            adj *= len(S) / n_total
                        p.loc[S] += adj
                        done = False
                except Exception as e:
                    print(f"⚠️ Error in subgroup {query}: {e}")

            # Clip once per full pass.
            p = p.clip(0, 1)

            if done:
                print("✅ Calibration completed!")
                break

        self.proba = np.column_stack([1 - p, p])
        self._proba_index = X.index
        return self.proba


In [None]:
def compare_performance_measures(X_test, y_test, model, subgroup_conds, model_name, show=True):
    """
    Compute (and optionally plot) AUC, sensitivity, PPV, specificity and NPV
    for each subgroup, for a single model.

    Args:
        X_test (pd.DataFrame): Test features.
        y_test (pd.Series): Test labels, indexed like ``X_test``.
        model: Fitted classifier with ``predict``; ``predict_proba`` (positive
            class in column 1) is used for AUC when available.
        subgroup_conds (dict): Maps a subgroup query (a ``DataFrame.query``
            string or a tuple of condition strings AND-ed together) to its
            display label. Subgroups selecting no rows are skipped.
        model_name (str): Model name used in the plot title.
        show (bool): Draw the grouped bar chart (default True).

    Returns:
        pd.DataFrame: Long-format table with columns ``Subgroup``, ``Metric``
        and ``Value``. AUC is None when it cannot be computed (e.g. only one
        class in the subgroup); ratio metrics with a zero denominator are 0.
    """
    metrics_data = []
    has_proba = hasattr(model, "predict_proba")

    for cond, label in subgroup_conds.items():
        query_str = " and ".join(cond) if isinstance(cond, tuple) else cond
        sub_df = X_test.query(query_str)
        if sub_df.empty:
            continue
        true_labels = y_test.loc[sub_df.index]

        predictions = model.predict(sub_df)
        pred_proba = model.predict_proba(sub_df)[:, 1] if has_proba else None

        # labels=[0, 1] keeps the matrix 2x2 even when a class is absent here.
        tn, fp, fn, tp = confusion_matrix(true_labels, predictions, labels=[0, 1]).ravel()

        # AUC is undefined with a single class (or when sklearn rejects the input).
        # Report it as missing rather than 0, which would read as a fully inverted ranking.
        auc_value = None
        if has_proba and true_labels.nunique() > 1:
            try:
                auc_value = roc_auc_score(true_labels, pred_proba)
            except ValueError:
                auc_value = None

        subgroup_metrics = {
            'AUC': auc_value,
            'Sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0,
            'PPV': tp / (tp + fp) if (tp + fp) > 0 else 0,
            'Specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
            'NPV': tn / (tn + fn) if (tn + fn) > 0 else 0,
        }
        metrics_data.extend(
            {'Subgroup': label, 'Metric': metric, 'Value': value}
            for metric, value in subgroup_metrics.items()
        )

    df_metrics = pd.DataFrame(metrics_data)

    if show and not df_metrics.empty:
        pastel_colors = ['#6BAED6', '#9ECAE1', '#BCBDDC', '#FDD0A2', '#FDBB84']

        fig = px.bar(
            df_metrics,
            x='Subgroup',
            y='Value',
            color='Metric',
            barmode='group',
            title=f"Performance Metrics by Subgroup with {model_name}",
            color_discrete_sequence=pastel_colors,
        )

        fig.update_xaxes(
            tickfont=dict(size=12),
            tickangle=35
        )

        fig.update_layout(
            xaxis_title='Subgroup',
            yaxis_title='Metric Value',
            width=1400,
            height=500
        )

        fig.show()

    return df_metrics


def model_comparison(model, model_name, X=None, y=None):
    """
    Plot per-subgroup performance of ``model`` for a fixed sequence of marginal
    and intersectional subgroup families (one chart per family).

    Args:
        model: Fitted classifier passed to ``compare_performance_measures``.
        model_name (str): Model name used in the plot titles.
        X (pd.DataFrame | None): Encoded test features; defaults to the
            notebook-level ``X_test_processed``.
        y (pd.Series | None): Test labels; defaults to the notebook-level ``y_test``.
    """
    X = X_test_processed if X is None else X
    y = y_test if y is None else y

    # African American has no explicit dummy column (presumably the level removed by
    # drop_first — confirm), so it is selected by every explicit ethnicity dummy being 0.
    # Hispanic must be in this list: without it, Hispanic patients were counted as
    # African American (the subgroup-definition cell already includes Hispanic).
    _other_ethnicities_for_aa_definition = ['Caucasian', 'Hispanic', 'Native American', 'Other/Unknown', 'Asian']
    african_american_compound_query = ' and '.join(
        f'`ethnicity_{e}` == 0' for e in _other_ethnicities_for_aa_definition
    )

    # Ages are continuous, so the "71+" band is written as age > 70 throughout
    # (age >= 71 would silently drop ages in (70, 71)).
    all_data = {'age > 0': 'All data'}

    males_vs_females = {
        **all_data,
        'gender_M == 1': 'Males',
        'gender_M == 0': 'Females',
    }

    age_subgroups = {
        **all_data,
        'age <= 30': 'Aged 0–30',
        'age > 30 and age <= 50': 'Aged 31–50',
        'age > 50 and age <= 70': 'Aged 51–70',
        'age > 70': 'Aged 71+',
    }

    ethnicity_subgroups = {
        **all_data,
        '`ethnicity_Caucasian` == 1': 'Caucasian',
        '`ethnicity_Native American` == 1': 'Native American',
        '`ethnicity_Other/Unknown` == 1': 'Other/Unknown Ethnicity',
        african_american_compound_query: 'African American',
        '`ethnicity_Asian` == 1': 'Asian',
        '`ethnicity_Hispanic` == 1': 'Hispanic',
    }

    intersecting_female_subgroups_dict = {
        **all_data,
        'gender_M == 0': 'Females',
        ('gender_M == 0', african_american_compound_query): "African American Females",
        ('`ethnicity_Native American` == 1', 'gender_M == 0'): 'Native Americans Females',
        ('`ethnicity_Caucasian` == 1', 'gender_M == 0'): 'Caucasian Females',
        ('`ethnicity_Asian` == 1', 'gender_M == 0'): 'Asian Females',
        ('`ethnicity_Other/Unknown` == 1', 'gender_M == 0'): 'Other/Unknown ethnicity Females',
    }

    intersecting_female_and_age_subgroups_dict = {
        **all_data,
        ('`ethnicity_Asian` == 1', 'gender_M == 0', 'age <= 30'): 'Asian Females aged 0-30',
        ('`ethnicity_Asian` == 1', 'gender_M == 0', 'age > 30 and age <= 50'): 'Asian Females aged 31-50',
        ('`ethnicity_Asian` == 1', 'gender_M == 0', 'age > 50 and age <= 70'): 'Asian Females aged 51-70',
        ('`ethnicity_Asian` == 1', 'gender_M == 0', 'age > 70'): 'Asian Females aged 71+',
    }

    intersecting_2_subgroups_dict = {
        **all_data,
        ('age > 70', 'gender_M == 0'): 'Females Aged 71+',
        ('`ethnicity_Caucasian` == 1', 'gender_M == 1'): 'Caucasian Males',
        ('`ethnicity_Asian` == 1', 'gender_M == 0'): 'Asian Females',
        ('gender_M == 0', '`ethnicity_Caucasian` == 1'): 'Caucasian Females',
    }

    intersecting_male_subgroups_dict = {
        **all_data,
        'gender_M == 1': 'Males',
        ('gender_M == 1', african_american_compound_query): "African American Males",
        ('`ethnicity_Native American` == 1', 'gender_M == 1'): 'Native Americans Males',
        ('`ethnicity_Caucasian` == 1', 'gender_M == 1'): 'Caucasian Males',
        ('`ethnicity_Asian` == 1', 'gender_M == 1'): 'Asian Males',
        ('`ethnicity_Other/Unknown` == 1', 'gender_M == 1'): 'Other/Unknown ethnicity Males',
    }

    intersecting_male_and_age_subgroups_dict = {
        **all_data,
        ('`ethnicity_Asian` == 1', 'gender_M == 1', 'age <= 30'): 'Asian Males aged 0-30',
        ('`ethnicity_Asian` == 1', 'gender_M == 1', 'age > 30 and age <= 50'): 'Asian Males aged 31-50',
        ('`ethnicity_Asian` == 1', 'gender_M == 1', 'age > 50 and age <= 70'): 'Asian Males aged 51-70',
        ('`ethnicity_Asian` == 1', 'gender_M == 1', 'age > 70'): 'Asian Males aged 71+',
    }

    intersecting_3_subgroups_dict = {
        **all_data,
        ('gender_M == 1', 'age <= 30'): 'Young Males (<= 30)',
        ('gender_M == 0', '`ethnicity_Caucasian` == 1'): 'Caucasian Females',
        ('age > 70', '`ethnicity_Native American` == 1'): 'Native American Aged 71+',
        ('gender_M == 0', african_american_compound_query, 'age > 50 and age <= 70'): 'African American Females Aged 51 to 70',
        ('gender_M == 0', 'age > 70', '`ethnicity_Caucasian` == 1'): 'Caucasian Females Aged 71+',
    }

    intersecting_4_subgroups_dict = {
        **all_data,
        ('age > 30 and age <= 50', 'gender_M == 1', '`ethnicity_Caucasian` == 1'): "Caucasian Males Aged 31 to 50",
        ('age > 50 and age <= 70', 'gender_M == 0', african_american_compound_query): "African American Females Aged 51 to 70",
        ('age > 30 and age <= 50', 'gender_M == 0', '`ethnicity_Caucasian` == 1'): "Caucasian Females Aged 31 to 50",
    }

    intersecting_5_subgroups_dict = {
        ('gender_M == 1', '`ethnicity_Asian` == 1'): 'Asian Males',
        ('gender_M == 0', '`ethnicity_Native American` == 1'): 'Native American Females',
        ('age <= 30', '`ethnicity_Other/Unknown` == 1'): 'Young Patients, Unknown Ethnicity',
        ('age > 70', '`ethnicity_Caucasian` == 1'): 'Caucasian Elderly (71+)',
    }

    # These groups are labelled "Aged 71+", so the age condition is age > 70
    # (it previously read age < 70, i.e. the opposite population).
    intersecting_10_subgroups_dict = {
        **all_data,
        ('age > 70', 'gender_M == 1', '`ethnicity_Asian` == 1'): "Asian Males Aged 71+",
        ('age > 70', 'gender_M == 1', '`ethnicity_Native American` == 1'): "Native American Males Aged 71+",
        ('age > 70', 'gender_M == 1', african_american_compound_query): "african american Males Aged 71+",
    }

    intersecting_6_subgroups_dict = {
        ('gender_M == 1', 'age > 70', '`ethnicity_Asian` == 1'): 'Asian Males Aged 71+',
        ('gender_M == 0', '`ethnicity_Caucasian` == 1', 'age > 50 and age <= 70'): 'Caucasian Females Aged 51–70',
        ('gender_M == 1', '`ethnicity_Other/Unknown` == 1', 'age <= 30'): 'Young Males, Unknown Ethnicity',
        ('gender_M == 0', 'age > 70', '`ethnicity_Native American` == 1'): 'Native American Females Aged 71+',
    }

    # (banner printed before the chart, subgroup family) in display order.
    sections = [
        ("Let's compare basic models", males_vs_females),
        (None, age_subgroups),
        (None, ethnicity_subgroups),
        (None, intersecting_female_subgroups_dict),
        (None, intersecting_female_and_age_subgroups_dict),
        ("Let's go deeper", intersecting_2_subgroups_dict),
        (None, intersecting_male_subgroups_dict),
        (None, intersecting_male_and_age_subgroups_dict),
        ("Even deeper", intersecting_3_subgroups_dict),
        (None, intersecting_4_subgroups_dict),
        (None, intersecting_5_subgroups_dict),
        (None, intersecting_10_subgroups_dict),
        (None, intersecting_6_subgroups_dict),
    ]
    for banner, subgroup_conds in sections:
        if banner:
            print(banner)
        compare_performance_measures(X, y, model, subgroup_conds, model_name)





# NOTE(review): calc_calibration_itl is defined twice in this cell; the later
# definition shadows this one.
def calc_calibration_itl(model, X, y, subgroup_queries):
    """
    Calibration-in-the-large (ITL) per subgroup: mean predicted probability minus
    observed event rate. Positive means the event is over-predicted in that
    subgroup; negative means it is under-predicted.

    Args:
        model: Fitted model with ``predict_proba(X)``; column 1 is used as the
            positive-class probability.
        X (pandas.DataFrame): Feature frame the queries are evaluated against.
        y (array-like | pandas.Series): Labels; a Series is aligned to ``X.index``.
        subgroup_queries (list): ``DataFrame.query`` strings, or tuples of
            condition strings that are AND-ed together.

    Returns:
        list[float]: One value per query, in input order; NaN for a query that
        selects no rows.
    """
    y = pd.Series(y, index=X.index)
    p = pd.Series(model.predict_proba(X)[:, 1], index=X.index)

    calibrations_list = []
    for query in subgroup_queries:
        expression = " and ".join(query) if isinstance(query, tuple) else query
        S = X.query(expression).index
        if len(S) == 0:
            # Undefined for an empty subgroup; emit NaN explicitly instead of
            # relying on the mean of an empty selection (which also warns).
            calibrations_list.append(np.nan)
            continue
        calibrations_list.append(p.loc[S].mean() - y.loc[S].mean())

    return calibrations_list


# NOTE(review): test_mc_across_models is defined twice in this cell; the later
# definition shadows this one.
def test_mc_across_models(model, sorted_queries, y_test, model_name):
    """
    Compare calibration-in-the-large (ITL) across subgroups before and after
    multi-calibration, and plot both curves.

    Args:
        model: Fitted classifier with ``predict_proba``.
        sorted_queries (list): Subgroup queries, in the order they appear on the x-axis.
        y_test: Test labels; used both to fit the calibration and to score it.
        model_name (str): Used in the plot title. 'KNN' and 'LR' are evaluated on
            ``X_test_scaled``, any other name on ``X_test_xgb``.

    Returns:
        tuple:
           - the ``CalibratedPredictor`` wrapping ``model``
           - the original ITL values
           - the ITL values after calibration
    """
    # In this notebook both frames are aliases of the same encoded test set; the
    # split is kept so scaled inputs can be reintroduced for distance/linear models.
    X_eval = X_test_scaled if model_name in ('KNN', 'LR') else X_test_xgb

    calibrated_model = CalibratedPredictor(model, sorted_queries, y_test)
    calibrations_list = calc_calibration_itl(model, X_eval, y_test, sorted_queries)
    calibrated_calibrations_list = calc_calibration_itl(
        calibrated_model, X_eval, y_test, sorted_queries
    )

    results = pd.DataFrame({
        'Original': calibrations_list,
        'Calibrated': calibrated_calibrations_list,
        'Index': range(len(calibrations_list)),
    })
    # Long format lets Plotly build the legend from the 'Type' column.
    df_melted = results.melt(id_vars=['Index'], value_vars=['Original', 'Calibrated'],
                             var_name='Type', value_name='Value')

    fig = px.line(
        df_melted,
        x='Index',
        y='Value',
        color='Type',
        title=f'Multi-calibration Across Subgroups ({model_name} model)'
    )

    layout = dict(
        xaxis_title="Subgroup Index (Widest to Narrowest)",
        yaxis_title="Mean Predicted Probability − Actual Outcome Rate",
    )
    # Centre the y-axis on zero; skip when the extent is NaN or 0, since a
    # [0, 0] / NaN range would collapse the plot.
    y_abs_max = df_melted['Value'].abs().max()
    if pd.notna(y_abs_max) and y_abs_max > 0:
        layout['yaxis'] = dict(range=[-y_abs_max, y_abs_max], zeroline=True,
                               zerolinewidth=2, zerolinecolor='gray')
    fig.update_layout(**layout)

    fig.show()

    return calibrated_model, calibrations_list, calibrated_calibrations_list


def calc_calibration_itl(model, X, y, subgroup_queries):
    """
    Calibration-in-the-large (ITL) per subgroup: mean predicted probability minus
    observed event rate. Positive means the event is over-predicted in that
    subgroup; negative means it is under-predicted.

    Args:
        model: Fitted model with ``predict_proba(X)``; column 1 is used as the
            positive-class probability.
        X (pandas.DataFrame): Feature frame the queries are evaluated against.
        y (array-like | pandas.Series): Labels; a Series is aligned to ``X.index``.
        subgroup_queries (list): ``DataFrame.query`` strings, or tuples of
            condition strings that are AND-ed together.

    Returns:
        list[float]: One value per query, in input order; NaN for a query that
        selects no rows.
    """
    y = pd.Series(y, index=X.index)
    p = pd.Series(model.predict_proba(X)[:, 1], index=X.index)

    calibrations_list = []
    for query in subgroup_queries:
        expression = " and ".join(query) if isinstance(query, tuple) else query
        S = X.query(expression).index
        if len(S) == 0:
            # Undefined for an empty subgroup; emit NaN explicitly instead of
            # relying on the mean of an empty selection (which also warns).
            calibrations_list.append(np.nan)
            continue
        calibrations_list.append(p.loc[S].mean() - y.loc[S].mean())

    return calibrations_list


def test_mc_across_models(model, sorted_queries, y_test, model_name):
    """
    Compare calibration-in-the-large (ITL) across subgroups before and after
    multi-calibration, and plot both curves.

    Args:
        model: Fitted classifier with ``predict_proba``.
        sorted_queries (list): Subgroup queries, in the order they appear on the x-axis.
        y_test: Test labels; used both to fit the calibration and to score it.
        model_name (str): Used in the plot title. 'KNN' and 'LR' are evaluated on
            ``X_test_scaled``, any other name on ``X_test_xgb``.

    Returns:
        tuple:
           - the ``CalibratedPredictor`` wrapping ``model``
           - the original ITL values
           - the ITL values after calibration
    """
    # In this notebook both frames are aliases of the same encoded test set; the
    # split is kept so scaled inputs can be reintroduced for distance/linear models.
    X_eval = X_test_scaled if model_name in ('KNN', 'LR') else X_test_xgb

    calibrated_model = CalibratedPredictor(model, sorted_queries, y_test)
    calibrations_list = calc_calibration_itl(model, X_eval, y_test, sorted_queries)
    calibrated_calibrations_list = calc_calibration_itl(
        calibrated_model, X_eval, y_test, sorted_queries
    )

    results = pd.DataFrame({
        'Original': calibrations_list,
        'Calibrated': calibrated_calibrations_list,
        'Index': range(len(calibrations_list)),
    })
    # Long format lets Plotly build the legend from the 'Type' column.
    df_melted = results.melt(id_vars=['Index'], value_vars=['Original', 'Calibrated'],
                             var_name='Type', value_name='Value')

    fig = px.line(
        df_melted,
        x='Index',
        y='Value',
        color='Type',
        title=f'Multi-calibration Across Subgroups ({model_name} model)'
    )

    layout = dict(
        xaxis_title="Subgroup Index (Widest to Narrowest)",
        yaxis_title="Mean Predicted Probability − Actual Outcome Rate",
    )
    # Centre the y-axis on zero; skip when the extent is NaN or 0, since a
    # [0, 0] / NaN range would collapse the plot.
    y_abs_max = df_melted['Value'].abs().max()
    if pd.notna(y_abs_max) and y_abs_max > 0:
        layout['yaxis'] = dict(range=[-y_abs_max, y_abs_max], zeroline=True,
                               zerolinewidth=2, zerolinecolor='gray')
    fig.update_layout(**layout)

    fig.show()

    return calibrated_model, calibrations_list, calibrated_calibrations_list







In [None]:
calibrated_dict = {
    'KNN': {'model': knn},
    'LR': {'model': lr},
    'XGB': {'model': xgboost},
}

# Calibrate each model over the subgroup list; store the wrapped model and the
# before/after calibration-in-the-large values next to the original model.
for model_name, model_dict in calibrated_dict.items():
    calibrated, itl_before, itl_after = test_mc_across_models(
        model_dict['model'], sorted_queries, y_test, model_name
    )
    model_dict['calibrated-model'] = calibrated
    model_dict['calibrations_list'] = itl_before
    model_dict['calibrated-calibrations-list'] = itl_after

In [None]:
# Subgroup performance report for the (uncalibrated) XGBoost model.
model_comparison(model=xgboost, model_name='XGboost')

Let's compare basic models


Let's go deeper


Even deeper
