In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import os
sys.path.append(os.path.dirname(os.getcwd()))
from bart_playground import *

import bartz

In [2]:
# Tree-move proposal probabilities, passed to ChangeNumTreeBART in the next cell.
proposal_probs = dict(grow=0.4, prune=0.4, change=0.1, swap=0.1)

# Synthetic "piecewise_flat" regression problem with 2 features; both the
# generator and the train/test split are seeded so the data is reproducible.
generator = DataGenerator(
    n_samples=160,
    n_features=2,
    noise=0.1,
    random_seed=42,
)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Show floats without scientific notation, then peek at the first targets.
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
# Fit the ChangeNumTreeBART sampler under pyinstrument.
# The profiler is used as a context manager so it is always stopped, even if
# `fit` raises; the manual start()/stop() pair would leave it running on error.
# ndpost + nskip = 300 sampler iterations (matches the progress bar total).
with Profiler() as profiler:
    bart = ChangeNumTreeBART(ndpost=200, nskip=100, n_trees=200, proposal_probs=proposal_probs)
    bart.fit(X_train, y_train)
profiler.print()

Iterations:   1%|          | 2/300 [00:00<00:45,  6.53it/s]

<bart_playground.moves.Combine object at 0x0000015453829E70>


Iterations:   1%|▏         | 4/300 [00:00<00:49,  6.03it/s]

<bart_playground.moves.Combine object at 0x000001545382A2F0>


Iterations:   2%|▏         | 5/300 [00:00<00:54,  5.38it/s]

<bart_playground.moves.Combine object at 0x000001545382BF10>


Iterations:   2%|▏         | 7/300 [00:01<01:06,  4.41it/s]

<bart_playground.moves.Combine object at 0x000001546AC0EB30>


Iterations:   3%|▎         | 8/300 [00:01<01:18,  3.70it/s]

<bart_playground.moves.Combine object at 0x000001546AC0EB30>


Iterations:   3%|▎         | 9/300 [00:02<01:20,  3.60it/s]

<bart_playground.moves.Combine object at 0x000001546AC0EB30>


Iterations:   4%|▎         | 11/300 [00:02<01:15,  3.83it/s]

<bart_playground.moves.Combine object at 0x000001546B324AF0>


Iterations:   4%|▍         | 12/300 [00:02<01:11,  4.00it/s]

<bart_playground.moves.Combine object at 0x000001546B6DCBB0>


Iterations:   5%|▍         | 14/300 [00:03<01:10,  4.07it/s]

<bart_playground.moves.Combine object at 0x000001546B324AF0>


Iterations:   5%|▌         | 15/300 [00:03<01:19,  3.59it/s]

<bart_playground.moves.Combine object at 0x000001546B93CF10>


Iterations:   6%|▌         | 17/300 [00:04<01:18,  3.61it/s]

<bart_playground.moves.Combine object at 0x000001546B93C700>


Iterations:   6%|▋         | 19/300 [00:04<01:14,  3.75it/s]

<bart_playground.moves.Combine object at 0x000001546BB89240>


Iterations:   7%|▋         | 20/300 [00:04<01:13,  3.83it/s]

<bart_playground.moves.Combine object at 0x000001546BE8E890>


Iterations:   7%|▋         | 21/300 [00:05<01:16,  3.65it/s]

<bart_playground.moves.Combine object at 0x000001546C23F160>


Iterations:   9%|▊         | 26/300 [00:06<01:01,  4.44it/s]

<bart_playground.moves.Combine object at 0x000001546DA86680>


Iterations:  10%|▉         | 29/300 [00:07<01:03,  4.26it/s]

<bart_playground.moves.Combine object at 0x000001546E096CE0>


Iterations:  12%|█▏        | 35/300 [00:08<01:10,  3.76it/s]

<bart_playground.moves.Combine object at 0x000001546E535960>


Iterations:  12%|█▏        | 37/300 [00:09<01:18,  3.35it/s]

