In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to the source are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from hidmed import *
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.lines import Line2D


# Ignore scikit-learn's ConvergenceWarning so it does not flood the cell output.
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)

## Evaluate estimators

In [8]:
# Base seed: passed to the data-generating process and used as the offset for
# the per-evaluation dataset seeds in the evaluation cell.
seed = 2

# setups = ["a", "b", "c"]
setups = ["a",]
dims = [5,]

sample_sizes = (np.array([750, 1500, 3000])).astype(int)

# Pickle file the evaluation cell writes `results` to after each configuration.
filename = "assets/results_test.pkl"

estimators = [
#     ProximalInverseProbWeighting,
#     ProximalOutcomeRegression,
    ProximalMultiplyRobust,
]
num_evals = 10  # number of independently sampled datasets per configuration
folds = 4       # forwarded as `folds=` to every estimator

# When True, start from an empty result set so that every configuration is
# recomputed (the behaviour this cell always had). Set to False to resume from
# the results already stored in `filename`; the evaluation cell skips any
# (dim, setup, n, estimator) key that is already present.
force_recompute = True

results = {}
if not force_recompute:
    try:
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were produced by this notebook.
        with open(filename, "rb") as cache_file:
            results = pickle.load(cache_file)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # No usable cache (missing or truncated file): fall back to a fresh run.
        results = {}

results.keys()

dict_keys([])

In [9]:
# Repeats the ConvergenceWarning filter that the import cell at the top of the
# notebook already installs.
import warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Dedicated, seeded generator for the hyperparameter-tuning draws. Using it
# instead of NumPy's global RNG makes a rerun with the same `seed` tune on the
# same dataset with the same fit seed.
tuning_rng = np.random.default_rng(seed)

# Number of re-fits (different fit seeds, same dataset) used to build the
# quantile interval stored under "bootstrap_ci" for ProximalMultiplyRobust.
num_refits = 100

