In [1]:
import numpy as np
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF

In [2]:

def generate_bcf_data(n=1000, p=4, noise_level=0.5, random_state=42):
    """Simulate a synthetic dataset with a known heterogeneous treatment effect.

    Covariates are i.i.d. standard normal, the treatment indicator is a fair
    coin flip drawn independently of the covariates, and the outcome is
    ``y = mu(X) + z * tau(X) + noise`` with

    * ``mu  = 2 * sin(pi * X[:, 0]) + X[:, 1] ** 2``
    * ``tau = 0.5 * X[:, 2] + 0.5 * (X[:, 3] > 0)``

    Parameters
    ----------
    n : int
        Number of observations.
    p : int
        Number of covariates. Columns 0-3 are read, so fewer than four
        columns raises ``IndexError``.
    noise_level : float
        Standard deviation of the Gaussian noise added to the outcome.
    random_state
        Seed handed to ``np.random.default_rng``.

    Returns
    -------
    tuple
        ``(X, y, z, mu, tau)``: the ``(n, p)`` covariate matrix followed by
        four length-``n`` arrays (observed outcome, binary treatment,
        noiseless prognostic term, true treatment effect).
    """
    generator = np.random.default_rng(random_state)

    # The draw order (covariates -> treatment -> noise) determines the random
    # stream, so it must stay exactly like this for a given seed.
    covariates = generator.normal(loc=0, scale=1, size=(n, p))
    treatment = generator.binomial(1, 0.5, size=n)
    noise = generator.normal(loc=0, scale=noise_level, size=n)

    x0, x1, x2, x3 = (covariates[:, j] for j in range(4))
    prognostic = 2 * np.sin(np.pi * x0) + x1**2
    effect = 0.5 * x2 + 0.5 * (x3 > 0)
    outcome = prognostic + treatment * effect + noise

    return covariates, outcome, treatment, prognostic, effect

# Simulate 400 observations, then hold out 20% of them for evaluation.
# The true mu and tau are split alongside X, y, z so the held-out ground
# truth stays row-aligned with the test covariates.
X, y, z, true_mu, true_tau = generate_bcf_data(n=400)
(
    X_train, X_test,
    y_train, y_test,
    z_train, z_test,
    mu_train, mu_test,
    tau_train, tau_test,
) = train_test_split(X, y, z, true_mu, true_tau, test_size=0.2, random_state=42)


In [3]:
# Peek at the first ten training rows (all covariate columns).
print(X_train[:10])

[[ 0.0660307   1.12724121  0.46750934 -0.85929246]
 [-0.91945229  0.49716074  0.14242574  0.69048535]
 [-0.3191066   0.21723119 -0.20208325 -0.57789411]
 [-0.05928265 -0.72928694 -0.41447307  0.63391038]
 [ 1.5791855   0.49455666  0.97366351  1.24196   ]
 [ 1.72869764 -0.98685708 -0.24527785  0.77733758]
 [ 1.79437005  1.31480787 -0.10973418  0.35272016]
 [-0.65056362 -0.67421246 -0.71233706 -0.87950963]
 [-0.13108665 -1.83090583  0.92829699 -0.60500071]
 [-0.15252253  0.38339386  0.99982425 -1.05853608]]


In [4]:
bcf = BCF(
    # Prognostic (mu) forest
    n_mu_trees=100,   # number of prognostic-effect trees
    mu_alpha=0.95,    # tree depth prior for mu
    mu_beta=2.0,      # tree depth prior for mu
    # Treatment-effect (tau) forest
    n_tau_trees=50,   # number of treatment-effect trees
    tau_alpha=0.5,    # simpler trees for treatment effects
    tau_beta=3.0,     # penalize complex tau trees
    tau_k=0.5,        # regularization for treatment effects
    # Sampler
    ndpost=100,       # posterior samples
    nskip=100,        # burn-in iterations
    random_state=42,
)

