In [1]:
# Imports: stdlib, then third-party, then local packages.
import os
import sys

import numpy as np
from pyinstrument import Profiler
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Make packages in the parent of the current working directory (e.g. bart_playground)
# importable; the imports below are kept after this so they see the extended path.
sys.path.append(os.path.dirname(os.getcwd()))
# Explicit names instead of `import *` so readers can see where they come from.
from bart_playground import ChangeNumTreeBART, DataGenerator

import bartz

In [2]:
# Grow/prune move probabilities handed to ChangeNumTreeBART below.
proposal_probs = dict(grow=0.5, prune=0.5)

# Seeded synthetic regression problem: 160 rows, 2 features, "piecewise_flat" target, noise 0.1.
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")

# Default train_test_split proportions, seeded for a reproducible split.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Show plain decimals rather than scientific notation, then peek at the first targets.
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
# Profile one full ChangeNumTreeBART fit (nskip=100 burn-in + ndpost=200 kept iterations).
# NOTE(review): fit() itself prints a move object and a numeric value on many iterations,
# which floods this cell's output — consider silencing that debug print in bart_playground.
profiler = Profiler()
profiler.start()
try:
    bart = ChangeNumTreeBART(ndpost=200, nskip=100, n_trees=200, proposal_probs=proposal_probs)
    bart.fit(X_train, y_train)
finally:
    # Stop the profiler even if fit() raises, so it is never left running in the kernel.
    profiler.stop()
profiler.print()

Iterations:   1%|          | 2/300 [00:06<12:27,  2.51s/it]

<bart_playground.moves.Combine object at 0x000001EED187EA10>
[10865.47754013]


Iterations:   1%|          | 3/300 [00:06<07:13,  1.46s/it]

<bart_playground.moves.Combine object at 0x000001EED187EA10>
[16.88564266]


Iterations:   2%|▏         | 5/300 [00:06<03:31,  1.39it/s]

<bart_playground.moves.Combine object at 0x000001EED1CE6A10>
[58.50680051]


Iterations:   2%|▏         | 6/300 [00:06<02:43,  1.80it/s]

<bart_playground.moves.Combine object at 0x000001EED1D6EA10>
[43.45790246]


Iterations:   4%|▍         | 13/300 [00:08<01:23,  3.45it/s]

<bart_playground.moves.Combine object at 0x000001EED2972610>
[11.4922321]


Iterations:   5%|▍         | 14/300 [00:09<01:22,  3.48it/s]

<bart_playground.moves.Combine object at 0x000001EED2B5EA10>
[9.55764894]


Iterations:   5%|▌         | 15/300 [00:09<01:20,  3.53it/s]

<bart_playground.moves.Combine object at 0x000001EED2D0E890>
[517.24899254]


Iterations:   5%|▌         | 16/300 [00:09<01:20,  3.55it/s]

<bart_playground.moves.Combine object at 0x000001EED2D0EA10>
[71.07600625]


Iterations:   6%|▌         | 17/300 [00:09<01:19,  3.54it/s]

<bart_playground.moves.Combine object at 0x000001EED301BA90>
[14.94165332]


Iterations:   7%|▋         | 20/300 [00:10<01:17,  3.59it/s]

<bart_playground.moves.Combine object at 0x000001EED34D6A10>
[45.72082099]


Iterations:   7%|▋         | 21/300 [00:11<01:17,  3.59it/s]

<bart_playground.moves.Combine object at 0x000001EED36B6A10>
[1284.18429031]


Iterations:   8%|▊         | 23/300 [00:11<01:17,  3.55it/s]

<bart_playground.moves.Combine object at 0x000001EED3AA6A10>
[214.72045037]


Iterations:   8%|▊         | 24/300 [00:11<01:17,  3.58it/s]

<bart_playground.moves.Combine object at 0x000001EED3B8EA10>
[144.61972604]