for dim in dims:
    # Dimensions of the variables handed to the data-generating process.
    if dim == 5:
        xdim, zdim, wdim, mdim, udim = 5, 2, 2, 2, 1
    else:
        xdim, zdim, wdim, mdim, udim = 1, 1, 1, 1, 1

    for setup in setups:
        for n in sample_sizes:
            # set up model
            datagen = LinearHidMedDGP(
                xdim, zdim, wdim, mdim, udim, setup=setup, seed=seed
            )
            true_psi = datagen.true_psi()

            # hyperparameter tuning
            # NOTE(review): the tuner is always built with setup "c", regardless
            # of the loop's `setup` -- confirm this is intended.
            # NOTE(review): the tuning dataset seed falls inside the range of
            # evaluation seeds (seed+1 .. seed+num_evals), so tuning happens on
            # one of the evaluation datasets -- confirm this is intended.
            tuner = ProximalMultiplyRobust("c", folds=folds, num_runs=100)
            tuning_data_seed = seed + 1 + int(tuning_rng.integers(num_evals))
            tuning_fit_seed = int(tuning_rng.integers(2**32))
            dataset = datagen.sample_dataset(n, seed=tuning_data_seed)
            tuner.fit(dataset, seed=tuning_fit_seed)
            tuned_param_dict = tuner.param_dict

            # estimation
            for estimator in estimators:
                print(f">>> {n} samples, {folds}-fold, setup {setup}, dim {dim}, est: {estimator.__name__}")
                result_key = (dim, setup, n, estimator.__name__)
                if result_key in results:
                    # Already computed (e.g. loaded from the results file).
                    continue

                print(
                    f"Running {estimator.__name__} for {dim}-dimensional case, n={n}, setup={setup}"
                )
                is_pmr = estimator.__name__ == "ProximalMultiplyRobust"

                # evaluation: one entry per evaluated dataset
                res = {
                    "predictors": [],
                    "prop_scores": [],
                    "true_psi": np.zeros(num_evals),
                    "estimates": [],
                    "estimate": np.zeros(num_evals),
                    "bias": np.zeros(num_evals),
                    "mse": np.zeros(num_evals),
                    "anb": np.zeros(num_evals),
                    "covered": np.zeros(num_evals),
                    "ci_width": np.zeros(num_evals),
                    "bootstrap_ci": np.zeros((num_evals, 2)),
                    "bootstrap_covered": np.zeros(num_evals),
                }
                eval_seeds = range(seed + 1, seed + num_evals + 1)
                for idx, data_seed in enumerate(tqdm(eval_seeds)):
                    dataset_i = datagen.sample_dataset(n, seed=data_seed)
                    predictor = estimator(
                        setup, folds=folds, param_dict=tuned_param_dict, verbose=False,
                    )
                    predictor.fit(dataset_i, seed=data_seed)
                    estimate = predictor.evaluate(reduce=False)

                    # Flatten and average once; reused by all metrics below.
                    flat_estimate = estimate.flatten()
                    point_estimate = np.mean(flat_estimate)

                    res["predictors"].append(predictor)
                    # Predicted treatment probabilities from every fitted
                    # sub-estimator that exposes a treatment model.
                    res["prop_scores"].append([
                        pred.treatment.predict_proba(dataset_i.x)[:, 1]
                        for pred in predictor.estimators
                        if hasattr(pred, "treatment") and pred.treatment is not None
                    ])
                    res["estimates"].append(estimate)
                    res["true_psi"][idx] = true_psi
                    res["estimate"][idx] = point_estimate
                    res["bias"][idx] = point_estimate - true_psi
                    res["mse"][idx] = calculate_mse(flat_estimate, true_psi)
                    res["anb"][idx] = absolute_normalized_bias(flat_estimate, true_psi)
                    res["covered"][idx] = is_covered(flat_estimate, true_psi)
                    res["ci_width"][idx] = confidence_interval(flat_estimate)

                    # bootstrap CI for PMR only: refit on the same dataset with
                    # different fit seeds and take the 5% / 95% quantiles of the
                    # resulting reduced estimates.
                    if is_pmr:
                        psi_means = []
                        for run in range(num_refits):
                            # Separate name so the predictor stored above is
                            # not rebound by the refits.
                            refit = estimator(
                                setup, folds=folds, param_dict=tuned_param_dict, verbose=False
                            )
                            refit.fit(dataset_i, seed=num_evals + 1 + run)
                            psi_means.append(refit.evaluate(reduce=True))

                        lower, upper = np.quantile(psi_means, [0.05, 0.95])
                        res["bootstrap_ci"][idx, 0] = lower
                        res["bootstrap_ci"][idx, 1] = upper
                        res["bootstrap_covered"][idx] = lower <= true_psi <= upper

                results[result_key] = res
                print("bias", np.mean(res["bias"]))
                print("mse", np.mean(res["mse"]))
                print("est", np.mean(res["estimate"]))
                print("true", true_psi)
                print("ci_width", np.mean(res["ci_width"]))
                print("coverage", np.mean(res["covered"]))

                if is_pmr:
                    print("bootstrap coverage", np.mean(res["bootstrap_covered"]))

                # Checkpoint after every configuration; the context manager
                # guarantees the file is flushed and closed.
                with open(filename, "wb") as results_file:
                    pickle.dump(results, results_file)

==== Fitting fold 1 (281 fitting, 281 valid.)
Treatment prob params: {'C': 0.1, 'degree': 1} log_loss:  0.696
Bridge q params: {'lambda1': 0.05014961228546211, 'lambda2': 1.4926287082308678, 'gamma1': 0.2480080574000914, 'gamma2': 0.005000000000000002}, score: 0.23030550226204255
Bridge h params: {'lambda1': 0.030884549011656662, 'lambda2': 0.5661078260112563, 'gamma1': 0.02016063531328705, 'gamma2': 0.005000000000000002}, score: 0.05216722063980911
eta params: {'alpha': 0.3273697063589532, 'gamma': 0.02628309571044793}, r2: 0.8392207921095405
==== Fitting fold 2 (281 fitting, 281 valid.)
Treatment prob params: {'C': 0.1, 'degree': 1} log_loss:  0.701
Bridge q params: {'lambda1': 0.08143177390846636, 'lambda2': 0.030884549011656662, 'gamma1': 0.015254478735307919, 'gamma2': 0.005000000000000002}, score: 0.2320121646768003
Bridge h params: {'lambda1': 0.030884549011656662, 'lambda2': 0.9192327197497739, 'gamma1': 0.015254478735307919, 'gamma2': 0.005000000000000002}, score: 0.0468943373