<bart_playground.moves.Combine object at 0x000001546EB9E6B0>


Iterations:  13%|█▎        | 39/300 [00:09<01:14,  3.52it/s]

<bart_playground.moves.Combine object at 0x000001546EB9E290>


Iterations:  15%|█▌        | 45/300 [00:11<01:22,  3.08it/s]

<bart_playground.moves.Combine object at 0x000001546F85DD80>


Iterations:  16%|█▌        | 47/300 [00:12<01:12,  3.49it/s]

<bart_playground.moves.Combine object at 0x000001546FB99810>


Iterations:  17%|█▋        | 50/300 [00:13<01:08,  3.64it/s]

<bart_playground.moves.Combine object at 0x0000015470032500>


Iterations:  17%|█▋        | 51/300 [00:13<01:03,  3.94it/s]

<bart_playground.moves.Combine object at 0x00000154703DDE10>


Iterations:  17%|█▋        | 52/300 [00:13<01:01,  4.01it/s]

<bart_playground.moves.Combine object at 0x00000154703DE050>


Iterations:  19%|█▊        | 56/300 [00:14<00:53,  4.54it/s]

<bart_playground.moves.Combine object at 0x0000015471CD9480>


Iterations:  20%|█▉        | 59/300 [00:15<00:58,  4.12it/s]

<bart_playground.moves.Combine object at 0x0000015472281060>


Iterations:  21%|██▏       | 64/300 [00:17<01:19,  2.96it/s]

<bart_playground.moves.Combine object at 0x0000015472767880>


Iterations:  22%|██▏       | 65/300 [00:17<01:21,  2.88it/s]

<bart_playground.moves.Combine object at 0x0000015472AD9870>


Iterations:  22%|██▏       | 66/300 [00:17<01:21,  2.88it/s]

<bart_playground.moves.Combine object at 0x0000015472ADBB50>


Iterations:  23%|██▎       | 69/300 [00:18<01:06,  3.50it/s]

<bart_playground.moves.Combine object at 0x0000015472ADA200>


Iterations:  24%|██▍       | 72/300 [00:19<01:01,  3.73it/s]

<bart_playground.moves.Combine object at 0x0000015474E6A8F0>
<bart_playground.moves.Combine object at 0x0000015472F3FF40>


Iterations:  25%|██▍       | 74/300 [00:19<00:53,  4.26it/s]

<bart_playground.moves.Combine object at 0x00000154537E2830>


Iterations:  25%|██▌       | 76/300 [00:20<00:44,  5.03it/s]

<bart_playground.moves.Combine object at 0x0000015444042B90>
<bart_playground.moves.Combine object at 0x0000015474E6BBE0>


Iterations:  26%|██▌       | 78/300 [00:20<00:43,  5.05it/s]

<bart_playground.moves.Combine object at 0x0000015475142A40>


Iterations:  27%|██▋       | 82/300 [00:21<00:40,  5.44it/s]

<bart_playground.moves.Combine object at 0x00000154753E3FD0>
<bart_playground.moves.Combine object at 0x00000154753E2680>


Iterations:  29%|██▊       | 86/300 [00:22<00:45,  4.71it/s]

<bart_playground.moves.Combine object at 0x0000015475B5BAC0>


Iterations:  29%|██▉       | 88/300 [00:22<00:46,  4.55it/s]

<bart_playground.moves.Combine object at 0x0000015475D6F940>


Iterations:  30%|███       | 91/300 [00:23<00:44,  4.68it/s]

<bart_playground.moves.Combine object at 0x0000015475FBB430>


Iterations:  31%|███       | 92/300 [00:23<00:42,  4.87it/s]

<bart_playground.moves.Combine object at 0x0000015475FBAEC0>
<bart_playground.moves.Combine object at 0x0000015476297280>


Iterations:  32%|███▏      | 95/300 [00:23<00:40,  5.11it/s]

