In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os
# Append the parent of the current working directory to the import path.
# Presumably this is what lets `bart_playground` resolve when the notebook is
# launched from a subdirectory of the repository -- confirm the layout.
sys.path.append(os.path.dirname(os.getcwd()))
# NOTE(review): the wildcard import hides which names come from the package;
# the cells below rely on at least `DataGenerator` and `ChangeNumTreeBART`.
from bart_playground import *

import bartz

In [2]:
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

# Equal weights for the "grow" and "prune" proposals handed to the sampler.
proposal_probs = dict(grow=0.5, prune=0.5)

# Synthetic data set for the "piecewise_flat" scenario, seeded for repeatability.
generator = DataGenerator(
    n_samples=160,
    n_features=2,
    noise=0.1,
    random_seed=42,
)
X, y = generator.generate(scenario="piecewise_flat")

# Seeded train/test split (default split proportions of train_test_split).
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Quick sanity check on the first few training targets.
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
# Profile construction and fitting of the variable-tree-count BART model.
# The try/finally guarantees the profiler is stopped even when fitting raises;
# otherwise it would be left running and re-executing this cell could fail
# when a second profiler is started.
profiler = Profiler()
profiler.start()
try:
    # ndpost + nskip = 300, matching the 300 iterations in the progress bar
    # (presumably 100 burn-in iterations followed by 200 kept draws).
    bart = ChangeNumTreeBART(ndpost=200, nskip=100, n_trees=200, proposal_probs=proposal_probs)
    bart.fit(X_train, y_train)
finally:
    profiler.stop()
profiler.print()

Iterations:   1%|▏         | 4/300 [00:01<02:22,  2.08it/s]

<bart_playground.moves.Combine object at 0x000002023BF51DE0>
[112.41811441]


Iterations:   2%|▏         | 5/300 [00:02<02:30,  1.96it/s]

<bart_playground.moves.Combine object at 0x000002023BEEE470>
[44.831226]


Iterations:   2%|▏         | 6/300 [00:02<02:37,  1.87it/s]

<bart_playground.moves.Combine object at 0x00000202535E2D40>
[275.72162132]


Iterations:   2%|▏         | 7/300 [00:03<02:36,  1.87it/s]

<bart_playground.moves.Combine object at 0x000002023BF520B0>
[554.8807424]


Iterations:   3%|▎         | 9/300 [00:04<02:38,  1.84it/s]

<bart_playground.moves.Combine object at 0x0000020253F560E0>
[2.99094199]


Iterations:   4%|▎         | 11/300 [00:05<02:47,  1.72it/s]

<bart_playground.moves.Combine object at 0x0000020253F55A80>
[1.03177664]


Iterations:   4%|▍         | 12/300 [00:06<02:48,  1.71it/s]

<bart_playground.moves.Combine object at 0x00000202545ECC70>
[153.23218457]


Iterations:   4%|▍         | 13/300 [00:07<02:47,  1.71it/s]

<bart_playground.moves.Combine object at 0x00000202545EEA10>
[65.93976785]


Iterations:   5%|▌         | 15/300 [00:08<02:46,  1.71it/s]

<bart_playground.moves.Combine object at 0x0000020254B67370>
[11.23139931]


Iterations:   5%|▌         | 16/300 [00:08<02:44,  1.73it/s]

<bart_playground.moves.Combine object at 0x000002025503BCA0>
[134.81528915]


Iterations:   6%|▌         | 17/300 [00:09<02:47,  1.69it/s]

<bart_playground.moves.Combine object at 0x00000202557006D0>
[8.08327701]


Iterations:   6%|▋         | 19/300 [00:10<02:44,  1.71it/s]

<bart_playground.moves.Combine object at 0x000002025503B910>
[235.68809831]


Iterations:   7%|▋         | 21/300 [00:11<02:38,  1.76it/s]

<bart_playground.moves.Combine object at 0x000002025684D6C0>
[225.51564808]


Iterations:   8%|▊         | 25/300 [00:13<02:36,  1.76it/s]

<bart_playground.moves.Combine object at 0x0000020257E62AA0>
[154.86304643]


Iterations:   9%|▊         | 26/300 [00:14<02:34,  1.77it/s]

