In [1]:
import numpy as np
import sys
from pyinstrument import Profiler
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from bart_playground import *
# import bartz

In [2]:
proposal_probs = {"grow" : 0.5,
                  "prune" : 0.5}
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:12])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [3]:
# profiler = Profiler()
# profiler.start()
bart = DefaultBART(ndpost=200, nskip=100, n_trees=100, proposal_probs=proposal_probs)
bart.fit(X_train, y_train)
# profiler.stop()
# profiler.print()

Running iteration 0
Running iteration 10
Running iteration 20
Running iteration 30
Running iteration 40
Running iteration 50
Running iteration 60
Running iteration 70
Running iteration 80
Running iteration 90
Running iteration 100
Running iteration 110
Running iteration 120
Running iteration 130
Running iteration 140
Running iteration 150
Running iteration 160
Running iteration 170
Running iteration 180
Running iteration 190
Running iteration 200
Running iteration 210
Running iteration 220
Running iteration 230
Running iteration 240
Running iteration 250
Running iteration 260
Running iteration 270
Running iteration 280
Running iteration 290


In [4]:
arrays = [tree.vars for tree in bart.trace[-1].trees]
counts = np.array([np.count_nonzero(arr >= 0) for arr in arrays])
print(counts)
deep_trees = np.array([count >= 3 for count in counts])
print(np.where(deep_trees))

[1 2 1 0 0 1 2 1 1 2 1 2 2 1 1 1 1 2 2 1 0 1 1 1 1 1 1 1 2 1 1 1 0 1 1 1 1
 1 1 1 1 3 1 2 1 1 1 2 1 1 2 2 1 1 1 1 0 1 1 0 1 2 1 1 1 3 1 2 1 1 1 1 3 1
 1 1 1 2 1 1 1 1 1 1 1 2 2 1 2 1 1 3 2 4 0 1 2 1 2 1]
(array([41, 65, 72, 91, 93], dtype=int64),)


In [5]:
print(bart.sampler.trace[-1].global_params)

{'eps_sigma2': 0.007642923877694697}


In [6]:
from bart_playground import visualize_tree
tree_sp : Tree = bart.sampler.trace[1].trees[0]

print(tree_sp)
print(tree_sp.vars)
print(tree_sp.leaf_vals)
# print(tree_sp.node_indicators)
# visualize_tree(tree_sp, tree_sp)

X_0 <= 0.830 (split, n = 120)
	Val: -0.046 (leaf, n = 102)
	Val: -0.032 (leaf, n = 18)
[ 0 -1 -1 -2 -2 -2 -2 -2]
[        nan -0.04601371 -0.0316645          nan         nan         nan
         nan         nan]


In [7]:
y_train_le = y_train[X_train[:, 0] <= 0.830]
y_train_gt = y_train[X_train[:, 0] > 0.830]

print("y_train where X_train[:,0] <= 0.830 mean:", y_train_le.mean())
print("y_train where X_train[:,0] > 0.830 mean:", y_train_gt.mean())

y_train where X_train[:,0] <= 0.830 mean: 0.28478345154810286
y_train where X_train[:,0] > 0.830 mean: 0.5282667945838203


In [8]:
rf = RandomForestRegressor()
lr = LinearRegression()
rf.fit(X_train, y_train)
lr.fit(X_train, y_train)

# btz = bartz.BART.gbart(np.transpose(X_train), y_train, ntree=100, ndpost=200, nskip=100)
# btpred_all = btz.predict(np.transpose(X_test))
# btpred = np.mean(np.array(btpred_all), axis=0)

In [9]:
# print(btz.lamda)
# print(btz._show_tree(1, 0))# , print_all=True))

# btz.first_sigma

In [10]:

models = {"bart" : bart, 
          "rf" : rf, 
          "lr" : lr}#,
          #"btz" : btz}
results = {}
for model_name, model in models.items():
    if model_name == "btz":
        results[model_name] = mean_squared_error(y_test, btpred)
    else:
        results[model_name] = mean_squared_error(y_test, model.predict(X_test))
results

{'bart': 0.02348080777490483,
 'rf': 0.02169203808209099,
 'lr': 0.048045521328019404}

In [11]:
print(bart.sampler.trace[-1].evaluate(X_train)[:12])
print(y_train[:12])

[ 0.34395396  0.36112538  0.1337295   0.20110625  0.31703823  0.41806062
 -0.41340886  0.30941164 -0.11689252  0.01780752  0.19348933  0.33764771]
[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]


In [12]:
mean_squared_error(y_test, np.ones_like(y_test) * y_test.mean())

0.10534048469161521

In [13]:
if all([(bart.sampler.trace[-1].trees[i].evaluate() == bart.sampler.trace[-1].trees[i].evaluate(X_train)).all()
            for i in range(100)]):
    print("True")
else:
    print("False")

True