<bart_playground.moves.Combine object at 0x00000154765B0E50>


Iterations:  32%|███▏      | 97/300 [00:24<00:38,  5.23it/s]

<bart_playground.moves.Combine object at 0x0000015476899480>
<bart_playground.moves.Combine object at 0x000001547689A830>


Iterations:  33%|███▎      | 98/300 [00:24<00:38,  5.18it/s]

<bart_playground.moves.Combine object at 0x000001547689A830>


Iterations:  33%|███▎      | 100/300 [00:24<00:39,  5.07it/s]

<bart_playground.moves.Combine object at 0x0000015476BBC880>


Iterations:  34%|███▎      | 101/300 [00:25<00:39,  5.09it/s]

<bart_playground.moves.Combine object at 0x0000015477E26440>


Iterations:  34%|███▍      | 103/300 [00:25<00:38,  5.08it/s]

<bart_playground.moves.Combine object at 0x000001547808E6B0>


Iterations:  35%|███▌      | 105/300 [00:25<00:36,  5.38it/s]

<bart_playground.moves.Combine object at 0x0000015478336E60>


Iterations:  36%|███▌      | 108/300 [00:26<00:38,  5.05it/s]

<bart_playground.moves.Combine object at 0x00000154788EA7A0>


Iterations:  38%|███▊      | 113/300 [00:27<00:35,  5.32it/s]

<bart_playground.moves.Combine object at 0x00000154788EAD70>


Iterations:  39%|███▊      | 116/300 [00:28<00:33,  5.48it/s]

<bart_playground.moves.Combine object at 0x000001547911ACE0>


Iterations:  39%|███▉      | 118/300 [00:28<00:31,  5.77it/s]

<bart_playground.moves.Combine object at 0x0000015479416140>


Iterations:  40%|████      | 120/300 [00:28<00:32,  5.49it/s]

<bart_playground.moves.Combine object at 0x00000154794166B0>


Iterations:  40%|████      | 121/300 [00:28<00:31,  5.70it/s]

<bart_playground.moves.Combine object at 0x000001547967BDC0>


Iterations:  42%|████▏     | 126/300 [00:29<00:28,  6.20it/s]

<bart_playground.moves.Combine object at 0x000001547AB76B60>
<bart_playground.moves.Combine object at 0x000001547AB77CD0>


Iterations:  43%|████▎     | 130/300 [00:30<00:28,  5.99it/s]

<bart_playground.moves.Combine object at 0x000001547B0C1930>
<bart_playground.moves.Combine object at 0x000001547B0C2DA0>


Iterations:  44%|████▍     | 132/300 [00:30<00:27,  6.08it/s]

<bart_playground.moves.Combine object at 0x000001547B0C1930>


Iterations:  46%|████▌     | 137/300 [00:31<00:25,  6.29it/s]

<bart_playground.moves.Combine object at 0x000001547B845420>


Iterations:  46%|████▋     | 139/300 [00:31<00:24,  6.59it/s]

<bart_playground.moves.Combine object at 0x000001547B8464A0>
<bart_playground.moves.Combine object at 0x000001547B8464A0>


Iterations:  48%|████▊     | 144/300 [00:32<00:23,  6.58it/s]

<bart_playground.moves.Combine object at 0x000001547B8464A0>
<bart_playground.moves.Combine object at 0x000001547BD32B60>


Iterations:  49%|████▉     | 148/300 [00:33<00:25,  6.04it/s]

<bart_playground.moves.Combine object at 0x000001547BFDF460>
<bart_playground.moves.Combine object at 0x000001547C1F93F0>


Iterations:  50%|█████     | 150/300 [00:33<00:23,  6.30it/s]

<bart_playground.moves.Combine object at 0x000001547C4D5EA0>


Iterations:  52%|█████▏    | 155/300 [00:34<00:20,  7.10it/s]

<bart_playground.moves.Combine object at 0x000001547D706B90>


