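### Test BCF mu trees when tau_trees are set to 0 (equivalent to BART)

**Note:** the fit below actually uses `n_tau_trees=50` (not 0) and marks every observation as treated (`z` is all `True`), so the tau forest is not disabled; reconcile the setup with this title.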

In [12]:
import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
import bartz

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground import *
from bart_playground.bcf.bcf import BCF
from bart_playground.params import Tree

In [13]:
# Move probabilities for a BART sampler; in this notebook they are only
# referenced by the commented-out DefaultBART baseline in the next cell.
proposal_probs = {"grow": 0.5, "prune": 0.5}

# Synthetic 2-feature regression data with a piecewise-flat signal.
generator = DataGenerator(n_samples=160, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")

# Treatment indicator for BCF: every observation is marked treated (all True).
# Same array as the former `(1 - np.zeros_like(y)).astype(bool)`, written idiomatically.
z = np.ones_like(y, dtype=bool)
X_train, X_test, y_train, y_test, z_train, z_test = train_test_split(X, y, z, random_state=42)

np.set_printoptions(suppress=True)
print(y_train[:12])
print(X_train[0:5, :])

[ 0.50327821  0.60672224  0.26898966  0.55211673  0.50693811  0.66162097
 -0.64127659  0.65112284  0.03487759  0.23276531  0.44055996  0.38216964]
[[0.78692438 0.66485086]
 [0.3179911  0.50452624]
 [0.55485247 0.37092228]
 [0.12275793 0.83111267]
 [0.92484187 0.09911314]]


In [14]:
# Optional plain-BART baseline, currently disabled:
# bart = DefaultBART(ndpost=100, nskip=100, n_trees=100, proposal_probs=proposal_probs)
# bart.fit(X_train, y_train)

# BCF hyperparameters.
# NOTE(review): the section title says tau trees are set to 0, but 50 tau trees
# are used here, and z_train marks every observation as treated.
bcf_params = dict(
    n_mu_trees=100,    # prognostic-effect (mu) trees
    n_tau_trees=50,    # treatment-effect (tau) trees
    ndpost=100,        # posterior samples kept
    nskip=100,         # burn-in iterations
    random_state=42,
)
bcf = BCF(**bcf_params)
bcf.fit(X_train, y_train, z_train)

Iterations: 100%|██████████| 200/200 [00:06<00:00, 30.88it/s]


In [15]:
# For each mu tree in the final posterior sample, count its split nodes, i.e.
# the entries of `tree.vars` that are >= 0. In the tree inspected later in this
# notebook, -1 entries carry leaf values and -2 entries carry none.
# `deep_trees` flags trees with at least MIN_SPLITS split nodes. This is a split
# count, not the tree depth.
MIN_SPLITS = 3
final_mu_trees = bcf.sampler.trace[-1].mu_trees
split_vars = [tree.vars for tree in final_mu_trees]
counts = np.array([np.count_nonzero(node_vars >= 0) for node_vars in split_vars])
print(counts)
deep_trees = counts >= MIN_SPLITS  # vectorized comparison -> boolean mask
print(np.where(deep_trees))

[1 1 1 1 1 1 1 4 1 2 2 2 1 1 0 1 1 1 1 2 2 1 3 1 1 1 1 1 1 1 1 1 1 1 2 1 1
 2 1 1 2 1 1 1 1 1 1 3 2 1 3 2 1 1 1 1 2 1 1 1 2 1 1 0 1 2 1 1 1 1 1 1 3 1
 1 2 1 2 1 1 2 1 0 1 1 0 2 1 0 1 1 2 1 1 1 1 0 1 1 1]
(array([ 7, 22, 47, 50, 72], dtype=int64),)


In [16]:
# Non-tree parameters of the final posterior sample (here the noise variance).
final_state = bcf.sampler.trace[-1]
print(final_state.global_params)

{'eps_sigma2': 0.005733414508936352}


In [17]:
from bart_playground import visualize_tree

# Inspect one of the mu trees flagged in `deep_trees` (previous cell).
# The index is derived from the mask instead of the hard-coded 72 that was
# copied from that cell's output. It picks the last flagged tree (72 for the
# current seed), so it stays valid if the seed or sampler settings change.
# If no tree is flagged, it falls back to the tree with the most splits.
deep_tree_indices = np.flatnonzero(deep_trees)
tree_index = int(deep_tree_indices[-1]) if deep_tree_indices.size else int(np.argmax(counts))
tree_sp : Tree = bcf.sampler.trace[-1].mu_trees[tree_index]

print(tree_sp)
print(tree_sp.vars)       # per node: split feature (>= 0), -1 / -2 otherwise
print(tree_sp.leaf_vals)  # per node: leaf value, nan where there is none
# print(tree_sp.node_indicators)
# visualize_tree(tree_sp, tree_sp)

X_1 <= 0.116 (split, n = 120)
	X_0 <= 0.093 (split, n = 15)
		Val: 0.003 (leaf, n = 1)
		Val: 0.008 (leaf, n = 14)
	X_0 <= 0.287 (split, n = 105)
		Val: 0.005 (leaf, n = 34)
		Val: 0.035 (leaf, n = 71)
[ 1  0  0 -1 -1 -1 -1 -2 -2 -2 -2 -2 -2 -2 -2 -2]
[       nan        nan        nan 0.00320229 0.00783413 0.00538357
 0.03476661        nan        nan        nan        nan        nan
        nan        nan        nan        nan]


In [18]:
# Reference models fitted on the same training split as BCF.
rf = RandomForestRegressor(random_state=42)
rf.fit(X_train, y_train)
lr = LinearRegression()
lr.fit(X_train, y_train)

# bartz is given the feature matrices transposed (features x samples).
# NOTE(review): bartz keeps ndpost=200 draws while BCF above keeps ndpost=100.
btz = bartz.BART.gbart(X_train.T, y_train, ntree=100, ndpost=200, nskip=100)
btpred_all = btz.predict(X_test.T)
# Posterior-mean prediction: average over draws (assumed to be axis 0).
btpred = np.asarray(btpred_all).mean(axis=0)

Iteration 100/300 P_grow=0.55 P_prune=0.45 A_grow=0.36 A_prune=0.36 (burnin)
Iteration 200/300 P_grow=0.57 P_prune=0.43 A_grow=0.35 A_prune=0.37
Iteration 300/300 P_grow=0.57 P_prune=0.43 A_grow=0.39 A_prune=0.40


In [19]:
# Test-set MSE per model. Models default to `model.predict(X_test)`; the
# entries in `custom_predictors` are the exceptions:
#   - bcf: needs the treatment indicator; element [2] of predict_mean's result
#     is used as the outcome prediction (NOTE(review): confirm this tuple
#     layout against BCF.predict_mean).
#   - btz: reuses the bartz posterior-mean prediction `btpred` computed above.
models = {"bcf": bcf, "rf": rf, "lr": lr, "btz": btz}
custom_predictors = {
    "btz": lambda: btpred,
    "bcf": lambda: bcf.predict_mean(X_test, z_test)[2],
}
results = {}
for name, fitted in models.items():
    predict_test = custom_predictors.get(name, lambda: fitted.predict(X_test))
    results[name] = mean_squared_error(y_test, predict_test())
results

{'bcf': 0.023709398713852352,
 'rf': 0.022139023845392215,
 'lr': 0.048045521328019404,
 'btz': 0.02328283761397566}

In [20]:
# First 12 training points: mu-forest fit (final posterior sample) next to the
# training targets after the preprocessor's y transform.
last_sample = bcf.sampler.trace[-1]
mu_fit_train = last_sample.mu_view.evaluate(X_train)
print(mu_fit_train[:12])
y_train_transformed = bcf.preprocessor.transform_y(y_train)
print(y_train_transformed[:12])

[ 0.36899463  0.33800563  0.27153915  0.31520681  0.34463921  0.38089753
 -0.41651139  0.39616038 -0.11335028 -0.00896246  0.22936939  0.34190874]
[ 0.34923863  0.42552948  0.17644883  0.38525745  0.35193784  0.46601776
 -0.49488025  0.4582753   0.0037892   0.14973307  0.30298339  0.25992003]


In [21]:
# Constant baseline: predict the mean of y_test for every test point.
# NOTE: this uses the *test* mean (an oracle constant); a leakage-free null
# model would predict y_train.mean() instead.
mean_squared_error(y_test, np.full(y_test.shape, y_test.mean()))

0.10534048469161521

In [22]:
# Consistency check: for every mu tree in the final posterior sample, the
# no-argument `evaluate()` (presumably the values stored during fitting —
# NOTE(review): confirm) must exactly equal a fresh evaluation on X_train.
# The loop runs over the trees themselves instead of range(100). The old form
# assumed n_mu_trees == 100: it raised IndexError for fewer trees and silently
# skipped trees for more.
final_mu_trees = bcf.sampler.trace[-1].mu_trees
cached_matches_fresh = all(
    (tree.evaluate() == tree.evaluate(X_train)).all() for tree in final_mu_trees
)
print(cached_matches_fresh)  # prints "True" / "False"

True
