## Test BinaryBART
<sup>*</sup>Including ProbitBART and LogisticBART

In [1]:
import numpy as np
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, accuracy_score, roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification, load_breast_cancer, load_wine, fetch_openml
import pandas as pd
from bart_playground import *
from bart_playground.bart import DefaultBART, ProbitBART
from xgboost import XGBRegressor, XGBClassifier
import bartz

In [2]:
# Parameters
N_TREES = 50
NDPOST = 1000
NSKIP = 200
RANDOM_STATE = 42

# If debug then run with only one dataset and record running time
# Otherwise run with all datasets
debug = True

In [3]:
from sklearn.preprocessing import OrdinalEncoder, normalize

def load_mushroom():
    X, y = fetch_openml('mushroom', version=1, return_X_y=True)
    for col in X.select_dtypes('category'):
        # -1 in codes indicates NaN by pandas convention
        X[col] = X[col].cat.codes
    X = normalize(X)
    y_array = y.to_numpy().reshape(-1, 1)
    y_arm = OrdinalEncoder(dtype=int).fit_transform(y_array).flatten()
    
    # make the dataset a little bit smaller
    indices = np.random.choice(X.shape[0], size=2000, replace=False)
    X = X[indices, :]
    y_arm = y_arm[indices]
    
    return X, y_arm

def load_mushroom_encoded(X, y_arm):
    n_arm = np.max(y_arm) + 1
    dim = X.shape[1] * n_arm # total number of encoded covariates (location-encoded for each arm) 
    act_dim = X.shape[1] # number of covariates
    covariates = np.zeros((X.shape[0], dim))
    rewards = np.zeros((X.shape[0], ))
    for cursor in range(X.shape[0]):
        a = np.random.randint(0, n_arm)
        covariates[cursor, a * act_dim:(a * act_dim + act_dim)] = X[cursor]
        if y_arm[cursor] == a:
            rewards[cursor] = 1 # reward is 1 if the true category matches the chosen arm

    return covariates, rewards

def load_mushroom_encoded_hot(X, y_arm):
    n_arm = np.max(y_arm) + 1
    dim = X.shape[1] + n_arm # total number of encoded covariates (onehot-encoded for each arm) 
    act_dim = X.shape[1] # number of covariates
    covariates = np.zeros((X.shape[0], dim))
    rewards = np.zeros((X.shape[0], ))
    for cursor in range(X.shape[0]):
        a = np.random.randint(0, n_arm)
        covariates[cursor, :act_dim] = X[cursor]
        covariates[cursor, act_dim + a] = 1 # one-hot encoding
        if y_arm[cursor] == a:
            rewards[cursor] = 1 # reward is 1 if the true category matches the chosen arm

    return covariates, rewards

In [4]:
# Load datasets
def load_datasets():
    # Synthetic dataset
    X_syn, y_syn = make_classification(n_samples=400, n_features=8, n_informative=6, 
                                       n_redundant=0, n_classes=2, random_state=RANDOM_STATE)
    
    # Breast cancer dataset
    X_bc, y_bc = load_breast_cancer(return_X_y=True)
    
    # Wine dataset (convert to binary: class 0 vs rest)
    X_wine, y_wine = load_wine(return_X_y=True)
    y_wine = (y_wine == 0).astype(int)
    
    np.random.seed(RANDOM_STATE)
    
    X_mushroom, y_mushroom = load_mushroom()
    X_mr_encoded, y_mr_encoded = load_mushroom_encoded(X_mushroom, y_mushroom)
    
    return {
        # "Synthetic": (X_syn, y_syn),
        # "Breast Cancer": (X_bc, y_bc),
        # "Wine Binary": (X_wine, y_wine),
        # "Mushroom": (X_mushroom, y_mushroom),
        # "Mushroom Encoded": (X_mr_encoded, y_mr_encoded),
        "Mushroom Encoded Hot": load_mushroom_encoded_hot(X_mushroom, y_mushroom),
    }

## Utility functions to evaluate model performance

In [5]:
def ece_score(prob_true, prob_pred, y_true, y_prob, n_bins):
    # Compute bin counts for weighting
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(y_prob, bin_edges) - 1
    ece = 0.0
    for i in range(n_bins):
        mask = bin_ids == i
        if not np.any(mask):
            continue
        # weight by fraction of samples in bin
        weight = mask.sum() / len(y_true)
        ece += weight * abs(prob_true[i] - prob_pred[i])
    return ece

