In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os
sys.path.append(os.path.dirname(os.getcwd()))
from bart_playground import *

import bartz

In [2]:
proposal_probs = {"grow" : 0.4,
                  "prune" : 0.4,
                  "change" : 0.1,
                  "swap" : 0.1}
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
profiler = Profiler()
profiler.start()
bart = ChangeNumTreeBART(ndpost=200, nskip=100, n_trees=200, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)
profiler.stop()
profiler.print()

Iterations:   1%|          | 2/300 [00:00<00:19, 15.03it/s]

<bart_playground.moves.Combine object at 0x0000025A8EF0B9D0>
<bart_playground.moves.Combine object at 0x0000025A8EF08510>


Iterations:   2%|▏         | 6/300 [00:00<00:30,  9.57it/s]

<bart_playground.moves.Combine object at 0x0000025A930CCED0>


Iterations:   3%|▎         | 8/300 [00:00<00:33,  8.80it/s]

<bart_playground.moves.Combine object at 0x0000025A930CCED0>
<bart_playground.moves.Combine object at 0x0000025A9351CED0>


Iterations:   3%|▎         | 10/300 [00:01<00:35,  8.07it/s]

<bart_playground.moves.Combine object at 0x0000025A93608510>


Iterations:   4%|▍         | 12/300 [00:01<00:36,  7.98it/s]

<bart_playground.moves.Combine object at 0x0000025A936F39D0>
<bart_playground.moves.Combine object at 0x0000025A938FCED0>


Iterations:   5%|▌         | 15/300 [00:01<00:37,  7.67it/s]

<bart_playground.moves.Combine object at 0x0000025A93A5CED0>
<bart_playground.moves.Combine object at 0x0000025A93A1CED0>


Iterations:   6%|▌         | 18/300 [00:02<00:40,  6.91it/s]

<bart_playground.moves.Combine object at 0x0000025A93DCCED0>


Iterations:   7%|▋         | 20/300 [00:02<00:38,  7.27it/s]

<bart_playground.moves.Combine object at 0x0000025A9501CED0>
<bart_playground.moves.Combine object at 0x0000025A95164ED0>


Iterations:   7%|▋         | 22/300 [00:02<00:40,  6.95it/s]

<bart_playground.moves.Combine object at 0x0000025A9525F710>


Iterations:   9%|▉         | 27/300 [00:03<00:34,  7.87it/s]

<bart_playground.moves.Combine object at 0x0000025A956F39D0>


Iterations:  10%|█         | 30/300 [00:03<00:34,  7.85it/s]

<bart_playground.moves.Combine object at 0x0000025A959B4ED0>


Iterations:  12%|█▏        | 36/300 [00:04<00:33,  7.79it/s]

<bart_playground.moves.Combine object at 0x0000025A96CEB9D0>


Iterations:  13%|█▎        | 38/300 [00:04<00:35,  7.46it/s]

<bart_playground.moves.Combine object at 0x0000025A970DCED0>


Iterations:  13%|█▎        | 40/300 [00:05<00:33,  7.74it/s]

<bart_playground.moves.Combine object at 0x0000025A97270510>


Iterations:  15%|█▌        | 46/300 [00:05<00:31,  8.17it/s]

<bart_playground.moves.Combine object at 0x0000025A977B39D0>


Iterations:  16%|█▌        | 48/300 [00:06<00:32,  7.77it/s]

<bart_playground.moves.Combine object at 0x0000025A97A10510>


Iterations:  17%|█▋        | 51/300 [00:06<00:29,  8.31it/s]

<bart_playground.moves.Combine object at 0x0000025A98D539D0>
<bart_playground.moves.Combine object at 0x0000025A98DEB9D0>


Iterations:  18%|█▊        | 53/300 [00:06<00:28,  8.62it/s]

<bart_playground.moves.Combine object at 0x0000025A98ECB9D0>


Iterations:  19%|█▉        | 57/300 [00:07<00:27,  8.83it/s]

<bart_playground.moves.Combine object at 0x0000025A99204ED0>


Iterations:  20%|██        | 60/300 [00:07<00:29,  8.24it/s]

<bart_playground.moves.Combine object at 0x0000025A98B54ED0>


Iterations:  22%|██▏       | 65/300 [00:08<00:31,  7.38it/s]

