In [1]:
import numpy as np
import sys
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
import bartz

In [2]:
# Proposal weights passed to DefaultBART further down: only "grow" and "prune" are listed,
# each with weight 0.5.
proposal_probs = {"grow": 0.5, "prune": 0.5}

# Synthetic "piecewise_flat" regression data; seed 42 is used both for generation and for
# the train/test split so the notebook is repeatable.
generator = DataGenerator(
    n_samples=360,
    n_features=2,
    noise=0.1,
    random_seed=42,
)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Print floats in fixed-point rather than scientific notation, then peek at the first
# dozen training targets.
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50114513  0.62873311  0.53828857  0.41549547  0.00699196  0.31192016
  0.0558549   0.4105654   0.478931   -0.12503155  0.06700792  0.30414086]


In [3]:
# Initial fit on the training split. ndpost=50 plus nskip=100 gives the 150 sampler
# iterations shown by the progress bar (nskip is presumably the burn-in — confirm).
# The commented-out Profiler lines are an optional timing harness around the fit.
# profiler = Profiler()
# profiler.start()
bart = DefaultBART(
    ndpost=50,
    nskip=100,
    n_trees=100,
    proposal_probs=proposal_probs,
)
bart.fit(X_train, y_train)
# profiler.stop()
# profiler.print()

Iterations: 100%|██████████| 150/150 [00:06<00:00, 22.98it/s]


In [4]:
from tqdm import tqdm

# Stream the test rows into the already-fitted model one observation at a time.
# Every row except the last gets a short, silent update; the last row gets a longer
# update whose own progress bar is shown.
# profiler = Profiler()
# profiler.start()
test_n = X_test.shape[0]
for start in tqdm(range(test_n - 1)):
    # A length-1 slice (rather than an integer index) keeps the array dimensions intact.
    row = slice(start, start + 1)
    bart.update_fit(
        X_test[row, :],
        y_test[row],
        add_ndpost=10,
        add_nskip=5,
        quietly=True,
    )
# profiler.stop()
# profiler.print()
# Kept as the last expression of the cell so the returned object is displayed.
bart.update_fit(
    X_test[-1:, :],
    y_test[-1:],
    add_ndpost=40,
    add_nskip=20,
)

100%|██████████| 89/89 [00:18<00:00,  4.90it/s]
Iterations: 100%|██████████| 60/60 [00:00<00:00, 75.63it/s]


<bart_playground.bart.DefaultBART at 0x7c05c0d60610>

In [5]:
# Count, for every tree in the last trace entry, how many entries of `tree.vars` are >= 0.
# Judging by the single-split tree printed later in this notebook (a split on X_0 stored
# as 0, negative codes everywhere else), this is presumably the number of split nodes —
# confirm against the Tree implementation.
arrays = [tree.vars for tree in bart.trace[-1].trees]
# The per-tree arrays are counted one by one rather than stacked, since nothing here
# guarantees they all share the same length.
counts = np.array([np.count_nonzero(arr >= 0) for arr in arrays])
print(counts)

# Flag trees with at least 3 such entries. The comparison is vectorized, so the mask is a
# boolean array even when `counts` is empty.
deep_trees = counts >= 3
# np.nonzero returns the same tuple-of-index-arrays as single-argument np.where.
print(np.nonzero(deep_trees))

[2 1 1 2 2 2 2 3 1 1 1 0 3 2 2 1 2 1 1 3 1 1 2 3 1 1 3 2 1 1 1 1 2 1 1 0 1
 2 1 1 0 2 1 2 1 1 2 2 2 1 3 1 1 2 2 2 1 2 1 2 1 1 3 1 1 2 0 1 1 0 1 2 1 1
 2 4 1 1 1 2 4 2 2 1 1 1 2 3 1 1 0 1 1 2 1 1 1 3 2 2]
(array([ 7, 12, 19, 23, 26, 50, 62, 75, 80, 87, 97]),)


In [6]:
# Show the global (non-tree) parameters stored on the last entry of the trace.
print(bart.trace[-1].global_params)

{'eps_sigma2': array([0.00727223])}


In [7]:
# Inspect a single tree (index 72) from the last trace entry: its printed form, then its
# raw `vars` and `leaf_vals` arrays.
tree_sp: Tree = bart.trace[-1].trees[72]