def mce_score(prob_true, prob_pred):
    diffs = np.abs(prob_true - prob_pred)
    return np.max(diffs)

import matplotlib.pyplot as plt

def calibration_plot(prob_true, prob_pred, model_name, dataset_name):
    plt.figure(figsize=(8, 6))
    plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve')
    plt.plot([0,1], [0,1], linestyle='--', label='Perfect calibration')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.title(f'Reliability Diagram ({dataset_name}, {model_name})')
    plt.legend()
    plt.savefig(f'./results/{dataset_name}_{model_name}_calibration.png')
    plt.close()

In [6]:
from math import nan

def evaluate_model(model, model_name, X_train, X_test, y_train, y_test, dataset_name):
    """Evaluate a single model and return metrics"""
    if model_name == "Bartz":
        # Bartz regression treating 0/1 as continuous
        fit_result = bartz.BART.gbart(
            x_train=X_train.T, y_train=y_train.astype(float),
            x_test=X_test.T,
            ntree=N_TREES, ndpost=NDPOST, nskip=NSKIP,
            seed=RANDOM_STATE,
            printevery=NDPOST + NSKIP + 100
        )
        btpred_all = fit_result.predict(np.transpose(X_test))
        btpred = np.mean(np.array(btpred_all), axis=0)
        y_pred_prob = np.clip(btpred, 1e-9, 1 - 1e-9)
        y_pred = (y_pred_prob > 0.5).astype(int)
        
    elif model_name == "ProbitBART" or model_name == "LogisticBART":
        # Proper binary BART
        model.fit(X_train, y_train)
        proba_output = model.predict_proba(X_test)
        y_pred_prob = proba_output[:, 1]
        y_pred = np.argmax(proba_output, axis=1)
        
        # y_post = model.posterior_predict(X_test)
        # lower_ci = np.percentile(y_post, 2.5, axis=1)
        # upper_ci = np.percentile(y_post, 97.5, axis=1)
        # coverage = np.mean((y_test >= lower_ci) & (y_test <= upper_ci))

    elif model_name == "RandomForestClassifier":
        # Native binary classifier
        model.fit(X_train, y_train)
        y_pred_prob = model.predict_proba(X_test)[:, 1]
        y_pred = model.predict(X_test)
        
    else:
        # Regression methods treating 0/1 as continuous
        model.fit(X_train, y_train)
        raw_pred = model.predict(X_test)
        y_pred_prob = np.clip(raw_pred, 1e-9, 1 - 1e-9)
        y_pred = (y_pred_prob > 0.5).astype(int)
    
    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)
    logloss = log_loss(y_test, y_pred_prob)
    auc = roc_auc_score(y_test, y_pred_prob)
    
    n_bins = 5
    prob_true, prob_pred = calibration_curve(y_test, y_pred_prob, n_bins=n_bins, strategy='uniform')
    if(len(prob_true) < n_bins):
        ece = nan
        mce = nan
    else:
        ece = ece_score(prob_true, prob_pred, y_test, y_pred_prob, n_bins=n_bins)
        mce = mce_score(prob_true, prob_pred)
        calibration_plot(prob_true, prob_pred, model_name, dataset_name)

    return {'Accuracy': accuracy, 'LogLoss': logloss, 'AUC': auc,
            'ECE': ece, 'MCE': mce}

In [7]:
metrics = None
results = []
 
def record_evaluation_results(dataset_name, X, y):
    global metrics
    
    print(f"\n=== Testing on {dataset_name} ===")
    
    if dataset_name == "Mushroom" or dataset_name == "Mushroom Encoded":
        test_size = 0.3
    else:
        test_size = 0.3
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=RANDOM_STATE, stratify=y
    )
    
    # Define models
    models = {
        "XGBClassifier": XGBClassifier(n_estimators=N_TREES, random_state=RANDOM_STATE, eval_metric='logloss'),
        "XGBRegressor": XGBRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE, eval_metric='logloss'),
        "RFClassifier": RandomForestClassifier(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "RFRegressor": RandomForestRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "Bartz": "placeholder",
        "DefaultBART": DefaultBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE),
        # "ProbitBART": ProbitBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE),
        "LogisticBART": LogisticBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE)
    }
    
    for model_name, model in models.items():
        print(f"  Training {model_name}...")
        
        X_tr, X_te = X_train, X_test
        
        metrics = evaluate_model(model, model_name, X_tr, X_te, y_train, y_test, dataset_name)
        
        result = {'Dataset': dataset_name, 'Model': model_name, **metrics}
        results.append(result)
        print(f"    Acc: {metrics['Accuracy']:.3f}, LogLoss: {metrics['LogLoss']:.3f}, AUC: {metrics['AUC']:.4f}")
        print(f"    ECE: {metrics['ECE']:.4f}, MCE: {metrics['MCE']:.4f}")

