## Test DefaultBART, MultiChainBART for Regression

In [2]:
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import pandas as pd
from bart_playground import *
from bart_playground.bart import DefaultBART
from bart_playground.mcbart import MultiChainBART
from xgboost import XGBRegressor
import bartz

In [3]:
# Parameters
N_TREES = 50
NDPOST = 1000
NSKIP = 200
RANDOM_STATE = 42

# If debug then run with only one dataset and record running time
# Otherwise run with all datasets
debug = True

In [4]:
# Load datasets
def load_datasets():
    # Friedman regression dataset
    generator = DataGenerator(n_samples=1000, n_features=5, noise=1, random_seed=RANDOM_STATE)
    X_friedman, y_friedman = generator.generate(scenario="friedman1")

    return {
        "Friedman1": (X_friedman, y_friedman)
    }

## Utility functions to evaluate model performance

In [5]:
import time


def evaluate_model(model, model_name, X_train, X_test, y_train, y_test, dataset_name):
    """Evaluate a single model and return regression metrics"""
    time_start = time.time()
    if model_name == "Bartz":
        # Bartz regression
        fit_result = bartz.BART.gbart(
            x_train=X_train.T, y_train=y_train.astype(float),
            x_test=X_test.T,
            ntree=N_TREES, ndpost=NDPOST, nskip=NSKIP,
            seed=RANDOM_STATE,
            printevery=NDPOST + NSKIP + 100
        )
        btpred_all = fit_result.predict(np.transpose(X_test))
        y_pred = np.mean(np.array(btpred_all), axis=0)
    else:
        # Standard regression models
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)
    time_end = time.time()
    
    # Calculate regression metrics
    mse = mean_squared_error(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)

    return {'MSE': mse, 'MAE': mae, 'R2': r2, 'Time': time_end - time_start}

In [6]:
metrics = None
results = []
 