100%|██████████| 10/10 [17:29<00:00, 104.97s/it]


bias 0.1532932292969001
mse 0.09575973843010446
est -0.772847287567426
true -0.9261405168643262
ci_width 0.3704219669446401
coverage 0.5
bootstrap coverage 0.5
==== Fitting fold 1 (563 fitting, 562 valid.)
Treatment prob params: {'C': 0.164, 'degree': 1} log_loss:  0.683
Bridge q params: {'lambda1': 0.24672146895704905, 'lambda2': 0.03548964694594352, 'gamma1': 0.005000000000000002, 'gamma2': 0.005000000000000002}, score: 0.1436770350650347
Bridge h params: {'lambda1': 0.05762726319284316, 'lambda2': 0.03548964694594352, 'gamma1': 0.011542251415688211, 'gamma2': 0.005000000000000002}, score: 0.03419060704436755
eta params: {'alpha': 0.010024213354123476, 'gamma': 0.0024699382190216733}, r2: 0.7021493376222561
==== Fitting fold 2 (563 fitting, 562 valid.)
Treatment prob params: {'C': 0.1, 'degree': 1} log_loss:  0.689
Bridge q params: {'lambda1': 0.4006205823509287, 'lambda2': 0.05762726319284316, 'gamma1': 0.005000000000000002, 'gamma2': 0.005000000000000002}, score: 0.1866441376336992

100%|██████████| 10/10 [31:37<00:00, 189.78s/it]


bias 0.1250581238687218
mse 0.030499102958667933
est -0.8010823929956044
true -0.9261405168643262
ci_width 0.24012775302415693
coverage 0.3
bootstrap coverage 0.8
==== Fitting fold 1 (1125 fitting, 1125 valid.)
Treatment prob params: {'C': 0.1, 'degree': 2} log_loss:  0.691
Bridge q params: {'lambda1': 0.04075965548029613, 'lambda2': 0.10746898225406258, 'gamma1': 0.14198815201729706, 'gamma2': 0.005000000000000002}, score: 0.14450019749760956
Bridge h params: {'lambda1': 0.04075965548029613, 'lambda2': 1.9698857148449014, 'gamma1': 0.008733406762343065, 'gamma2': 0.005000000000000002}, score: 0.013600108718028322
eta params: {'alpha': 0.012760364519243493, 'gamma': 0.0016434135223429905}, r2: 0.7314292721418261
==== Fitting fold 2 (1125 fitting, 1125 valid.)
Treatment prob params: {'C': 590205999.693, 'degree': 1} log_loss:  0.685
Bridge h params: {'lambda1': 0.04075965548029613, 'lambda2': 3.1986546026252993, 'gamma1': 0.015254478735307919, 'gamma2': 0.005000000000000002}, score: 0.0

100%|██████████| 10/10 [2:42:25<00:00, 974.57s/it] 


bias 0.056640389046447394
mse 0.013449286984179656
est -0.8695001278178788
true -0.9261405168643262
ci_width 0.2107623770059258
coverage 0.7
bootstrap coverage 0.9


## Visualize results

In [None]:
# Reload the results that the evaluation cell pickled to `filename`.
# NOTE: unpickling can execute arbitrary code; only load files produced by
# this notebook.
with open(filename, "rb") as results_file:
    results = pickle.load(results_file)

In [None]:
# Box plots of the per-dataset bias, one figure per (dim, setup), with one
# group of boxes per sample size and one box per selected estimator.
#
# Legend labels and colours are keyed by estimator class name so the legend
# always matches the boxes, whatever subset or order `estimators` has.
short_names = {
    "ProximalOutcomeRegression": "POR",
    "ProximalInverseProbWeighting": "PIPW",
    "ProximalMultiplyRobust": "PMR",
}
box_colors = {
    "ProximalOutcomeRegression": "cornflowerblue",
    "ProximalInverseProbWeighting": "mediumaquamarine",
    "ProximalMultiplyRobust": "khaki",
}
fallback_color = "lightgray"