Iterations:   9%|▉         | 27/300 [00:12<01:15,  3.60it/s]

<bart_playground.moves.Combine object at 0x000001EED400EA10>
[46585.1361601]


Iterations:   9%|▉         | 28/300 [00:13<01:15,  3.59it/s]

<bart_playground.moves.Combine object at 0x000001EED3EC3B50>
[167.33973902]


Iterations:  10%|▉         | 29/300 [00:13<01:16,  3.57it/s]

<bart_playground.moves.Combine object at 0x000001EED44ABCD0>
[1.63637142]


Iterations:  10%|█         | 30/300 [00:13<01:16,  3.51it/s]

<bart_playground.moves.Combine object at 0x000001EED451EA10>
[44.0241946]


Iterations:  11%|█         | 32/300 [00:14<01:15,  3.55it/s]

<bart_playground.moves.Combine object at 0x000001EED474ABD0>
[26.97671712]


Iterations:  11%|█▏        | 34/300 [00:14<01:15,  3.51it/s]

<bart_playground.moves.Combine object at 0x000001EED4A5EA10>
[15.64997471]


Iterations:  12%|█▏        | 35/300 [00:15<01:15,  3.50it/s]

<bart_playground.moves.Combine object at 0x000001EED4FE8190>
[14.01859239]


Iterations:  13%|█▎        | 39/300 [00:16<01:15,  3.48it/s]

<bart_playground.moves.Combine object at 0x000001EED669EA10>
[75.55639787]


Iterations:  13%|█▎        | 40/300 [00:16<01:15,  3.47it/s]

<bart_playground.moves.Combine object at 0x000001EED669EA10>
[16.92739755]


Iterations:  15%|█▍        | 44/300 [00:17<01:13,  3.50it/s]

<bart_playground.moves.Combine object at 0x000001EED69F3210>
[7.12457748]


Iterations:  15%|█▌        | 46/300 [00:18<01:12,  3.49it/s]

<bart_playground.moves.Combine object at 0x000001EED73150D0>
[151.34800423]


Iterations:  17%|█▋        | 52/300 [00:19<01:11,  3.48it/s]

<bart_playground.moves.Combine object at 0x000001EED7DBEA10>
[744.47804932]


Iterations:  19%|█▉        | 57/300 [00:21<01:12,  3.37it/s]

<bart_playground.moves.Combine object at 0x000001EED85AEA10>
[123.70990432]


Iterations:  20%|█▉        | 59/300 [00:21<01:12,  3.34it/s]

<bart_playground.moves.Combine object at 0x000001EED8883CD0>
[49.76148443]


Iterations:  21%|██        | 62/300 [00:22<01:09,  3.40it/s]

<bart_playground.moves.Combine object at 0x000001EED9FA1E10>
[177.99328573]


Iterations:  21%|██▏       | 64/300 [00:23<01:09,  3.41it/s]

<bart_playground.moves.Combine object at 0x000001EEDA1EBCD0>
[60.87019065]


Iterations:  22%|██▏       | 65/300 [00:23<01:09,  3.37it/s]

<bart_playground.moves.Combine object at 0x000001EEDA50EA10>
[13.38366095]


Iterations:  22%|██▏       | 67/300 [00:24<01:09,  3.37it/s]

<bart_playground.moves.Combine object at 0x000001EEDA85EA10>
[167.41490581]


Iterations:  23%|██▎       | 69/300 [00:24<01:08,  3.37it/s]

<bart_playground.moves.Combine object at 0x000001EEDAC22BD0>
[7.45109045]


Iterations:  24%|██▍       | 72/300 [00:25<01:06,  3.43it/s]

<bart_playground.moves.Combine object at 0x000001EEDB18EA10>
[531.94690895]


Iterations:  26%|██▌       | 77/300 [00:27<01:07,  3.30it/s]

<bart_playground.moves.Combine object at 0x000001EEDB99EA10>
[17.42057802]


