In [1]:
import numpy as np
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF
from bart_playground.params import Tree
from bart_playground import *

In [2]:
proposal_probs = {"grow" : 0.5,
                  "prune" : 0.5}
n_samples = 640
generator = DataGenerator(n_samples=n_samples, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")
z_rng = np.random.default_rng(0)
z1 = z_rng.binomial(1, 0.5, n_samples).astype(bool)
z2 = ((1 - z1) * z_rng.binomial(1, 0.5, n_samples)).astype(bool)
z = np.column_stack((z1, z2))
y = y + z[:, 0] * 0.5 - z[:, 1] * 0.5
# z = z1.reshape(-1, 1)
# y = y + z[:, 0] * 0.5 - 0.5

X_train, X_test, y_train, y_test, z_train, z_test = train_test_split(X, y, z, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:10])

[ 0.56517481 -0.04621271 -0.27779943  0.53019707  0.98857901  0.44933973
  0.77858126  0.43837069 -0.01191688  0.875094  ]


In [3]:
print(X_train[0:5, :])

[[0.73489316 0.20240459]
 [0.74882078 0.80138943]
 [0.58106114 0.3468698 ]
 [0.97069802 0.89312112]
 [0.23855282 0.84940884]]


In [4]:
bcf = BCF(
    n_treat_arms=z.shape[1],  # Number of treatment arms
    n_mu_trees=100,       # Number of prognostic effect trees
    n_tau_trees=[50, 50],       # Number of treatment effect trees
    ndpost=100,          # Posterior samples
    nskip=100,            # Burn-in iterations
    random_state=42
)

bcf.fit(X_train, y_train, z_train)

Iterations:   0%|          | 0/200 [00:00<?, ?it/s]

Iterations: 100%|██████████| 200/200 [00:58<00:00,  3.40it/s]


In [15]:
tree_sp : Tree = bcf.sampler.trace[-1].mu_trees[70]

print(tree_sp)
print(tree_sp.vars)
print(tree_sp.leaf_vals)
# print(bcf.sampler.trace[-1].evaluate(z = np.zeros_like(y_train, dtype = bool), X = X_train)[0:10])

np.testing.assert_allclose(bcf.sampler.trace[-1].evaluate(z_train), bcf.sampler.trace[-1].evaluate(z_train, X_train))

X_1 <= 0.802 (split, n = 480)
	Val: 0.040 (leaf, n = 374)
	Val: -0.011 (leaf, n = 106)
[ 1 -1 -1 -2 -2 -2 -2 -2 -2 -2 -2 -2 -2 -2 -2 -2]
[        nan  0.0403628  -0.01113588         nan         nan         nan
         nan         nan         nan         nan         nan         nan
         nan         nan         nan         nan]


In [16]:
bcf_result = bcf.predict_components(X_test, z_test)

In [17]:
print(bcf_result[0][0:10])
print(bcf_result[1][0:10])

[ 0.18081937  0.16190407 -0.04493251  0.15664016  0.07212702  0.15650407
  0.16830937  0.13487491  0.15732622  0.03827908]
[[ 0.21249684 -0.14761714]
 [ 0.20908512 -0.18370751]
 [ 0.22524417 -0.18702716]
 [ 0.22611821 -0.18696958]
 [ 0.22692862 -0.1954979 ]
 [ 0.20886132 -0.18517991]
 [ 0.21264844 -0.17227509]
 [ 0.22662866 -0.18646481]
 [ 0.21099638 -0.1674805 ]
 [ 0.23044729 -0.1885336 ]]


In [8]:
print(bcf_result[1][0:10].shape)

(10, 2)


#### MLearner class

In [9]:
import re
from sklearn import clone

def control_indices(z):
    result = np.zeros(z.shape[0], dtype=bool)
    for arm in range(z.shape[1]):
        result = result | z[:, arm]
    return ~result

