In [1]:
import numpy as np
import sys

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from bart_playground.bcf.bcf import BCF
from bart_playground.params import Tree
from bart_playground import *

In [2]:
# Simulated data: a 2-feature "piecewise_flat" regression target plus two
# treatment arms with additive effects of +0.5 (arm 1) and -0.5 (arm 2).
proposal_probs = {"grow": 0.5, "prune": 0.5}  # NOTE(review): not referenced by any later cell
n_samples = 640
generator = DataGenerator(n_samples=n_samples, n_features=2, noise=0.1, random_seed=42)
X, y = generator.generate(scenario="piecewise_flat")

z_rng = np.random.default_rng(0)
z1 = z_rng.binomial(1, 0.5, n_samples).astype(bool)
# Arm 2 is drawn only among units without arm 1, so each row has at most one
# active arm; rows with neither are controls.
z2 = ~z1 & z_rng.binomial(1, 0.5, n_samples).astype(bool)
z = np.column_stack((z1, z2))
y = y + 0.5 * z[:, 0] - 0.5 * z[:, 1]
# Single-arm variant (replace the two lines above):
# z = z1.reshape(-1, 1)
# y = y + z[:, 0] * 0.5 - 0.5

X_train, X_test, y_train, y_test, z_train, z_test = train_test_split(X, y, z, random_state=42)
np.set_printoptions(suppress=True)
print(y_train[:10])

[ 0.56517481 -0.04621271 -0.27779943  0.53019707  0.98857901  0.44933973
  0.77858126  0.43837069 -0.01191688  0.875094  ]


In [3]:
print(X_train[:5])  # first five training rows, all feature columns

[[0.73489316 0.20240459]
 [0.74882078 0.80138943]
 [0.58106114 0.3468698 ]
 [0.97069802 0.89312112]
 [0.23855282 0.84940884]]


In [4]:
bcf = BCF(
    n_treat_arms=z.shape[1],  # Number of treatment arms
    n_mu_trees=100,  # Number of prognostic effect trees
    # Treatment-effect trees: one entry (50 trees) per treatment arm. Deriving
    # the length from z keeps it in sync with n_treat_arms if the number of
    # arms changes (e.g. the single-arm simulation variant).
    n_tau_trees=[50] * z.shape[1],
    ndpost=100,  # Posterior samples
    nskip=100,  # Burn-in iterations
    random_state=42,
)


In [5]:
# Profile the MCMC fit and dump the raw profiler stats to profile.prof, which the
# next cell renders as a call graph. -q suppresses the printed report, and -s
# only affects that printed/text report, so "-s cumtime" has no effect here.
%prun -s cumtime -D profile.prof -q bcf.fit(X_train, y_train, z_train)

Iterations:   0%|          | 0/200 [00:00<?, ?it/s]

Iterations: 100%|██████████| 200/200 [00:48<00:00,  4.15it/s]

 
*** Profile stats marshalled to file 'profile.prof'.





In [6]:
!gprof2dot -f pstats profile.prof -o profile.dot
!dot -Tpng profile.dot -o profile.png

'gprof2dot' is not recognized as an internal or external command,
operable program or batch file.
Error: dot: can't open profile.dot: No such file or directory


In [7]:
# Inspect one prognostic (mu) tree from the last stored posterior draw.
# In the output, `vars` shows feature indices at split slots, -1 at slots whose
# leaf_vals are set, and -2 at slots whose leaf_vals are NaN (apparently unused
# nodes) — NOTE(review): confirm these sentinel meanings against Tree.
tree_sp: Tree = bcf.sampler.trace[-1].mu_trees[70]
print(tree_sp, tree_sp.vars, tree_sp.leaf_vals, sep="\n")

# Consistency check: evaluating the last draw from z alone must agree with
# evaluating it on explicitly supplied training X (presumably the former uses
# predictions cached during fitting — confirm).
np.testing.assert_allclose(
    bcf.sampler.trace[-1].evaluate(z_train),
    bcf.sampler.trace[-1].evaluate(z_train, X_train),
)

X_0 <= 0.634 (split, n = 480)
	Val: 0.025 (leaf, n = 302)
	X_0 <= 0.645 (split, n = 178)
		X_1 <= 0.513 (split, n = 5)
			Val: -0.015 (leaf, n = 4)
			Val: -0.001 (leaf, n = 1)
		Val: 0.030 (leaf, n = 173)
[ 0 -1  0 -2 -2  1 -1 -2 -2 -2 -2 -1 -1 -2 -2 -2]
[        nan  0.02528837         nan         nan         nan         nan
  0.03047521         nan         nan         nan         nan -0.01490972
 -0.00079991         nan         nan         nan]