<bart_playground.moves.Combine object at 0x0000025A9A8E0510>
<bart_playground.moves.Combine object at 0x0000025A9AACCED0>


Iterations:  22%|██▏       | 67/300 [00:08<00:32,  7.20it/s]

<bart_playground.moves.Combine object at 0x0000025A9ABB39D0>


Iterations:  23%|██▎       | 69/300 [00:08<00:29,  7.74it/s]

<bart_playground.moves.Combine object at 0x0000025A9ADAF250>


Iterations:  24%|██▍       | 72/300 [00:09<00:28,  8.12it/s]

<bart_playground.moves.Combine object at 0x0000025A9AF80510>
<bart_playground.moves.Combine object at 0x0000025A9B174ED0>


Iterations:  25%|██▍       | 74/300 [00:09<00:26,  8.60it/s]

<bart_playground.moves.Combine object at 0x0000025A9C1D8510>


Iterations:  25%|██▌       | 75/300 [00:09<00:25,  8.92it/s]

<bart_playground.moves.Combine object at 0x0000025A9C35CED0>
<bart_playground.moves.Combine object at 0x0000025A9C4239D0>


Iterations:  26%|██▌       | 77/300 [00:09<00:24,  9.25it/s]

<bart_playground.moves.Combine object at 0x0000025A9C4FB9D0>


Iterations:  27%|██▋       | 82/300 [00:10<00:22,  9.67it/s]

<bart_playground.moves.Combine object at 0x0000025A9C8139D0>
<bart_playground.moves.Combine object at 0x0000025A9C774ED0>


Iterations:  29%|██▉       | 87/300 [00:10<00:23,  9.06it/s]

<bart_playground.moves.Combine object at 0x0000025A9CB388D0>


Iterations:  30%|██▉       | 89/300 [00:10<00:24,  8.75it/s]

<bart_playground.moves.Combine object at 0x0000025A9CC839D0>


Iterations:  30%|███       | 91/300 [00:11<00:24,  8.67it/s]

<bart_playground.moves.Combine object at 0x0000025A9DF1CED0>
<bart_playground.moves.Combine object at 0x0000025A9E057710>


Iterations:  31%|███       | 93/300 [00:11<00:24,  8.51it/s]

<bart_playground.moves.Combine object at 0x0000025A9C5ACED0>


Iterations:  32%|███▏      | 95/300 [00:11<00:22,  9.20it/s]

<bart_playground.moves.Combine object at 0x0000025A9E18CED0>


Iterations:  32%|███▏      | 97/300 [00:11<00:25,  7.91it/s]

<bart_playground.moves.Combine object at 0x0000025A9E350510>
<bart_playground.moves.Combine object at 0x0000025A9E4F4ED0>


Iterations:  33%|███▎      | 100/300 [00:12<00:23,  8.51it/s]

<bart_playground.moves.Combine object at 0x0000025A9E649C10>
<bart_playground.moves.Combine object at 0x0000025A9E73C850>


Iterations:  34%|███▍      | 102/300 [00:12<00:22,  8.76it/s]

<bart_playground.moves.Combine object at 0x0000025A9F7AB9D0>
<bart_playground.moves.Combine object at 0x0000025A9E73CED0>


Iterations:  35%|███▌      | 105/300 [00:12<00:20,  9.64it/s]

<bart_playground.moves.Combine object at 0x0000025A9FA839D0>


Iterations:  36%|███▋      | 109/300 [00:13<00:20,  9.28it/s]

<bart_playground.moves.Combine object at 0x0000025A9FCFB9D0>


Iterations:  38%|███▊      | 114/300 [00:13<00:19,  9.75it/s]

<bart_playground.moves.Combine object at 0x0000025AA009B9D0>


Iterations:  39%|███▊      | 116/300 [00:13<00:18,  9.76it/s]

<bart_playground.moves.Combine object at 0x0000025AA1318510>


Iterations:  40%|███▉      | 119/300 [00:14<00:18,  9.96it/s]

<bart_playground.moves.Combine object at 0x0000025AA1404ED0>


Iterations:  40%|████      | 120/300 [00:14<00:18,  9.65it/s]

<bart_playground.moves.Combine object at 0x0000025AA15BCED0>
<bart_playground.moves.Combine object at 0x0000025AA1708510>


