In [1]:
import sys
import os

# Make the parent of the notebook's working directory importable
# (presumably so the local `bart_playground` package in the next cell
# resolves). Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
import warnings

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground.bart import DefaultBART
from bart_playground.mcbart import MultiChainBART
from bart_playground.diagnostics import compute_diagnostics
import bartz
from mushroom import load_mushroom

INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [3]:
# Relative probabilities of each tree-move proposal, passed to the BART
# sampler below. Checked to sum to 1 so an edited value cannot silently
# skew the move mix.
proposal_probs = {"grow": 0.25,
                  "prune": 0.25,
                  "change": 0.4,
                  "swap": 0.1}
assert np.isclose(sum(proposal_probs.values()), 1.0), "proposal_probs must sum to 1"

# Load the mushroom data and make a seeded train/test split.
# NOTE(review): the meaning of the 1000 argument (presumably the number of
# rows loaded) is defined in mushroom.py — confirm.
X, y = load_mushroom(1000)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

np.set_printoptions(suppress=True)
# y comes back as a single-column 2-D array; the values shown are 0/1 labels,
# which every model below treats as a continuous regression target.
print(y_train[:12])

[[1]
 [0]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]]


In [None]:
# Fit MultiChainBART: n_ensembles=4 independent DefaultBART chains (each run
# as a Ray actor, per the log below). Every chain runs nskip=500 burn-in
# iterations followed by ndpost=1000 retained draws (1500 iterations total).
mcb = MultiChainBART(
    n_ensembles=4,
    bart_class=DefaultBART,
    random_state=42,
    ndpost=1000,
    nskip=500,
    n_trees=200,
    proposal_probs=proposal_probs,
)
mcb.fit(X_train, y_train)

2025-09-11 15:08:13,318	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


Created 4 BARTActor(s) using BART class: DefaultBART


Iterations:   0%|          | 0/1500 [00:00<?, ?it/s]
Iterations:   0%|          | 1/1500 [00:01<40:55,  1.64s/it]
Iterations:   1%|          | 13/1500 [00:01<02:18, 10.71it/s]
Iterations:   0%|          | 0/1500 [00:00<?, ?it/s][32m [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)[0m
Iterations:  13%|█▎        | 194/1500 [00:06<00:33, 38.95it/s][32m [repeated 167x across cluster][0m
Iterations:  13%|█▎        | 190/1500 [00:06<00:33, 39.69it/s][32m [repeated 3x across cluster][0m
Iterations:  28%|██▊       | 413/1500 [00:11<00:25, 42.61it/s][32m [repeated 173x across cluster][0m
Iterations:  23%|██▎       | 352/1500 [00:10<00:29, 38.93it/s][32m [repeated 4x across cluster][0m
Iterations:  43%|████▎     | 642/1500 [00:16<00:15, 56.03it/s][32m [repeated 166x across cluster][0m
Iteration

Iterations: 100%|██████████| 1500/1500 [00:33<00:00, 44.84it/s]


In [5]:
# Cross-chain convergence diagnostics for eps_sigma2 (presumably the
# residual variance). Thresholds follow the R-hat / ESS guidance of
# Vehtari et al. (2021): R-hat < 1.01 and bulk-ESS > 400.
RHAT_MAX = 1.01
ESS_BULK_MIN = 400

diag = compute_diagnostics(mcb, key="eps_sigma2")
diag_summary = {name: diag[name]
                for name in ("n_chains", "n_draws", "rhat", "ess_bulk", "mcse_over_sd")}
print(diag_summary)

# Make non-convergence loud instead of letting the model comparison below
# proceed silently on poorly mixed chains.
if diag["rhat"] > RHAT_MAX or diag["ess_bulk"] < ESS_BULK_MIN:
    warnings.warn(
        f"eps_sigma2 chains may not have converged "
        f"(rhat={diag['rhat']:.3f}, ess_bulk={diag['ess_bulk']:.1f}); "
        "interpret the predictions and MSE below with caution."
    )

# Acceptance statistics: one row per proposal move plus an 'overall' row.
acc_df = pd.DataFrame(diag["acceptance"]).T
acc_df

{'n_chains': 4, 'n_draws': 1000, 'rhat': 1.437292345888082, 'ess_bulk': 7.9294524021810355, 'mcse_over_sd': 0.3521534605795039}


Unnamed: 0,selected,proposed,accepted,acc_rate,prop_rate
change,480274.0,470214.0,29222.0,0.062146,0.979054
grow,300004.0,300004.0,25430.0,0.084766,1.0
prune,299676.0,293476.0,23700.0,0.080756,0.979311
swap,120046.0,68464.0,11121.0,0.162436,0.570315
overall,1200000.0,1132158.0,89473.0,0.079029,0.943465


In [6]:
# Peek at the global parameters of the final stored draw (trace[-1]) of the
# first collected model state (presumably one state per chain/actor).
# Named chain_states so it is not confused with the `models` dict of
# fitted estimators built further down.
chain_states = mcb.collect_model_states()
print(chain_states[0].sampler.trace[-1].global_params)

{'eps_sigma2': array([0.00006835])}


In [7]:
# Baseline models. y_train is a single-column 2-D array; sklearn regressors
# expect a 1-D target, and RandomForestRegressor warns on a column vector
# (the truncated warning in the saved output). Flatten once, reuse everywhere.
y_train_1d = y_train.ravel()

rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train_1d)
lr.fit(X_train, y_train_1d)

# bartz is given the transposed design matrix (features along axis 0).
# NOTE(review): ntree/ndpost/nskip here (100/200/100) differ from the
# MultiChainBART settings above (200 trees, 1000 draws, 500 burn-in), so the
# comparison is not like-for-like — confirm this is intended.
btz = bartz.BART.gbart(np.transpose(X_train), y_train_1d, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
# Average over axis 0 (presumably the posterior-draw axis) to get a point prediction.
btpred = np.mean(np.array(btpred_all), axis=0)

  return fit_method(estimator, *args, **kwargs)
INFO:2025-09-11 15:10:33,374:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Iteration 100/300 P_grow=0.53 P_prune=0.47 A_grow=0.06 A_prune=0.04 (burnin)
Iteration 200/300 P_grow=0.50 P_prune=0.50 A_grow=0.00 A_prune=0.00
Iteration 300/300 P_grow=0.53 P_prune=0.47 A_grow=0.06 A_prune=0.04


In [8]:
# Test-set MSE for every model. bartz predictions were already computed
# (btpred) because its predict() takes transposed input; all other models
# share the same predict(X) call, so only bartz needs special handling.
models = {"mcb": mcb,
          "rf": rf,
          "lr": lr,
          "btz": btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        y_pred = btpred
    else:
        y_pred = model.predict(X_test)
    results[model_name] = mean_squared_error(y_test, y_pred)
results

{'mcb': 5.731185110844544e-05,
 'rf': 0.0013759999999999998,
 'lr': 0.051489656277622635,
 'btz': 7.91522252256982e-05}