In [1]:
import numpy as np
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF

In [2]:

def generate_bcf_data(n=1000, p=10, noise_level=0.5, random_state=42):
    """Simulate a dataset with a known prognostic function and treatment effect.

    The outcome is built as ``y = mu + z * tau + noise`` with

    * ``X``     : ``(n, p)`` matrix of independent standard-normal covariates
    * ``mu``    : ``2 * sin(pi * X[:, 0]) + X[:, 1] ** 2``
    * ``tau``   : ``0.5 * X[:, 2] + 0.5 * (X[:, 3] > 0)``
    * ``z``     : 0/1 treatment indicator, Binomial(1, 0.5), drawn independently of ``X``
    * ``noise`` : normal with mean 0 and standard deviation ``noise_level``

    Parameters
    ----------
    n : int
        Number of rows to simulate.
    p : int
        Number of covariate columns. Columns 0-3 are indexed, so ``p`` must be
        at least 4 (a smaller value fails with ``IndexError``).
    noise_level : float
        Standard deviation of the additive outcome noise.
    random_state : int
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    tuple
        ``(X, y, z, mu, tau)`` in that order.
    """
    generator = np.random.default_rng(random_state)

    # The draw order (covariates -> treatment -> noise) determines the random
    # stream, so it must not be rearranged.
    covariates = generator.normal(0, 1, (n, p))

    prognostic = 2 * np.sin(np.pi * covariates[:, 0]) + covariates[:, 1] ** 2
    effect = 0.5 * covariates[:, 2] + 0.5 * (covariates[:, 3] > 0)

    treatment = generator.binomial(1, 0.5, n)
    noise = generator.normal(0, noise_level, n)
    outcome = prognostic + treatment * effect + noise

    return covariates, outcome, treatment, prognostic, effect

# Simulate 400 rows, then hold out 20% for testing. All five arrays go through
# a single train_test_split call so that the same rows land in each split.
X, y, z, true_mu, true_tau = generate_bcf_data(n=400)
(
    X_train, X_test,
    y_train, y_test,
    z_train, z_test,
    mu_train, mu_test,
    tau_train, tau_test,
) = train_test_split(X, y, z, true_mu, true_tau, test_size=0.2, random_state=42)


In [3]:
# Configure the sampler. The mu and tau forests get separate tree priors.
bcf = BCF(
    # --- prognostic-effect (mu) forest ---
    n_mu_trees=200,
    mu_alpha=0.95,    # tree depth prior
    mu_beta=2.0,      # tree depth prior
    # --- treatment-effect (tau) forest ---
    n_tau_trees=50,
    tau_alpha=0.5,    # favour simpler trees
    tau_beta=3.0,     # penalise complex trees
    tau_k=0.5,        # regularisation for treatment effects
    # --- sampler ---
    nskip=100,        # burn-in iterations
    ndpost=100,       # posterior samples
    random_state=42,
)

# NOTE(review): the saved output of this cell shows fit() stopping at
# iteration 0 with "AttributeError: 'Tree' object has no attribute 'data'".
# Outputs of later cells may therefore come from an earlier run - re-run on a
# fresh kernel and confirm before relying on them.
bcf.fit(X_train, y_train, z_train)

Running iteration 0


AttributeError: 'Tree' object has no attribute 'data'

In [4]:
# Evaluate the model on the held-out covariates and treatment indicators.
# Later cells index the result at [0] and [2] ([1] appears only in a
# commented-out line); what each position holds is not shown in this
# notebook - TODO confirm against BCF.predict_components.
bcf_result = bcf.predict_components(X_test, z_test)

In [5]:
# Ground-truth mu returned by generate_bcf_data for all 400 simulated rows
# (the full array, before the train/test split).
print(true_mu)

[ 2.71684358  1.34472509 -0.63365769  1.02608723  1.73888124  1.9754376
  1.79107368  2.80794892  2.4197341   1.74576228 -0.16727784  2.01422626
  0.1796448   0.01390835  2.90731807  0.11659677  0.30721345 -1.56917465
 -1.5917818  -1.49401607  3.72622603  2.28159178 -1.23463075  2.33575867
 -0.74566678 -1.71445183 -1.28974473 -1.43459943  5.0396901  -1.39969616
  0.84155053 -0.8650911   1.2997825  -1.59230643 -1.49262385  2.54808842
 -0.77508353  1.83791286  2.37985367 -1.65157205 -1.03087048  7.44263279
  1.27203707  0.47939415 -1.01221795 -0.67920555  2.36234222  1.94114296
  3.34380179 -0.39514    -1.01849692  2.05183846  0.60072226 -0.92865176
  2.34918754  2.18070178 -0.32870426  1.97812309 -0.58331061  1.3483261
  2.33120779  0.27050157 -1.9643076   4.61931388  2.48856719 -0.88979993
  1.64750676  1.32543523  1.47943408  1.7288232  -0.634925    0.40661282
  0.67726992  1.97657709  1.3997243   0.16149443 -1.00902911  0.34309148
 -0.31393897  0.13184707  0.23879523 -0.85217321  0.0

In [None]:
# Inspect the first element of the predict_components result.
# NOTE(review): which component this is cannot be told from this notebook -
# TODO label once confirmed.
print(bcf_result[0])
# print(bcf_result[1])

[[-0.27231121 -0.15802242 -0.11514596 ... -0.15385393 -0.06573387
  -0.13114827]
 [-0.17577614 -0.15531792 -0.09444619 ... -0.23747505 -0.15955988
  -0.17352617]
 [-0.16653047 -0.23012336 -0.25865487 ... -0.20590028 -0.27744345
  -0.26902721]
 ...
 [-0.27319769 -0.36103053 -0.35124076 ... -0.2103615  -0.1782664
  -0.12743677]
 [-0.02765818 -0.09676122 -0.19219316 ... -0.09315378 -0.07574765
  -0.04070369]
 [-0.07167381 -0.05160261 -0.12410957 ... -0.25437118 -0.15726972
  -0.21938487]]
[[-0.09582354 -0.06562958 -0.05420542 ... -0.09160685 -0.09223696
  -0.09060468]
 [-0.25889359 -0.22669395 -0.24389164 ... -0.26551635 -0.24873014
  -0.27618366]
 [-0.30419328 -0.31638303 -0.32902712 ... -0.35064426 -0.30758179
  -0.34972707]
 ...
 [-0.123807   -0.13599675 -0.13079251 ... -0.21193768 -0.13826744
  -0.13611861]
 [-0.14821247 -0.16040223 -0.15397864 ... -0.13576621 -0.09270374
  -0.08735229]
 [-0.32061547 -0.33280522 -0.32138106 ... -0.27450819 -0.23144572
  -0.22609427]]


In [7]:
# Test-set MSE: average bcf_result[2] over axis 1 and compare with y_test.
# Arguments follow sklearn's (y_true, y_pred) order; MSE is symmetric, so the
# value is the same either way.
# NOTE(review): presumably axis 1 indexes posterior draws - confirm.
print(mean_squared_error(y_test, np.mean(bcf_result[2], axis=1)))

4.45698205774403


In [6]:
# Baseline for comparison: test-set MSE of predicting 0 for every row.
print(mean_squared_error(y_test, np.zeros_like(y_test)))

4.3616648463025856
