## Test BinaryBART
<sup>*</sup>Including ProbitBART and LogisticBART

In [1]:
import numpy as np
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss, accuracy_score, roc_auc_score
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification, load_breast_cancer, load_wine, fetch_openml
import pandas as pd
from bart_playground import *
from bart_playground.bart import DefaultBART, ProbitBART
import bartz

In [2]:
# Parameters
N_TREES = 50
NDPOST = 1000
NSKIP = 200
RANDOM_STATE = 42

# If debug then run with only one dataset and record running time
# Otherwise run with all datasets
debug = False

In [3]:
from numpy import indices
from sklearn.preprocessing import OrdinalEncoder, normalize

def load_mushroom():
    X, y = fetch_openml('mushroom', version=1, return_X_y=True)
    for col in X.select_dtypes('category'):
        # -1 in codes indicates NaN by pandas convention
        X[col] = X[col].cat.codes
    X = normalize(X)
    y_array = y.to_numpy().reshape(-1, 1)
    y_arm = OrdinalEncoder(dtype=int).fit_transform(y_array).flatten()
    
    # make the dataset a little bit smaller
    indices = np.random.choice(X.shape[0], size=X.shape[0] // 5, replace=False)
    X = X[indices, :]
    y_arm = y_arm[indices]
    
    return X, y_arm

def load_mushroom_encoded():
    X, y_arm = load_mushroom()
    
    n_arm = np.max(y_arm) + 1
    dim = X.shape[1] * n_arm # total number of encoded covariates (location-encoded for each arm) 
    act_dim = X.shape[1] # number of covariates
    covariates = np.zeros((X.shape[0], dim))
    rewards = np.zeros((X.shape[0], ))
    for cursor in range(X.shape[0]):
        a = np.random.randint(0, n_arm)
        covariates[cursor, a * act_dim:(a * act_dim + act_dim)] = X[cursor]
        if y_arm[cursor] == a:
            rewards[cursor] = 1 # reward is 1 if the true category matches the chosen arm

    return covariates, rewards

In [4]:
# Load datasets
def load_datasets():
    # Synthetic dataset
    X_syn, y_syn = make_classification(n_samples=400, n_features=8, n_informative=6, 
                                       n_redundant=0, n_classes=2, random_state=RANDOM_STATE)
    
    # Breast cancer dataset
    X_bc, y_bc = load_breast_cancer(return_X_y=True)
    
    # Wine dataset (convert to binary: class 0 vs rest)
    X_wine, y_wine = load_wine(return_X_y=True)
    y_wine = (y_wine == 0).astype(int)
    
    X_mushroom, y_mushroom = load_mushroom()
    X_mr_encoded, y_mr_encoded = load_mushroom_encoded()
    
    return {
        "Synthetic": (X_syn, y_syn),
        "Breast Cancer": (X_bc, y_bc),
        "Wine Binary": (X_wine, y_wine),
        "Mushroom": (X_mushroom, y_mushroom),
        "Mushroom Encoded": (X_mr_encoded, y_mr_encoded)
    }

## Utility functions to evaluate model performance

In [5]:
def ece_score(prob_true, prob_pred, y_true, y_prob, n_bins):
    # Compute bin counts for weighting
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.digitize(y_prob, bin_edges) - 1
    ece = 0.0
    for i in range(n_bins):
        mask = bin_ids == i
        if not np.any(mask):
            continue
        # weight by fraction of samples in bin
        weight = mask.sum() / len(y_true)
        ece += weight * abs(prob_true[i] - prob_pred[i])
    return ece

def mce_score(prob_true, prob_pred):
    diffs = np.abs(prob_true - prob_pred)
    return np.max(diffs)

import matplotlib.pyplot as plt

def calibration_plot(prob_true, prob_pred, model_name, dataset_name):
    plt.figure(figsize=(8, 6))
    plt.plot(prob_pred, prob_true, marker='o', label='Calibration curve')
    plt.plot([0,1], [0,1], linestyle='--', label='Perfect calibration')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.title(f'Reliability Diagram ({dataset_name}, {model_name})')
    plt.legend()
    plt.savefig(f'./results/{dataset_name}_{model_name}_calibration.png')
    plt.close()

In [6]:
from math import nan

def evaluate_model(model, model_name, X_train, X_test, y_train, y_test, dataset_name):
    """Evaluate a single model and return metrics"""
    
    if model_name == "Bartz":
        # Bartz regression treating 0/1 as continuous
        fit_result = bartz.BART.gbart(
            x_train=X_train.T, y_train=y_train.astype(float),
            x_test=X_test.T,
            ntree=N_TREES, ndpost=NDPOST, nskip=NSKIP,
            seed=RANDOM_STATE,
            printevery=NDPOST + NSKIP + 100
        )
        btpred_all = fit_result.predict(np.transpose(X_test))
        btpred = np.mean(np.array(btpred_all), axis=0)
        y_pred_prob = np.clip(btpred, 1e-9, 1 - 1e-9)
        y_pred = (y_pred_prob > 0.5).astype(int)
        
    elif model_name == "ProbitBART" or model_name == "LogisticBART":
        # Proper binary BART
        model.fit(X_train, y_train)
        proba_output = model.predict_proba(X_test)
        y_pred_prob = proba_output[:, 1]
        y_pred = model.predict(X_test)
        
    elif model_name == "RandomForestClassifier":
        # Native binary classifier
        model.fit(X_train, y_train)
        y_pred_prob = model.predict_proba(X_test)[:, 1]
        y_pred = model.predict(X_test)
        
    else:
        # Regression methods treating 0/1 as continuous
        model.fit(X_train, y_train)
        raw_pred = model.predict(X_test)
        y_pred_prob = np.clip(raw_pred, 1e-9, 1 - 1e-9)
        y_pred = (y_pred_prob > 0.5).astype(int)
    
    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)
    logloss = log_loss(y_test, y_pred_prob)
    auc = roc_auc_score(y_test, y_pred_prob)
    
    n_bins = 10
    prob_true, prob_pred = calibration_curve(y_test, y_pred_prob, n_bins=n_bins, strategy='uniform')
    if(len(prob_true) < n_bins):
        ece = nan
        mce = nan
    else:
        ece = ece_score(prob_true, prob_pred, y_test, y_pred_prob, n_bins=n_bins)
        mce = mce_score(prob_true, prob_pred)
        calibration_plot(prob_true, prob_pred, model_name, dataset_name)

    return {'Accuracy': accuracy, 'LogLoss': logloss, 'AUC': auc, 'ECE': ece, 'MCE': mce}

