In [1]:
# --- Standard library ---
import os
import sys

# --- Third-party ---
import numpy as np
from pyinstrument import Profiler
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# --- Local package ---
# Put the parent of the kernel's current working directory on the import path so
# `bart_playground` resolves when the notebook is launched from its own folder.
# NOTE(review): this depends on the kernel's cwd; an editable install (`pip install -e .`)
# would make the notebook location-independent.
sys.path.append(os.path.dirname(os.getcwd()))
# The star import supplies (at least) DataGenerator and ChangeNumTreeBART used below.
from bart_playground import *  # noqa: F401,F403

# Kept after the star import so the `bartz` module name cannot be shadowed by it.
import bartz

In [2]:
# Single seed shared by the synthetic data generator and the train/test split.
SEED = 42

# Move probabilities handed to ChangeNumTreeBART: equal odds of growing or pruning a tree.
proposal_probs = {"grow": 0.5,
                  "prune": 0.5}

# 160 noisy samples of a 2-feature "piecewise_flat" target, split with sklearn's default test fraction.
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=SEED)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

# Show plain decimals instead of scientific notation in array reprs.
np.set_printoptions(suppress=True)
# Peek at the first few training targets (bare last expression -> rich display).
y_train[:12]

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
# Profile one full fit of ChangeNumTreeBART: nskip=100 + ndpost=200 gives the
# 300 iterations shown in the progress bar.
# try/finally guarantees the profiler is stopped even if `fit` raises, so a failed
# run does not leave a sampling profiler attached to the kernel.
# NOTE(review): fit currently emits `<bart_playground.moves.Combine object ...>` lines;
# this looks like a leftover debug print inside bart_playground and should be removed there.
profiler = Profiler()
profiler.start()
try:
    bart = ChangeNumTreeBART(ndpost=200, nskip=100, n_trees=200, proposal_probs=proposal_probs)
    bart.fit(X_train, y_train)
finally:
    profiler.stop()
profiler.print()

Iterations:   1%|▏         | 4/300 [00:00<00:27, 10.78it/s]

<bart_playground.moves.Combine object at 0x000001706F1E3D90>
<bart_playground.moves.Combine object at 0x000001706F364A90>


Iterations:   2%|▏         | 6/300 [00:00<00:29, 10.00it/s]

<bart_playground.moves.Combine object at 0x000001706F4A8090>
<bart_playground.moves.Combine object at 0x000001706F5D7E10>


Iterations:   3%|▎         | 10/300 [00:01<00:30,  9.42it/s]

<bart_playground.moves.Combine object at 0x000001706F61BD90>
<bart_playground.moves.Combine object at 0x000001706F8F82D0>


Iterations:   4%|▍         | 13/300 [00:01<00:30,  9.30it/s]

<bart_playground.moves.Combine object at 0x000001706F9B8090>
<bart_playground.moves.Combine object at 0x000001706FAE0510>


Iterations:   5%|▌         | 16/300 [00:01<00:30,  9.33it/s]

<bart_playground.moves.Combine object at 0x000001706FCD9950>
<bart_playground.moves.Combine object at 0x000001706FD602D0>


Iterations:   6%|▌         | 18/300 [00:01<00:29,  9.46it/s]

<bart_playground.moves.Combine object at 0x000001706FDB3350>


Iterations:   7%|▋         | 21/300 [00:02<00:29,  9.44it/s]

<bart_playground.moves.Combine object at 0x0000017070F6BD90>


Iterations:   7%|▋         | 22/300 [00:02<00:29,  9.37it/s]

<bart_playground.moves.Combine object at 0x000001706F4F4A90>


Iterations:   9%|▊         | 26/300 [00:02<00:28,  9.69it/s]

<bart_playground.moves.Combine object at 0x00000170712C3D90>
<bart_playground.moves.Combine object at 0x00000170714C3D90>


Iterations:   9%|▉         | 28/300 [00:02<00:28,  9.45it/s]

<bart_playground.moves.Combine object at 0x00000170715D0510>


Iterations:  11%|█▏        | 34/300 [00:03<00:27,  9.59it/s]

<bart_playground.moves.Combine object at 0x00000170718C02D0>
<bart_playground.moves.Combine object at 0x0000017072B150D0>


Iterations:  12%|█▏        | 35/300 [00:03<00:27,  9.57it/s]