Iterations:  53%|█████▎    | 158/300 [00:34<00:20,  6.89it/s]

<bart_playground.moves.Combine object at 0x000001547DC14CA0>


Iterations:  53%|█████▎    | 160/300 [00:35<00:23,  6.02it/s]

<bart_playground.moves.Combine object at 0x000001547DE95300>


Iterations:  55%|█████▍    | 164/300 [00:35<00:21,  6.34it/s]

<bart_playground.moves.Combine object at 0x000001547DE97640>


Iterations:  55%|█████▌    | 166/300 [00:36<00:20,  6.42it/s]

<bart_playground.moves.Combine object at 0x000001547E39D1B0>
<bart_playground.moves.Combine object at 0x000001547E39D1B0>


Iterations:  56%|█████▋    | 169/300 [00:36<00:19,  6.57it/s]

<bart_playground.moves.Combine object at 0x000001547E6B7790>
<bart_playground.moves.Combine object at 0x000001547E6B77F0>


Iterations:  57%|█████▋    | 172/300 [00:37<00:19,  6.50it/s]

<bart_playground.moves.Combine object at 0x000001547E8FA440>
<bart_playground.moves.Combine object at 0x000001547E8FAB00>


Iterations:  59%|█████▊    | 176/300 [00:37<00:16,  7.71it/s]

<bart_playground.moves.Combine object at 0x000001547EB966B0>
<bart_playground.moves.Combine object at 0x000001547EE86020>


Iterations:  59%|█████▉    | 178/300 [00:37<00:17,  7.06it/s]

<bart_playground.moves.Combine object at 0x000001547EB966B0>
<bart_playground.moves.Combine object at 0x000001547EB966B0>


Iterations:  61%|██████    | 183/300 [00:38<00:16,  7.27it/s]

<bart_playground.moves.Combine object at 0x0000015400891930>


Iterations:  63%|██████▎   | 188/300 [00:39<00:16,  6.83it/s]

<bart_playground.moves.Combine object at 0x0000015400893C70>


Iterations:  64%|██████▍   | 192/300 [00:39<00:16,  6.74it/s]

<bart_playground.moves.Combine object at 0x0000015400B750F0>
<bart_playground.moves.Combine object at 0x0000015400B76920>


Iterations:  65%|██████▌   | 195/300 [00:40<00:15,  6.88it/s]

<bart_playground.moves.Combine object at 0x0000015401126920>
<bart_playground.moves.Break object at 0x0000015401125900>


Iterations:  66%|██████▌   | 198/300 [00:40<00:15,  6.68it/s]

<bart_playground.moves.Combine object at 0x00000154011251E0>
<bart_playground.moves.Combine object at 0x0000015401126D10>


Iterations:  67%|██████▋   | 200/300 [00:41<00:16,  6.14it/s]

<bart_playground.moves.Combine object at 0x00000154013FBD60>


Iterations:  68%|██████▊   | 205/300 [00:41<00:13,  7.24it/s]

<bart_playground.moves.Combine object at 0x000001540170D660>


Iterations:  69%|██████▉   | 208/300 [00:42<00:12,  7.18it/s]

<bart_playground.moves.Combine object at 0x000001540170D660>


Iterations:  70%|███████   | 210/300 [00:42<00:11,  7.90it/s]

<bart_playground.moves.Combine object at 0x0000015402C77070>


Iterations:  71%|███████   | 212/300 [00:42<00:12,  7.10it/s]

<bart_playground.moves.Combine object at 0x0000015402F25930>


Iterations:  72%|███████▏  | 216/300 [00:43<00:10,  7.65it/s]

<bart_playground.moves.Combine object at 0x00000154031CF9D0>
<bart_playground.moves.Combine object at 0x00000154031CF9D0>


Iterations:  73%|███████▎  | 218/300 [00:43<00:11,  7.28it/s]

