## FLAML for hyperparameter optimisation and model selection
We use FLAML twice: first to find the best component model for each estimator, and then to optimise the estimators themselves and choose the best one. This notebook shows how it's done.

In [1]:
%load_ext autoreload
%autoreload 2
import os, sys
import warnings
# NOTE: this silences *all* warnings for the session, not only the sklearn
# deprecation warnings it was added for; narrow the filter once those are fixed.
warnings.filterwarnings('ignore')

# Make the local auto-causality checkout (two directories up) importable.
root_path = os.path.realpath('../..')
package_path = os.path.join(root_path, "auto-causality")
# Guard the append so re-running this cell does not keep adding the same
# entry to sys.path.
if package_path not in sys.path:
    sys.path.append(package_path)

In [2]:
from auto_causality import AutoCausality
from auto_causality.datasets import synth_ihdp, preprocess_dataset

### Model fitting & scoring
Here we fit a selection of models to the data and score them with the ERUPT metric on held-out data.

In [3]:
# Load the dataset with synth_ihdp() and hand it straight to preprocess_dataset(),
# which returns the frame plus the feature, target and treatment descriptors.
data_df, features_X, features_W, targets, treatment = preprocess_dataset(synth_ihdp())

# Estimators to include in the search
estimator_list = ["LinearDML", "metalearners"]

# Configure the AutoCausality search; "erupt" is the metric being optimised
ac = AutoCausality(
    time_budget=1000,
    estimator_list=estimator_list,
    metric="erupt",
    verbose=3,
    components_verbose=2,
    components_time_budget=2,
    use_ray=False,
)

# Run the search, passing the first entry of `targets`
myresults = ac.fit(data_df, treatment, targets[0], features_W, features_X)

# Report the winning estimator, its configuration and its score
print(f"Best estimator: {ac.best_estimator}")
print(f"best config: {ac.best_config}")
print(f"best score: {ac.best_score}")


[flaml.tune.tune: 03-09 13:52:04] {447} INFO - trial 1 config: {'fit_cate_intercept': 0, 'mc_iters': 2}


config: {'fit_cate_intercept': 0, 'mc_iters': 2}


[flaml.tune.tune: 03-09 13:52:17] {108} INFO - result: {'erupt': 6.617034599232005, 'qini': 0.0686353473244199, 'auc': 0.5488984954335461, 'ate': 3.954803577752026, 'r_score': 0.08134845510334765, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 2}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 2, 'experiment_tag': 'exp', 'time_total_s': 13.417981624603271}
[flaml.tune.tune: 03-09 13:52:17] {447} INFO - trial 2 config: {'fit_cate_intercept': 0, 'mc_iters': 1}


config: {'fit_cate_intercept': 0, 'mc_iters': 1}


[flaml.tune.tune: 03-09 13:52:26] {108} INFO - result: {'erupt': 6.5969876361117565, 'qini': 0.048013551040238314, 'auc': 0.5522268805524556, 'ate': 4.051511723272021, 'r_score': 0.08264122567360821, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 1}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 1, 'experiment_tag': 'exp', 'time_total_s': 8.787203311920166}
[flaml.tune.tune: 03-09 13:52:26] {447} INFO - trial 3 config: {'fit_cate_intercept': 0, 'mc_iters': 3}


config: {'fit_cate_intercept': 0, 'mc_iters': 3}


[flaml.tune.tune: 03-09 13:52:38] {108} INFO - result: {'erupt': 6.617034599232005, 'qini': 0.0541811502496222, 'auc': 0.5481371718047101, 'ate': 3.897052467081625, 'r_score': 0.07709173345439224, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 3}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 3, 'experiment_tag': 'exp', 'time_total_s': 12.61718463897705}
[flaml.tune.tune: 03-09 13:52:38] {447} INFO - trial 4 config: {'fit_cate_intercept': 0, 'mc_iters': 7}


config: {'fit_cate_intercept': 0, 'mc_iters': 7}


[flaml.tune.tune: 03-09 13:53:08] {108} INFO - result: {'erupt': 6.617034599232005, 'qini': 0.05910850316775548, 'auc': 0.5506785754544984, 'ate': 3.9722660554472675, 'r_score': 0.08291888156692384, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 7}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 7, 'experiment_tag': 'exp', 'time_total_s': 29.093570709228516}
[flaml.tune.tune: 03-09 13:53:08] {447} INFO - trial 5 config: {'fit_cate_intercept': 0, 'mc_iters': 6}


config: {'fit_cate_intercept': 0, 'mc_iters': 6}


[flaml.tune.tune: 03-09 13:53:33] {108} INFO - result: {'erupt': 6.608235304460594, 'qini': 0.04882562209804825, 'auc': 0.551706058584556, 'ate': 4.0112334591736944, 'r_score': 0.07960891259324032, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 6}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 6, 'experiment_tag': 'exp', 'time_total_s': 25.257190465927124}
[flaml.tune.tune: 03-09 13:53:33] {447} INFO - trial 6 config: {'fit_cate_intercept': 0, 'mc_iters': 8}


config: {'fit_cate_intercept': 0, 'mc_iters': 8}


[flaml.tune.tune: 03-09 13:54:06] {108} INFO - result: {'erupt': 6.608235304460594, 'qini': 0.05226812355031053, 'auc': 0.5502799015359904, 'ate': 3.9802352821778126, 'r_score': 0.08108307668330783, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 8}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 8, 'experiment_tag': 'exp', 'time_total_s': 33.26601481437683}
[flaml.tune.tune: 03-09 13:54:06] {447} INFO - trial 7 config: {'fit_cate_intercept': 1, 'mc_iters': 0}


