In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground.DataGenerator import DataGenerator
from bart_playground.params import Tree
from bart_playground.bart import DefaultBART
import bartz

In [3]:
proposal_probs = {"grow" : 0.4,
                  "prune" : 0.4,
                  "change" : 0.1,
                  "swap" : 0.1}
generator = DataGenerator(n_samples=1000, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.59965549  0.66000473  0.60195109  0.59759841 -0.48024437 -0.05465378
 -0.09076175 -0.06775115  0.44628062  0.51643207 -0.17564767 -0.18611937]


In [4]:
# initialize numba
bart = DefaultBART(ndpost=100, nskip=0, n_trees=100, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)

Iterations: 100%|██████████| 100/100 [00:07<00:00, 14.27it/s]


In [5]:
import pandas as pd

# Collect move counts
selected = bart.sampler.move_selected_counts
success = bart.sampler.move_success_counts
accepted = bart.sampler.move_accepted_counts

# Combine into a DataFrame for easy viewing
df = pd.DataFrame({
    "selected": pd.Series(selected),
    "success": pd.Series(success),
    "accepted": pd.Series(accepted)
})

# Add success, acceptance and change rates
df["success_rate"] = df["success"] / df["selected"]
df["accept_rate"] = df["accepted"] / df["success"] # The "acceptance rate" reflects the MH acceptance probability
df["change_rate"] = df["accepted"] / df["selected"] # The "change rate" reflects the frequency of tree changes

print(df)

        selected  success  accepted  success_rate  accept_rate  change_rate
grow        4044     4044       558      1.000000     0.137982     0.137982
prune       4030     3330       361      0.826303     0.108408     0.089578
change       961      797        77      0.829344     0.096612     0.080125
swap         965      540        41      0.559585     0.075926     0.042487


In [6]:
bart2 = DefaultBART(ndpost=400, nskip=100, n_trees=100, proposal_probs=proposal_probs)
%prun -s cumtime -D profile_bart.prof -q bart2.fit(X_train, y_train)
!gprof2dot -f pstats profile_bart.prof -o profile_bart.dot
!dot -Tpng profile_bart.dot -o profile_bart.png

Iterations: 100%|██████████| 500/500 [00:41<00:00, 12.14it/s]


 
*** Profile stats marshalled to file 'profile_bart.prof'.


In [7]:
arrays = [tree.vars for tree in bart.trace[-1].trees]
counts = np.array([np.count_nonzero(arr >= 0) for arr in arrays])
print(counts)
deep_trees = np.array([count >= 3 for count in counts])
print(np.where(deep_trees))

[2 3 3 2 2 3 2 1 0 2 4 0 4 1 2 1 2 0 0 4 2 2 1 2 1 0 0 3 5 5 0 3 4 0 0 2 2
 1 3 2 3 3 0 5 2 2 2 2 0 2 3 1 2 2 3 2 3 5 2 0 3 2 5 1 0 2 2 0 3 2 2 2 3 1
 1 2 2 0 0 3 3 2 0 1 2 0 3 2 3 0 3 2 2 2 3 3 2 2 2 4]
(array([ 1,  2,  5, 10, 12, 19, 27, 28, 29, 31, 32, 38, 40, 41, 43, 50, 54,
       56, 57, 60, 62, 68, 72, 79, 80, 86, 88, 90, 94, 95, 99]),)


In [8]:
np.mean([len(arr) for arr in arrays])

np.float64(12.16)

In [9]:
print(bart.sampler.trace[-1].global_params)

{'eps_sigma2': array([0.07568119])}


In [10]:
from bart_playground import visualize_tree
tree_sp : Tree = bart.sampler.trace[-1].trees[50]

print(tree_sp)
print(tree_sp.vars)
print(tree_sp.leaf_vals)

X_1 <= 0.102725416 (split, n = 750)
	Val: -0.028377509 (leaf, n = 60)
	X_0 <= 0.102688663 (split, n = 690)
		Val: 0.072844885 (leaf, n = 69)
		X_0 <= 0.131728828 (split, n = 621)
			Val: -0.027562218 (leaf, n = 25)
			Val: 0.216302931 (leaf, n = 596)
[ 1 -1  0 -2 -2 -1  0 -2 -2 -2 -2 -2 -2 -1 -1 -2]
[        nan -0.02837751         nan         nan         nan  0.07284489
         nan         nan         nan         nan         nan         nan
         nan -0.02756222  0.21630293         nan]


In [11]:
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

INFO:2025-09-24 19:13:54,568:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


....................................................................................................
It 100/300 grow P=59% A=24%, prune P=41% A=39%, fill=7% (burnin)
....................................................................................................
It 200/300 grow P=58% A=29%, prune P=42% A=31%, fill=7%
....................................................................................................
It 300/300 grow P=54% A=35%, prune P=46% A=30%, fill=7%


In [12]:
models = {"bart" : bart, 
          "rf" : rf, 
          "lr" : lr,
          "btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

{'bart': 2762.2057622077195,
 'rf': 0.01570244066836391,
 'lr': 0.04058944845865094,
 'btz': 0.015576111635784784}

In [13]:
print(bart.sampler.trace[-1].evaluate(X_train)[:12])
print(bart.preprocessor.transform_y(y_train)[:12])

[49.32259  42.338333 53.702686 36.95055  11.455135 36.798134 28.875034
 32.902878 43.501865 48.10456  32.40367  38.513184]
[ 0.36207604  0.4024236   0.3636108   0.36070073 -0.35991025 -0.07537408
 -0.09951471 -0.08413056  0.25953454  0.30643553 -0.1562667  -0.16326775]


In [14]:
mean_squared_error(y_test, np.ones_like(y_test) * y_test.mean())

0.08625514042334204

In [15]:
if all([(bart.sampler.trace[-1].trees[i].evaluate() == bart.sampler.trace[-1].trees[i].evaluate(X_train)).all()
            for i in range(100)]):
    print("True")
else:
    print("False")

True


In [16]:
if np.allclose(bart.trace[-1].evaluate(), bart.trace[-1].evaluate(X_train), atol=1e-6):
    print("True")
else:
    print("False")

False


In [17]:
total_output = np.zeros(X_train.shape[0])
for i in range(bart.trace[-1].n_trees):
    total_output += bart.trace[-1].trees[i].evaluate(X_train)  # Add the tree's output to the total

In [18]:
total_output

array([49.32259324, 42.33833756, 53.70268771, 36.95056385, 11.45513523,
       36.79813394, 28.8750402 , 32.9028855 , 43.50186713, 48.10456579,
       32.40367432, 38.5131855 , 17.54897558, 48.95436662, 41.03805939,
       41.74073364, 18.9389603 , 38.81797611, 47.12058791, 10.2887535 ,
       53.70268771, 51.72792299, 53.83663927, 53.70268771, 44.57669884,
       53.70268771, 53.83663927, 53.43361105, 18.04786699,  9.00180127,
       53.22589366, 53.70268771, 50.90326113, 52.14013481,  8.75909105,
       44.57669884, 53.70268771, 43.42691373, 13.69933312, 37.30730517,
       36.35301114, 50.68538438, 15.52623656, 53.83663927, 10.63032992,
       36.75513431, 43.91703002, 53.70268771, 34.71284282, 48.43777646,
       53.70268771, 49.63903639, 53.2157343 , 40.13541559, 53.70268771,
        8.57513858, 41.19445306, 11.6098658 , 53.45136593, 46.66114728,
       43.76282653, 53.70268771, 39.63592641, 42.95714798, 24.57726281,
       33.24387073, 52.38964032, 46.07070429, 41.26839058, 53.83