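## Test BinaryBART

Benchmark the binary-outcome BART models from `bart_playground` against random-forest and regression-BART baselines on three binary classification datasets, comparing accuracy, log loss, and AUC on a held-out split.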

In [1]:
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification, load_breast_cancer, load_wine
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.metrics import log_loss, accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from bart_playground import *
from bart_playground.bart import DefaultBART, BinaryBART, LogisticBART
import bartz

In [2]:
# Parameters shared by the cells below. Some models scale them down when they
# are constructed (e.g. NDPOST // 5, N_TREES // 2) rather than using them as-is.

# Ensemble size: passed as `ntree` / `n_trees` / `n_estimators` depending on the library.
N_TREES = 50
# Passed as `ndpost` to the BART samplers (presumably the number of posterior draws kept — confirm).
NDPOST = 500
# Passed as `nskip` to the BART samplers (presumably the number of burn-in iterations — confirm).
NSKIP = 500
# Single seed used for synthetic data generation, the train/test split, and every model.
RANDOM_STATE = 42

In [3]:
# Load datasets
def load_datasets():
    """Build the three binary classification problems used in this notebook.

    Returns:
        dict mapping a display name to an ``(X, y)`` pair, in the order
        "Synthetic", "Breast Cancer", "Wine Binary".
    """
    named_data = {}

    # Synthetic: 400 samples, 8 features of which 6 are informative, seeded
    named_data["Synthetic"] = make_classification(
        n_samples=400,
        n_features=8,
        n_informative=6,
        n_redundant=0,
        n_classes=2,
        random_state=RANDOM_STATE,
    )

    # Breast cancer: used as shipped by scikit-learn
    named_data["Breast Cancer"] = load_breast_cancer(return_X_y=True)

    # Wine: collapse the original labels to "class 0" (1) vs "anything else" (0)
    wine_features, wine_labels = load_wine(return_X_y=True)
    named_data["Wine Binary"] = (wine_features, (wine_labels == 0).astype(int))

    return named_data

In [4]:
def evaluate_model(model, model_name, X_train, X_test, y_train, y_test):
    """Fit one model and score its predictions on the held-out split.

    The code path is selected by ``model_name``, not by inspecting ``model``:

    * "BinaryBART", "LogisticBART", "RandomForestClassifier": fit, then take
      column 1 of ``predict_proba`` as the positive-class score and
      ``predict`` as the hard label.
    * "Bartz": ``model`` is ignored; ``bartz.BART.gbart`` is run on the
      transposed feature matrices with the 0/1 labels cast to float, and its
      predictions are averaged over axis 0.
    * anything else: treated as a regressor fitted on the 0/1 labels.

    For the two regression paths the raw prediction is clipped to
    [1e-9, 1 - 1e-9] to serve as a probability and thresholded at 0.5.

    Returns:
        dict with keys 'Accuracy', 'LogLoss' and 'AUC'.
    """
    native_classifiers = ("BinaryBART", "LogisticBART", "RandomForestClassifier")

    if model_name in native_classifiers:
        # Models that expose class probabilities directly
        model.fit(X_train, y_train)
        y_pred_prob = model.predict_proba(X_test)[:, 1]
        y_pred = model.predict(X_test)
    else:
        if model_name == "Bartz":
            # printevery is set past the total iteration count (ndpost + nskip),
            # presumably to silence progress output — confirm against bartz docs.
            posterior = bartz.BART.gbart(
                x_train=X_train.T,
                y_train=y_train.astype(float),
                x_test=X_test.T,
                ntree=N_TREES,
                ndpost=NDPOST,
                nskip=NSKIP,
                seed=RANDOM_STATE,
                printevery=NDPOST + NSKIP + 100,
            )
            draws = np.array(posterior.predict(np.transpose(X_test)))
            raw_scores = np.mean(draws, axis=0)
        else:
            model.fit(X_train, y_train)
            raw_scores = model.predict(X_test)

        # Regression output is unbounded: squeeze it into (0, 1) so log_loss is finite
        y_pred_prob = np.clip(raw_scores, 1e-9, 1 - 1e-9)
        y_pred = (y_pred_prob > 0.5).astype(int)

    return {
        'Accuracy': accuracy_score(y_test, y_pred),
        'LogLoss': log_loss(y_test, y_pred_prob),
        'AUC': roc_auc_score(y_test, y_pred_prob),
    }

In [5]:
# Main evaluation loop: fit every model on every dataset and collect one
# metrics row per (dataset, model) pair in `results`.
datasets = load_datasets()
results = []

for dataset_name, (X, y) in datasets.items():
    print(f"\n=== Testing on {dataset_name} ===")

    # 70/30 split, stratified on the label and seeded for repeatability
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_STATE, stratify=y
    )

    # Fresh, unfitted estimators for each dataset. The dict key doubles as the
    # dispatch name inside evaluate_model, so it must match the names handled there.
    # DefaultBART and LogisticBART run with reduced tree/iteration budgets.
    # NOTE(review): the saved output of this cell shows LogisticBART at chance
    # level on all three datasets (LogLoss ~0.695, AUC 0.40-0.53) — investigate
    # before trusting its rows in the summary.
    models = {
        "RandomForestClassifier": RandomForestClassifier(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "RandomForestRegressor": RandomForestRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "DefaultBART": DefaultBART(n_trees=N_TREES, ndpost=NDPOST // 5, nskip=NSKIP // 5, random_state=RANDOM_STATE),
        # "BinaryBART": BinaryBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE),
        "LogisticBART": LogisticBART(n_trees=N_TREES // 2, ndpost=NDPOST, nskip=NSKIP // 2, random_state=RANDOM_STATE),
        # "Bartz": "placeholder"
    }

    for model_name, model in models.items():
        print(f"  Training {model_name}...")

        metrics = evaluate_model(model, model_name, X_train, X_test, y_train, y_test)
        results.append({'Dataset': dataset_name, 'Model': model_name, **metrics})

        print(f"    Acc: {metrics['Accuracy']:.3f}, LogLoss: {metrics['LogLoss']:.3f}, AUC: {metrics['AUC']:.3f}")


=== Testing on Synthetic ===
  Training RandomForestClassifier...
    Acc: 0.900, LogLoss: 0.318, AUC: 0.956
  Training RandomForestRegressor...
    Acc: 0.858, LogLoss: 0.341, AUC: 0.932
  Training DefaultBART...


Iterations: 100%|██████████| 200/200 [00:15<00:00, 12.70it/s]


    Acc: 0.875, LogLoss: 0.311, AUC: 0.964
  Training LogisticBART...


Iterations: 100%|██████████| 750/750 [07:12<00:00,  1.73it/s]  


    Acc: 0.525, LogLoss: 0.695, AUC: 0.422

=== Testing on Breast Cancer ===
  Training RandomForestClassifier...
    Acc: 0.924, LogLoss: 0.124, AUC: 0.991
  Training RandomForestRegressor...
    Acc: 0.942, LogLoss: 0.118, AUC: 0.989
  Training DefaultBART...


Iterations: 100%|██████████| 200/200 [00:03<00:00, 65.00it/s]


    Acc: 0.953, LogLoss: 0.123, AUC: 0.992
  Training LogisticBART...


Iterations: 100%|██████████| 750/750 [00:38<00:00, 19.63it/s]


    Acc: 0.380, LogLoss: 0.695, AUC: 0.533

=== Testing on Wine Binary ===
  Training RandomForestClassifier...
    Acc: 0.926, LogLoss: 0.155, AUC: 0.992
  Training RandomForestRegressor...
    Acc: 0.926, LogLoss: 0.519, AUC: 0.950
  Training DefaultBART...


Iterations: 100%|██████████| 200/200 [00:02<00:00, 76.92it/s]


    Acc: 0.944, LogLoss: 0.122, AUC: 0.988
  Training LogisticBART...


Iterations: 100%|██████████| 750/750 [00:28<00:00, 25.96it/s]


    Acc: 0.444, LogLoss: 0.696, AUC: 0.400


In [None]:
# Display results: one Dataset x Model table per metric.
# NOTE(review): the output saved with this cell lists Bartz and BinaryBART, which are
# commented out in the evaluation loop — re-run after the loop to refresh it.
results_df = pd.DataFrame(results)
rule = "=" * 60
print("\n" + rule, "SUMMARY RESULTS", rule, sep="\n")

for metric in ('Accuracy', 'AUC', 'LogLoss'):
    print(f"\n{metric}:")
    # pivot_table averages duplicates, so re-running the loop without resetting
    # `results` would blend runs rather than fail.
    metric_table = results_df.pivot_table(index='Dataset', columns='Model', values=metric)
    print(metric_table.round(3))


SUMMARY RESULTS

Accuracy:
Model          Bartz  BinaryBART  DefaultBART  RandomForestClassifier  \
Dataset                                                                 
Breast Cancer  0.947       0.942        0.936                   0.924   
Synthetic      0.883       0.775        0.867                   0.900   
Wine Binary    0.944       0.815        0.963                   0.926   

Model          RandomForestRegressor  
Dataset                               
Breast Cancer                  0.942  
Synthetic                      0.858  
Wine Binary                    0.926  

AUC:
Model          Bartz  BinaryBART  DefaultBART  RandomForestClassifier  \
Dataset                                                                 
Breast Cancer  0.991       0.989        0.984                   0.991   
Synthetic      0.949       0.871        0.952                   0.956   
Wine Binary    0.994       0.977        0.985                   0.992   

Model          RandomForestRegressor  


In [None]:
# Best model per dataset: highest Accuracy and AUC, lowest LogLoss.
print("\nBest Models per Dataset:")
for dataset in results_df['Dataset'].unique():
    per_dataset = results_df[results_df['Dataset'] == dataset]

    # Resolve all three winning rows before printing anything for this dataset
    winners = [
        ('Accuracy', per_dataset.loc[per_dataset['Accuracy'].idxmax()]),
        ('AUC', per_dataset.loc[per_dataset['AUC'].idxmax()]),
        ('LogLoss', per_dataset.loc[per_dataset['LogLoss'].idxmin()]),
    ]

    print(f"\n{dataset}:")
    for metric_name, winner in winners:
        print(f"  Best {metric_name}: {winner['Model']} ({winner[metric_name]:.3f})")

print("\nTesting complete!")



Best Models per Dataset:

Synthetic:
  Best Accuracy: RandomForestClassifier (0.900)
  Best AUC: RandomForestClassifier (0.956)
  Best LogLoss: RandomForestClassifier (0.318)

Breast Cancer:
  Best Accuracy: Bartz (0.947)
  Best AUC: RandomForestClassifier (0.991)
  Best LogLoss: RandomForestRegressor (0.118)

Wine Binary:
  Best Accuracy: DefaultBART (0.963)
  Best AUC: Bartz (0.994)
  Best LogLoss: Bartz (0.117)

Testing complete!
