# Simple usage of different acquisition functions

In [None]:
from asbe.base import *
from asbe.models import *
from asbe.estimators import *
from econml.orf import DMLOrthoForest
from econml.dml import CausalForestDML
#from causalml.inference.nn import CEVAE
#from openbt.openbt import OPENBT
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.model_selection import train_test_split
from copy import deepcopy
import econml
from causalml.dataset import synthetic_data
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ConstantKernel, RBF

In [None]:
# Seed NumPy's global generator before drawing the synthetic sample.
np.random.seed(1005)

# The fifth returned value is discarded; the others are kept under the names
# y, X, t, ite and e for the cells that follow.
y, X, t, ite, _, e = synthetic_data(
    mode=1,
    n=1000,
    p=5,
    sigma=1.0,
)

In [None]:
# Build both outcome columns from y and ite: rows with t == 1 keep y as y1
# (others get y + ite); rows with t == 0 keep y as y0 (others get y - ite).
y1 = np.where(t == 1, y, y + ite)
y0 = np.where(t == 0, y, y - ite)

# test_size=0.9 sends 90% of the rows to the *_test arrays; the remaining 10%
# become the initial training set.
(
    X_train, X_test,
    t_train, t_test,
    y_train, y_test,
    ite_train, ite_test,
    y1_train, y1_test,
    y0_train, y0_test,
) = train_test_split(X, t, y, ite, y1, y0, test_size=0.9, random_state=1005)

# The held-out rows serve twice: as the pool (X, y and t are deep-copied) and
# as the test set.
ds = dict(
    X_training=X_train,
    y_training=y_train,
    t_training=t_train,
    X_pool=deepcopy(X_test),
    y_pool=deepcopy(y_test),
    t_pool=deepcopy(t_test),
    y1_pool=y1_test,
    y0_pool=y0_test,
    X_test=X_test,
    y_test=y_test,
    t_test=t_test,
    ite_test=ite_test,
)

In [None]:
def test_acq(estimator, acq, no_query=10, dataset=None):
    """Smoke-test an acquisition function with a single-model ITE estimator.

    Wraps ``estimator`` in ``BaseITEEstimator(two_model=False, ps_model=None)``,
    builds a ``BaseActiveLearner`` around it, then runs one
    fit -> query -> teach -> predict -> score round. The queried indices and
    the score are printed.

    Parameters
    ----------
    estimator : object
        Base model passed as ``model`` to ``BaseITEEstimator``.
    acq : object
        Acquisition function instance handed to ``BaseActiveLearner``.
    no_query : int, default 10
        Number of points requested from ``asl.query``.
    dataset : dict, optional
        Dataset handed to the learner. When omitted, the notebook-level ``ds``
        is used.

    Returns
    -------
    bool
        ``True`` once every step has run without raising.

    Notes
    -----
    NOTE(review): there is no explicit ``fit()`` after ``teach()`` here, while
    ``test_acq_gp`` refits before its final score. Confirm whether ``teach()``
    refits internally; if not, the printed score reflects the initial fit.
    """
    if dataset is None:
        dataset = ds

    ite_estimator = BaseITEEstimator(
        model=estimator,
        two_model=False,
        ps_model=None,
    )
    asl = BaseActiveLearner(
        estimator=ite_estimator,
        acquisition_function=acq,
        assignment_function=MajorityAssignmentFunction(),
        stopping_function=None,
        dataset=dataset,
    )

    asl.fit()
    # Only the indices are needed; the returned covariates are discarded.
    _, query_idx = asl.query(no_query=no_query)
    print(query_idx)
    asl.teach(query_idx)
    # Exercise predict on the test covariates; the values are not used further.
    asl.predict(asl.dataset["X_test"])
    print(asl.score())
    return True

# def test_acq_ob(acq):
#     asl = BaseActiveLearner(estimator = OPENBTITEEstimator(model = OPENBT(
#         model="bart",ntrees=200),
#                                          two_model=False,ps_model=None),
#                         acquisition_function=acq,
#                         assignment_function=MajorityAssignmentFunction(),
#                         stopping_function = None,
#                         dataset=ds)
#     asl.fit()
#     X_new, query_idx = asl.query(no_query=10)
#     print(query_idx)
#     asl.teach(query_idx)
#     preds = asl.predict(asl.dataset["X_test"])
#     print(asl.score())
#     return True