config: {'fit_cate_intercept': 1, 'mc_iters': 0}


[flaml.tune.tune: 03-09 13:54:11] {108} INFO - result: {'erupt': 6.589599394724014, 'qini': 0.04896708671093154, 'auc': 0.5552458003188403, 'ate': 4.0426156965985784, 'r_score': 0.0762351122434487, 'training_iteration': 0, 'config': {'fit_cate_intercept': 1, 'mc_iters': 0}, 'config/fit_cate_intercept': 1, 'config/mc_iters': 0, 'experiment_tag': 'exp', 'time_total_s': 4.714990615844727}
[flaml.tune.tune: 03-09 13:54:11] {447} INFO - trial 8 config: {'fit_cate_intercept': 1, 'mc_iters': 2}


config: {'fit_cate_intercept': 1, 'mc_iters': 2}


[flaml.tune.tune: 03-09 13:54:19] {108} INFO - result: {'erupt': 6.5904054301125745, 'qini': 0.031631698218000395, 'auc': 0.5541527768768967, 'ate': 4.0245496762138755, 'r_score': 0.08763079964034703, 'training_iteration': 0, 'config': {'fit_cate_intercept': 1, 'mc_iters': 2}, 'config/fit_cate_intercept': 1, 'config/mc_iters': 2, 'experiment_tag': 'exp', 'time_total_s': 8.625952959060669}
[flaml.tune.tune: 03-09 13:54:19] {447} INFO - trial 9 config: {'fit_cate_intercept': 1, 'mc_iters': 4}


config: {'fit_cate_intercept': 1, 'mc_iters': 4}


[flaml.tune.tune: 03-09 13:54:36] {108} INFO - result: {'erupt': 6.598257958889015, 'qini': 0.0541511296545197, 'auc': 0.5541054655064134, 'ate': 4.017107066830389, 'r_score': 0.08815861999429098, 'training_iteration': 0, 'config': {'fit_cate_intercept': 1, 'mc_iters': 4}, 'config/fit_cate_intercept': 1, 'config/mc_iters': 4, 'experiment_tag': 'exp', 'time_total_s': 16.75154733657837}
[flaml.tune.tune: 03-09 13:54:36] {447} INFO - trial 10 config: {'fit_cate_intercept': 1, 'mc_iters': 6}


config: {'fit_cate_intercept': 1, 'mc_iters': 6}


[flaml.tune.tune: 03-09 13:55:02] {108} INFO - result: {'erupt': 6.593243371660474, 'qini': 0.04133045774729062, 'auc': 0.5533178189276199, 'ate': 4.00644935475858, 'r_score': 0.09058375101472749, 'training_iteration': 0, 'config': {'fit_cate_intercept': 1, 'mc_iters': 6}, 'config/fit_cate_intercept': 1, 'config/mc_iters': 6, 'experiment_tag': 'exp', 'time_total_s': 25.504411697387695}
[flaml.tune.tune: 03-09 13:55:02] {447} INFO - trial 1 config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 283, 'n_alphas_cov': 32, 'tol': 2e-07, 'max_iter': 100}


... Estimator: backdoor.econml.dml.LinearDML
 erupt (train): 6.617035
 qini (train): 0.068635
 auc (train): 0.548898
 ate (train): 3.954804
 r_score (train): 0.081348
config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 283, 'n_alphas_cov': 32, 'tol': 2e-07, 'max_iter': 100}


[flaml.tune.tune: 03-09 13:55:32] {108} INFO - result: {'erupt': 6.625710662003542, 'qini': 0.025886481936572118, 'auc': 0.5503406463088423, 'ate': 4.969852574324653, 'r_score': -0.004860680819799068, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 283, 'n_alphas_cov': 32, 'tol': 2e-07, 'max_iter': 100}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 7, 'config/n_alphas': 283, 'config/n_alphas_cov': 32, 'config/tol': 2e-07, 'config/max_iter': 100, 'experiment_tag': 'exp', 'time_total_s': 30.273974418640137}
[flaml.tune.tune: 03-09 13:55:32] {447} INFO - trial 2 config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 576, 'n_alphas_cov': 18, 'tol': 7e-07, 'max_iter': 100}


config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 576, 'n_alphas_cov': 18, 'tol': 7e-07, 'max_iter': 100}


[flaml.tune.tune: 03-09 13:56:02] {108} INFO - result: {'erupt': 6.625710662003542, 'qini': 0.045834772475797564, 'auc': 0.5488157147140639, 'ate': 5.2261504694273135, 'r_score': -0.06067992057934246, 'training_iteration': 0, 'config': {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 576, 'n_alphas_cov': 18, 'tol': 7e-07, 'max_iter': 100}, 'config/fit_cate_intercept': 0, 'config/mc_iters': 7, 'config/n_alphas': 576, 'config/n_alphas_cov': 18, 'config/tol': 7e-07, 'config/max_iter': 100, 'experiment_tag': 'exp', 'time_total_s': 30.328886032104492}
[flaml.tune.tune: 03-09 13:56:02] {447} INFO - trial 3 config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 139, 'n_alphas_cov': 57, 'tol': 1e-07, 'max_iter': 300}


config: {'fit_cate_intercept': 0, 'mc_iters': 7, 'n_alphas': 139, 'n_alphas_cov': 57, 'tol': 1e-07, 'max_iter': 300}
