In [1]:
import sys
import os
# Add the parent of the current working directory to the import search path,
# so modules located one level above this notebook can be imported
# (presumably the local `bart_playground` package -- confirm).
notebook_parent = os.path.dirname(os.getcwd())
sys.path.append(notebook_parent)

In [2]:
import bartz

In [3]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
import arviz as az

INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [4]:
# Proposal-move probabilities handed to the MultiBART samplers (they sum to 1).
# Alternative kept for experimentation -- grow/prune moves only:
#   {"multi_grow": 0.5, "multi_prune": 0.5}
proposal_probs = {
    "multi_grow": 0.25,
    "multi_prune": 0.25,
    "multi_change": 0.4,
    "multi_swap": 0.1,
}

# Synthetic regression data: 160 samples, 2 features, noise level 0.1,
# drawn from the "piecewise_flat" scenario with a fixed generator seed.
generator = DataGenerator(
    n_samples=160,
    n_features=2,
    noise=0.1,
    random_seed=42,
)
X, y = generator.generate(scenario="piecewise_flat")

# Seeded train/test split so every model below sees the same partition.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
)

In [5]:
# Short MultiBART run with a single try per proposal (multi_tries=1):
# 20 skipped + 50 kept iterations.
# NOTE(review): `bart1` is not referenced again in the cells visible here --
# presumably a warm-up / smoke run; confirm before removing.
bart1 = MultiBART(
    n_trees=100,
    ndpost=50,
    nskip=20,
    multi_tries=1,
    proposal_probs=proposal_probs,
)
bart1.fit(X_train, y_train)

Iterations: 100%|██████████| 70/70 [00:04<00:00, 16.82it/s]


In [6]:
bart = MultiBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=proposal_probs, multi_tries=10)
%prun -s cumtime -D profile_multi.prof -q bart.fit(X_train, y_train)
!gprof2dot -f pstats profile_multi.prof -o profile_multi.dot
!dot -Tpng profile_multi.dot -o profile_multi.png

Iterations: 100%|██████████| 300/300 [00:14<00:00, 20.21it/s]

 
*** Profile stats marshalled to file 'profile_multi.prof'.





In [7]:
# scikit-learn baselines, each fitted on the same training split.
rf = RandomForestRegressor(random_state=42)
rf.fit(X_train, y_train)

lr = LinearRegression()
lr.fit(X_train, y_train)

# bartz reference implementation. Inputs are passed transposed --
# presumably because bartz expects predictors laid out as
# (features, samples); confirm against the bartz documentation.
X_train_t = np.transpose(X_train)
X_test_t = np.transpose(X_test)
btz = bartz.BART.gbart(X_train_t, y_train, ntree=100, ndpost=200, nskip=100)

# Collapse the predictions along the first axis to get one value per test
# point (assumes axis 0 indexes posterior draws -- TODO confirm).
btpred_all = btz.predict(X_test_t)
btpred = np.array(btpred_all).mean(axis=0)

INFO:2025-09-14 23:35:46,729:jax._src.xla_bridge:822: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


....................................................................................................
It 100/300 grow P=60% A=40%, prune P=40% A=38%, fill=6% (burnin)
....................................................................................................
It 200/300 grow P=60% A=35%, prune P=40% A=38%, fill=6%
....................................................................................................
It 300/300 grow P=55% A=25%, prune P=45% A=29%, fill=6%


In [8]:
default_proposal_probs = {"grow": 0.25, "prune": 0.25, "change": 0.4, "swap": 0.1}
bart_default = DefaultBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=default_proposal_probs)
%prun -s cumtime -D profile_bart.prof -q bart_default.fit(X_train, y_train)
!gprof2dot -f pstats profile_bart.prof -o profile_bart.dot
!dot -Tpng profile_bart.dot -o profile_bart.png

Iterations: 100%|██████████| 300/300 [00:02<00:00, 118.54it/s]


 
*** Profile stats marshalled to file 'profile_bart.prof'.


In [9]:
# Fitted models to compare, keyed by a short label.
models = dict(
    bart=bart,
    rf=rf,
    lr=lr,
    btz=btz,
    bart_default=bart_default,
)

# Test-set mean squared error per model. The bartz model is scored on the
# precomputed `btpred` array (built in the bartz cell above) instead of calling
# `.predict(X_test)` like the other models.
results = {
    label: mean_squared_error(
        y_test,
        btpred if label == "btz" else fitted.predict(X_test),
    )
    for label, fitted in models.items()
}
results

{'bart': 0.021602058730366257,
 'rf': 0.022139023845392215,
 'lr': 0.048045521328019404,
 'btz': 0.02515917299472511,
 'bart_default': 0.0211558149339335}