<bart_playground.moves.Combine object at 0x0000017072B582D0>
<bart_playground.moves.Combine object at 0x00000170719D4A90>


Iterations:  13%|█▎        | 39/300 [00:04<00:27,  9.43it/s]

<bart_playground.moves.Combine object at 0x00000170713E82D0>
<bart_playground.moves.Combine object at 0x0000017072F53D90>


Iterations:  14%|█▎        | 41/300 [00:04<00:27,  9.37it/s]

<bart_playground.moves.Combine object at 0x0000017072F50510>


Iterations:  14%|█▍        | 43/300 [00:04<00:27,  9.22it/s]

<bart_playground.moves.Combine object at 0x0000017073207710>
<bart_playground.moves.Combine object at 0x000001707336BD90>


Iterations:  16%|█▌        | 47/300 [00:04<00:25,  9.90it/s]

<bart_playground.moves.Combine object at 0x00000170733F3D90>
<bart_playground.moves.Combine object at 0x0000017073290090>
<bart_playground.moves.Combine object at 0x0000017074653D90>


Iterations:  20%|██        | 61/300 [00:06<00:22, 10.46it/s]

<bart_playground.moves.Combine object at 0x0000017074DF02D0>


Iterations:  22%|██▏       | 65/300 [00:06<00:22, 10.48it/s]

<bart_playground.moves.Combine object at 0x000001707609BD90>
<bart_playground.moves.Combine object at 0x0000017075FB82D0>


Iterations:  23%|██▎       | 69/300 [00:07<00:22, 10.40it/s]

<bart_playground.moves.Combine object at 0x00000170764982D0>
<bart_playground.moves.Combine object at 0x000001707653D150>


Iterations:  24%|██▍       | 73/300 [00:07<00:21, 10.48it/s]

<bart_playground.moves.Combine object at 0x0000017076573D90>
<bart_playground.moves.Combine object at 0x00000170765B0510>


Iterations:  26%|██▌       | 77/300 [00:07<00:21, 10.60it/s]

<bart_playground.moves.Combine object at 0x0000017077A482D0>


Iterations:  26%|██▋       | 79/300 [00:08<00:20, 10.60it/s]

<bart_playground.moves.Combine object at 0x0000017077BE8950>


Iterations:  27%|██▋       | 81/300 [00:08<00:20, 10.70it/s]

<bart_playground.moves.Combine object at 0x0000017077BEBD90>
<bart_playground.moves.Combine object at 0x0000017077FA0A50>


Iterations:  29%|██▉       | 87/300 [00:08<00:20, 10.39it/s]

<bart_playground.moves.Combine object at 0x00000170781BBD90>
<bart_playground.moves.Combine object at 0x00000170783602D0>


Iterations:  30%|███       | 91/300 [00:09<00:19, 10.87it/s]

<bart_playground.moves.Combine object at 0x00000170794482D0>
<bart_playground.moves.Combine object at 0x00000170794E02D0>
<bart_playground.moves.Combine object at 0x0000017077A48090>


Iterations:  31%|███       | 93/300 [00:09<00:18, 11.17it/s]

<bart_playground.moves.Combine object at 0x00000170795BBD90>


Iterations:  32%|███▏      | 97/300 [00:09<00:18, 11.10it/s]

<bart_playground.moves.Combine object at 0x00000170797FBD90>


Iterations:  34%|███▎      | 101/300 [00:10<00:17, 11.35it/s]

<bart_playground.moves.Combine object at 0x00000170799F3D90>
<bart_playground.moves.Combine object at 0x0000017079BF3D90>


Iterations:  34%|███▍      | 103/300 [00:10<00:17, 11.56it/s]

<bart_playground.moves.Combine object at 0x000001707AC90510>
<bart_playground.moves.Combine object at 0x000001707ADFBC50>


Iterations:  36%|███▌      | 107/300 [00:10<00:17, 10.97it/s]

<bart_playground.moves.Combine object at 0x000001706F147B90>
<bart_playground.moves.Combine object at 0x000001707AEC02D0>


Iterations:  36%|███▋      | 109/300 [00:10<00:16, 11.26it/s]

<bart_playground.moves.Combine object at 0x000001707AEC02D0>
<bart_playground.moves.Combine object at 0x000001707B043D90>


