In [2]:
import functools
import os
import threading

import seaborn as sns

from causal_models.iv_with_covariates_scms import UniformIVCovariatesSCM
from logistic_iv_LATE_main import execute_strategy_iteration
from utils import parallel_utils, plot_utils

In [2]:
# Repository root. Set the REPO_DIR environment variable to point at a real checkout
# instead of editing this cell; the placeholder below is only the fallback.
REPO_DIR = os.environ.get("REPO_DIR", "/path/to/repo/directory")
# Progress log written by the background simulation run (passed as `log_path`).
LOG_DIR = os.path.join(REPO_DIR, "logs", "iv_late")
LOG_PATH = os.path.join(LOG_DIR, "progress.txt")
# Output locations for results and figures.
RESULTS_DIR = os.path.join(REPO_DIR, "results")
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")

In [None]:
# Keyword arguments for the ground-truth SCM that generates the simulated data.
# NOTE(review): meanings are inferred from the names only — presumably two-letter keys are
# edge coefficients (e.g. `zx` for Z -> X, `xy` for X -> Y), `var_n*` are noise variances and
# `bias_x` is an intercept on X. Confirm against UniformIVCovariatesSCM.
TRUE_SCM_PARAMS = dict(
    wz=0.1,
    wx=0.1,
    wy=0.1,
    ux=0.1,
    uy=1,
    zx=3,
    xy=0.5,
    var_nw=0.1,
    var_nu=0.1,
    var_nx=0,
    var_ny=0.8,
    bias_x=-1.5,
)
true_scm = UniformIVCovariatesSCM(**TRUE_SCM_PARAMS)

In [4]:
def run_in_background(
    results_dict,
    scm=None,
    strategies=(
        "oracle_with_true_nu",
        "fixed_equal",
        "etc_0.1",
        "etg_0.1",
        "etc_0.2",
        "etg_0.2",
    ),
    horizons=(2000, 4000, 6000, 8000),
    iterations=16000,
    oracle_kappa=0.65,
    log_path=None,
):
    """Run the strategy simulations via ``parallel_utils.get_timeseries_for``.

    Intended as a ``threading.Thread`` target. The defaults reproduce the original
    hard-coded experiment, so ``Thread(target=run_in_background, args=(results_dict,))``
    behaves as before.

    Args:
        results_dict: Dict handed to ``get_timeseries_for`` as ``results_dict``
            (presumably filled in place with per-strategy results — confirm in parallel_utils).
        scm: SCM to simulate from; ``None`` means the notebook-level ``true_scm``
            (looked up at call time).
        strategies: Strategy names to evaluate.
        horizons: Horizons passed as ``horizons``.
        iterations: Value passed as ``iterations``.
        oracle_kappa: Bound into ``execute_strategy_iteration`` via ``functools.partial``.
        log_path: Progress log path; ``None`` means the notebook-level ``LOG_PATH``.
    """
    parallel_utils.get_timeseries_for(
        true_scm if scm is None else scm,
        # Tuples are used for the defaults to avoid shared mutable defaults; lists are
        # passed on to match what the original call supplied.
        list(strategies),
        horizons=list(horizons),
        iterations=iterations,
        results_dict=results_dict,
        execute_fn=functools.partial(execute_strategy_iteration, oracle_kappa=oracle_kappa),
        log_path=LOG_PATH if log_path is None else log_path,
    )

In [10]:
# Refuse to start a second run while a previous one is still active. Otherwise re-running
# this cell would launch a duplicate simulation and rebind `results_dict`, orphaning the
# dict the first run is still writing to. The check happens before the rebinding.
_previous_thread = globals().get("thread")
if _previous_thread is not None and _previous_thread.is_alive():
    raise RuntimeError(
        "A background run is still in progress; wait for it to finish "
        "(progress is logged to LOG_PATH) before starting a new one."
    )

results_dict = {}
thread = threading.Thread(
    target=run_in_background, args=(results_dict,), name="iv_late_background_run"
)
thread.start()

In [3]:
# Poll the background run: displays True while the simulation thread is still executing.
threading.Thread.is_alive(thread)

In [8]:
palette = sns.color_palette("colorblind")

# Strategy name -> [linestyle, colour, marker], consumed by plot_utils.plot_regret_curve.
# NOTE(review): "loosely dashed" and "dashdotted" are not matplotlib linestyle strings; they
# match names in matplotlib's linestyles gallery, so plot_utils presumably maps them to dash
# tuples — confirm. "oracle_with_true_nu" has no entry, presumably because it is the regret
# baseline rather than a plotted curve — confirm.
name_to_linestyle_color = {
    strategy: [linestyle, palette[colour_index], marker]
    for strategy, linestyle, colour_index, marker in (
        ("etc_0.1", "solid", 0, "o"),
        ("etc_0.2", "dashed", -3, ">"),
        ("etg_0.1", "dotted", 2, "^"),
        ("etg_0.2", "loosely dashed", 3, "D"),
        ("fixed_equal", "dashdotted", 1, "v"),
    )
}

# Legend overrides; strategies not listed here presumably keep their raw names.
name_to_label = dict(fixed_equal="fixed")

In [4]:
# Wait for the background simulation before plotting. Under "Restart & Run All" this cell
# otherwise runs immediately after `thread.start()` and would plot an empty or partial
# `results_dict` — one the worker thread (presumably) is still inserting into.
if thread.is_alive():
    print(f"Waiting for the background run to finish (progress: {LOG_PATH}) ...")
    thread.join()

plot_utils.plot_regret_curve(
    results_dict=results_dict,
    name_to_linestyle_color=name_to_linestyle_color,
    name_to_label=name_to_label,
)