for view in (tree_sp, tree_sp.vars, tree_sp.leaf_vals):
    print(view)

X_0 <= 0.644340694 (split, n = 360)
	Val: -0.011254027 (leaf, n = 224)
	Val: 0.001365768 (leaf, n = 136)
[ 0 -1 -1 -2 -2 -2 -2 -2]
[        nan -0.01125403  0.00136577         nan         nan         nan
         nan         nan]


In [8]:
# Baseline models, fitted on the training split only.
rf = RandomForestRegressor(random_state=42)
rf.fit(X_train, y_train)

lr = LinearRegression()
lr.fit(X_train, y_train)

# Reference BART implementation from bartz. The design matrices are passed transposed,
# so features run along the first axis.
btz = bartz.BART.gbart(X_train.T, y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(X_test.T)
# Collapse the first axis (presumably the posterior draws — confirm) into one point
# prediction per test row.
btpred = np.array(btpred_all).mean(axis=0)

INFO:2025-06-20 10:31:22,431:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


Iteration 100/300 P_grow=0.58 P_prune=0.42 A_grow=0.41 A_prune=0.33 (burnin)
Iteration 200/300 P_grow=0.52 P_prune=0.48 A_grow=0.35 A_prune=0.31
Iteration 300/300 P_grow=0.53 P_prune=0.47 A_grow=0.26 A_prune=0.28


In [9]:
# Test-set mean squared error for every model.
# "btz" is scored from the precomputed `btpred`; the others are scored through their
# own `predict(X_test)`.
# NOTE(review): `bart` was updated with every row of X_test / y_test through update_fit
# in an earlier cell, so its score here is not a held-out error like the other three.
models = {
    "bart": bart,
    "rf": rf,
    "lr": lr,
    "btz": btz,
}
results = {
    model_name: mean_squared_error(
        y_test,
        btpred if model_name == "btz" else model.predict(X_test),
    )
    for model_name, model in models.items()
}
results

{'bart': 0.020800551823822983,
 'rf': 0.022721420713122355,
 'lr': 0.04590326679925266,
 'btz': 0.02072978001474411}

In [10]:
# Compare, for the first 12 training rows, the evaluation of the sampler's last trace
# entry with the training targets after the preprocessor's `transform_y`
# (presumably both are on the model's internal scale — confirm).
last_entry_fit = bart.sampler.trace[-1].evaluate(X_train)
transformed_targets = bart.preprocessor.transform_y(y_train)

print(last_entry_fit[:12])
print(transformed_targets[:12])

[ 0.40655979  0.34013918  0.36989591  0.30850492 -0.02861053  0.32757756
  0.00287076  0.33968681  0.33602279  0.05486739 -0.02842319  0.34891531]
[ 0.35283538  0.44447931  0.37951478  0.29131493 -0.00210506  0.21691883
  0.03299222  0.28777375  0.33687941 -0.09693493  0.04100322  0.21133111]


In [11]:
# Per-tree consistency check on the last trace entry: does each tree's argument-free
# `evaluate()` (restricted to the first len(X_train) rows) equal `evaluate(X_train)`
# exactly?
# NOTE(review): assumes the first len(X_train) rows of `evaluate()` correspond to
# X_train, and that `evaluate(X)` expects inputs on the same scale as X_train — confirm;
# the recorded output of this cell is "False".
last_entry_trees = bart.trace[-1].trees
train_rows = np.arange(X_train.shape[0])

# Iterate over the trees themselves instead of a hard-coded range(100), so the check
# covers exactly the trees that exist whatever n_trees was used.
# The full list is built so every tree is evaluated even after a mismatch.
trees_consistent = all(
    [
        (tree.evaluate()[train_rows] == tree.evaluate(X_train)).all()
        for tree in last_entry_trees
    ]
)
print("True" if trees_consistent else "False")

False


In [12]:
# Ensemble-level version of the consistency check: compare the last trace entry's
# argument-free `evaluate()` on the first len(X_train) rows with `evaluate(X_train)`,
# up to floating-point tolerance.
last_entry = bart.trace[-1]
train_rows = np.arange(X_train.shape[0])

cached_fit = last_entry.evaluate()[train_rows]
fresh_fit = last_entry.evaluate(X_train)
print("True" if np.allclose(cached_fit, fresh_fit) else "False")

False