In [7]:
metrics = None
results = []
 
def record_evaluation_results(dataset_name, X, y):
    global metrics
    
    print(f"\n=== Testing on {dataset_name} ===")
    
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_STATE, stratify=y
    )
    
    # Define models
    models = {
        "RFClassifier": RandomForestClassifier(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "RFRegressor": RandomForestRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "Bartz": "placeholder",
        "DefaultBART": DefaultBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE),
        "ProbitBART": ProbitBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE),
        "LogisticBART": LogisticBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE)
    }
    
    for model_name, model in models.items():
        print(f"  Training {model_name}...")
        
        X_tr, X_te = X_train, X_test
        
        metrics = evaluate_model(model, model_name, X_tr, X_te, y_train, y_test, dataset_name)
        
        result = {'Dataset': dataset_name, 'Model': model_name, **metrics}
        results.append(result)
        print(f"    Acc: {metrics['Accuracy']:.3f}, LogLoss: {metrics['LogLoss']:.3f}, AUC: {metrics['AUC']:.4f}")
        print(f"    ECE: {metrics['ECE']:.4f}, MCE: {metrics['MCE']:.4f}")

## Dataset Loading

In [8]:
from bart_playground.bart import LogisticBART

old_settings = np.seterr(invalid='raise')

datasets = load_datasets()

In [9]:
for name, (X, y) in datasets.items():
    # Print dataset shapes
    print(f"Dataset: {name}\nX shape: {X.shape}, y shape: {y.shape}")
    # Print 0-1 distribution of y
    print(f"y distribution: {pd.Series(y).value_counts(normalize=True).to_dict()}")