Iterations:  26%|██▌       | 78/300 [00:27<01:07,  3.30it/s]

<bart_playground.moves.Combine object at 0x000001EEDB99EA10>
[109.68665106]


Iterations:  26%|██▋       | 79/300 [00:27<01:05,  3.37it/s]

<bart_playground.moves.Combine object at 0x000001EEDBE76A10>
[109.57176247]


Iterations:  27%|██▋       | 80/300 [00:28<01:04,  3.43it/s]

<bart_playground.moves.Combine object at 0x000001EEDC00EA10>
[78.68904439]


Iterations:  27%|██▋       | 82/300 [00:28<01:02,  3.48it/s]

<bart_playground.moves.Combine object at 0x000001EEDC37EA10>
[174.52830188]


Iterations:  28%|██▊       | 84/300 [00:29<01:03,  3.38it/s]

<bart_playground.moves.Combine object at 0x000001EEDC62EA10>
[52.19569677]


Iterations:  29%|██▊       | 86/300 [00:29<01:02,  3.43it/s]

<bart_playground.moves.Combine object at 0x000001EEDCAAEA10>
[80.94604817]


Iterations:  29%|██▉       | 88/300 [00:30<01:02,  3.39it/s]

<bart_playground.moves.Combine object at 0x000001EEDDD63CD0>
[125.63006139]


Iterations:  30%|██▉       | 89/300 [00:30<01:01,  3.42it/s]

<bart_playground.moves.Combine object at 0x000001EEDE020DD0>
[47.27425197]


Iterations:  31%|███▏      | 94/300 [00:32<00:55,  3.68it/s]

<bart_playground.moves.Combine object at 0x000001EEDE7DFA50>
[296.09604307]


Iterations:  32%|███▏      | 95/300 [00:32<00:55,  3.68it/s]

<bart_playground.moves.Combine object at 0x000001EEDE95EA10>
[237.84813417]