Iterations:  38%|███▊      | 113/300 [00:11<00:16, 11.12it/s]

<bart_playground.moves.Combine object at 0x000001707B2582D0>
<bart_playground.moves.Combine object at 0x000001707B2ECA90>


Iterations:  41%|████      | 123/300 [00:11<00:15, 11.24it/s]

<bart_playground.moves.Combine object at 0x000001707C8B5150>
<bart_playground.moves.Combine object at 0x000001707C7EBD90>


Iterations:  42%|████▏     | 125/300 [00:12<00:15, 11.14it/s]

<bart_playground.moves.Combine object at 0x000001707CA30E50>


Iterations:  42%|████▏     | 127/300 [00:12<00:15, 11.41it/s]

<bart_playground.moves.Combine object at 0x000001707CC3F710>
<bart_playground.moves.Combine object at 0x000001707CDC82D0>


Iterations:  44%|████▍     | 133/300 [00:12<00:13, 11.97it/s]

<bart_playground.moves.Combine object at 0x000001707CEE82D0>
<bart_playground.moves.Combine object at 0x000001707E063D90>


Iterations:  46%|████▌     | 137/300 [00:13<00:13, 12.31it/s]

<bart_playground.moves.Combine object at 0x000001707E1DBB90>
<bart_playground.moves.Combine object at 0x000001707E335490>


Iterations:  46%|████▋     | 139/300 [00:13<00:12, 12.47it/s]

<bart_playground.moves.Combine object at 0x000001707E370510>


Iterations:  49%|████▉     | 147/300 [00:13<00:11, 12.80it/s]

<bart_playground.moves.Combine object at 0x000001707E6CBD90>
<bart_playground.moves.Combine object at 0x000001707E488510>


Iterations:  50%|████▉     | 149/300 [00:14<00:11, 12.70it/s]

<bart_playground.moves.Combine object at 0x000001707F91BD90>


Iterations:  51%|█████     | 153/300 [00:14<00:11, 12.75it/s]

<bart_playground.moves.Combine object at 0x000001707FBD82D0>
<bart_playground.moves.Combine object at 0x000001707F8D02D0>


Iterations:  52%|█████▏    | 157/300 [00:14<00:10, 13.12it/s]

<bart_playground.moves.Combine object at 0x000001707FD7FCD0>


Iterations:  54%|█████▎    | 161/300 [00:15<00:10, 12.80it/s]

<bart_playground.moves.Combine object at 0x000001707FF4CA90>


Iterations:  55%|█████▌    | 165/300 [00:15<00:10, 13.02it/s]

<bart_playground.moves.Combine object at 0x000001700121D150>


Iterations:  56%|█████▌    | 167/300 [00:15<00:10, 13.08it/s]

<bart_playground.moves.Combine object at 0x000001700132BD90>
<bart_playground.moves.Combine object at 0x00000170014A5610>
<bart_playground.moves.Combine object at 0x0000017001531850>


Iterations:  57%|█████▋    | 171/300 [00:15<00:09, 13.04it/s]

<bart_playground.moves.Combine object at 0x00000170016402D0>
<bart_playground.moves.Combine object at 0x000001700176B9D0>


Iterations:  60%|█████▉    | 179/300 [00:16<00:09, 13.39it/s]

<bart_playground.moves.Combine object at 0x00000170029BBD90>
<bart_playground.moves.Combine object at 0x0000017002AE02D0>


Iterations:  62%|██████▏   | 185/300 [00:16<00:08, 13.84it/s]

<bart_playground.moves.Combine object at 0x0000017002D77710>
<bart_playground.moves.Combine object at 0x0000017002E0BD90>


Iterations:  64%|██████▎   | 191/300 [00:17<00:07, 14.38it/s]

<bart_playground.moves.Combine object at 0x00000170030382D0>


Iterations:  65%|██████▌   | 195/300 [00:17<00:07, 14.60it/s]

<bart_playground.moves.Combine object at 0x00000170031982D0>
<bart_playground.moves.Combine object at 0x000001700303BD90>
<bart_playground.moves.Combine object at 0x0000017004313D90>


Iterations:  66%|██████▌   | 197/300 [00:17<00:07, 14.43it/s]

<bart_playground.moves.Combine object at 0x00000170043982D0>
<bart_playground.moves.Combine object at 0x00000170044282D0>
<bart_playground.moves.Combine object at 0x00000170044BF710>