Iterations:  42%|████▏     | 126/300 [00:14<00:16, 10.43it/s]

<bart_playground.moves.Combine object at 0x0000025AA19E4ED0>
<bart_playground.moves.Combine object at 0x0000025AA1B0AED0>


Iterations:  44%|████▎     | 131/300 [00:15<00:16, 10.36it/s]

<bart_playground.moves.Combine object at 0x0000025AA1CD4ED0>
<bart_playground.moves.Combine object at 0x0000025AA2E6CE10>
<bart_playground.moves.Combine object at 0x0000025AA2D8CED0>


Iterations:  46%|████▌     | 137/300 [00:15<00:14, 11.10it/s]

<bart_playground.moves.Combine object at 0x0000025AA320B9D0>
<bart_playground.moves.Combine object at 0x0000025AA3337250>


Iterations:  47%|████▋     | 141/300 [00:16<00:13, 11.74it/s]

<bart_playground.moves.Combine object at 0x0000025AA33839D0>


Iterations:  48%|████▊     | 145/300 [00:16<00:13, 11.60it/s]

<bart_playground.moves.Combine object at 0x0000025AA34E39D0>
<bart_playground.moves.Combine object at 0x0000025AA4721DD0>


Iterations:  50%|████▉     | 149/300 [00:17<00:13, 11.13it/s]

<bart_playground.moves.Combine object at 0x0000025AA4723110>
<bart_playground.moves.Combine object at 0x0000025AA47B39D0>


Iterations:  50%|█████     | 151/300 [00:17<00:13, 11.19it/s]

<bart_playground.moves.Combine object at 0x0000025AA4944ED0>


Iterations:  52%|█████▏    | 155/300 [00:17<00:12, 11.76it/s]

<bart_playground.moves.Combine object at 0x0000025AA4D74ED0>


Iterations:  53%|█████▎    | 159/300 [00:17<00:12, 11.10it/s]

<bart_playground.moves.Combine object at 0x0000025AA4F30510>


Iterations:  54%|█████▎    | 161/300 [00:18<00:12, 11.19it/s]

<bart_playground.moves.Combine object at 0x0000025AA503B9D0>


Iterations:  55%|█████▌    | 165/300 [00:18<00:12, 10.81it/s]

<bart_playground.moves.Combine object at 0x0000025AA4E04ED0>


Iterations:  56%|█████▌    | 167/300 [00:18<00:11, 11.36it/s]

<bart_playground.moves.Combine object at 0x0000025AA4A6CED0>
<bart_playground.moves.Combine object at 0x0000025AA6567090>


Iterations:  56%|█████▋    | 169/300 [00:18<00:11, 11.54it/s]

<bart_playground.moves.Combine object at 0x0000025AA65FB9D0>
<bart_playground.moves.Combine object at 0x0000025AA6694ED0>


Iterations:  58%|█████▊    | 173/300 [00:19<00:11, 11.43it/s]

<bart_playground.moves.Combine object at 0x0000025AA68539D0>
<bart_playground.moves.Combine object at 0x0000025AA68A39D0>


Iterations:  59%|█████▉    | 177/300 [00:19<00:10, 12.20it/s]

<bart_playground.moves.Combine object at 0x0000025AA7AD8F90>
<bart_playground.moves.Combine object at 0x0000025AA68EAED0>
<bart_playground.moves.Combine object at 0x0000025AA7BAB9D0>


Iterations:  60%|█████▉    | 179/300 [00:19<00:09, 12.14it/s]

<bart_playground.moves.Combine object at 0x0000025AA7C0B9D0>


Iterations:  61%|██████    | 183/300 [00:19<00:09, 12.46it/s]

<bart_playground.moves.Combine object at 0x0000025AA7D94ED0>


Iterations:  63%|██████▎   | 189/300 [00:20<00:09, 11.90it/s]

<bart_playground.moves.Combine object at 0x0000025AA81D0510>


Iterations:  64%|██████▍   | 193/300 [00:20<00:09, 11.79it/s]

<bart_playground.moves.Combine object at 0x0000025AA82B8510>
<bart_playground.moves.Combine object at 0x0000025AA9534590>


Iterations:  66%|██████▌   | 197/300 [00:21<00:07, 12.96it/s]

<bart_playground.moves.Combine object at 0x0000025AA957CED0>
<bart_playground.moves.Combine object at 0x0000025AA9749CD0>


