In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
import bartz

In [2]:
proposal_probs = {"grow" : 0.5,
                  "prune" : 0.5}
generator = DataGenerator(n_samples=160, n_features=2, noise=0.01, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train)

[ 0.50032782  0.51067222  0.47689897  0.50521167  0.50069381  0.5161621
 -0.51412766  0.51511228  0.00348776  0.02327653  0.494056    0.48821696
  0.5084182   0.50459386  0.50827988  0.4878344  -0.01601278  0.49668512
  0.50928297  0.51475949  0.01743935 -0.48679338  0.4770871   0.49394999
  0.49810756  0.49794838  0.50989584  0.49606259  0.51327686  0.50582655
 -0.4827923   0.51353586  0.00856794 -0.00570944  0.50701954  0.4932886
  0.48418406  0.50092127 -0.51013717  0.00637534  0.49641983  0.50524188
  0.50590906  0.00196776 -0.01117384  0.47927017 -0.49638078  0.50232878
  0.50313908  0.00844385  0.51309551  0.50304367  0.48617818  0.50803585
  0.48169094  0.00011426  0.49706808  0.5107624   0.49166631  0.50383394
  0.50438993  0.49154339  0.49868913  0.50270343  0.0041389   0.50276274
  0.49205864  0.00072034  0.51068472 -0.00296571  0.50846584  0.51604254
  0.48959456  0.49930862  0.49295327  0.50054354  0.01951013  0.52076981
  0.50229211  0.00066546 -0.00837398 -0.00190651  0.4

In [4]:
# profiler = Profiler()
# profiler.start()
bart = DefaultBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)
# profiler.stop()
# profiler.print()

Running iteration 0
Running iteration 10
Running iteration 20
Running iteration 30
Running iteration 40
Running iteration 50
Running iteration 60
Running iteration 70
Running iteration 80
Running iteration 90
Running iteration 100
Running iteration 110
Running iteration 120
Running iteration 130
Running iteration 140
Running iteration 150
Running iteration 160
Running iteration 170
Running iteration 180
Running iteration 190
Running iteration 200
Running iteration 210
Running iteration 220
Running iteration 230
Running iteration 240
Running iteration 250
Running iteration 260
Running iteration 270
Running iteration 280
Running iteration 290


In [5]:
arrays = [tree.vars for tree in bart.trace[-1].trees]
counts = np.array([np.count_nonzero(arr >= 0) for arr in arrays])
print(counts)
deep_trees = np.array([count >= 3 for count in counts])
print(np.where(deep_trees))

[1 0 1 1 2 1 1 1 2 1 2 2 1 1 2 1 1 1 1 0 2 1 0 1 1 1 1 1 2 1 1 1 2 2 0 1 3
 1 0 0 1 1 1 1 1 1 2 1 1 2 2 1 1 3 1 1 1 1 0 1 0 1 2 2 1 1 2 2 3 1 1 1 1 1
 1 5 1 1 2 1 1 1 1 1 2 2 1 1 0 2 1 1 1 1 2 1 1 2 1 1]
(array([36, 53, 68, 75], dtype=int64),)


In [8]:
print(bart.sampler.trace[-1].global_params)

{'eps_sigma2': 0.004743495735960561}


In [16]:
from bart_playground import visualize_tree
tree_sp : Tree = bart.sampler.trace[1].trees[0]

print(tree_sp)
print(tree_sp.vars)
print(tree_sp.leaf_vals)
# print(tree_sp.node_indicators)
# visualize_tree(tree_sp, tree_sp)

X_0 <= 0.830 (split, n = 120)
	Val: -0.047 (leaf, n = 102)
	Val: -0.032 (leaf, n = 18)
[ 0 -1 -1 -2 -2 -2 -2 -2]
[        nan -0.04656084 -0.03189183         nan         nan         nan
         nan         nan]


In [19]:
y_train_le = y_train[X_train[:, 0] <= 0.830]
y_train_gt = y_train[X_train[:, 0] > 0.830]

print("y_train where X_train[:,0] <= 0.830 mean:", y_train_le.mean())
print("y_train where X_train[:,0] > 0.830 mean:", y_train_gt.mean())

y_train where X_train[:,0] <= 0.830 mean: 0.28436069809598674
y_train where X_train[:,0] > 0.830 mean: 0.5028266794583821


In [10]:
rf = RandomForestRegressor()
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(np.transpose(X_test))
btpred = np.mean(np.array(btpred_all), axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.24 A_prune=0.24 (burnin)
Iteration 200/300 P_grow=0.58 P_prune=0.42 A_grow=0.40 A_prune=0.40
Iteration 300/300 P_grow=0.59 P_prune=0.41 A_grow=0.36 A_prune=0.39


In [11]:
print(btz.lamda)
# print(btz.first_sigma)
print(btz.sigest)

0.0070522516
0.19027376


In [12]:

models = {"bart" : bart, 
          "rf" : rf, 
          "lr" : lr,
          "btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

{'bart': 0.07488821265630978,
 'rf': 0.007564635053349439,
 'lr': 0.038809233220870286,
 'btz': 0.010248932694085864}

In [13]:
print(bart.sampler.trace[-1].evaluate(X_test))
print(y_test)
for i in range(100):
    mse_i = mean_squared_error(y_test, bart.trace[i].evaluate(X_test))
    # print(mse_i)

[ 0.34329356 -0.38988196 -0.14587971 -0.0361436   0.24744113 -0.37452441
 -0.04885274  0.05947299  0.49775278  0.13365391 -0.21828501  0.15856022
  0.41096897  0.13146075  0.19905843  0.24747273  0.23031417  0.49662627
  0.34780141 -0.04929071  0.50737513 -0.35642943  0.17145617  0.38470598
  0.29165901  0.13285134 -0.41427073  0.00938116 -0.15963852 -0.28591256
 -0.20823378 -0.04912247  0.21525107 -0.30172976 -0.10862269  0.18891709
 -0.17961635 -0.55700762  0.10709073  0.49662627]
[ 0.50749434 -0.48910783  0.00383377 -0.01058536  0.50820528 -0.49687991
 -0.01296472  0.49852471  0.4800694   0.49868862  0.49933047  0.49528224
  0.50750869  0.50554117  0.4930107   0.51155312  0.494661    0.50160191
  0.49882458  0.01820646  0.48748353 -0.0157204   0.5078235   0.49302576
  0.50608761  0.49974137 -0.00265839  0.49874991  0.0042789  -0.00678264
 -0.00342786  0.50439637  0.50202306 -0.01480537  0.51177412  0.50474697
  0.50813764 -0.49631643  0.47996478  0.50530065]


In [16]:
mean_squared_error(y_test, np.ones_like(y_test) * y_test.mean())

0.09757576385098224

In [18]:
if all([(bart.sampler.trace[-1].trees[i].evaluate() == bart.sampler.trace[-1].trees[i].evaluate(X_train)).all()
        for i in range(100)]):
    print("True")
else:
    print("False")

False