def test_acq_gp(acq, no_query=20, dataset=None):
    """Smoke-test an acquisition function with a two-model GP estimator.

    Builds a ``GaussianProcessRegressor`` whose kernel is
    ``ConstantKernel() * RBF(...)`` with one unit length scale per column of
    ``dataset["X_training"]``, wraps it in
    ``GPEstimator(two_model=True, ps_model=None)`` and runs one active-learning
    round. The score is printed twice: after the initial fit, and again after
    the queried points have been taught and the learner refitted.

    Parameters
    ----------
    acq : object
        Acquisition function instance handed to ``BaseActiveLearner``.
    no_query : int, default 20
        Number of points requested from ``asl.query``.
    dataset : dict, optional
        Dataset handed to the learner; must contain ``"X_training"``. When
        omitted, the notebook-level ``ds`` is used.

    Returns
    -------
    bool
        ``True`` once every step has run without raising.
    """
    if dataset is None:
        dataset = ds

    # One length scale per feature column of the training covariates.
    n_features = dataset["X_training"].shape[1]
    kernel = ConstantKernel() * RBF(np.ones(n_features))
    gp_estimator = GPEstimator(
        model=GaussianProcessRegressor(kernel),
        two_model=True,
        ps_model=None,
    )
    asl = BaseActiveLearner(
        estimator=gp_estimator,
        acquisition_function=acq,
        assignment_function=MajorityAssignmentFunction(),
        stopping_function=None,
        dataset=dataset,
    )

    asl.fit()
    # Only the indices are needed; the returned covariates are discarded.
    _, query_idx = asl.query(no_query=no_query)
    print(query_idx)
    # Score of the initial fit, before the queried points are added.
    print(asl.score())
    asl.teach(query_idx)
    asl.fit()
    # Score after teaching the queried points and refitting.
    print(asl.score())
    return True

In [None]:
# Random-forest base model with the random acquisition function.
test_acq(
    estimator=RandomForestRegressor(),
    acq=RandomAcquisitionFunction(),
)

[702 193 166 741  59 250  98  55 733 645]
0.8529185036438194


True

In [None]:
# test_acq(RandomForestRegressor(), UncertaintyAcquisitionFunction())

In [None]:
#test_acq_ob(UncertaintyAcquisitionFunction())

Overwriting k to agree with the model's default
Overwriting overallnu to agree with the model's default
Overwriting ntree to agree with the model's default
Overwriting ntreeh to agree with the model's default
Overwriting overallsd to agree with the model's default
Writing config file and data
/var/folders/44/gtm_t6x110jg6b13p4rbwkfh0000gn/T/openbtpy_vhh3z6f8
3+ x variables
Running model...
[ 24 597 834 510 860 710 671 432 795 791]
0.5621710540328109


True

In [None]:
#test_acq_ob(TypeSAcquistionFunction())

Overwriting k to agree with the model's default
Overwriting overallnu to agree with the model's default
Overwriting ntree to agree with the model's default
Overwriting ntreeh to agree with the model's default
Overwriting overallsd to agree with the model's default
Writing config file and data
/var/folders/44/gtm_t6x110jg6b13p4rbwkfh0000gn/T/openbtpy_abu0pjzg
3+ x variables
Running model...
[ 62 301 817 861 450 741 353 878 157  13]
0.5798006422653739


True

In [None]:
#test_acq_ob(EMCMAcquisitionFunction(no_query=10, B=10))

Overwriting k to agree with the model's default
Overwriting overallnu to agree with the model's default
Overwriting ntree to agree with the model's default
Overwriting ntreeh to agree with the model's default
Overwriting overallsd to agree with the model's default
Writing config file and data
/var/folders/44/gtm_t6x110jg6b13p4rbwkfh0000gn/T/openbtpy_jhj8mkn7
3+ x variables
Running model...
[782 834 100 821 145 545  62 164 168 671]
0.5720390720113978


True

In [None]:
#test_acq_gp(RandomAcquisitionFunction())

[304 297 703 649 724  81  45 249 807 432 258 458 556 363 846 736 284 121
 535 756]
(120, 5)
[[ 1.17921141e+00 -9.77120587e-01  5.46791330e-01 ... -4.67885567e-01
   2.10854854e-01  2.45374743e-02]
 [ 2.67492032e-01 -4.56857175e-01 -1.86445088e-01 ... -4.48442913e-01
   8.39346601e-01 -1.31788984e+00]
 [ 6.54254394e-01  2.45691765e-01  9.56956316e-01 ...  1.91443599e+00
   4.98526544e-01 -6.54651197e-01]
 ...
 [-1.03050086e+00  1.22630506e-01  4.93734689e-01 ... -3.37496729e-01
  -7.66367306e-01  4.00539635e-01]
 [ 1.73428815e+00 -3.14675994e-01  8.71824414e-01 ...  5.47451731e-01
  -4.18753576e-04  1.21294057e+00]
 [-2.70082400e-01  1.82355482e-01  7.00631387e-02 ...  1.10864941e-01
   1.29988691e+00 -1.99717705e-01]]
0.5337187351959323


True

In [None]:
# GP estimator with the uncertainty acquisition function.
test_acq_gp(acq=UncertaintyAcquisitionFunction())

[687 421 116 185 205 482 486 440 635 243  76 476 869 612 506 108 530 611
 268 368]
0.5337187351959323
0.5751091655772831


True

In [None]:
# GP estimator with the Type-S acquisition function.
test_acq_gp(acq=TypeSAcquistionFunction())

[300 756 111  71 769 390 612 371 122 360 135 352 835 337 624 141 636 316
 529 142]
0.5337187351959323
0.6183597398651168


True

In [None]:
# GP estimator with the EMCM acquisition function (default arguments).
test_acq_gp(acq=EMCMAcquisitionFunction())

[380 899 295 306 305 304 303 302 301 300 299 298 297 296 294 281 293 292
 291 290]
0.5337187351959323
0.5672603979810199


True