Iterations:  66%|██████▋   | 199/300 [00:21<00:07, 13.08it/s]

<bart_playground.moves.Combine object at 0x0000025AA970F890>
<bart_playground.moves.Combine object at 0x0000025AA98757D0>
<bart_playground.moves.Combine object at 0x0000025AA98B0510>


Iterations:  67%|██████▋   | 201/300 [00:21<00:07, 12.71it/s]

<bart_playground.moves.Combine object at 0x0000025AA98B0510>
<bart_playground.moves.Combine object at 0x0000025AA99B7250>


Iterations:  69%|██████▉   | 207/300 [00:21<00:07, 13.02it/s]

<bart_playground.moves.Combine object at 0x0000025AA9B80210>


Iterations:  70%|██████▉   | 209/300 [00:22<00:07, 12.68it/s]

<bart_playground.moves.Combine object at 0x0000025AA6464ED0>
<bart_playground.moves.Combine object at 0x0000025AAAEB8510>


Iterations:  71%|███████   | 213/300 [00:22<00:06, 12.93it/s]

<bart_playground.moves.Combine object at 0x0000025AAB00B950>


Iterations:  72%|███████▏  | 217/300 [00:22<00:06, 12.50it/s]

<bart_playground.moves.Combine object at 0x0000025AAB1ECED0>
<bart_playground.moves.Combine object at 0x0000025AAB2D0510>
<bart_playground.moves.Combine object at 0x0000025AAB332ED0>


Iterations:  74%|███████▎  | 221/300 [00:22<00:06, 12.44it/s]

<bart_playground.moves.Combine object at 0x0000025AAB3D39D0>
<bart_playground.moves.Combine object at 0x0000025AAB51F250>


Iterations:  76%|███████▋  | 229/300 [00:23<00:05, 14.08it/s]

<bart_playground.moves.Combine object at 0x0000025AAB69CED0>
<bart_playground.moves.Combine object at 0x0000025AAC8B7410>
<bart_playground.moves.Combine object at 0x0000025AA9B30510>


Iterations:  78%|███████▊  | 233/300 [00:23<00:04, 14.36it/s]

<bart_playground.moves.Combine object at 0x0000025AA9B30510>
<bart_playground.moves.Combine object at 0x0000025AACA80510>
<bart_playground.moves.Combine object at 0x0000025AACA80510>


Iterations:  78%|███████▊  | 235/300 [00:23<00:04, 14.48it/s]

<bart_playground.moves.Combine object at 0x0000025AACBBA310>
<bart_playground.moves.Combine object at 0x0000025AAADCF510>
<bart_playground.moves.Combine object at 0x0000025AACC5B9D0>


Iterations:  80%|███████▉  | 239/300 [00:24<00:04, 14.70it/s]

<bart_playground.moves.Combine object at 0x0000025AACCE8510>


Iterations:  80%|████████  | 241/300 [00:24<00:03, 14.95it/s]

<bart_playground.moves.Combine object at 0x0000025AACD3B9D0>
<bart_playground.moves.Combine object at 0x0000025AAC6CB9D0>


Iterations:  82%|████████▏ | 247/300 [00:24<00:03, 15.93it/s]

<bart_playground.moves.Combine object at 0x0000025AAE084ED0>
<bart_playground.moves.Combine object at 0x0000025AAE110510>
<bart_playground.moves.Combine object at 0x0000025AAE1677D0>
<bart_playground.moves.Combine object at 0x0000025AAE207E10>


Iterations:  84%|████████▍ | 253/300 [00:25<00:03, 15.56it/s]

<bart_playground.moves.Combine object at 0x0000025AAE33C2D0>
<bart_playground.moves.Combine object at 0x0000025AAE33CED0>


Iterations:  86%|████████▌ | 257/300 [00:25<00:02, 16.29it/s]

<bart_playground.moves.Combine object at 0x0000025AAE4CCED0>
<bart_playground.moves.Combine object at 0x0000025AAE5639D0>


Iterations:  87%|████████▋ | 262/300 [00:25<00:02, 18.29it/s]

<bart_playground.moves.Combine object at 0x0000025AAE6997D0>
<bart_playground.moves.Combine object at 0x0000025AAE6DA110>
<bart_playground.moves.Combine object at 0x0000025AAE724ED0>


