In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground.bart import DefaultBART
from bart_playground.mcbart import MultiChainBART
from bart_playground.diagnostics import compute_diagnostics
import bartz

INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [3]:
proposal_probs = {"grow" : 0.25,
                  "prune" : 0.25,
                  "change" : 0.4,
                  "swap" : 0.1}
# generator = DataGenerator(n_samples=3000, n_features=3, noise=0.1, random_seed=42)
from mushroom import load_mushroom
X, y = load_mushroom(1000, rng=np.random.default_rng(42))
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:12])

[[0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [0]
 [0]
 [1]
 [0]]


In [None]:
# initialize numba and run MultiChainBART with 2 chains
mcb = MultiChainBART(
    n_ensembles=4,
    bart_class=DefaultBART,
    random_state=42,
    ndpost=1500,
    nskip=500,
    n_trees=200,
    proposal_probs=proposal_probs,
)
mcb.fit(X_train, y_train)

2025-09-13 20:25:12,582	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


Created 4 BARTActor(s) using BART class: DefaultBART


Iterations:   0%|          | 0/2000 [00:00<?, ?it/s]
Iterations:   0%|          | 1/2000 [00:00<12:40,  2.63it/s]
Iterations:   0%|          | 7/2000 [00:00<01:53, 17.63it/s]
Iterations:   4%|▍         | 82/2000 [00:02<00:44, 43.35it/s]
Iterations:   0%|          | 0/2000 [00:00<?, ?it/s][32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
Iterations:  13%|█▎        | 251/2000 [00:05<00:29, 59.40it/s][32m [repeated 169x across cluster][0m
Iterations:  13%|█▎        | 264/2000 [00:05<00:29, 59.14it/s]
Iterations:  21%|██        | 419/2000 [00:08<00:26, 59.48it/s]
Iterations:  27%|██▋       | 536/2000 [00:10<00:41, 35.59it/s][32m [repeated 172x across cluster][0m
Iterations:  30%|███       | 605/2000 [00:11<00:24, 56.37it/s][32m [repeated 2x across cluster][0m
Iterations:  40%|███▉     

Iterations: 100%|██████████| 2000/2000 [00:51<00:00, 38.67it/s]
[36m(BARTActor pid=3347523)[0m INFO:arviz.preview:arviz_base not installed
[36m(BARTActor pid=3347523)[0m INFO:arviz.preview:arviz_stats not installed
[36m(BARTActor pid=3347523)[0m INFO:arviz.preview:arviz_plots not installed
[36m(BARTActor pid=3347536)[0m INFO:arviz.preview:arviz_base not installed[32m [repeated 3x across cluster][0m
[36m(BARTActor pid=3347536)[0m INFO:arviz.preview:arviz_stats not installed[32m [repeated 3x across cluster][0m
[36m(BARTActor pid=3347536)[0m INFO:arviz.preview:arviz_plots not installed[32m [repeated 3x across cluster][0m


In [10]:
# Compute diagnostics for MultiChainBART
diag = compute_diagnostics(mcb, key="eps_sigma2")
print({
    "n_chains": diag["n_chains"],
    "n_draws": diag["n_draws"],
    "rhat": diag["rhat"],
    "ess_bulk": diag["ess_bulk"],
    "mcse_over_sd": diag["mcse_over_sd"],
})

{'n_chains': 4, 'n_draws': 1500, 'rhat': array([1.16637192]), 'ess_bulk': array([15.91862345]), 'mcse_over_sd': array([0.26195221])}


In [11]:
import pandas as pd

# Show acceptance statistics
acc_df = pd.DataFrame(diag["acceptance"]).T
acc_df

Unnamed: 0,selected,proposed,accepted,acc_rate,prop_rate
change,640607.0,624781.0,39782.0,0.063674,0.975295
grow,399820.0,399820.0,34028.0,0.085108,1.0
prune,400048.0,390191.0,32331.0,0.082859,0.97536
swap,159525.0,88772.0,14054.0,0.158316,0.556477
overall,1600000.0,1503564.0,120195.0,0.07994,0.939728


In [6]:
diag_X = compute_diagnostics(mcb, X=X_train[1:10, :])
print(diag_X)

{'n_chains': 4, 'n_draws': 1500, 'rhat': array([1.0320045 , 1.08132425, 1.10381339, 1.09144862, 1.0694225 ,
       1.06234806, 1.13254428, 1.05725985, 1.08361328]), 'ess_bulk': array([73.93899791, 41.10839446, 25.50458238, 40.67654167, 60.82698775,
       66.26420919, 24.94915777, 59.07153281, 42.65487627]), 'mcse_mean': array([0.00042219, 0.00074193, 0.00102948, 0.00050698, 0.00024395,
       0.00038115, 0.00059923, 0.00032602, 0.00046678]), 'mcse_over_sd': array([0.11624268, 0.1619003 , 0.19777783, 0.15659377, 0.12864192,
       0.12850829, 0.20020352, 0.13351641, 0.15156721]), 'acceptance': {'change': {'selected': 640607.0, 'proposed': 624781.0, 'accepted': 39782.0, 'acc_rate': 0.06367351119832389, 'prop_rate': 0.97529530585835}, 'grow': {'selected': 399820.0, 'proposed': 399820.0, 'accepted': 34028.0, 'acc_rate': 0.08510829873443049, 'prop_rate': 1.0}, 'prune': {'selected': 400048.0, 'proposed': 390191.0, 'accepted': 32331.0, 'acc_rate': 0.08285942013014139, 'prop_rate': 0.97536045

In [7]:
# Show one actor's last global params (for curiosity)
print(mcb.collect(lambda x: x.sampler.trace[-1].global_params)[0])

{'eps_sigma2': array([0.00007939])}


In [8]:
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

btz = bartz.BART.gbart(np.transpose(X_train), y_train.flatten(), ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

  return fit_method(estimator, *args, **kwargs)
INFO:2025-09-13 20:26:28,431:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Iteration 100/300 P_grow=0.53 P_prune=0.47 A_grow=0.04 A_prune=0.02 (burnin)
Iteration 200/300 P_grow=0.51 P_prune=0.49 A_grow=0.02 A_prune=0.02
Iteration 300/300 P_grow=0.50 P_prune=0.50 A_grow=0.00 A_prune=0.00


In [9]:
models = {"mcb" : mcb, 
          "rf" : rf, 
          "lr" : lr,
          "btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    elif model_name == "mcb":
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

{'mcb': 0.00013731324753101315,
 'rf': 0.0018815999999999998,
 'lr': 0.05785262134863105,
 'btz': 4.5180382585385814e-05}