In [1]:
import sys
import os
# Put the parent of the current working directory on the import path
# (presumably so that the local `bart_playground` package resolves when the
# notebook is launched from a sub-folder of the repository -- confirm).
# The membership check keeps the cell idempotent: re-running it does not
# keep appending duplicate entries to sys.path.
parent_dir = os.path.dirname(os.getcwd())
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [2]:
import bartz

In [3]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
import arviz as az

INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [4]:
from sklearn.datasets import fetch_california_housing

# Fetch the California housing data with pandas containers (as_frame=True).
data = fetch_california_housing(as_frame=True)

# Convert to plain NumPy arrays: a float feature matrix and a flat 1-D target.
X = data.data.values.astype(float)
y = np.array(data.target).reshape(-1)

# Hold out a test set; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    random_state=42,
)

In [5]:
# Probabilities of each multi-try tree move; also reused by the profiled
# MultiBART run in a later cell.
proposal_probs = {
    "multi_grow": 0.25,
    "multi_prune": 0.25,
    "multi_change": 0.4,
    "multi_swap": 0.1,
}
# Alternative kept for experimentation (grow/prune moves only):
#proposal_probs = {"multi_grow": 0.5, "multi_prune": 0.5}

# Short MultiBART run with a single try per proposal.
# NOTE(review): `bart1` is not referenced again in the visible cells --
# presumably a quick smoke test / warm-up before the profiled run; confirm.
bart1 = MultiBART(
    ndpost=50,
    nskip=20,
    n_trees=100,
    proposal_probs=proposal_probs,
    multi_tries=1,
)
bart1.fit(X_train, y_train)

Iterations: 100%|██████████| 70/70 [00:06<00:00, 11.48it/s]


In [6]:
# MultiBART model used in the final comparison: 100 trees, nskip=100 and
# ndpost=200 (the progress bar reports 300 iterations in total), with
# multi_tries=10 and the proposal_probs defined in an earlier cell.
bart = MultiBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=proposal_probs, multi_tries=10)
# Fit under IPython's %prun profiler: -s cumtime sorts by cumulative time,
# -D dumps the raw stats to Cal_profile_multi.prof, -q suppresses the pager.
# The fit itself still happens, so `bart` is trained after this line.
%prun -s cumtime -D Cal_profile_multi.prof -q bart.fit(X_train, y_train)
# Turn the pstats dump into a Graphviz call graph, then render it as a PNG.
# Both steps shell out, so `gprof2dot` and Graphviz `dot` must be on PATH.
!gprof2dot -f pstats Cal_profile_multi.prof -o Cal_profile_multi.dot
!dot -Tpng Cal_profile_multi.dot -o Cal_profile_multi.png

Iterations: 100%|██████████| 300/300 [01:23<00:00,  3.57it/s]

 
*** Profile stats marshalled to file 'Cal_profile_multi.prof'.





In [7]:
# Baseline regressors fitted on the same training split
# (scikit-learn's fit() returns the estimator, so construction and fitting chain).
rf = RandomForestRegressor(random_state=42).fit(X_train, y_train)
lr = LinearRegression().fit(X_train, y_train)

# Reference BART implementation from `bartz`, with the same tree count and
# nskip/ndpost as the profiled BART runs. The design matrices are passed
# transposed -- presumably gbart expects predictors as
# (features, observations); confirm against the bartz docs.
btz = bartz.BART.gbart(X_train.T, y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(X_test.T)
# Collapse axis 0 (presumably the posterior draws -- confirm) into a single
# point prediction per test observation.
btpred = np.array(btpred_all).mean(axis=0)

INFO:2025-08-18 17:11:23,564:jax._src.xla_bridge:752: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


....................................................................................................
It 100/300 grow P=58% A=9%, prune P=42% A=12%, fill=8% (burnin)
....................................................................................................
It 200/300 grow P=57% A=12%, prune P=43% A=12%, fill=9%
....................................................................................................
It 300/300 grow P=53% A=15%, prune P=47% A=11%, fill=10%


In [8]:
# Move probabilities for the single-try sampler; same weights as the
# multi-try configuration but with the plain move names.
default_proposal_probs = {"grow": 0.25, "prune": 0.25, "change": 0.4, "swap": 0.1}
# DefaultBART with the same ndpost / nskip / n_trees as the profiled
# MultiBART run, so the two profiles and test errors are comparable.
bart_default = DefaultBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=default_proposal_probs)
# Fit under IPython's %prun profiler: -s cumtime sorts by cumulative time,
# -D dumps the raw stats to Cal_profile_bart.prof, -q suppresses the pager.
# The fit itself still happens, so `bart_default` is trained after this line.
%prun -s cumtime -D Cal_profile_bart.prof -q bart_default.fit(X_train, y_train)
# Turn the pstats dump into a Graphviz call graph, then render it as a PNG.
# Both steps shell out, so `gprof2dot` and Graphviz `dot` must be on PATH.
!gprof2dot -f pstats Cal_profile_bart.prof -o Cal_profile_bart.dot
!dot -Tpng Cal_profile_bart.dot -o Cal_profile_bart.png

Iterations: 100%|██████████| 300/300 [00:09<00:00, 31.95it/s]

 
*** Profile stats marshalled to file 'Cal_profile_bart.prof'.





In [9]:
# Fitted models to compare, keyed by the label used in the results table.
models = {"bart": bart, "rf": rf, "lr": lr, "btz": btz, "bart_default": bart_default}

# Test-set mean squared error per model.
results = {}
for model_name, model in models.items():
    # bartz predictions were already computed and averaged into `btpred`
    # in an earlier cell; every other model predicts directly from X_test.
    test_pred = btpred if model_name == "btz" else model.predict(X_test)
    results[model_name] = mean_squared_error(y_test, test_pred)
results

{'bart': 0.221192917849968,
 'rf': 0.2542358390056568,
 'lr': 0.5411287478470684,
 'btz': 0.3176016792627629,
 'bart_default': 0.25522702814372716}