<bart_playground.moves.Combine object at 0x0000020257E62AA0>
[37.4115415]


Iterations:   9%|▉         | 27/300 [00:15<02:39,  1.72it/s]

<bart_playground.moves.Combine object at 0x0000020257E62AA0>
[18.97049357]


Iterations:  11%|█         | 32/300 [00:18<02:37,  1.70it/s]

<bart_playground.moves.Combine object at 0x00000202590AEF20>
[104.89002927]


Iterations:  11%|█         | 33/300 [00:18<02:38,  1.69it/s]

<bart_playground.moves.Combine object at 0x00000202590AEF20>
[9.80673439]


Iterations:  11%|█▏        | 34/300 [00:19<02:39,  1.67it/s]

<bart_playground.moves.Combine object at 0x0000020259CE4100>
[65.30335774]


Iterations:  12%|█▏        | 35/300 [00:19<02:38,  1.67it/s]

<bart_playground.moves.Combine object at 0x0000020259CE4100>
[42.31078566]


Iterations:  13%|█▎        | 38/300 [00:21<02:38,  1.65it/s]

<bart_playground.moves.Combine object at 0x000002025A3B4250>
[188.90589435]


Iterations:  13%|█▎        | 39/300 [00:22<02:36,  1.67it/s]

<bart_playground.moves.Combine object at 0x000002025AAA34F0>
[667.25531151]


Iterations:  13%|█▎        | 40/300 [00:22<02:37,  1.65it/s]

<bart_playground.moves.Combine object at 0x000002025AFA9180>
[90.225664]


Iterations:  14%|█▍        | 42/300 [00:24<02:42,  1.59it/s]

<bart_playground.moves.Combine object at 0x000002025B661210>
[215.63550085]


Iterations:  15%|█▍        | 44/300 [00:25<02:32,  1.68it/s]

<bart_playground.moves.Combine object at 0x000002025BC99210>
[180.23417755]


Iterations:  15%|█▌        | 46/300 [00:26<02:40,  1.59it/s]

<bart_playground.moves.Combine object at 0x000002025C32DB10>
[19.45854125]


Iterations:  16%|█▌        | 47/300 [00:27<02:33,  1.64it/s]

<bart_playground.moves.Combine object at 0x000002025C32D990>
[29.85850081]


Iterations:  16%|█▌        | 48/300 [00:27<02:34,  1.63it/s]

<bart_playground.moves.Combine object at 0x000002025C32FC10>
[241.11964282]


Iterations:  20%|██        | 60/300 [00:34<02:16,  1.75it/s]

<bart_playground.moves.Combine object at 0x00000202608F6C80>
[3.23635636]


Iterations:  21%|██        | 63/300 [00:36<02:18,  1.71it/s]

<bart_playground.moves.Combine object at 0x0000020260FFEB00>
[413.12434526]


Iterations:  21%|██▏       | 64/300 [00:37<02:18,  1.71it/s]

<bart_playground.moves.Combine object at 0x0000020260FFCA60>
[10.23849461]


Iterations:  22%|██▏       | 67/300 [00:38<02:14,  1.73it/s]

<bart_playground.moves.Combine object at 0x0000020262366260>
[40.71933059]


Iterations:  23%|██▎       | 68/300 [00:39<02:14,  1.73it/s]

<bart_playground.moves.Combine object at 0x0000020262367460>
[343.72402743]


Iterations:  24%|██▍       | 72/300 [00:41<02:15,  1.69it/s]

<bart_playground.moves.Combine object at 0x00000202630C0E50>
[5896.81213258]


Iterations:  24%|██▍       | 73/300 [00:42<02:13,  1.69it/s]

<bart_playground.moves.Combine object at 0x0000020263752FB0>
[12.72228217]


Iterations:  25%|██▌       | 75/300 [00:43<02:13,  1.69it/s]

<bart_playground.moves.Combine object at 0x0000020264E79360>
[4.29981348]


Iterations:  26%|██▌       | 77/300 [00:44<02:09,  1.72it/s]

<bart_playground.moves.Combine object at 0x0000020264E78310>
[9.52567354]


Iterations:  27%|██▋       | 80/300 [00:46<02:06,  1.74it/s]