# Fit on the training split: covariates, observed outcome, treatment indicator.
bcf.fit(X_train, y_train, z_train)

Iterations: 100%|██████████| 200/200 [00:07<00:00, 28.52it/s]


In [5]:
# Predict on the held-out covariates and treatment indicators.
# Later cells index the result at [0], [1] and [2]; presumably these are the
# mu, tau and outcome predictions respectively — TODO confirm against
# BCF.predict_mean.
bcf_result = bcf.predict_mean(X_test, z_test)

In [6]:
# Ground truth on the test split: first ten prognostic values, then the
# first ten treatment effects.
print(mu_test[:10])
print(tau_test[:10])

[ 0.32830421  1.97569722  1.87179898 -0.58645334  1.57840294 -1.62047604
 -1.04906521  2.81244886 -1.79417073  0.07297647]
[ 0.91869334 -0.37490667 -0.73540315  0.87684551  0.77705825  1.07765602
 -0.24252774  0.80212718 -0.03598128  1.71020751]


In [7]:
# First ten entries of the first two prediction components, for a
# side-by-side look against the true mu / tau values.
print(bcf_result[0][:10])
print(bcf_result[1][:10])

[-47.93188535 -51.12461387 -51.12461387 -50.58730143 -51.12461387
 -51.12461387 -51.12461387 -41.80783354 -51.12461387 -48.80099429]
[-38.4908773  -38.4908773  -38.4908773  -38.4908773  -38.03732651
 -38.4908773  -38.4908773  -34.42720899 -38.4908773  -38.4908773 ]


In [9]:
# Test-set MSE of the third prediction component against the observed outcome.
# sklearn's signature is mean_squared_error(y_true, y_pred), so the observed
# y_test goes first; the value is unchanged because squared error is symmetric.
# NOTE(review): the saved output (~5315) is far above the all-zeros baseline
# (~5.14) computed in a later cell, so the fit or the way predict_mean's
# output is interpreted looks wrong — investigate before trusting results.
print(mean_squared_error(y_test, bcf_result[2]))

5314.907124786562


In [10]:
# Display the full second component of the prediction result.
# NOTE(review): in the saved output these values are almost constant near
# -38.5, while the true tau_test values printed earlier lie roughly in
# [-1, 2] — presumably this component is the tau estimate; verify.
bcf_result[1]

array([-38.4908773 , -38.4908773 , -38.4908773 , -38.4908773 ,
       -38.03732651, -38.4908773 , -38.4908773 , -34.42720899,
       -38.4908773 , -38.4908773 , -38.4908773 , -38.03362695,
       -38.03362695, -38.4908773 , -36.65214323, -38.4908773 ,
       -38.4908773 , -38.4908773 , -34.25908121, -38.4908773 ,
       -34.25653145, -38.4908773 , -38.4908773 , -38.4908773 ,
       -38.4908773 , -37.52426602, -38.4908773 , -38.03732651,
       -38.4908773 , -38.4908773 , -38.4908773 , -38.4908773 ,
       -38.4908773 , -38.4908773 , -38.4908773 , -38.03732651,
       -34.2594069 , -38.4908773 , -34.42720899, -38.4908773 ,
       -38.4908773 , -38.4908773 , -38.11292352, -38.4908773 ,
       -38.4908773 , -38.4908773 , -37.52328615, -37.37287038,
       -38.4908773 , -34.43141571, -36.65214323, -38.4908773 ,
       -38.4908773 , -36.67879573, -38.11292352, -38.4908773 ,
       -38.4908773 , -38.4908773 , -38.4908773 , -38.4908773 ,
       -38.4908773 , -38.4908773 , -38.4908773 , -38.49

In [11]:
# Baseline for comparison: MSE of predicting 0 for every test outcome.
# Arguments follow sklearn's (y_true, y_pred) order; the value is the same
# either way because squared error is symmetric.
print(mean_squared_error(y_test, np.zeros_like(y_test)))

5.142700247470763