Iterations:  68%|██████▊   | 203/300 [00:18<00:06, 15.03it/s]

<bart_playground.moves.Combine object at 0x000001700450BD90>
<bart_playground.moves.Combine object at 0x00000170046702D0>


Iterations:  70%|██████▉   | 209/300 [00:18<00:06, 14.87it/s]

<bart_playground.moves.Combine object at 0x00000170048B9390>
<bart_playground.moves.Combine object at 0x0000017004990510>
<bart_playground.moves.Combine object at 0x0000017004993D90>


Iterations:  71%|███████   | 213/300 [00:18<00:05, 15.37it/s]

<bart_playground.moves.Combine object at 0x0000017005A48510>
<bart_playground.moves.Combine object at 0x0000017005BA0510>


Iterations:  73%|███████▎  | 219/300 [00:19<00:04, 16.23it/s]

<bart_playground.moves.Combine object at 0x0000017005CF82D0>
<bart_playground.moves.Combine object at 0x0000017005CF82D0>


Iterations:  75%|███████▌  | 225/300 [00:19<00:04, 16.83it/s]

<bart_playground.moves.Combine object at 0x0000017005E68510>
<bart_playground.moves.Combine object at 0x00000170061F82D0>


Iterations:  76%|███████▋  | 229/300 [00:19<00:04, 15.74it/s]

<bart_playground.moves.Combine object at 0x00000170072BBB10>
<bart_playground.moves.Combine object at 0x000001700734DBD0>


Iterations:  78%|███████▊  | 233/300 [00:19<00:04, 16.67it/s]

<bart_playground.moves.Combine object at 0x00000170073E02D0>
<bart_playground.moves.Combine object at 0x000001700754ECD0>


Iterations:  79%|███████▉  | 237/300 [00:20<00:03, 17.09it/s]

<bart_playground.moves.Combine object at 0x00000170075902D0>
<bart_playground.moves.Combine object at 0x00000170076702D0>
<bart_playground.moves.Combine object at 0x000001700778E750>


Iterations:  81%|████████  | 243/300 [00:20<00:03, 17.87it/s]

<bart_playground.moves.Combine object at 0x00000170078A02D0>
<bart_playground.moves.Combine object at 0x00000170078A02D0>
<bart_playground.moves.Combine object at 0x00000170079B8510>
<bart_playground.moves.Combine object at 0x0000017007A06A90>


Iterations:  83%|████████▎ | 250/300 [00:20<00:02, 19.22it/s]

<bart_playground.moves.Combine object at 0x0000017008AB6190>
<bart_playground.moves.Combine object at 0x0000017008B341D0>


Iterations:  84%|████████▍ | 253/300 [00:20<00:02, 19.88it/s]

<bart_playground.moves.Combine object at 0x0000017008C002D0>
<bart_playground.moves.Combine object at 0x0000017008C9D690>
<bart_playground.moves.Combine object at 0x0000017008D68510>


Iterations:  87%|████████▋ | 260/300 [00:21<00:02, 19.82it/s]

<bart_playground.moves.Combine object at 0x0000017008EE0510>
<bart_playground.moves.Combine object at 0x0000017008F6D150>
<bart_playground.moves.Combine object at 0x00000170090402D0>


Iterations:  88%|████████▊ | 264/300 [00:21<00:01, 19.33it/s]

<bart_playground.moves.Combine object at 0x00000170090D59D0>
<bart_playground.moves.Combine object at 0x000001700A0F6F50>
<bart_playground.moves.Combine object at 0x000001700A178C90>


Iterations:  91%|█████████▏| 274/300 [00:22<00:01, 20.48it/s]

<bart_playground.moves.Combine object at 0x000001700A3E8290>
<bart_playground.moves.Combine object at 0x000001700A425290>


Iterations:  92%|█████████▏| 277/300 [00:22<00:01, 21.18it/s]

<bart_playground.moves.Combine object at 0x000001700A5582D0>
<bart_playground.moves.Combine object at 0x000001700A6C82D0>


Iterations:  94%|█████████▍| 283/300 [00:22<00:00, 21.23it/s]

<bart_playground.moves.Combine object at 0x000001700A6C82D0>
<bart_playground.moves.Combine object at 0x000001700A81FF50>
<bart_playground.moves.Combine object at 0x000001700B8CD5D0>