<bart_playground.moves.Combine object at 0x0000020265C6EA70>
[16.71314335]


Iterations:  27%|██▋       | 82/300 [00:47<02:03,  1.76it/s]

<bart_playground.moves.Combine object at 0x00000202663801F0>
[958.88269127]


Iterations:  29%|██▊       | 86/300 [00:49<02:02,  1.74it/s]

<bart_playground.moves.Combine object at 0x000002026774CA90>
[202.00470531]


Iterations:  29%|██▉       | 88/300 [00:50<01:59,  1.78it/s]

<bart_playground.moves.Combine object at 0x000002026774CA30>
[114.05095295]


Iterations:  30%|██▉       | 89/300 [00:51<01:56,  1.81it/s]

<bart_playground.moves.Combine object at 0x000002026774E290>
[60.87631394]


Iterations:  30%|███       | 90/300 [00:51<01:56,  1.81it/s]

<bart_playground.moves.Combine object at 0x000002026774DE10>
[35.17123729]


Iterations:  30%|███       | 91/300 [00:52<01:55,  1.81it/s]

<bart_playground.moves.Combine object at 0x0000020268457640>
[199.99460712]


Iterations:  31%|███       | 92/300 [00:53<01:53,  1.84it/s]

<bart_playground.moves.Combine object at 0x0000020268457FD0>
[85.57677653]


Iterations:  32%|███▏      | 95/300 [00:54<01:58,  1.73it/s]

<bart_playground.moves.Combine object at 0x00000202691255D0>
[31.69540786]


Iterations:  33%|███▎      | 99/300 [00:56<01:47,  1.88it/s]

<bart_playground.moves.Combine object at 0x000002026A84AF50>
[39.25670305]


Iterations:  34%|███▎      | 101/300 [00:57<01:45,  1.88it/s]

<bart_playground.moves.Combine object at 0x0000020269126770>
[747.38206746]


Iterations:  34%|███▍      | 103/300 [00:59<01:47,  1.83it/s]

<bart_playground.moves.Combine object at 0x0000020269126770>
[82.013487]


Iterations:  35%|███▍      | 104/300 [00:59<01:47,  1.83it/s]

<bart_playground.moves.Combine object at 0x0000020269126770>
[100.3979734]


Iterations:  35%|███▌      | 105/300 [01:00<01:46,  1.84it/s]

<bart_playground.moves.Combine object at 0x000002026B7125C0>
[50.76205803]


Iterations:  36%|███▌      | 107/300 [01:01<01:39,  1.94it/s]

<bart_playground.moves.Combine object at 0x000002026BD1F9A0>
[89.06509168]


Iterations:  36%|███▌      | 108/300 [01:01<01:39,  1.92it/s]

<bart_playground.moves.Combine object at 0x000002026C39F850>
[8.58041087]


Iterations:  36%|███▋      | 109/300 [01:02<01:37,  1.96it/s]

<bart_playground.moves.Combine object at 0x000002026C39F850>
[77.76776198]


Iterations:  37%|███▋      | 111/300 [01:03<01:40,  1.88it/s]

<bart_playground.moves.Combine object at 0x000002026D259ED0>
[6.88737363]


Iterations:  38%|███▊      | 113/300 [01:04<01:41,  1.84it/s]

<bart_playground.moves.Combine object at 0x000002026D899B10>
[68.81867452]


Iterations:  40%|████      | 121/300 [01:08<01:25,  2.08it/s]

<bart_playground.moves.Combine object at 0x000002026ED8B3A0>
[27.9254107]


Iterations:  41%|████      | 122/300 [01:08<01:26,  2.05it/s]

<bart_playground.moves.Combine object at 0x00000202703E8640>
[111.5669366]


Iterations:  41%|████      | 123/300 [01:09<01:26,  2.04it/s]

<bart_playground.moves.Combine object at 0x000002026ED8B9D0>
[10.03342706]


Iterations:  42%|████▏     | 126/300 [01:10<01:25,  2.03it/s]

<bart_playground.moves.Combine object at 0x0000020270B437F0>
[51.15706333]


Iterations:  43%|████▎     | 128/300 [01:11<01:26,  2.00it/s]

