In [1]:
import numpy as np
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF

In [2]:

def generate_bcf_data(n=1000, p=10, noise_level=0.5, random_state=42):
    rng = np.random.default_rng(random_state)
    
    X = rng.normal(0, 1, (n, p))
    mu = 2 * np.sin(np.pi * X[:,0]) + X[:,1]**2
    tau = 0.5 * X[:,2] + 0.5 * (X[:,3] > 0)
    z = rng.binomial(1, 0.5, n)
    y = mu + z * tau + rng.normal(0, noise_level, n)
    
    return X, y, z, mu, tau

# Generate data
X, y, z, true_mu, true_tau = generate_bcf_data(n=400)
X_train, X_test, y_train, y_test, z_train, z_test, mu_train, mu_test, tau_train, tau_test = \
    train_test_split(X, y, z, true_mu, true_tau, test_size=0.2, random_state=42)


In [3]:
bcf = BCF(
    n_mu_trees=100,       # Number of prognostic effect trees
    n_tau_trees=50,       # Number of treatment effect trees
    mu_alpha=0.95,        # Tree depth prior for mu
    mu_beta=2.0,          # Tree depth prior for mu
    tau_alpha=0.5,        # Simpler trees for treatment effects
    tau_beta=3.0,         # Penalize complex tau trees
    tau_k=0.5,            # Regularization for treatment effects
    ndpost=100,          # Posterior samples
    nskip=100,            # Burn-in iterations
    random_state=42
)

bcf.fit(X_train, y_train, z_train)

Running iteration 0
Running iteration 10
Running iteration 20
Running iteration 30
Running iteration 40
Running iteration 50
Running iteration 60
Running iteration 70
Running iteration 80
Running iteration 90
Running iteration 100
Running iteration 110
Running iteration 120
Running iteration 130
Running iteration 140
Running iteration 150
Running iteration 160
Running iteration 170
Running iteration 180
Running iteration 190


In [4]:
bcf_result = bcf.predict_mean(X_test, z_test)

In [5]:
print(mu_test[0:10])
print(tau_test[0:10])

[ 1.54669433 -0.51575534 -1.59230643  4.38294783  0.23879614 -0.58645334
 -1.18404848  2.01948447  0.54231799  0.01365932]
[ 0.56593691  0.43713884 -0.34946521  0.14309025  1.08265377  0.87684551
  0.25262578  0.74538335 -0.17369142  0.397889  ]


In [6]:
print(bcf_result[0][0:10])
print(bcf_result[1][0:10])

[-49.25558521 -47.04657084 -49.25558521 -38.015537   -48.20189178
 -48.73100482 -49.25558521 -47.05882873 -48.88826311 -49.25558521]
[-27.30272924 -25.16638152 -27.82990031 -20.10066101 -26.76907796
 -27.41681809 -28.07060733 -25.62590694 -29.14730617 -28.73297877]


In [7]:
print(mean_squared_error(np.mean(bcf_result[2], axis=1), y_test))

15.123774624866945


In [18]:
bcf_result[1]

array([-3.18901233, -3.14922538, -3.10004103, -3.03921503, -3.18166351,
       -3.21154326, -3.18183622, -3.10676569, -3.22645011, -3.18908948,
       -3.08522202, -3.12971113, -3.20543036, -3.05988483, -3.33791747,
       -3.06623325, -3.17222586, -3.15414745, -3.18698642, -3.10589177,
       -3.10208394, -3.13981763, -3.02303773, -3.18476844, -3.12856429,
       -3.03770024, -3.0759509 , -3.18367701, -3.19679083, -3.09898327,
       -3.12225554, -3.13591076, -3.12395639, -3.19621599, -3.05529295,
       -3.13747383, -3.15343735, -3.20090348, -3.06413098, -2.8969067 ,
       -3.16508478, -3.11918572, -3.13392106, -3.22502351, -3.06903249,
       -3.17258849, -3.13575054, -3.17147294, -3.18418621, -3.22253932,
       -3.01108473, -3.2211039 , -3.19965803, -3.10628653, -3.06204315,
       -3.20020574, -2.99440505, -2.97488509, -3.14646601, -3.19319967,
       -3.13994192, -3.12947754, -3.13734884, -3.22547481, -3.07884877,
       -3.18377584, -3.26136836, -3.16224746, -3.12001615, -3.24

In [8]:
print(mean_squared_error(np.zeros_like(y_test), y_test))

4.3616648463025856