Iterations:  32%|███▏      | 96/300 [00:32<00:54,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEDEB66A10>
[42.2350407]


Iterations:  32%|███▏      | 97/300 [00:32<00:54,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEDECFDCD0>
[95.52380448]


Iterations:  33%|███▎      | 98/300 [00:33<01:02,  3.25it/s]

<bart_playground.moves.Combine object at 0x000001EEDEE86A10>
[399.57505941]


Iterations:  34%|███▍      | 102/300 [00:34<00:53,  3.69it/s]

<bart_playground.moves.Combine object at 0x000001EEDF386A10>
[657.29007355]


Iterations:  35%|███▍      | 104/300 [00:34<00:52,  3.75it/s]

<bart_playground.moves.Combine object at 0x000001EEDF84EA10>
[26.72237978]


Iterations:  37%|███▋      | 110/300 [00:36<00:52,  3.60it/s]

<bart_playground.moves.Combine object at 0x000001EEE0216A10>
[63.07196731]


Iterations:  37%|███▋      | 111/300 [00:36<00:53,  3.53it/s]

<bart_playground.moves.Combine object at 0x000001EEE02D3CD0>
[7303.24994313]


Iterations:  40%|███▉      | 119/300 [00:39<00:49,  3.64it/s]

<bart_playground.moves.Combine object at 0x000001EEE2126A10>
[19.6807908]


Iterations:  40%|████      | 121/300 [00:39<00:49,  3.63it/s]

<bart_playground.moves.Combine object at 0x000001EEE224EA10>
[223.67506203]


Iterations:  41%|████      | 122/300 [00:39<00:49,  3.62it/s]

<bart_playground.moves.Combine object at 0x000001EEE24CEA10>
[43.49358323]


Iterations:  41%|████      | 123/300 [00:40<00:48,  3.62it/s]

<bart_playground.moves.Combine object at 0x000001EEE262EA10>
[247.25172021]


Iterations:  43%|████▎     | 128/300 [00:41<00:47,  3.63it/s]

<bart_playground.moves.Combine object at 0x000001EEE2F56A10>
[102.67700115]


Iterations:  45%|████▌     | 136/300 [00:43<00:45,  3.57it/s]

<bart_playground.moves.Combine object at 0x000001EEE3D9EA10>
[33.57801144]


Iterations:  46%|████▋     | 139/300 [00:44<00:44,  3.63it/s]

<bart_playground.moves.Combine object at 0x000001EEE3E66A10>
[45.30396496]


Iterations:  47%|████▋     | 140/300 [00:44<00:44,  3.62it/s]

<bart_playground.moves.Break object at 0x000001EEE42DE810>
[0.70755389]


Iterations:  48%|████▊     | 144/300 [00:45<00:41,  3.72it/s]

<bart_playground.moves.Combine object at 0x000001EEE5A16A10>
[158.2179745]


Iterations:  49%|████▊     | 146/300 [00:46<00:41,  3.67it/s]

<bart_playground.moves.Combine object at 0x000001EEE5CCBCD0>
[110.97541026]


Iterations:  49%|████▉     | 148/300 [00:47<00:40,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEE5F0EA10>
[8.01514011]


Iterations:  50%|████▉     | 149/300 [00:47<00:40,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEE62BABD0>
[4472.0882425]


Iterations:  50%|█████     | 150/300 [00:47<00:40,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEE63DABD0>
[74.07604015]


Iterations:  50%|█████     | 151/300 [00:47<00:40,  3.71it/s]

<bart_playground.moves.Combine object at 0x000001EEE65C6A10>
[91.97150904]


Iterations:  52%|█████▏    | 155/300 [00:48<00:37,  3.87it/s]

<bart_playground.moves.Combine object at 0x000001EEE6C71910>
[23.84760622]


Iterations:  52%|█████▏    | 156/300 [00:49<00:37,  3.86it/s]

<bart_playground.moves.Combine object at 0x000001EEE6E029D0>
[312.58693318]


Iterations:  52%|█████▏    | 157/300 [00:49<00:36,  3.88it/s]

<bart_playground.moves.Combine object at 0x000001EEE6F56A10>
[2990.64905305]


Iterations:  53%|█████▎    | 159/300 [00:49<00:35,  3.96it/s]

<bart_playground.moves.Combine object at 0x000001EEE71FABD0>
[250.54821383]


Iterations:  54%|█████▎    | 161/300 [00:50<00:34,  4.01it/s]

<bart_playground.moves.Combine object at 0x000001EEE756EA10>
[14.22819008]


Iterations:  54%|█████▍    | 163/300 [00:50<00:33,  4.10it/s]

<bart_playground.moves.Combine object at 0x000001EEE76EEA10>
[9551.96262048]


Iterations:  55%|█████▌    | 166/300 [00:51<00:32,  4.16it/s]

<bart_playground.moves.Combine object at 0x000001EEE7C4F9D0>
[9.53147077]


Iterations:  56%|█████▌    | 167/300 [00:51<00:31,  4.16it/s]

<bart_playground.moves.Combine object at 0x000001EEE7B5EA10>
[6.77411964]


Iterations:  56%|█████▌    | 168/300 [00:52<00:32,  4.07it/s]

<bart_playground.moves.Combine object at 0x000001EEE7FBABD0>
[81.60361687]


Iterations:  57%|█████▋    | 170/300 [00:52<00:32,  4.03it/s]

<bart_playground.moves.Combine object at 0x000001EEE8123CD0>
[347.67145855]


Iterations:  57%|█████▋    | 172/300 [00:53<00:30,  4.19it/s]

<bart_playground.moves.Combine object at 0x000001EEE9492BD0>
[201.49297288]


Iterations:  58%|█████▊    | 173/300 [00:53<00:29,  4.24it/s]

<bart_playground.moves.Combine object at 0x000001EEE9677A50>
[70.49747627]


Iterations:  58%|█████▊    | 175/300 [00:53<00:28,  4.35it/s]

<bart_playground.moves.Combine object at 0x000001EEE99ACF10>
[8.92957148]


Iterations:  60%|█████▉    | 179/300 [00:54<00:27,  4.37it/s]

<bart_playground.moves.Combine object at 0x000001EEE9F06A10>
[84.87794112]


Iterations:  60%|██████    | 180/300 [00:54<00:27,  4.37it/s]

<bart_playground.moves.Combine object at 0x000001EEEA076A10>
[0.32625116]


Iterations:  61%|██████    | 182/300 [00:55<00:26,  4.41it/s]

<bart_playground.moves.Combine object at 0x000001EEEA1FEA10>
[55.02696594]


Iterations:  62%|██████▏   | 185/300 [00:55<00:25,  4.50it/s]

<bart_playground.moves.Combine object at 0x000001EF0010AA50>
[35.31870332]


Iterations:  62%|██████▏   | 186/300 [00:56<00:24,  4.61it/s]

<bart_playground.moves.Combine object at 0x000001EF002C99D0>
[84.01614747]


Iterations:  63%|██████▎   | 190/300 [00:57<00:23,  4.68it/s]

<bart_playground.moves.Combine object at 0x000001EF0077ABD0>
[4257.38894036]


Iterations:  64%|██████▍   | 192/300 [00:57<00:22,  4.74it/s]

<bart_playground.moves.Combine object at 0x000001EF0077ABD0>
[1658.39678599]


Iterations:  65%|██████▌   | 195/300 [00:58<00:22,  4.69it/s]

<bart_playground.moves.Combine object at 0x000001EF00DF2BD0>
[31.07152182]


Iterations:  65%|██████▌   | 196/300 [00:58<00:21,  4.73it/s]

<bart_playground.moves.Combine object at 0x000001EF00E0EA10>
[21.96267889]


Iterations:  67%|██████▋   | 200/300 [00:59<00:20,  4.93it/s]

<bart_playground.moves.Combine object at 0x000001EF02326A10>
[616.04838557]


Iterations:  67%|██████▋   | 201/300 [00:59<00:19,  4.98it/s]

<bart_playground.moves.Combine object at 0x000001EF0235EA10>
[31.52520437]


Iterations:  69%|██████▊   | 206/300 [01:00<00:18,  5.08it/s]

<bart_playground.moves.Combine object at 0x000001EF02A81010>
[37.34994506]


Iterations:  69%|██████▉   | 208/300 [01:00<00:17,  5.19it/s]

<bart_playground.moves.Combine object at 0x000001EF02C6EA10>
[0.85773658]
<bart_playground.moves.Combine object at 0x000001EF02DCEA10>
[9.07073752]


Iterations:  70%|███████   | 210/300 [01:01<00:17,  5.19it/s]

<bart_playground.moves.Combine object at 0x000001EF02E43CD0>
[0.45217049]


Iterations:  72%|███████▏  | 215/300 [01:01<00:16,  5.20it/s]

<bart_playground.moves.Combine object at 0x000001EF0344EA10>
[1881.94636263]


Iterations:  73%|███████▎  | 218/300 [01:02<00:15,  5.13it/s]

<bart_playground.moves.Combine object at 0x000001EF03716A10>
[2770.56272204]


Iterations:  73%|███████▎  | 219/300 [01:02<00:15,  5.19it/s]

<bart_playground.moves.Combine object at 0x000001EF0398EA10>
[80.00255392]


Iterations:  73%|███████▎  | 220/300 [01:02<00:15,  5.24it/s]

<bart_playground.moves.Combine object at 0x000001EF03A56A10>
[76.02177152]


Iterations:  75%|███████▍  | 224/300 [01:03<00:14,  5.39it/s]

<bart_playground.moves.Combine object at 0x000001EF03F2ABD0>
[1095.54490831]


Iterations:  76%|███████▌  | 228/300 [01:04<00:12,  5.54it/s]

<bart_playground.moves.Combine object at 0x000001EF04366A10>
[91.10634149]


Iterations:  77%|███████▋  | 230/300 [01:04<00:12,  5.55it/s]

<bart_playground.moves.Combine object at 0x000001EF04366A10>
[39.66519377]


Iterations:  77%|███████▋  | 232/300 [01:05<00:11,  5.77it/s]

<bart_playground.moves.Combine object at 0x000001EF040B2BD0>
[9.45193916]


Iterations:  78%|███████▊  | 234/300 [01:05<00:11,  5.79it/s]

<bart_playground.moves.Combine object at 0x000001EF05A0ABD0>
[20.57731001]


Iterations:  79%|███████▊  | 236/300 [01:05<00:10,  5.90it/s]

<bart_playground.moves.Combine object at 0x000001EF05C39450>
[36.09870609]


Iterations:  79%|███████▉  | 238/300 [01:06<00:10,  5.91it/s]

<bart_playground.moves.Combine object at 0x000001EF05E1F650>
[1.65810246]


Iterations:  80%|████████  | 240/300 [01:06<00:10,  5.94it/s]

<bart_playground.moves.Combine object at 0x000001EF05E8EA10>
[298.67606257]


Iterations:  81%|████████  | 243/300 [01:06<00:09,  6.21it/s]

<bart_playground.moves.Combine object at 0x000001EF06322BD0>
[174.03909792]
<bart_playground.moves.Combine object at 0x000001EF05AF3CD0>
[3.77339063]


Iterations:  82%|████████▏ | 245/300 [01:07<00:08,  6.29it/s]

<bart_playground.moves.Combine object at 0x000001EF064F68D0>
[17.22081961]


Iterations:  84%|████████▎ | 251/300 [01:08<00:07,  6.57it/s]

<bart_playground.moves.Combine object at 0x000001EF06A26A10>
[57.2713909]
<bart_playground.moves.Combine object at 0x000001EF06A26A10>
[46.89240403]


Iterations:  84%|████████▍ | 253/300 [01:08<00:07,  6.67it/s]

<bart_playground.moves.Combine object at 0x000001EF06BB6A10>
[26.17594007]


Iterations:  87%|████████▋ | 262/300 [01:09<00:05,  6.91it/s]

<bart_playground.moves.Combine object at 0x000001EF072DEA10>
[1968.19889869]
<bart_playground.moves.Combine object at 0x000001EF074DBCD0>
[4.71780699]


Iterations:  88%|████████▊ | 265/300 [01:10<00:05,  6.83it/s]

<bart_playground.moves.Combine object at 0x000001EF07717ED0>
[60.39074411]


Iterations:  89%|████████▉ | 267/300 [01:10<00:04,  7.09it/s]

<bart_playground.moves.Combine object at 0x000001EF0789BCD0>
[4237.12916704]


Iterations:  90%|█████████ | 271/300 [01:10<00:03,  7.48it/s]

<bart_playground.moves.Combine object at 0x000001EF08B76A10>
[4.81601352]


Iterations:  92%|█████████▏| 276/300 [01:11<00:03,  7.70it/s]

<bart_playground.moves.Combine object at 0x000001EF08F9EA10>
[58.77956328]
<bart_playground.moves.Combine object at 0x000001EF0907EA10>
[66.45171703]


Iterations:  93%|█████████▎| 278/300 [01:11<00:02,  7.68it/s]

<bart_playground.moves.Combine object at 0x000001EF091D6A10>
[21.08126491]
<bart_playground.moves.Combine object at 0x000001EF076B03D0>
[34.7975915]


Iterations:  93%|█████████▎| 280/300 [01:12<00:02,  7.86it/s]

<bart_playground.moves.Combine object at 0x000001EF09322BD0>
[5.36829687]


Iterations:  95%|█████████▌| 286/300 [01:12<00:01,  8.18it/s]

<bart_playground.moves.Combine object at 0x000001EF0977EA10>
[28.58946789]
<bart_playground.moves.Combine object at 0x000001EF0977EA10>
[86.75362432]


Iterations:  99%|█████████▉| 297/300 [01:14<00:00,  8.37it/s]

<bart_playground.moves.Combine object at 0x000001EF09FF6A10>
[293.62481446]
<bart_playground.moves.Combine object at 0x000001EF0A0BFF90>
[36.12678664]


Iterations: 100%|██████████| 300/300 [01:14<00:00,  4.02it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 02:58:28  Samples:  69637
 /_//_/// /_\ / //_// / //_'/ //     Duration: 74.583    CPU time: 18.531
/   _/                      v5.0.1

Profile at C:\Users\ztykk\AppData\Local\Temp\ipykernel_31284\4120808940.py:2

74.581 MainThread  <thread>:31024
└─ 73.740 _run_module_as_main  <frozen runpy>:173
      [19 frames hidden]  <frozen runpy>, ipykernel_launcher, t...
         73.740 ZMQInteractiveShell.run_ast_nodes  IPython\core\interactiveshell.py:3420
         └─ 73.740 <module>  ..\..\..\Temp\ipykernel_31284\4120808940.py:1
            └─ 73.740 ChangeNumTreeBART.fit  bart_playground\bart.py:25
               └─ 73.739 NTreeSampler.run  bart_playground\samplers.py:74
                  └─ 73.428 NTreeSampler.one_iter  bart_playground\samplers.py:288
                     ├─ 64.035 NTreeSampler.log_mh_ratio  bart_playground\samplers.py:280
                     │  ├─ 58.386 LeafValPrior.leaf_vals_log_prior_ratio  bart_playground\priors.py:335


In [4]:
# Tree count of the last sampler state recorded in the trace.
last_state = bart.trace[-1]
last_state.n_trees

77

In [5]:
# Baselines to compare against: random forest, linear regression, and bartz's BART.
# random_state is fixed so the RF score is reproducible, matching the seeded data/split.
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

# bartz's gbart takes predictors with features along the first axis, hence the transposes.
# NOTE(review): ntree=100 here vs n_trees=200 for the bart_playground model — confirm intended.
btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
# Presumably axis 0 indexes posterior draws; averaging gives one point prediction per test row.
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [6]:
# Point predictions of the fitted bart_playground model on the held-out rows.
bart_test_pred = bart.predict(X_test)
bart_test_pred

array([ 0.10701153, -0.37168918, -0.09408964, -0.01623137,  0.23304708,
       -0.37697258, -0.07291095,  0.29328867,  0.12260477,  0.19869506,
        0.22055105,  0.31708713,  0.14132524,  0.17105998, -0.00280411,
        0.16150211,  0.12924694,  0.13259508,  0.16554859,  0.12300439,
        0.12516578,  0.11175026,  0.1513158 ,  0.13625087,  0.2378987 ,
        0.20829832, -0.29625306,  0.23833624, -0.19933885,  0.13483923,
        0.09734963,  0.20335612,  0.1677779 ,  0.05705412,  0.32767282,
        0.18301567,  0.25916432, -0.40542409,  0.09705688,  0.11091615])

In [7]:

# Test-set MSE for each model. bartz is scored from its posterior-mean predictions
# (btpred, computed earlier) rather than model.predict(X_test), because its predict
# takes transposed input and returns per-draw predictions.
models = {"bart": bart, "rf": rf, "lr": lr, "btz": btz}
results = {}
for model_name, model in models.items():
    preds = btpred if model_name == "btz" else model.predict(X_test)
    results[model_name] = mean_squared_error(y_test, preds)
results

{'bart': 0.08483996019293796,
 'rf': 0.02347109382463492,
 'lr': 0.048045521328019404,
 'btz': 0.02328283761397566}