<bart_playground.moves.Combine object at 0x0000020270B43D30>
[6555.46129109]


Iterations:  44%|████▍     | 132/300 [01:13<01:20,  2.08it/s]

<bart_playground.moves.Combine object at 0x0000020270B43D30>
[2413.46990495]


Iterations:  44%|████▍     | 133/300 [01:14<01:20,  2.08it/s]

<bart_playground.moves.Combine object at 0x00000202717E6C80>
[4.64675848]


Iterations:  45%|████▌     | 135/300 [01:15<01:18,  2.10it/s]

<bart_playground.moves.Combine object at 0x00000202717E5420>
[192.54238112]


Iterations:  46%|████▌     | 137/300 [01:16<01:15,  2.17it/s]

<bart_playground.moves.Combine object at 0x00000202717E5420>
[5.63672344]


Iterations:  46%|████▌     | 138/300 [01:16<01:13,  2.19it/s]

<bart_playground.moves.Combine object at 0x0000020272CE79A0>
[749.32621212]


Iterations:  48%|████▊     | 145/300 [01:19<01:11,  2.17it/s]

<bart_playground.moves.Combine object at 0x0000020273A04940>
[178.74967222]


Iterations:  49%|████▉     | 147/300 [01:20<01:09,  2.21it/s]

<bart_playground.moves.Combine object at 0x000002027401B0A0>
[8.50223276]


Iterations:  49%|████▉     | 148/300 [01:21<01:08,  2.23it/s]

<bart_playground.moves.Combine object at 0x000002027570A0E0>
[54.8604021]


Iterations:  51%|█████     | 152/300 [01:22<01:04,  2.31it/s]

<bart_playground.moves.Combine object at 0x00000202764563E0>
[255.14002596]


Iterations:  51%|█████▏    | 154/300 [01:23<01:03,  2.31it/s]

<bart_playground.moves.Combine object at 0x0000020276457C40>
[3723.56780008]


Iterations:  52%|█████▏    | 155/300 [01:24<01:03,  2.29it/s]

<bart_playground.moves.Combine object at 0x0000020276A749A0>
[55.50265641]


Iterations:  53%|█████▎    | 159/300 [01:25<01:00,  2.31it/s]

<bart_playground.moves.Combine object at 0x00000202777152A0>
[14.82095618]


Iterations:  54%|█████▍    | 163/300 [01:27<00:57,  2.39it/s]

<bart_playground.moves.Combine object at 0x0000020277717FA0>
[24.42680147]


Iterations:  55%|█████▌    | 166/300 [01:28<00:56,  2.36it/s]

<bart_playground.moves.Combine object at 0x000002027853E770>
[16.19700585]


Iterations:  56%|█████▌    | 167/300 [01:29<00:56,  2.35it/s]

<bart_playground.moves.Combine object at 0x000002027853E2F0>
[6839.74277997]


Iterations:  56%|█████▌    | 168/300 [01:29<00:56,  2.34it/s]

<bart_playground.moves.Combine object at 0x000002027853DD20>
[3.06127597]


Iterations:  57%|█████▋    | 170/300 [01:30<00:53,  2.41it/s]

<bart_playground.moves.Combine object at 0x0000020278ADE800>
[40.32691125]


Iterations:  57%|█████▋    | 172/300 [01:31<00:51,  2.47it/s]

<bart_playground.moves.Combine object at 0x0000020279277460>
[320.1274391]


Iterations:  59%|█████▉    | 177/300 [01:33<00:50,  2.44it/s]

<bart_playground.moves.Combine object at 0x000002027B573EE0>
[128.69542476]


Iterations:  60%|█████▉    | 179/300 [01:34<00:48,  2.48it/s]

<bart_playground.moves.Combine object at 0x000002027B572FB0>
[156.06258001]


Iterations:  61%|██████▏   | 184/300 [01:36<00:45,  2.55it/s]

<bart_playground.moves.Combine object at 0x000002027C2A8F70>
[56.39614318]


Iterations:  62%|██████▏   | 186/300 [01:36<00:43,  2.60it/s]

<bart_playground.moves.Combine object at 0x000002027C8B9DB0>
[101.32253561]


Iterations:  63%|██████▎   | 189/300 [01:38<00:44,  2.49it/s]

