In [1]:
import sys
import os

# Make the parent of the notebook's working directory importable; presumably that is the
# project root containing the `bart_playground` package imported in the next cell — TODO confirm.
# The membership check keeps this cell idempotent: re-running it no longer appends a
# duplicate entry to sys.path every time.
project_root = os.path.dirname(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
import bartz
import arviz as az

INFO:arviz.preview:arviz_base not installed
INFO:arviz.preview:arviz_stats not installed
INFO:arviz.preview:arviz_plots not installed


In [3]:
# Single seed shared by data generation and the train/test split, so the comparison is reproducible.
SEED = 42

# Move probabilities for the MultiBART sampler (multi_* proposal moves).
# Alternative configuration (previously left commented out): grow/prune moves only —
#   {"multi_grow": 0.5, "multi_prune": 0.5}
proposal_probs = {"multi_grow": 0.25, "multi_prune": 0.25, "multi_change": 0.4, "multi_swap": 0.1}
# NOTE(review): presumably the sampler expects a normalised distribution — sanity-check it here.
assert abs(sum(proposal_probs.values()) - 1.0) < 1e-9, "proposal probabilities must sum to 1"

# 160 samples, 2 features, noise=0.1, from the "piecewise_flat" scenario.
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=SEED)
X, y = generator.generate(scenario="piecewise_flat")
# train_test_split uses its default test fraction (0.25) here.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)
assert len(X_train) + len(X_test) == len(X)

In [4]:
# DefaultBART fit without an explicit proposal_probs, i.e. with DefaultBART's own default moves.
# NOTE(review): `bart1` is not in the `models` comparison and nothing visible reads it afterwards;
# presumably a warm-up run before the profiled fits — TODO confirm, otherwise drop or evaluate it.
bart1 = DefaultBART(n_trees=100, ndpost=50, nskip=20)
bart1.fit(X_train, y_train)

Iterations: 100%|██████████| 70/70 [00:08<00:00,  7.80it/s]


In [5]:
# Profile a MultiBART fit that uses the multi_* proposal probabilities defined in the data-setup cell.
# `Profiler` is already imported in the imports cell, so it is not re-imported here.
# The context manager stops the profiler even if fit() raises, instead of leaving it running
# in the kernel the way a bare start()/stop() pair would.
with Profiler() as profiler:
    bart = MultiBART(ndpost=50, nskip=20, n_trees=100, proposal_probs=proposal_probs, multi_tries=1)
    bart.fit(X_train, y_train)

profiler.print()

Iterations: 100%|██████████| 70/70 [00:05<00:00, 11.74it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 14:16:57  Samples:  5857
 /_//_/// /_\ / //_// / //_'/ //     Duration: 5.968     CPU time: 5.891
/   _/                      v5.0.1

Profile at C:\Windows\Temp\ipykernel_23776\2353391020.py:4

5.970 <module>  C:\Windows\Temp\ipykernel_23776\2353391020.py:1
└─ 5.970 MultiBART.fit  bart_playground\bart.py:27
   └─ 5.969 MultiSampler.run  bart_playground\samplers.py:83
      └─ 5.908 MultiSampler.one_iter  bart_playground\samplers.py:358
         ├─ 5.228 MultiChange.propose  bart_playground\moves.py:33
         │  ├─ 2.326 MultiChange.try_propose  bart_playground\moves.py:401
         │  │  ├─ 1.038 Tree.change_split  bart_playground\params.py:391
         │  │  │  └─ 1.015 Tree.update_n  bart_playground\params.py:405
         │  │  │     ├─ 0.472 [self]  bart_playground\params.py
         │  │  │     ├─ 0.272 Tree.update_n  bart_playground\params.py:405
         │  │  │     └─ 0.232 sum  numpy\core\fromnumeric.py:2177
         │  │  │    

In [6]:
# Baselines: a seeded random forest and ordinary least squares.
# scikit-learn's fit() returns the estimator itself, so construction and fitting chain.
rf = RandomForestRegressor(random_state=42).fit(X_train, y_train)
lr = LinearRegression().fit(X_train, y_train)

# bartz is given the transposed design matrix (features x samples, since X_train is
# samples x features for the sklearn models above).
btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=50, nskip=20)

# Presumably one prediction row per posterior draw (ndpost); averaging over axis 0
# collapses the draws into a single point prediction per test sample.
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

INFO:2025-06-28 14:17:04,430:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-06-28 14:17:04,437:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': UNIMPLEMENTED: LoadPjrtPlugin is not implemented on windows yet.


In [7]:
# Standard (non-multi) proposal moves, with the same probabilities as the MultiBART run's
# multi_* moves. The dict is built outside the profiler so only model work is profiled.
default_proposal_probs = {"grow": 0.25, "prune": 0.25, "change": 0.4, "swap": 0.1}

# The context manager stops the profiler even if fit() raises.
with Profiler() as profiler:
    bart_default = DefaultBART(ndpost=50, nskip=20, n_trees=100, proposal_probs=default_proposal_probs)
    bart_default.fit(X_train, y_train)

profiler.print()

Iterations: 100%|██████████| 70/70 [00:03<00:00, 23.06it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 14:17:12  Samples:  3015
 /_//_/// /_\ / //_// / //_'/ //     Duration: 3.045     CPU time: 3.031
/   _/                      v5.0.1

Profile at C:\Windows\Temp\ipykernel_23776\3672608610.py:2

3.044 ZMQInteractiveShell.run_code  IPython\core\interactiveshell.py:3543
└─ 3.043 <module>  C:\Windows\Temp\ipykernel_23776\3672608610.py:1
   └─ 3.043 DefaultBART.fit  bart_playground\bart.py:27
      └─ 3.042 DefaultSampler.run  bart_playground\samplers.py:83
         └─ 3.004 DefaultSampler.one_iter  bart_playground\samplers.py:280
            ├─ 1.988 Swap.propose  bart_playground\moves.py:33
            │  ├─ 1.014 Swap.try_propose  bart_playground\moves.py:146
            │  │  ├─ 0.861 Tree.swap_split  bart_playground\params.py:398
            │  │  │  └─ 0.843 Tree.change_split  bart_playground\params.py:391
            │  │  │     └─ 0.831 Tree.update_n  bart_playground\params.py:405
            │  │  │        ├─ 0.417 Tree.update_n  bart

In [8]:
# Test-set mean squared error for each fitted model.
# bartz's predict() needs transposed input, so its averaged prediction (`btpred`) was
# computed in the fitting cell and is used directly instead of model.predict(X_test).
models = {
    "bart": bart,
    "rf": rf,
    "lr": lr,
    "btz": btz,
    "bart_default": bart_default,
}
results = {}
for model_name, model in models.items():
    y_pred = btpred if model_name == "btz" else model.predict(X_test)
    results[model_name] = mean_squared_error(y_test, y_pred)
results

{'bart': 0.02112037030574617,
 'rf': 0.022139023845392215,
 'lr': 0.048045521328019404,
 'btz': 0.022409798535525886,
 'bart_default': 0.02292141162635982}