estimator_names = [est.__name__ for est in estimators]
group_width = len(estimator_names)

for dim in dims:
    for setup in setups:
        # Use the sample sizes actually present in `results` (and complete for
        # every selected estimator) instead of a hard-coded list, and do not
        # overwrite the global `sample_sizes` used by the evaluation cell.
        available_sizes = sorted(
            {key[2] for key in results if key[0] == dim and key[1] == setup}
        )
        plot_sizes = [
            n
            for n in available_sizes
            if all((dim, setup, n, name) in results for name in estimator_names)
        ]
        if not estimator_names or not plot_sizes:
            print(f"No results to plot for dim={dim}, setup={setup}")
            continue

        # x positions: boxes of one sample size sit side by side, consecutive
        # groups are separated by two empty slots; the tick goes at the centre
        # of each group.
        positions = []
        tick_locations = []
        next_position = 1
        for _ in plot_sizes:
            group = list(range(next_position, next_position + group_width))
            positions.extend(group)
            tick_locations.append(np.mean(group))
            next_position += group_width + 2

        plotdata = [
            np.array(results[dim, setup, n, name]["bias"])
            for n in plot_sizes
            for name in estimator_names
        ]

        fig, ax = plt.subplots(figsize=(3 * 2, 4))
        bplot = ax.boxplot(
            plotdata,
            positions=positions,
            patch_artist=True,
            widths=1,
            showmeans=True,
        )

        # Boxes come back in data order: estimators cycle within each size.
        for patch, name in zip(bplot["boxes"], estimator_names * len(plot_sizes)):
            patch.set_facecolor(box_colors.get(name, fallback_color))

        custom_lines = [
            Line2D([0], [0], color=box_colors.get(name, fallback_color), lw=4)
            for name in estimator_names
        ]
        ax.legend(custom_lines, [short_names.get(name, name) for name in estimator_names])

        ax.axhline(y=0.0, color="k", linestyle="--", zorder=3)
        # ax.set_ylim(-2.5, 2.5)
        ax.set_xticks(tick_locations)
        ax.set_xticklabels([str(n) for n in plot_sizes])
        ax.set_xlabel("Sample Size")
        ax.set_ylabel("Bias")
        ax.set_title(f"Bias by sample size (setup {setup}, dim {dim})")
        fig.tight_layout()
        # fig.savefig(f"assets/prox_hidmed_{dim}d_sample_sizes_setup_{setup}.png", dpi=500)
        plt.show()

In [None]:
# Print the average of each stored metric for ProximalMultiplyRobust, for every
# (dim, setup, sample size) combination present in `results`.
pmr_name = "ProximalMultiplyRobust"
scalar_metrics = ["estimate", "bias", "mse", "anb", "bootstrap_covered"]

for setup in setups:
    for dim in dims:
        # Take the sample sizes from the stored keys rather than a hard-coded
        # list, so the summary always matches what was actually evaluated.
        summary_sizes = sorted(
            key[2]
            for key in results
            if key[0] == dim and key[1] == setup and key[3] == pmr_name
        )
        if not summary_sizes:
            print(f"No {pmr_name} results for dim={dim}, setup={setup}")
            print()
            continue

        for n in summary_sizes:
            key = (dim, setup, int(n), pmr_name)
            data = results[key]

            print(f"Summary for {key}:")
            for metric in scalar_metrics:
                print(f"  Average {metric}: {np.mean(data[metric]):.4f}")

            # "bootstrap_ci" has one (lower, upper) row per evaluation; average
            # each bound separately instead of pooling both into one number.
            lower, upper = np.mean(data["bootstrap_ci"], axis=0)
            print(f"  Average bootstrap_ci: [{lower:.4f}, {upper:.4f}]")
            print()