<bart_playground.moves.Combine object at 0x000002027C8BB760>
[76.49914794]


Iterations:  64%|██████▍   | 193/300 [01:39<00:41,  2.55it/s]

<bart_playground.moves.Combine object at 0x000002027D589990>
[50.01881429]


Iterations:  65%|██████▍   | 194/300 [01:40<00:41,  2.55it/s]

<bart_playground.moves.Combine object at 0x000002027D58BD90>
[26.15497396]


Iterations:  65%|██████▌   | 195/300 [01:40<00:42,  2.45it/s]

<bart_playground.moves.Combine object at 0x000002027ED51F00>
[41.77145767]


Iterations:  65%|██████▌   | 196/300 [01:40<00:41,  2.48it/s]

<bart_playground.moves.Combine object at 0x000002027ED514B0>
[46.78339353]


Iterations:  66%|██████▌   | 197/300 [01:41<00:40,  2.54it/s]

<bart_playground.moves.Combine object at 0x000002027ED51480>
[178.03764257]


Iterations:  66%|██████▌   | 198/300 [01:41<00:39,  2.58it/s]

<bart_playground.moves.Combine object at 0x000002027F3D4100>
[3468.25429834]


Iterations:  67%|██████▋   | 201/300 [01:42<00:37,  2.61it/s]

<bart_playground.moves.Combine object at 0x000002027F3D6BC0>
[79.66902685]


Iterations:  67%|██████▋   | 202/300 [01:43<00:39,  2.47it/s]

<bart_playground.moves.Combine object at 0x000002027F3D7070>
[71.26478025]


Iterations:  69%|██████▊   | 206/300 [01:44<00:39,  2.37it/s]

<bart_playground.moves.Combine object at 0x000002027BCA4EE0>
[95.1646303]


Iterations:  69%|██████▉   | 207/300 [01:45<00:37,  2.51it/s]

<bart_playground.moves.Combine object at 0x000002027FA4EB60>
[36.42419903]


Iterations:  69%|██████▉   | 208/300 [01:45<00:35,  2.62it/s]

<bart_playground.moves.Combine object at 0x000002023BE9ED10>
[8.28401547]


Iterations:  70%|███████   | 211/300 [01:46<00:30,  2.94it/s]

<bart_playground.moves.Combine object at 0x0000020200917BE0>
[66.37017779]


Iterations:  71%|███████▏  | 214/300 [01:47<00:28,  3.06it/s]

<bart_playground.moves.Combine object at 0x00000202015AD0F0>
[3251.84364762]


Iterations:  72%|███████▏  | 216/300 [01:48<00:27,  3.11it/s]

<bart_playground.moves.Combine object at 0x0000020200F9AAA0>
[42.93828724]


Iterations:  72%|███████▏  | 217/300 [01:48<00:26,  3.12it/s]

<bart_playground.moves.Combine object at 0x00000202015AD0F0>
[0.87295595]


Iterations:  75%|███████▍  | 224/300 [01:50<00:23,  3.25it/s]

<bart_playground.moves.Combine object at 0x0000020203836140>
[64.66774896]


Iterations:  75%|███████▌  | 226/300 [01:51<00:22,  3.29it/s]

<bart_playground.moves.Combine object at 0x0000020203DCC820>
[17.44767455]


Iterations:  76%|███████▌  | 227/300 [01:51<00:23,  3.14it/s]

<bart_playground.moves.Combine object at 0x0000020203DCFEB0>
[120.96696703]


Iterations:  76%|███████▌  | 228/300 [01:51<00:22,  3.23it/s]

<bart_playground.moves.Combine object at 0x0000020203DCF2B0>
[2.10145021]


Iterations:  77%|███████▋  | 231/300 [01:52<00:19,  3.49it/s]

<bart_playground.moves.Combine object at 0x0000020204456EF0>
[139.04308893]


Iterations:  78%|███████▊  | 233/300 [01:53<00:19,  3.51it/s]

<bart_playground.moves.Combine object at 0x0000020204457E50>
[78.60148997]


Iterations:  78%|███████▊  | 235/300 [01:53<00:18,  3.53it/s]

<bart_playground.moves.Combine object at 0x0000020204A2E140>
[23.94988138]