<bart_playground.moves.Combine object at 0x000001540343F3D0>
<bart_playground.moves.Combine object at 0x000001540343FB50>


Iterations:  73%|███████▎  | 220/300 [00:43<00:10,  7.35it/s]

<bart_playground.moves.Combine object at 0x000001540343E200>


Iterations:  76%|███████▌  | 228/300 [00:45<00:11,  6.39it/s]

<bart_playground.moves.Combine object at 0x0000015403A395D0>


Iterations:  77%|███████▋  | 230/300 [00:45<00:10,  6.82it/s]

<bart_playground.moves.Combine object at 0x0000015403CE3430>
<bart_playground.moves.Combine object at 0x0000015403FA2C50>


Iterations:  78%|███████▊  | 234/300 [00:45<00:08,  7.89it/s]

<bart_playground.moves.Combine object at 0x00000154042E66E0>
<bart_playground.moves.Combine object at 0x0000015403FA2DA0>


Iterations:  79%|███████▉  | 237/300 [00:46<00:07,  8.04it/s]

<bart_playground.moves.Combine object at 0x00000154042E6E30>


Iterations:  80%|███████▉  | 239/300 [00:46<00:07,  7.85it/s]

<bart_playground.moves.Combine object at 0x000001540554B910>
<bart_playground.moves.Combine object at 0x00000154057D58D0>


Iterations:  81%|████████  | 242/300 [00:46<00:06,  8.67it/s]

<bart_playground.moves.Combine object at 0x00000154057D5930>
<bart_playground.moves.Combine object at 0x000001540554A6E0>


Iterations:  82%|████████▏ | 247/300 [00:47<00:06,  8.74it/s]

<bart_playground.moves.Combine object at 0x0000015405AB35E0>
<bart_playground.moves.Combine object at 0x0000015405D942E0>


Iterations:  84%|████████▍ | 252/300 [00:47<00:04,  9.74it/s]

<bart_playground.moves.Combine object at 0x0000015406052E00>
<bart_playground.moves.Combine object at 0x0000015406052E00>


Iterations:  85%|████████▌ | 256/300 [00:48<00:04, 10.70it/s]

<bart_playground.moves.Combine object at 0x0000015406050CD0>
<bart_playground.moves.Combine object at 0x00000154062ADE70>
<bart_playground.moves.Combine object at 0x00000154062AFD90>


Iterations:  87%|████████▋ | 260/300 [00:48<00:03, 10.79it/s]

<bart_playground.moves.Combine object at 0x00000154062ADE70>


Iterations:  87%|████████▋ | 262/300 [00:48<00:03, 10.43it/s]

<bart_playground.moves.Combine object at 0x00000154065330D0>


Iterations:  89%|████████▉ | 268/300 [00:49<00:03,  9.82it/s]

<bart_playground.moves.Combine object at 0x00000154079BE0B0>


Iterations:  91%|█████████ | 272/300 [00:49<00:02,  9.52it/s]

<bart_playground.moves.Combine object at 0x00000154079BFBB0>
<bart_playground.moves.Combine object at 0x00000154079BFBB0>


Iterations:  92%|█████████▏| 276/300 [00:50<00:02,  9.96it/s]

<bart_playground.moves.Combine object at 0x0000015407F914B0>
<bart_playground.moves.Combine object at 0x0000015407F93A60>


Iterations:  93%|█████████▎| 278/300 [00:50<00:02,  9.98it/s]

<bart_playground.moves.Combine object at 0x0000015407F93820>
<bart_playground.moves.Combine object at 0x0000015407F93A60>


Iterations:  93%|█████████▎| 280/300 [00:50<00:01, 10.56it/s]

<bart_playground.moves.Combine object at 0x0000015407F93A60>
<bart_playground.moves.Combine object at 0x0000015408212EC0>


Iterations:  94%|█████████▍| 282/300 [00:50<00:01, 10.44it/s]