## Dataset Loading

In [8]:
from bart_playground.bart import LogisticBART

old_settings = np.seterr(invalid='raise')

datasets = load_datasets()

In [9]:
for name, (X, y) in datasets.items():
    # Print dataset shapes
    print(f"Dataset: {name}\nX shape: {X.shape}, y shape: {y.shape}")
    # Print 0-1 distribution of y
    print(f"y distribution: {pd.Series(y).value_counts(normalize=True).to_dict()}")

Dataset: Mushroom Encoded Hot
X shape: (2000, 24), y shape: (2000,)
y distribution: {0.0: 0.5045, 1.0: 0.4955}


## Experiments

In [10]:
if not debug:
    for dataset_name, (X, y) in list(datasets.items()):
        record_evaluation_results(dataset_name, X, y)

In [11]:
if debug == True:
    dataset_name, (X, y) = list(datasets.items())[-1]  # Last dataset for debugging
    
    profile = False
    if not profile:
        record_evaluation_results(dataset_name, X, y)
    else:
        %prun -s cumtime -D temp_profile.prof -q record_evaluation_results(dataset_name, X, y)

        fname = "profile_logisticbart"

        !mv temp_profile.prof {fname}.prof
        !gprof2dot -f pstats {fname}.prof -o {fname}.dot
        !dot -Tpng {fname}.dot -o {fname}.png


=== Testing on Mushroom Encoded Hot ===
  Training XGBClassifier...
    Acc: 0.970, LogLoss: 0.622, AUC: 0.9699
    ECE: nan, MCE: nan
  Training XGBRegressor...
    Acc: 0.968, LogLoss: 0.141, AUC: 0.9899
    ECE: 0.0152, MCE: 0.1041
  Training RFClassifier...
    Acc: 0.973, LogLoss: 0.553, AUC: 0.9732
    ECE: nan, MCE: nan
  Training RFRegressor...


INFO:2025-07-07 11:14:15,776:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


    Acc: 0.982, LogLoss: 0.112, AUC: 0.9989
    ECE: 0.0751, MCE: 0.2422
  Training Bartz...
    Acc: 0.998, LogLoss: 0.029, AUC: 0.9999
    ECE: 0.0183, MCE: 0.3060
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:06<00:00, 199.62it/s]


    Acc: 0.985, LogLoss: 0.077, AUC: 0.9977
    ECE: 0.0403, MCE: 0.3252
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:11<00:00, 104.34it/s]


    Acc: 0.922, LogLoss: 0.213, AUC: 0.9768
    ECE: 0.0560, MCE: 0.1184


In [12]:
# Display results
results_df = pd.DataFrame(results)
print("\n" + "="*60)
print("SUMMARY RESULTS")
print("="*60)

# Pivot tables for easy comparison
for metric in ['Accuracy', 'AUC', 'LogLoss', 'ECE', 'MCE']:
    print(f"\n{metric}:")
    pivot = results_df.pivot_table(index='Dataset', columns='Model', values=metric)
    print(pivot.round(3))


SUMMARY RESULTS

Accuracy:
Model                 Bartz  DefaultBART  LogisticBART  RFClassifier  \
Dataset                                                                
Mushroom Encoded Hot  0.998        0.985         0.922         0.973   

Model                 RFRegressor  XGBClassifier  XGBRegressor  
Dataset                                                         
Mushroom Encoded Hot        0.982           0.97         0.968  

AUC:
Model                 Bartz  DefaultBART  LogisticBART  RFClassifier  \
Dataset                                                                
Mushroom Encoded Hot    1.0        0.998         0.977         0.973   

Model                 RFRegressor  XGBClassifier  XGBRegressor  
Dataset                                                         
Mushroom Encoded Hot        0.999           0.97          0.99  

LogLoss:
Model                 Bartz  DefaultBART  LogisticBART  RFClassifier  \
Dataset                                                     