Dataset: Synthetic
X shape: (400, 8), y shape: (400,)
y distribution: {0: 0.5, 1: 0.5}
Dataset: Breast Cancer
X shape: (569, 30), y shape: (569,)
y distribution: {1: 0.6274165202108963, 0: 0.37258347978910367}
Dataset: Wine Binary
X shape: (178, 13), y shape: (178,)
y distribution: {0: 0.6685393258426966, 1: 0.33146067415730335}
Dataset: Mushroom
X shape: (1624, 22), y shape: (1624,)
y distribution: {1: 0.5080049261083743, 0: 0.4919950738916256}
Dataset: Mushroom Encoded
X shape: (1624, 44), y shape: (1624,)
y distribution: {1.0: 0.5049261083743842, 0.0: 0.49507389162561577}


## Experiments

In [10]:
if not debug:
    for dataset_name, (X, y) in list(datasets.items()):
        record_evaluation_results(dataset_name, X, y)


=== Testing on Synthetic ===
  Training RFClassifier...
    Acc: 0.900, LogLoss: 2.072, AUC: 0.9000
    ECE: nan, MCE: nan
  Training RFRegressor...


INFO:2025-06-22 13:08:01,977:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


    Acc: 0.858, LogLoss: 0.341, AUC: 0.9315
    ECE: 0.0970, MCE: 0.5640
  Training Bartz...
    Acc: 0.883, LogLoss: 0.327, AUC: 0.9486
    ECE: 0.0977, MCE: 0.3013
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:12<00:00, 94.19it/s] 


    Acc: 0.858, LogLoss: 0.318, AUC: 0.9525
    ECE: 0.1077, MCE: 0.2558
  Training ProbitBART...


Iterations: 100%|██████████| 1200/1200 [00:09<00:00, 130.15it/s]


    Acc: 0.817, LogLoss: 0.552, AUC: 0.8764
    ECE: nan, MCE: nan
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:16<00:00, 70.84it/s]


    Acc: 0.883, LogLoss: 0.312, AUC: 0.9508
    ECE: 0.0792, MCE: 0.4606

=== Testing on Breast Cancer ===
  Training RFClassifier...
    Acc: 0.924, LogLoss: 1.575, AUC: 0.9204
    ECE: nan, MCE: nan
  Training RFRegressor...
    Acc: 0.942, LogLoss: 0.118, AUC: 0.9892
    ECE: 0.0266, MCE: 0.6000
  Training Bartz...
    Acc: 0.947, LogLoss: 0.126, AUC: 0.9915
    ECE: 0.0393, MCE: 0.5586
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:07<00:00, 162.85it/s]


    Acc: 0.942, LogLoss: 0.123, AUC: 0.9920
    ECE: 0.0430, MCE: 0.3258
  Training ProbitBART...


Iterations: 100%|██████████| 1200/1200 [00:08<00:00, 139.65it/s]


    Acc: 0.936, LogLoss: 0.265, AUC: 0.9892
    ECE: nan, MCE: nan
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:16<00:00, 74.54it/s]


    Acc: 0.971, LogLoss: 0.103, AUC: 0.9940
    ECE: 0.0495, MCE: 0.4175

=== Testing on Wine Binary ===
  Training RFClassifier...
    Acc: 0.926, LogLoss: 1.535, AUC: 0.9028
    ECE: nan, MCE: nan
  Training RFRegressor...
    Acc: 0.926, LogLoss: 0.519, AUC: 0.9498
    ECE: nan, MCE: nan
  Training Bartz...
    Acc: 0.944, LogLoss: 0.111, AUC: 0.9923
    ECE: nan, MCE: nan
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:05<00:00, 226.61it/s]


    Acc: 0.944, LogLoss: 0.137, AUC: 0.9846
    ECE: nan, MCE: nan
  Training ProbitBART...


Iterations: 100%|██████████| 1200/1200 [00:07<00:00, 167.70it/s]


    Acc: 0.815, LogLoss: 0.463, AUC: 0.9861
    ECE: nan, MCE: nan
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:16<00:00, 72.14it/s]


    Acc: 0.963, LogLoss: 0.128, AUC: 0.9907
    ECE: nan, MCE: nan