In [8]:
# Decompose the BCF predictions on the test set. Judging by the outputs below:
#   bcf_result[0] - 1-D, one value per row (presumably the prognostic part mu(x));
#   bcf_result[1] - (n_rows, n_arms), one column per treatment arm (presumably tau(x));
#   bcf_result[2] - outcome prediction; its test MSE equals that of bcf.predict.
# NOTE(review): [0] and [1] appear to be on the model's internal (scaled) y
# scale rather than the original one — confirm.
bcf_result = bcf.predict_components(X_test, z_test)

In [9]:
# First 10 test rows: prognostic component, then the per-arm effect matrix.
print(bcf_result[0][:10], bcf_result[1][:10], sep="\n")

[ 0.15709025  0.15305458 -0.01036046  0.16424118  0.07738599  0.23181437
  0.17920782  0.13620924  0.15456179  0.00673656]
[[ 0.20657058 -0.17048685]
 [ 0.20462666 -0.18841114]
 [ 0.23645313 -0.19565541]
 [ 0.19918737 -0.18590392]
 [ 0.25908708 -0.16420505]
 [ 0.20548417 -0.19205775]
 [ 0.2091397  -0.18844518]
 [ 0.25779398 -0.16288561]
 [ 0.20489799 -0.17245193]
 [ 0.26108411 -0.170051  ]]


In [10]:
print(bcf_result[1][:10].shape)  # (rows, treatment arms)

(10, 2)


#### MLearner: multi-arm T-learner baseline (one model per treatment arm plus one control model)

In [11]:
from sklearn import clone

def control_indices(z):
    """Boolean mask of control rows, i.e. rows with no active treatment arm.

    Parameters
    ----------
    z : array-like, shape (n_samples, n_arms) or (n_samples,)
        Treatment-assignment matrix; column ``arm`` is truthy where the unit
        received that arm. Boolean or integer 0/1 codes are accepted, and a
        1-D vector is treated as a single arm.

    Returns
    -------
    numpy.ndarray of bool, shape (n_samples,)
        True where no arm is active. With zero arms, every row is a control.
    """
    # Cast first: on integer codes, ``~`` is a bitwise NOT (yielding -1/-2)
    # rather than a logical negation, so the mask would be wrong.
    z = np.asarray(z, dtype=bool)
    if z.ndim == 1:
        z = z[:, np.newaxis]
    return ~z.any(axis=1)

class MLearner:
    """Multi-arm T-learner: one outcome model per treatment arm plus one for controls.

    Rows are routed by the treatment matrix ``z`` (one column per arm). A row whose
    column ``arm`` is truthy is handled by ``model_treated_list[arm]``; a row with no
    active arm is handled by ``model_control``. Rows are expected to have at most one
    active arm. If several are active, the row is used to fit every such arm, and at
    prediction time the highest-indexed active arm's prediction is kept.

    Parameters
    ----------
    n_treated_arms : int
        Number of treatment arms; must equal the number of columns of ``z``.
    model_treated : estimator
        sklearn-style regressor template; an unfitted clone is made for every arm.
    model_control : estimator
        sklearn-style regressor template cloned for the control group.
    """

    def __init__(self, n_treated_arms, model_treated, model_control):
        # clone() returns unfitted copies, so the caller's templates are never
        # mutated and the per-arm models do not share fitted state.
        self.model_treated_list = [clone(model_treated) for _ in range(n_treated_arms)]
        self.model_control = clone(model_control)

    def _arm_matrix(self, z):
        """Return ``z`` as a boolean (n_samples, n_arms) matrix, checking the arm count.

        Casting to bool matters: an integer 0/1 column used as an index would be
        interpreted as row positions (fancy indexing) instead of as a mask.

        Raises
        ------
        ValueError
            If the number of columns of ``z`` differs from the number of arm models.
        """
        z = np.asarray(z, dtype=bool)
        if z.ndim == 1:
            z = z[:, np.newaxis]
        n_arms = len(self.model_treated_list)
        if z.shape[1] != n_arms:
            raise ValueError(
                f"z has {z.shape[1]} treatment column(s) but the learner was built "
                f"with n_treated_arms={n_arms}"
            )
        return z

    def fit(self, X, y, z):
        """Fit each arm model on its treated rows and the control model on control rows.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
        y : array-like, shape (n_samples,)
        z : array-like, shape (n_samples, n_arms) or (n_samples,)
            Treatment assignment (boolean or 0/1 codes).

        Returns
        -------
        MLearner
            ``self``, following the scikit-learn ``fit`` convention.

        Raises
        ------
        ValueError
            If the number of columns of ``z`` differs from ``n_treated_arms``.
        """
        z = self._arm_matrix(z)
        is_control = ~z.any(axis=1)
        for arm, model in enumerate(self.model_treated_list):
            is_treated = z[:, arm]
            model.fit(X[is_treated], y[is_treated])
        self.model_control.fit(X[is_control], y[is_control])
        return self

    def predict(self, X, z):
        """Predict outcomes, routing each row to the model of the arm it received.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
        z : array-like, shape (n_samples, n_arms) or (n_samples,)
            Treatment assignment (boolean or 0/1 codes).

        Returns
        -------
        numpy.ndarray of float, shape (n_samples,)

        Raises
        ------
        ValueError
            If the number of columns of ``z`` differs from ``n_treated_arms``.
        """
        z = self._arm_matrix(z)
        preds = np.empty(len(X))
        # Control first, then arms in index order, so a row with several active
        # arms keeps the last arm's prediction.
        routes = [(self.model_control, ~z.any(axis=1))]
        routes += [(model, z[:, arm]) for arm, model in enumerate(self.model_treated_list)]
        for model, mask in routes:
            # scikit-learn estimators reject zero-row input, so skip any group
            # (control or arm) with no rows in this batch.
            if mask.any():
                preds[mask] = model.predict(X[mask])
        return preds