Iterations:  96%|█████████▋| 289/300 [00:22<00:00, 22.00it/s]

<bart_playground.moves.Combine object at 0x000001700B95AB10>
<bart_playground.moves.Combine object at 0x000001700B9E05D0>


Iterations:  98%|█████████▊| 295/300 [00:22<00:00, 22.90it/s]

<bart_playground.moves.Combine object at 0x000001700BA7BD90>


Iterations: 100%|██████████| 300/300 [00:23<00:00, 12.94it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 23:01:18  Samples:  22926
 /_//_/// /_\ / //_// / //_'/ //     Duration: 23.185    CPU time: 4.922
/   _/                      v5.0.1

Profile at C:\Users\ztykk\AppData\Local\Temp\ipykernel_27452\4120808940.py:2

23.185 <module>  ..\..\..\Temp\ipykernel_27452\4120808940.py:1
└─ 23.185 ChangeNumTreeBART.fit  bart_playground\bart.py:22
   └─ 23.185 NTreeSampler.run  bart_playground\samplers.py:69
      └─ 23.053 NTreeSampler.one_iter  bart_playground\samplers.py:226
         ├─ 9.146 NTreeSampler.log_mh_ratio  bart_playground\samplers.py:219
         │  ├─ 6.876 BARTLikelihood.trees_log_marginal_lkhd_ratio  bart_playground\priors.py:299
         │  │  └─ 6.759 BARTLikelihood.trees_log_marginal_lkhd  bart_playground\priors.py:257
         │  │     ├─ 1.569 Parameters.leaf_basis  bart_playground\params.py:581
         │  │     │  ├─ 0.955 <listcomp>  bart_playground\params.py:591
         │  │     │  │  └─ 0.912 Tree.leaf_basis  bart_playgrou

In [4]:
# Tree count recorded in the last entry of the sampler trace
# (the model was initialised with n_trees=200; this sampler changes that number).
final_trace_entry = bart.trace[-1]
final_trace_entry.n_trees

67

In [5]:
# Baselines: random forest and ordinary least squares.
# random_state pins the forest's internal randomness so the MSE comparison below is
# reproducible across runs (42 matches the seed used for the data split).
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

# Reference BART implementation from bartz. gbart is given transposed inputs
# (features along axis 0, samples along axis 1).
# NOTE(review): ntree=100 here vs n_trees=200 for ChangeNumTreeBART; confirm the mismatch is intended.
btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
# predict presumably returns one row per posterior draw; averaging over axis 0
# leaves one prediction per test point.
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.asarray(btpred_all).mean(axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [6]:
# Held-out predictions from the ChangeNumTreeBART model fitted above.
bart_test_pred = bart.predict(X_test)
bart_test_pred

array([ 0.21482394, -0.26699378, -0.09046789,  0.03366177,  0.22883975,
       -0.26535576,  0.07187753,  0.24451332,  0.2711581 ,  0.19308552,
        0.23035877,  0.28834725,  0.25594342,  0.17606583,  0.17276658,
        0.25383368,  0.25687272,  0.28959783,  0.27028897,  0.07271488,
        0.29550951,  0.04023117,  0.22163899,  0.30395459,  0.37268176,
        0.26273772, -0.24558592,  0.22060259, -0.04311584,  0.07570632,
        0.01055021,  0.23385197,  0.29068019, -0.06674279,  0.19621228,
        0.18857238,  0.23534077, -0.34132456,  0.21302191,  0.26105411])

In [7]:

# Fitted models by name (kept available for lookup; the scores below use the predictions map).
models = {"bart": bart,
          "rf": rf,
          "lr": lr,
          "btz": btz}

# Test-set predictions per model. bartz's gbart takes transposed inputs and yields
# per-draw predictions, so its averaged `btpred` from the earlier cell is used directly.
test_predictions = {
    "bart": models["bart"].predict(X_test),
    "rf": models["rf"].predict(X_test),
    "lr": models["lr"].predict(X_test),
    "btz": btpred,
}

# Held-out mean squared error per model (lower is better).
results = {name: mean_squared_error(y_test, pred) for name, pred in test_predictions.items()}
results

{'bart': 0.058403174970126884,
 'rf': 0.02177684727829322,
 'lr': 0.048045521328019404,
 'btz': 0.02328283761397566}