Iterations:  79%|███████▊  | 236/300 [01:54<00:18,  3.50it/s]

<bart_playground.moves.Combine object at 0x0000020204A2E380>
[223.68636924]


Iterations:  79%|███████▉  | 238/300 [01:54<00:17,  3.60it/s]

<bart_playground.moves.Combine object at 0x0000020204FE2AA0>
[41.63160357]


Iterations:  80%|████████  | 241/300 [01:55<00:15,  3.69it/s]

<bart_playground.moves.Combine object at 0x00000202055E2020>
[69.86576842]


Iterations:  81%|████████  | 242/300 [01:55<00:16,  3.52it/s]

<bart_playground.moves.Combine object at 0x00000202055E3910>
[30.07442671]


Iterations:  81%|████████  | 243/300 [01:56<00:18,  3.15it/s]

<bart_playground.moves.Combine object at 0x00000202055E33A0>
[1.45120903]


Iterations:  81%|████████▏ | 244/300 [01:56<00:16,  3.31it/s]

<bart_playground.moves.Combine object at 0x0000020205B909A0>
[22.07125139]


Iterations:  82%|████████▏ | 246/300 [01:56<00:14,  3.65it/s]

<bart_playground.moves.Combine object at 0x0000020205B938E0>
[1270.42795572]


Iterations:  82%|████████▏ | 247/300 [01:57<00:14,  3.76it/s]

<bart_playground.moves.Combine object at 0x0000020205B92D10>
[67.40954897]


Iterations:  83%|████████▎ | 250/300 [01:57<00:12,  4.01it/s]

<bart_playground.moves.Combine object at 0x000002020619EFE0>
[3.37759765]


Iterations:  84%|████████▎ | 251/300 [01:58<00:12,  4.08it/s]

<bart_playground.moves.Combine object at 0x000002020777B610>
[13.64069304]


Iterations:  85%|████████▍ | 254/300 [01:58<00:10,  4.27it/s]

<bart_playground.moves.Combine object at 0x0000020207D26E90>
[3.42395227]


Iterations:  86%|████████▌ | 257/300 [01:59<00:10,  4.29it/s]

<bart_playground.moves.Combine object at 0x0000020207D25FC0>
[1994.82783647]


Iterations:  87%|████████▋ | 260/300 [02:00<00:09,  4.33it/s]

<bart_playground.moves.Combine object at 0x000002020828F760>
[130.71759928]


Iterations:  87%|████████▋ | 261/300 [02:00<00:09,  4.07it/s]

<bart_playground.moves.Combine object at 0x000002020828C460>
[75.54112584]


Iterations:  87%|████████▋ | 262/300 [02:00<00:09,  4.18it/s]

<bart_playground.moves.Combine object at 0x000002020828EE00>
[6.50076441]


Iterations:  88%|████████▊ | 263/300 [02:01<00:09,  3.78it/s]

<bart_playground.moves.Combine object at 0x000002020828E620>
[27.78463838]


Iterations:  88%|████████▊ | 264/300 [02:01<00:09,  3.98it/s]

<bart_playground.moves.Combine object at 0x000002020828EE00>
[936.90431892]


Iterations:  90%|█████████ | 271/300 [02:03<00:06,  4.24it/s]

<bart_playground.moves.Combine object at 0x0000020208E50E80>
[8.42989092]
<bart_playground.moves.Combine object at 0x0000020208E53010>
[113.17236375]


Iterations:  92%|█████████▏| 276/300 [02:04<00:05,  4.36it/s]

<bart_playground.moves.Combine object at 0x0000020209AB7580>
[14.93372206]


Iterations:  93%|█████████▎| 279/300 [02:04<00:04,  4.54it/s]

<bart_playground.moves.Combine object at 0x0000020209AB7280>
[33643.23219126]


Iterations:  94%|█████████▍| 282/300 [02:05<00:03,  4.91it/s]

<bart_playground.moves.Combine object at 0x0000020209AB7970>
[128.04108506]


Iterations:  95%|█████████▍| 284/300 [02:05<00:03,  5.08it/s]

<bart_playground.moves.Combine object at 0x000002020A11AB30>
[85.03981973]