=== Testing on Mushroom ===
  Training RFClassifier...
    Acc: 0.992, LogLoss: 0.170, AUC: 0.9919
    ECE: nan, MCE: nan
  Training RFRegressor...
    Acc: 0.986, LogLoss: 0.068, AUC: 0.9975
    ECE: nan, MCE: nan
  Training Bartz...
    Acc: 0.990, LogLoss: 0.048, AUC: 0.9968
    ECE: nan, MCE: nan
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:15<00:00, 77.57it/s]


    Acc: 0.990, LogLoss: 0.041, AUC: 0.9996
    ECE: nan, MCE: nan
  Training ProbitBART...


Iterations: 100%|██████████| 1200/1200 [00:12<00:00, 95.04it/s]


    Acc: 0.955, LogLoss: 0.226, AUC: 0.9872
    ECE: 0.1250, MCE: 0.4625
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:17<00:00, 67.84it/s]


    Acc: 0.986, LogLoss: 0.078, AUC: 0.9983
    ECE: 0.0401, MCE: 0.4411

=== Testing on Mushroom Encoded ===
  Training RFClassifier...
    Acc: 0.988, LogLoss: 0.255, AUC: 0.9877
    ECE: nan, MCE: nan
  Training RFRegressor...
    Acc: 0.988, LogLoss: 0.090, AUC: 0.9970
    ECE: nan, MCE: nan
  Training Bartz...
    Acc: 0.994, LogLoss: 0.032, AUC: 0.9998
    ECE: nan, MCE: nan
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:15<00:00, 78.38it/s]


    Acc: 0.996, LogLoss: 0.057, AUC: 0.9998
    ECE: 0.0449, MCE: 0.4833
  Training ProbitBART...


Iterations: 100%|██████████| 1200/1200 [00:13<00:00, 91.14it/s]


    Acc: 0.898, LogLoss: 0.388, AUC: 0.9701
    ECE: nan, MCE: nan
  Training LogisticBART...


Iterations: 100%|██████████| 1200/1200 [00:19<00:00, 62.97it/s]


    Acc: 0.953, LogLoss: 0.141, AUC: 0.9933
    ECE: 0.0502, MCE: 0.1732


In [11]:
if debug == True:
    dataset_name, (X, y) = list(datasets.items())[-1]  # Last dataset for debugging
    
    profile = False
    if not profile:
        record_evaluation_results(dataset_name, X, y)
    else:
        %prun -s cumtime -D temp_profile.prof -q record_evaluation_results(dataset_name, X, y)

        fname = "profile_logisticbart"

        !mv temp_profile.prof {fname}.prof
        !gprof2dot -f pstats {fname}.prof -o {fname}.dot
        !dot -Tpng {fname}.dot -o {fname}.png

In [12]:
# Display results
results_df = pd.DataFrame(results)
print("\n" + "="*60)
print("SUMMARY RESULTS")
print("="*60)

# Pivot tables for easy comparison
for metric in ['Accuracy', 'AUC', 'LogLoss', 'ECE', 'MCE']:
    print(f"\n{metric}:")
    pivot = results_df.pivot_table(index='Dataset', columns='Model', values=metric)
    print(pivot.round(3))


SUMMARY RESULTS

Accuracy:
Model             Bartz  DefaultBART  LogisticBART  ProbitBART  RFClassifier  \
Dataset                                                                        
Breast Cancer     0.947        0.942         0.971       0.936         0.924   
Mushroom          0.990        0.990         0.986       0.955         0.992   
Mushroom Encoded  0.994        0.996         0.953       0.898         0.988   
Synthetic         0.883        0.858         0.883       0.817         0.900   
Wine Binary       0.944        0.944         0.963       0.815         0.926   

Model             RFRegressor  
Dataset                        
Breast Cancer           0.942  
Mushroom                0.986  
Mushroom Encoded        0.988  
Synthetic               0.858  
Wine Binary             0.926  

AUC:
Model             Bartz  DefaultBART  LogisticBART  ProbitBART  RFClassifier  \
Dataset                                                                        
Breast Cancer     0.9