Iterations:  88%|████████▊ | 265/300 [00:25<00:01, 19.87it/s]

<bart_playground.moves.Combine object at 0x0000025AAE806050>
<bart_playground.moves.Combine object at 0x0000025AAF863290>
<bart_playground.moves.Combine object at 0x0000025AAF8AB490>
<bart_playground.moves.Combine object at 0x0000025AAF903E90>


Iterations:  91%|█████████ | 272/300 [00:26<00:01, 18.26it/s]

<bart_playground.moves.Combine object at 0x0000025AAFB297D0>
<bart_playground.moves.Combine object at 0x0000025AAFBD0510>


Iterations:  93%|█████████▎| 278/300 [00:26<00:01, 20.62it/s]

<bart_playground.moves.Combine object at 0x0000025AAFCB0510>
<bart_playground.moves.Combine object at 0x0000025AAFD90510>
<bart_playground.moves.Combine object at 0x0000025AAFDD3250>


Iterations:  94%|█████████▎| 281/300 [00:26<00:00, 21.68it/s]

<bart_playground.moves.Combine object at 0x0000025AAFE68090>


Iterations:  96%|█████████▌| 287/300 [00:26<00:00, 19.49it/s]

<bart_playground.moves.Combine object at 0x0000025AAFEA8510>
<bart_playground.moves.Combine object at 0x0000025AAFF9E750>
<bart_playground.moves.Combine object at 0x0000025AAFFE7D50>
<bart_playground.moves.Combine object at 0x0000025AACA34ED0>


Iterations:  99%|█████████▊| 296/300 [00:27<00:00, 20.38it/s]

<bart_playground.moves.Combine object at 0x0000025AB1272810>
<bart_playground.moves.Combine object at 0x0000025AB1323510>


Iterations: 100%|██████████| 300/300 [00:27<00:00, 10.93it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:40:43  Samples:  27102
 /_//_/// /_\ / //_// / //_'/ //     Duration: 27.448    CPU time: 1.875
/   _/                      v5.0.1

Profile at C:\Users\ztykk\AppData\Local\Temp\ipykernel_29204\4120808940.py:2

27.447 <module>  ..\..\..\Temp\ipykernel_29204\4120808940.py:1
└─ 27.447 ChangeNumTreeBART.fit  bart_playground\bart.py:22
   └─ 27.446 NTreeSampler.run  bart_playground\samplers.py:69
      └─ 27.292 NTreeSampler.one_iter  bart_playground\samplers.py:226
         ├─ 10.113 Swap.propose  bart_playground\moves.py:34
         │  ├─ 3.332 Swap.try_propose  bart_playground\moves.py:146
         │  │  ├─ 1.942 Tree.swap_split  bart_playground\params.py:296
         │  │  │  └─ 1.892 Tree.change_split  bart_playground\params.py:289
         │  │  │     └─ 1.840 Tree.update_n  bart_playground\params.py:384
         │  │  │        ├─ 0.996 Tree.update_n  bart_playground\params.py:384
         │  │  │        │  ├─ 0.447 sum  <__array_func

In [11]:
bart.trace[-1].n_trees

57

In [4]:
rf = RandomForestRegressor()
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [5]:
bart.predict(X_test)

array([ 0.22001675, -0.230913  , -0.10374638,  0.04823599,  0.25229121,
       -0.2393027 ,  0.08098   ,  0.3200357 ,  0.35016712,  0.1823632 ,
        0.15326534,  0.34643374,  0.36212052,  0.17095316,  0.18518097,
        0.22246058,  0.25227022,  0.32611823,  0.31454847,  0.04541519,
        0.3164249 , -0.05924212,  0.22449371,  0.24489868,  0.32737593,
        0.18960932, -0.23711059,  0.16141928, -0.05281844, -0.04245347,
       -0.04566077,  0.16538296,  0.24238725, -0.11083148,  0.26970229,
        0.18398514,  0.14808005, -0.31253769,  0.23028383,  0.30723097])

In [6]:

models = {"bart" : bart, 
          "rf" : rf, 
          "lr" : lr,
          "btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

{'bart': 0.06035556187640084,
 'rf': 0.022391298855853174,
 'lr': 0.048045521328019404,
 'btz': 0.02328283761397566}