Iterations:  95%|█████████▌| 286/300 [02:06<00:02,  4.91it/s]

<bart_playground.moves.Combine object at 0x000002020A11AC50>
[1.44902503]


Iterations:  96%|█████████▌| 288/300 [02:06<00:02,  5.05it/s]

<bart_playground.moves.Combine object at 0x000002020A11AB30>
[52.59764148]


Iterations:  96%|█████████▋| 289/300 [02:06<00:02,  5.11it/s]

<bart_playground.moves.Combine object at 0x000002020B657F40>
[63.25693336]


Iterations:  98%|█████████▊| 293/300 [02:07<00:01,  5.46it/s]

<bart_playground.moves.Combine object at 0x000002020BB92830>
[5.92306468]


Iterations: 100%|██████████| 300/300 [02:08<00:00,  2.33it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 19:25:35  Samples:  119795
 /_//_/// /_\ / //_// / //_'/ //     Duration: 128.899   CPU time: 127.250
/   _/                      v5.0.1

Profile at C:\Windows\Temp\ipykernel_28564\4120808940.py:2

128.898 <module>  C:\Windows\Temp\ipykernel_28564\4120808940.py:1
└─ 128.898 ChangeNumTreeBART.fit  bart_playground\bart.py:22
   └─ 128.897 NTreeSampler.run  bart_playground\samplers.py:69
      └─ 128.541 NTreeSampler.one_iter  bart_playground\samplers.py:228
         ├─ 106.750 NTreeSampler.log_mh_ratio  bart_playground\samplers.py:220
         │  ├─ 88.111 LeafValPrior.leaf_vals_log_prior_ratio  bart_playground\priors.py:277
         │  │  ├─ 50.471 LeafValPrior.total_leaf_count  bart_playground\priors.py:265
         │  │  │  └─ 49.468 <genexpr>  bart_playground\priors.py:270
         │  │  │     ├─ 46.771 Tree.n_leaves  bart_playground\params.py:441
         │  │  │     │  ├─ 44.482 Tree.leaves  bart_playground\params.py:437
         │  │

In [4]:
# Tree count stored on the last entry of the sampler trace.
last_trace_entry = bart.trace[-1]
last_trace_entry.n_trees

67

In [5]:
# Baseline models fitted on the same training split as the BART model.
# The random forest is seeded (42, like the data generation and the split) so
# that its reported test error is reproducible across kernel restarts.
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

# Reference BART implementation from bartz. The design matrices are transposed
# before being passed in -- presumably bartz expects (features, samples);
# confirm against the bartz documentation.
btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
# Collapse axis 0 of the predictions (presumably the posterior-draw axis) into
# a single point prediction per test row.
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [6]:
# Predictions of the fitted BART model on the held-out test rows.
bart_test_pred = bart.predict(X_test)
bart_test_pred

array([ 0.21482394, -0.26699378, -0.09046789,  0.03366177,  0.22883975,
       -0.26535576,  0.07187753,  0.24451332,  0.2711581 ,  0.19308552,
        0.23035877,  0.28834725,  0.25594342,  0.17606583,  0.17276658,
        0.25383368,  0.25687272,  0.28959783,  0.27028897,  0.07271488,
        0.29550951,  0.04023117,  0.22163899,  0.30395459,  0.37268176,
        0.26273772, -0.24558592,  0.22060259, -0.04311584,  0.07570632,
        0.01055021,  0.23385197,  0.29068019, -0.06674279,  0.19621228,
        0.18857238,  0.23534077, -0.34132456,  0.21302191,  0.26105411])

In [7]:

# Test-set mean squared error for every fitted model, keyed by a short name.
models = {
    "bart": bart,
    "rf": rf,
    "lr": lr,
    "btz": btz,
}

results = {}
for model_name, model in models.items():
    # bartz predictions were already averaged into `btpred` in an earlier cell;
    # every other model is asked to predict on X_test directly.
    y_pred = btpred if model_name == "btz" else model.predict(X_test)
    results[model_name] = mean_squared_error(y_test, y_pred)
results

{'bart': 0.05840317497012706,
 'rf': 0.022411429186007883,
 'lr': 0.048045521328019404,
 'btz': 0.02328283761397566}