def record_evaluation_results(dataset_name, X, y):
    global metrics
    
    print(f"\n=== Testing on {dataset_name} ====")
    
    # Split data (no stratification needed for regression)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_STATE
    )
    
    # Define regression models only
    models = {
        "XGBRegressor": XGBRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "RFRegressor": RandomForestRegressor(n_estimators=N_TREES, random_state=RANDOM_STATE),
        "MultiChainBART": MultiChainBART(n_ensembles=4, bart_class = DefaultBART, random_state=RANDOM_STATE, n_trees=N_TREES, ndpost=NDPOST//4, nskip=NSKIP//4),
        "Bartz": "placeholder",
        "DefaultBART": DefaultBART(n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP, random_state=RANDOM_STATE)
    }
    
    for model_name, model in models.items():
        print(f"  Training {model_name}...")
        
        X_tr, X_te = X_train, X_test
        
        metrics = evaluate_model(model, model_name, X_tr, X_te, y_train, y_test, dataset_name)
        
        result = {'Dataset': dataset_name, 'Model': model_name, **metrics}
        results.append(result)
        print(f"    MSE: {metrics['MSE']:.3f}, MAE: {metrics['MAE']:.3f}, "
              f"R2: {metrics['R2']:.4f}, Time: {metrics['Time']:.2f}s")

## Dataset Loading

In [7]:
old_settings = np.seterr(invalid='raise')

datasets = load_datasets()

In [8]:
for name, (X, y) in datasets.items():
    # Print dataset shapes
    print(f"Dataset: {name}\nX shape: {X.shape}, y shape: {y.shape}")
    # Print target variable statistics for regression
    print(f"y statistics:")
    print(f"  Mean: {y.mean():.3f}, Std: {y.std():.3f}")
    print(f"  Min: {y.min():.3f}, Max: {y.max():.3f}")
    print()

Dataset: Friedman1
X shape: (1000, 5), y shape: (1000,)
y statistics:
  Mean: 14.342, Std: 5.099
  Min: 0.354, Max: 28.031



## Experiments

In [9]:
if not debug:
    for dataset_name, (X, y) in list(datasets.items()):
        record_evaluation_results(dataset_name, X, y)

In [9]:
if debug == True:
    dataset_name, (X, y) = list(datasets.items())[-1]  # Last dataset for debugging
    
    profile = False
    if not profile:
        record_evaluation_results(dataset_name, X, y)
    else:
        %prun -s cumtime -D temp_profile.prof -q record_evaluation_results(dataset_name, X, y)

        fname = "profile_mcbart"

        !mv temp_profile.prof {fname}.prof
        !gprof2dot -f pstats {fname}.prof -o {fname}.dot
        !dot -Tpng {fname}.dot -o {fname}.png


=== Testing on Friedman1 ====


2025-06-13 14:48:39,021	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


Created 4 BARTActor(s).
  Training XGBRegressor...
    MSE: 2.663, MAE: 1.337, R2: 0.8878, Time: 2.11s
  Training RFRegressor...
    MSE: 3.356, MAE: 1.469, R2: 0.8586, Time: 0.14s
  Training MultiChainBART...
    MSE: 2.663, MAE: 1.337, R2: 0.8878, Time: 2.11s
  Training RFRegressor...
    MSE: 3.356, MAE: 1.469, R2: 0.8586, Time: 0.14s
  Training MultiChainBART...


Iterations:   0%|          | 0/300 [00:00<?, ?it/s]
Iterations:   0%|          | 1/300 [00:05<25:13,  5.06s/it]
Iterations:   0%|          | 0/300 [00:00<?, ?it/s][32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
Iterations:   0%|          | 1/300 [00:05<25:13,  5.06s/it]
Iterations:   0%|          | 0/300 [00:00<?, ?it/s][32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
Iterations:  94%|█████████▎| 281/300 [00:09<00:00, 88.92it/s]
Iterations:  94%|█████████▎| 281/300 [00:09<00:00, 88.92it/s]
Iterations: 100%|██████████| 300/300 [00:09<00:00, 32.10it/s]
Iterations: 100%|██████████| 300/300

    MSE: 1.251, MAE: 0.916, R2: 0.9473, Time: 10.76s
  Training Bartz...
    MSE: 1.792, MAE: 1.062, R2: 0.9245, Time: 8.03s
  Training DefaultBART...
    MSE: 1.792, MAE: 1.062, R2: 0.9245, Time: 8.03s
  Training DefaultBART...


Iterations: 100%|██████████| 1200/1200 [00:18<00:00, 64.18it/s] 



    MSE: 1.799, MAE: 1.058, R2: 0.9242, Time: 20.49s


In [10]:
# Display results
results_df = pd.DataFrame(results)
print("\n" + "="*60)
print("SUMMARY RESULTS")
print("="*60)

# Pivot tables for easy comparison
for metric in ['MSE', 'MAE', 'R2']:
    print(f"\n{metric}:")
    pivot = results_df.pivot_table(index='Dataset', columns='Model', values=metric)
    print(pivot.round(3))


SUMMARY RESULTS

MSE:
Model      Bartz  DefaultBART  MultiChainBART  RFRegressor  XGBRegressor
Dataset                                                                 
Friedman1  1.792        1.799           1.251        3.356         2.663

MAE:
Model      Bartz  DefaultBART  MultiChainBART  RFRegressor  XGBRegressor
Dataset                                                                 
Friedman1  1.062        1.058           0.916        1.469         1.337

R2:
Model      Bartz  DefaultBART  MultiChainBART  RFRegressor  XGBRegressor
Dataset                                                                 
Friedman1  0.924        0.924           0.947        0.859         0.888


## Test MCBART update_fit

In [8]:
# Test MCBART update_fit with one-by-one updates
print("Testing MCBART update_fit with incremental samples...")

# Use the Friedman dataset
X, y = datasets["Friedman1"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=RANDOM_STATE)

# Initialize MCBART with initial training data (first 10% of training set)
initial_size = int(0.1 * len(X_train))
X_initial = X_train[:initial_size]
y_initial = y_train[:initial_size]
X_update = X_train[initial_size:]
y_update = y_train[initial_size:]

# record total time
import time
start_time = time.time()
# Fit MCBART initially
mcbart_update = MultiChainBART(n_ensembles=4, bart_class=DefaultBART, 
                               random_state=RANDOM_STATE, 
                               n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP)
mcbart_update.fit(X_initial, y_initial, quietly=True)

# Update one sample at a time
for i in range(len(X_update)):
    mcbart_update.update_fit(X_update[i:i+1], y_update[i:i+1], add_ndpost=3, quietly=True)
end_time = time.time()
print(f"Time taken for incremental updates: {end_time - start_time:.2f} seconds")

# Evaluate final model
y_pred_update = mcbart_update.predict(X_test)
mse_update = mean_squared_error(y_test, y_pred_update)
mae_update = mean_absolute_error(y_test, y_pred_update)
r2_update = r2_score(y_test, y_pred_update)

print(f"MCBART update_fit results:")
print(f"  MSE: {mse_update:.3f}, MAE: {mae_update:.3f}, R2: {r2_update:.4f}")

# mcbart_update.shutdown()

Testing MCBART update_fit with incremental samples...


2025-06-13 15:30:08,877	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m


Created 4 BARTActor(s).
Time taken for incremental updates: 50.97 seconds
MCBART update_fit results:
  MSE: 1.316, MAE: 0.917, R2: 0.9445
Time taken for incremental updates: 50.97 seconds
MCBART update_fit results:
  MSE: 1.316, MAE: 0.917, R2: 0.9445


In [9]:
def uniform_schedule(i): 
    return 1.0/mcbart_update._trace_length
posterior_sample = mcbart_update.posterior_sample(X_test, schedule=uniform_schedule)
print(f"Posterior sample shape: {posterior_sample.shape}")

Posterior sample shape: (300,)


In [10]:
mcbart_update.shutdown()

Ray has been shut down.


In [11]:
# compare with sequential update_fit
from tqdm import tqdm

bart_seq = DefaultBART(random_state=RANDOM_STATE, 
                               n_trees=N_TREES, ndpost=NDPOST, nskip=NSKIP)
start_time_seq = time.time()
bart_seq.fit(X_initial, y_initial, quietly=False)
for i in tqdm(range(len(X_update))):
    bart_seq.update_fit(X_update[i:i+1], y_update[i:i+1], add_ndpost=5, quietly=True)
end_time_seq = time.time()
y_pred_seq = bart_seq.predict(X_test)
mse_seq = mean_squared_error(y_test, y_pred_seq)
mae_seq = mean_absolute_error(y_test, y_pred_seq)
r2_seq = r2_score(y_test, y_pred_seq)
print(f"BART sequential update_fit results:")
print(f"  MSE: {mse_seq:.3f}, MAE: {mae_seq:.3f}, R2: {r2_seq:.4f}")
print(f"Time taken for sequential updates: {end_time_seq - start_time_seq:.2f} seconds")

Iterations: 100%|██████████| 1200/1200 [00:14<00:00, 83.25it/s] 
Iterations: 100%|██████████| 1200/1200 [00:14<00:00, 83.25it/s] 
100%|██████████| 630/630 [00:31<00:00, 19.96it/s]

BART sequential update_fit results:
  MSE: 1.968, MAE: 1.128, R2: 0.9170
Time taken for sequential updates: 45.98 seconds





## Test Logistic MCBART

In [11]:
# By Copilot

# Test MCBART with LogisticBART for classification
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from bart_playground.bart import LogisticBART
from bart_playground.mcbart import MultiChainBART
import time

def load_classification_datasets():
    """Load classification datasets for testing"""
    # Simple binary classification
    X_binary, y_binary = make_classification(
        n_samples=1000, n_features=10, n_informative=8, n_redundant=2,
        n_classes=2, random_state=RANDOM_STATE
    )
    
    # Multiclass classification
    X_multi, y_multi = make_classification(
        n_samples=1000, n_features=10, n_informative=8, n_redundant=2,
        n_classes=3, random_state=RANDOM_STATE
    )
    
    return {
        "Binary": (X_binary, y_binary),
        "Multiclass": (X_multi, y_multi)
    }

def evaluate_classification_model(model, model_name, X_train, X_test, y_train, y_test, dataset_name):
    """Evaluate a classification model and return metrics"""
    time_start = time.time()
    
    # Fit the model
    model.fit(X_train, y_train, quietly=True)
    
    # Make predictions
    y_pred = model.predict(X_test)
    
    # Get probabilities if available
    if hasattr(model, 'predict_proba'):
        y_proba = model.predict_proba(X_test)
    else:
        y_proba = None
    
    time_end = time.time()
    y_pred = y_pred.astype(int)
    # Calculate metrics
    accuracy = accuracy_score(y_test, y_pred)
    
    return {
        'Accuracy': accuracy,
        'Time': time_end - time_start,
        'y_pred': y_pred,
        'y_proba': y_proba
    }

def test_logistic_mcbart():
    """Test MultiChainBART with LogisticBART for classification"""
    print("\n" + "="*60)
    print("TESTING MCBART WITH LOGISTICBART")
    print("="*60)
    
    # Load classification datasets
    datasets = load_classification_datasets()
    classification_results = []
    
    for dataset_name, (X, y) in datasets.items():
        print(f"\n=== Testing on {dataset_name} Classification ====")
        print(f"Dataset shape: X={X.shape}, y={y.shape}")
        print(f"Classes: {np.unique(y)}")
        
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=RANDOM_STATE, stratify=y
        )
        
        # Define models for comparison
        n_trees_logistic = 25  # LogisticBART typically uses fewer trees
        ndpost_logistic = 500
        nskip_logistic = 100
        
        models = {
            "LogisticRegression": LogisticRegression(random_state=RANDOM_STATE, max_iter=1000),
            "RandomForest": RandomForestClassifier(n_estimators=n_trees_logistic, random_state=RANDOM_STATE),
            "XGBClassifier": XGBClassifier(n_estimators=n_trees_logistic, random_state=RANDOM_STATE),
            "LogisticBART": LogisticBART(
                n_trees=n_trees_logistic, 
                ndpost=ndpost_logistic, 
                nskip=nskip_logistic, 
                random_state=RANDOM_STATE
            ),
            "MCLogisticBART": MultiChainBART(
                n_ensembles=3,
                bart_class=LogisticBART,
                random_state=RANDOM_STATE,
                n_trees=n_trees_logistic,
                ndpost=ndpost_logistic//2,
                nskip=nskip_logistic//2
            ),
        }
        
        for model_name, model in models.items():
            print(f"  Training {model_name}...")
            
            if model_name in ["LogisticBART", "MCLogisticBART"]:
                metrics = evaluate_classification_model(
                    model, model_name, X_train, X_test, y_train, y_test, dataset_name
                )
            else:
                # Standard sklearn models
                time_start = time.time()
                model.fit(X_train, y_train)
                y_pred = model.predict(X_test)
                y_proba = model.predict_proba(X_test) if hasattr(model, 'predict_proba') else None
                time_end = time.time()
                
                metrics = {
                    'Accuracy': accuracy_score(y_test, y_pred),
                    'Time': time_end - time_start,
                    'y_pred': y_pred,
                    'y_proba': y_proba
                }
            
            result = {
                'Dataset': dataset_name, 
                'Model': model_name, 
                'Accuracy': metrics['Accuracy'],
                'Time': metrics['Time']
            }
            classification_results.append(result)
            
            print(f"    Accuracy: {metrics['Accuracy']:.4f}, Time: {metrics['Time']:.2f}s")
        
            # Shutdown MultiChainBART if needed
            if model_name == "MCLogisticBART" and hasattr(model, 'shutdown'):
                model.shutdown()
    
    # Display results
    results_df = pd.DataFrame(classification_results)
    print("\n" + "="*60)
    print("CLASSIFICATION SUMMARY RESULTS")
    print("="*60)
    
    pivot = results_df.pivot_table(index='Dataset', columns='Model', values='Accuracy')
    print("\nAccuracy:")
    print(pivot.round(4))
    
    pivot_time = results_df.pivot_table(index='Dataset', columns='Model', values='Time')
    print("\nTime (seconds):")
    print(pivot_time.round(2))

def test_logistic_mcbart_update_fit():
    """Test LogisticBART update_fit functionality"""
    print("\n" + "="*60)
    print("TESTING LOGISTIC MCBART UPDATE_FIT")
    print("="*60)
    
    # Use binary classification dataset
    X, y = make_classification(
        n_samples=500, n_features=8, n_informative=6, n_redundant=2,
        n_classes=2, random_state=RANDOM_STATE
    )
    
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=RANDOM_STATE, stratify=y
    )
    
    # Initialize with small subset
    initial_size = int(0.2 * len(X_train))
    X_initial = X_train[:initial_size]
    y_initial = y_train[:initial_size]
    X_update = X_train[initial_size:]
    y_update = y_train[initial_size:]
    
    print(f"Initial training size: {len(X_initial)}")
    print(f"Update size: {len(X_update)}")
    print(f"Test size: {len(X_test)}")
    
    # Test MultiChain LogisticBART update
    start_time = time.time()
    
    mcbart_logistic = MultiChainBART(
        n_ensembles=2, 
        bart_class=LogisticBART,
        random_state=RANDOM_STATE,
        n_trees=15,
        ndpost=200,
        nskip=50
    )
    
    # Initial fit
    mcbart_logistic.fit(X_initial, y_initial, quietly=True)
    initial_accuracy = accuracy_score(y_test, mcbart_logistic.predict(X_test).astype(int))
    
    # Update in batches
    batch_size = 10
    update_accuracies = []
    
    for i in range(0, len(X_update), batch_size):
        end_idx = min(i + batch_size, len(X_update))
        X_batch = X_update[i:end_idx]
        y_batch = y_update[i:end_idx]
        
        mcbart_logistic.update_fit(X_batch, y_batch, add_ndpost=5, quietly=True)
        
        # Evaluate after update
        y_pred = mcbart_logistic.predict(X_test).astype(int)
        accuracy = accuracy_score(y_test, y_pred)
        update_accuracies.append(accuracy)
    
    end_time = time.time()
    
    print(f"\nUpdate Results:")
    print(f"Initial accuracy: {initial_accuracy:.4f}")
    print(f"Final accuracy: {update_accuracies[-1]:.4f}")
    print(f"Accuracy improvement: {update_accuracies[-1] - initial_accuracy:.4f}")
    print(f"Total update time: {end_time - start_time:.2f} seconds")
    
    # Test posterior sampling
    print(f"\nTesting posterior sampling...")
    
    def uniform_schedule(i):
        return 1.0 / mcbart_logistic._trace_length
    
    # Get posterior samples (probabilities)
    posterior_proba_sample = mcbart_logistic.posterior_sample(X_test[:5], schedule=uniform_schedule)
    print(f"Posterior probability sample shape: {posterior_proba_sample.shape}")
    
    # Test posterior prediction probabilities
    posterior_proba = mcbart_logistic.posterior_f(X_test[:5])
    print(f"Posterior f(x) shape: {posterior_proba.shape}")
    
    mcbart_logistic.shutdown()

# Run the tests
test_logistic_mcbart()
test_logistic_mcbart_update_fit()


TESTING MCBART WITH LOGISTICBART

=== Testing on Binary Classification ====
Dataset shape: X=(1000, 10), y=(1000,)
Classes: [0 1]


2025-06-13 16:31:34,426	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m


Created 3 BARTActor(s) using BART class: LogisticBART
  Training LogisticRegression...
    Accuracy: 0.6600, Time: 0.06s
  Training RandomForest...
    Accuracy: 0.8267, Time: 0.08s
  Training XGBClassifier...
    Accuracy: 0.8900, Time: 0.76s
  Training LogisticBART...
    Accuracy: 0.8300, Time: 14.25s
  Training MCLogisticBART...
    Accuracy: 0.8267, Time: 7.39s
Ray has been shut down.

=== Testing on Multiclass Classification ====
Dataset shape: X=(1000, 10), y=(1000,)
Classes: [0 1 2]


2025-06-13 16:32:01,497	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m


Created 3 BARTActor(s) using BART class: LogisticBART
  Training LogisticRegression...
    Accuracy: 0.7267, Time: 0.16s
  Training RandomForest...
    Accuracy: 0.8767, Time: 0.11s
  Training XGBClassifier...
    Accuracy: 0.8867, Time: 0.05s
  Training LogisticBART...
    Accuracy: 0.8467, Time: 11.67s
  Training MCLogisticBART...
    Accuracy: 0.8033, Time: 7.88s
Ray has been shut down.

CLASSIFICATION SUMMARY RESULTS

Accuracy:
Model       LogisticBART  LogisticRegression  MCLogisticBART  RandomForest  \
Dataset                                                                      
Binary            0.8300              0.6600          0.8267        0.8267   
Multiclass        0.8467              0.7267          0.8033        0.8767   

Model       XGBClassifier  
Dataset                    
Binary             0.8900  
Multiclass         0.8867  

Time (seconds):
Model       LogisticBART  LogisticRegression  MCLogisticBART  RandomForest  \
Dataset                                     

2025-06-13 16:32:25,487	INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m


Created 2 BARTActor(s) using BART class: LogisticBART

Update Results:
Initial accuracy: 0.7867
Final accuracy: 0.8333
Accuracy improvement: 0.0467
Total update time: 8.53 seconds

Testing posterior sampling...
Posterior probability sample shape: (5, 2)
Posterior f(x) shape: (5, 10, 2)
Ray has been shut down.