<bart_playground.moves.Combine object at 0x0000015408213580>
<bart_playground.moves.Combine object at 0x0000015408213F70>
<bart_playground.moves.Combine object at 0x0000015408213D00>


Iterations:  95%|█████████▌| 286/300 [00:51<00:01, 10.98it/s]

<bart_playground.moves.Combine object at 0x00000154084E6A40>
<bart_playground.moves.Combine object at 0x00000154087A36A0>


Iterations:  97%|█████████▋| 292/300 [00:51<00:00, 10.87it/s]

<bart_playground.moves.Combine object at 0x00000154087A22C0>
<bart_playground.moves.Combine object at 0x00000154087A2A40>


Iterations:  98%|█████████▊| 294/300 [00:51<00:00, 11.79it/s]

<bart_playground.moves.Break object at 0x00000154087A2A40>


Iterations:  99%|█████████▉| 298/300 [00:52<00:00, 12.29it/s]

<bart_playground.moves.Combine object at 0x0000015408A26BC0>
<bart_playground.moves.Combine object at 0x0000015408D25210>
<bart_playground.moves.Combine object at 0x0000015408D27C70>


Iterations: 100%|██████████| 300/300 [00:52<00:00,  5.73it/s]



  _     ._   __/__   _ _  _  _ _/_   Recorded: 18:10:33  Samples:  51436
 /_//_/// /_\ / //_// / //_'/ //     Duration: 52.419    CPU time: 51.172
/   _/                      v5.0.1

Profile at C:\Windows\Temp\ipykernel_25160\4120808940.py:2

52.418 <module>  C:\Windows\Temp\ipykernel_25160\4120808940.py:1
└─ 52.418 ChangeNumTreeBART.fit  bart_playground\bart.py:22
   └─ 52.415 NTreeSampler.run  bart_playground\samplers.py:69
      └─ 52.075 NTreeSampler.one_iter  bart_playground\samplers.py:226
         ├─ 20.728 Swap.propose  bart_playground\moves.py:34
         │  ├─ 7.669 Swap.try_propose  bart_playground\moves.py:146
         │  │  ├─ 4.469 Tree.swap_split  bart_playground\params.py:296
         │  │  │  └─ 4.336 Tree.change_split  bart_playground\params.py:289
         │  │  │     └─ 4.239 Tree.update_n  bart_playground\params.py:384
         │  │  │        ├─ 2.176 Tree.update_n  bart_playground\params.py:384
         │  │  │        │  ├─ 0.951 sum  <__array_function__ internal

In [4]:
# Baseline models to compare against the BART-playground sampler.
# random_state pins the forest's randomness so a Restart & Run All reproduces the
# same test MSE (same seed as the train/test split above).
rf = RandomForestRegressor(random_state=42)
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

# Reference BART implementation from bartz. Feature matrices are passed
# transposed — presumably bartz expects (n_features, n_samples); TODO confirm.
# NOTE(review): ntree=100 here vs n_trees=200 for ChangeNumTreeBART above, so the
# two BART fits are not configured identically.
btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
# Average over axis 0 (presumably the posterior draws) to get a point prediction.
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [5]:
# FIXME: in the recorded run this call raised
#   TypeError: list indices must be integers or slices, not list
# The evaluation cell below calls bart.predict as well and will fail the same
# way until this is fixed in bart_playground.
bart_pred = bart.predict(X_test)
bart_pred[:12]  # show only the head rather than dumping every prediction

TypeError: list indices must be integers or slices, not list

In [None]:

# Test-set mean squared error for every fitted model.
models = dict(bart=bart, rf=rf, lr=lr, btz=btz)

# bartz is fed transposed inputs and its draws were averaged into `btpred` in an
# earlier cell, so that precomputed array is scored instead of calling predict.
results = {
    name: mean_squared_error(
        y_test,
        btpred if name == "btz" else model.predict(X_test),
    )
    for name, model in models.items()
}
results