class MLearner:
    def __init__(self, n_treated_arms, model_treated, model_control):
        self.model_treated_list = [None] * n_treated_arms
        for i in range(n_treated_arms):
            self.model_treated_list[i] = clone(model_treated)
        self.model_control = clone(model_control)

    def fit(self, X, y, z):
        X_treated_list = [X[z[:, arm]] for arm in range(z.shape[1])]
        y_treated_list = [y[z[:, arm]] for arm in range(z.shape[1])]
        X_control = X[control_indices(z)]
        y_control = y[control_indices(z)]

        # Fit the models.
        for i in range(len(self.model_treated_list)):
            self.model_treated_list[i].fit(X_treated_list[i], y_treated_list[i])
        self.model_control.fit(X_control, y_control)
        # return self

    def predict(self, X, z):
        # Predict outcome: if z is True, use model_treated; else, use model_control.
        preds = np.empty(len(X))
        preds[control_indices(z)] = self.model_control.predict(X[control_indices(z)])
        for arm in range(z.shape[1]):
            preds[z[:, arm]] = self.model_treated_list[arm].predict(X[z[:, arm]])
        return preds

In [10]:
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor

rf = RandomForestRegressor(random_state=42)
dt = DecisionTreeRegressor(random_state=42)
lr = LinearRegression()

#### Comparison

In [11]:
nta = z.shape[1]  # Number of treatment arms
tlearner_rf = MLearner(n_treated_arms=nta, model_treated=rf, model_control=rf)
tlearner_rf.fit(X_train, y_train, z_train)
tlearner_lr = MLearner(n_treated_arms=nta, model_treated=lr, model_control=lr)
tlearner_lr.fit(X_train, y_train, z_train)
tlearner_dt = MLearner(n_treated_arms=nta, model_treated=dt, model_control=dt)
tlearner_dt.fit(X_train, y_train, z_train)

In [12]:
models = {"bcf" : bcf, 
          "rf" : tlearner_rf, 
          "lr" : tlearner_lr,
          "dt" : tlearner_dt}
results = {}
for model_name, model in models.items():
    results[model_name] = mean_squared_error(y_test, model.predict(X_test, z_test))
results

{'bcf': 0.023985779987415565,
 'rf': 0.021441280609918477,
 'lr': 0.05466893346335514,
 'dt': 0.03223509459291045}

In [13]:
print(mean_squared_error(bcf_result[2], y_test))

0.023985779987415565


In [14]:
print(bcf_result[1][0:10])
resultq = bcf_result[0][0:10] + np.sum(z_test[0:10] * bcf_result[1][0:10], axis=1)
print(bcf.preprocessor.transform_y(resultq))

print(bcf_result[2][0:10])
print(y_test[0:10])
# print(bcf.preprocessor.backtransform_y(y_test[0:10]))

# print(bcf_result[0][0] + z_test[0, 0] * bcf_result[1][0, 0] + z_test[0, 1] * bcf_result[1][0, 1])
# print(bcf.preprocessor.transform_y(bcf_result[0][0] + z_test[0, 0] * bcf_result[1][0, 0] + z_test[0, 1] * bcf_result[1][0, 1]))

[[ 0.21249684 -0.14761714]
 [ 0.20908512 -0.18370751]
 [ 0.22524417 -0.18702716]
 [ 0.22611821 -0.18696958]
 [ 0.22692862 -0.1954979 ]
 [ 0.20886132 -0.18517991]
 [ 0.21264844 -0.17227509]
 [ 0.22662866 -0.18646481]
 [ 0.21099638 -0.1674805 ]
 [ 0.23044729 -0.1885336 ]]
[-0.00284772 -0.0248307  -0.1088195   0.04648417  0.10340043  0.12990108
  0.13613258 -0.03673484  0.04675836  0.09127936]
[ 0.12340614 -0.01422866 -0.54008018  0.43227173  0.78862283  0.95454264
  0.99355787 -0.08876015  0.4339884   0.71273312]
[ 0.02738038 -0.00431336 -0.61648281  0.40993078  0.72292736  1.0178741
  1.11825307  0.08045703  0.53730523  0.63120253]
