In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from hidmed import *
import pickle
import matplotlib.pyplot as plt

In [3]:
def tune_parameters(xdim, zdim, wdim, mdim, udim, setup, n, folds=1, seed=0,
                    num_runs=200, n_jobs=1):
    """Tune the parameters of the estimators for a given setup.

    Builds a ``LinearHidMedDGP`` with the given dimensions, draws one dataset of
    size ``n`` from it, fits a ``ProximalMultiplyRobust`` estimator on that
    dataset and returns the estimator's ``params`` attribute.

    Args:
        xdim, zdim, wdim, mdim, udim: dimensions forwarded to ``LinearHidMedDGP``.
        setup: data-generating setup, one of ``"a"``, ``"b"``, ``"c"``. Setups
            ``"b"`` and ``"c"`` fit the estimator with ``generalized_model=True``.
        n: number of samples drawn for tuning.
        folds: number of cross-fitting folds passed to the estimator.
        seed: seed for the DGP; the dataset is sampled with ``seed + 1``.
        num_runs: forwarded to ``ProximalMultiplyRobust`` (presumably the number
            of hyperparameter-search runs -- TODO confirm against hidmed).
        n_jobs: forwarded to ``ProximalMultiplyRobust`` (parallelism setting).

    Returns:
        ``estimator.params`` after fitting -- presumably the tuned
        hyperparameters of the nuisance models (NOTE(review): confirm type).

    Raises:
        AssertionError: if ``setup`` is not one of ``"a"``, ``"b"``, ``"c"``.
    """
    assert setup in ["a", "b", "c"], "Invalid setup. Must be 'a', 'b', or 'c'."
    datagen = LinearHidMedDGP(xdim, zdim, wdim, mdim, udim, setup=setup, seed=seed)

    estimator = ProximalMultiplyRobust(
        generalized_model=setup in ("b", "c"),
        folds=folds,
        num_runs=num_runs,
        n_jobs=n_jobs,
    )
    # Use a different seed for sampling than for constructing the DGP.
    dataset = datagen.sample_dataset(n, seed=seed + 1)
    estimator.fit(dataset)
    return estimator.params

# Tuning grid shared by the 1d and 5d runs below. np.arange excludes the stop
# value, so the sample sizes are 200, 700, ..., 5700.
TUNING_SETUPS = ("a", "b", "c")
TUNING_SAMPLE_SIZES = np.arange(200, 6000, 500)
TUNING_FOLDS = 2


def tune_parameters_grid(dim, setups=TUNING_SETUPS, sample_sizes=TUNING_SAMPLE_SIZES,
                         folds=TUNING_FOLDS):
    """Run `tune_parameters` for every (setup, n) pair with all five dimensions set to `dim`.

    Returns a dict mapping ``(setup, n)`` to the value returned by
    `tune_parameters`. ``n`` is stored as a plain Python ``int`` (not the numpy
    integer scalar produced by ``np.arange``) so the pickled dict does not carry
    numpy scalar types; lookups with numpy ints still match.
    """
    tuned = {}
    for setup in setups:
        for n in sample_sizes:
            n = int(n)
            print(f"Setup: {setup}, n: {n}")
            tuned[setup, n] = tune_parameters(dim, dim, dim, dim, dim, setup, n, folds=folds)
            print("\n")  # visually separate runs in the long estimator log
    return tuned


def save_pickle(obj, path):
    """Pickle `obj` to `path`, closing (and therefore flushing) the file before returning."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


# 1d-case
tuned_parameters_1d = tune_parameters_grid(1)
save_pickle(tuned_parameters_1d, "tuned_parameters_1d.pkl")

# 5d-case
tuned_parameters_5d = tune_parameters_grid(5)
save_pickle(tuned_parameters_5d, "tuned_parameters_5d.pkl")

Setup: a, n: 200
==== Cross-fitting fold 1 (67/134 fit/eval)
Bridge h params: {'lambda1': 4e-06, 'lambda2': 0.03445, 'gamma': 0.0010217} score:  -0.0074084
Bridge q params: {'lambda1': 6.19e-05, 'lambda2': 0.0872144, 'gamma': 0.0011102} score:  -0.3440903
eta params: {'alpha': 0.006, 'gamma': 0.003} r2:  0.977
==== Estimate 1: 0.775510815601925
==== Cross-fitting fold 2 (67/134 fit/eval)
Bridge h params: {'lambda1': 4e-06, 'lambda2': 0.03445, 'gamma': 0.0010217} score:  nan
Bridge q params: {'lambda1': 6.19e-05, 'lambda2': 0.0872144, 'gamma': 0.0011102} score:  nan
eta params: {'alpha': 0.006, 'gamma': 0.003} r2:  nan
==== Estimate 2: 1.4015360035146316
==== Estimate: 1.0885234095582783


Setup: a, n: 700
==== Cross-fitting fold 1 (234/466 fit/eval)
Bridge h params: {'lambda1': 0.001629, 'lambda2': 3e-07, 'gamma': 0.0028176} score:  4.7531446
Bridge q params: {'lambda1': 5.32e-05, 'lambda2': 0.1, 'gamma': 0.001714} score:  -0.2460591
eta params: {'alpha': 0.0, 'gamma': 0.192} r2:  0.98