In [12]:
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor

# Unfitted template estimators; MLearner clones them for each arm and the control
# group. Fixed seeds keep the tree-based baselines reproducible.
lr = LinearRegression()
dt = DecisionTreeRegressor(random_state=42)
rf = RandomForestRegressor(random_state=42)

#### Comparison: test-set MSE of BCF vs. the T-learner baselines

In [13]:
nta = z.shape[1]  # Number of treatment arms
# T-learner baselines: each wraps one sklearn template (cloned per arm and for the
# control group) and is fitted on the training split.
tlearner_rf = MLearner(n_treated_arms=nta, model_treated=rf, model_control=rf)
tlearner_lr = MLearner(n_treated_arms=nta, model_treated=lr, model_control=lr)
tlearner_dt = MLearner(n_treated_arms=nta, model_treated=dt, model_control=dt)
for tlearner in (tlearner_rf, tlearner_lr, tlearner_dt):
    tlearner.fit(X_train, y_train, z_train)

In [14]:
models = {"bcf": bcf, "rf": tlearner_rf, "lr": tlearner_lr, "dt": tlearner_dt}
# Test-set MSE of each model's predicted outcome under the observed treatment
# assignment, keyed by model name.
results = {
    model_name: mean_squared_error(y_test, model.predict(X_test, z_test))
    for model_name, model in models.items()
}
results

{'bcf': 0.02411122732234205,
 'rf': 0.021441280609918477,
 'lr': 0.05466893346335514,
 'dt': 0.03223509459291045}

In [15]:
# Test-set MSE of the outcome prediction from predict_components; it should
# reproduce results["bcf"] computed from bcf.predict. Arguments follow
# scikit-learn's (y_true, y_pred) order, as in the comparison cell.
print(mean_squared_error(y_test, bcf_result[2]))

0.02411122732234205


In [16]:
print(bcf_result[1][0:10])
resultq = bcf_result[0][0:10] + np.sum(z_test[0:10] * bcf_result[1][0:10], axis=1)
print(bcf.preprocessor.transform_y(resultq))

print(bcf_result[2][0:10])
print(y_test[0:10])

[[ 0.20657058 -0.17048685]
 [ 0.20462666 -0.18841114]
 [ 0.23645313 -0.19565541]
 [ 0.19918737 -0.18590392]
 [ 0.25908708 -0.16420505]
 [ 0.20548417 -0.19205775]
 [ 0.2091397  -0.18844518]
 [ 0.25779398 -0.16288561]
 [ 0.20489799 -0.17245193]
 [ 0.26108411 -0.170051  ]]
[-0.02147091 -0.03024719 -0.09845107  0.04952192  0.11835428  0.15864912
  0.13908587 -0.02677816  0.04565355  0.0909174 ]
[ 0.00680687 -0.0481412  -0.47516378  0.45129097  0.88224845  1.1345333
  1.01204837 -0.02642168  0.42707125  0.71046688]
[ 0.02738038 -0.00431336 -0.61648281  0.40993078  0.72292736  1.0178741
  1.11825307  0.08045703  0.53730523  0.63120253]
