In [1]:
import functools
import os
import threading

import seaborn as sns

from causal_models.observational_two_covariates_scms import (
    UniformObservationalDataTwoCovariatesSCM,
)
from observational_two_covariates_main import execute_strategy_iteration
from utils import parallel_utils, plot_utils

In [2]:
# Root of the repository checkout. Set the REPO_DIR environment variable to
# point at your checkout instead of editing this cell; the placeholder below
# is only the fallback when the variable is unset.
REPO_DIR = os.environ.get("REPO_DIR", "/path/to/repo/directory")

# Progress log location; LOG_PATH is passed as `log_path` to the background
# simulation further down.
LOG_DIR = os.path.join(REPO_DIR, "logs", "obs_two_covar")
LOG_PATH = os.path.join(LOG_DIR, "progress.txt")

# Output locations (PLOTS_DIR presumably for saved figures — not used in the
# cells shown here).
RESULTS_DIR = os.path.join(REPO_DIR, "results")
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")

In [12]:
# Ground-truth structural causal model; it is the SCM handed to the background
# simulation below.
# NOTE(review): the meaning of each parameter (beta, the var_n* terms, and the
# u*/w* coefficients) is defined in
# causal_models.observational_two_covariates_scms — confirm there before
# changing any value.
true_scm = UniformObservationalDataTwoCovariatesSCM(
    beta=1,
    ux=0.3,
    wx=1,
    uy=0.2,
    wy=1,
    var_nu=1.0,
    var_nw=1.0,
    var_ny=1.0,
)

In [46]:
def run_in_background(
    results_dict,
    scm=None,
    strategies=(
        "oracle_with_true_nu",
        "fixed_single_source",
        "etc_0.1",
        "etg_0.1",
        "etc_0.2",
        "etg_0.2",
    ),
    horizons=(2000, 4000, 6000, 8000, 10000, 12000),
    iterations=4000,
    optimal_kappa=0.4,
    cost_per_source=(2.0, 1.0),
    log_path=None,
):
    """Run the strategy simulation via ``parallel_utils.get_timeseries_for``.

    Meant to be used as a ``threading.Thread`` target. It returns nothing; the
    caller keeps a reference to ``results_dict``, which is handed to
    ``get_timeseries_for`` (presumably populated in place — confirm in
    parallel_utils).

    With only ``results_dict`` supplied, the call is identical to the
    notebook's original hard-coded configuration.

    Args:
        results_dict: Dict forwarded as ``results_dict`` to ``get_timeseries_for``.
        scm: SCM to simulate; ``None`` uses the notebook-level ``true_scm``
            (looked up at call time).
        strategies: Strategy names to evaluate, forwarded as a list.
        horizons: Horizons to evaluate, forwarded as a list.
        iterations: Forwarded as ``iterations`` (semantics defined in
            parallel_utils — TODO confirm, e.g. repetitions per horizon).
        optimal_kappa: Bound into ``execute_strategy_iteration``.
        cost_per_source: Bound into ``execute_strategy_iteration`` as a list.
        log_path: Progress-log path; ``None`` uses the notebook-level
            ``LOG_PATH`` (looked up at call time).
    """
    if scm is None:
        scm = true_scm
    if log_path is None:
        log_path = LOG_PATH

    # Fix the per-iteration settings once so get_timeseries_for only needs to
    # supply the remaining arguments to each call.
    execute_fn = functools.partial(
        execute_strategy_iteration,
        optimal_kappa=optimal_kappa,
        cost_per_source=list(cost_per_source),
    )

    parallel_utils.get_timeseries_for(
        scm,
        list(strategies),
        horizons=list(horizons),
        iterations=iterations,
        results_dict=results_dict,
        execute_fn=execute_fn,
        log_path=log_path,
    )

In [13]:
# Run the simulation on a background thread so the notebook stays usable.
# Re-running this cell while an earlier run is still alive would start a
# second, concurrent simulation and leave the first thread writing into an
# orphaned results dict, so refuse to start instead.
_previous_thread = globals().get("thread")
if isinstance(_previous_thread, threading.Thread) and _previous_thread.is_alive():
    raise RuntimeError(
        "A background simulation is still running; wait for it to finish "
        "(thread.join()) or restart the kernel before starting another."
    )

results_dict = {}
thread = threading.Thread(target=run_in_background, args=(results_dict,))
thread.start()

In [2]:
# Poll the background simulation: True while it is still running.
simulation_running = thread.is_alive()
simulation_running

In [9]:
palette = sns.color_palette("colorblind")

# Per-strategy plot style: [linestyle, color, marker code].
# NOTE(review): "loosely dashed" and "dashdotted" are not plain matplotlib
# linestyle strings; they match names from matplotlib's linestyle example and
# are presumably mapped to dash tuples inside plot_utils — confirm.
# NOTE(review): "oracle_with_true_nu" is simulated but has no style entry;
# presumably it is the regret baseline and is not drawn — confirm.
name_to_linestyle_color = {
    strategy: [linestyle, palette[color_index], marker]
    for strategy, linestyle, color_index, marker in (
        ("etc_0.1", "solid", 0, "o"),
        ("etc_0.2", "dashed", -3, ">"),
        ("etg_0.1", "dotted", 2, "^"),
        ("etg_0.2", "loosely dashed", 3, "D"),
        ("fixed_single_source", "dashdotted", 1, "v"),
    )
}

# Legend overrides; strategies not listed here keep their own name.
name_to_label = dict(fixed_single_source="fixed")

In [1]:
# Block until the background simulation has finished. results_dict is handed
# to the worker thread (presumably filled there), so plotting while it still
# runs would draw partial curves — or could fail if the dict changes size
# while being iterated. Waiting here also makes Restart & Run All produce the
# complete figure instead of an empty one.
thread.join()

plot_utils.plot_regret_curve(
    results_dict=results_dict,
    name_to_linestyle_color=name_to_linestyle_color,
    name_to_label=name_to_label,
)