In [1]:
import functools
import os
import threading

import seaborn as sns

from causal_models.copd_data_scms import CopdDataSCM
from copd_data_main import execute_strategy_iteration
from utils import parallel_utils, plot_utils

In [2]:
# Root of the repository checkout. Set the REPO_DIR environment variable to
# point at your checkout instead of editing the notebook; the placeholder
# below is only the fallback when the variable is unset.
REPO_DIR = os.environ.get("REPO_DIR", "/path/to/repo/directory")
# Log location; LOG_PATH is passed as ``log_path`` to the background sweep.
LOG_DIR = os.path.join(REPO_DIR, "logs", "copd")
LOG_PATH = os.path.join(LOG_DIR, "progress.txt")
# Output locations (PLOTS_DIR is not referenced in the cells shown here).
RESULTS_DIR = os.path.join(REPO_DIR, "results")
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")

In [3]:
# CSV inputs under datasets/yang_and_ding: one file for the validation SCM
# and one for the main SCM.
DATA_DIR = os.path.join(REPO_DIR, "datasets", "yang_and_ding")
DATA_VAL_FILEPATH, DATA_MAIN_FILEPATH = (
    os.path.join(DATA_DIR, csv_name)
    for csv_name in ("validation_ns3.csv", "main_ns3.csv")
)

In [4]:
# One CopdDataSCM per CSV: the validation SCM is built first, then the main one.
true_scm_val, true_scm_main = (
    CopdDataSCM(data_filepath=filepath)
    for filepath in (DATA_VAL_FILEPATH, DATA_MAIN_FILEPATH)
)

In [5]:
def run_in_background(
    results_dict,
    strategies=(
        "oracle",
        "single_source",
        "etc_0.1",
        "etg_0.1",
        "etc_0.2",
        "etg_0.2",
    ),
    horizons=(4000, 6000, 8000, 10000, 12000, 14000, 16000),
    iterations=2000,
    optimal_kappa=0.11,
    cost_per_source=(4.0, 1.0),
):
    """Run the COPD strategy sweep, handing ``results_dict`` to the runner.

    Meant to be the target of a ``threading.Thread`` so the notebook stays
    responsive while the sweep runs. All defaults reproduce the original
    experiment exactly; pass other values for a smaller or different sweep.
    Defaults are tuples (not lists) to avoid shared mutable defaults; they are
    converted to lists before being passed on.

    Args:
        results_dict: Dict passed as ``results_dict`` to
            ``parallel_utils.get_timeseries_for``; presumably filled in place
            with per-strategy results (TODO confirm against parallel_utils).
        strategies: Strategy names to evaluate.
        horizons: Horizon values passed as ``horizons``.
        iterations: Value passed as ``iterations``.
        optimal_kappa: Bound into ``execute_strategy_iteration``.
        cost_per_source: Bound into ``execute_strategy_iteration``
            (as a list).

    Reads the notebook globals ``true_scm_val``, ``true_scm_main`` and
    ``LOG_PATH``.
    """
    # Bind the fixed SCMs and cost settings so the runner only has to supply
    # the per-iteration arguments.
    execute_fn = functools.partial(
        execute_strategy_iteration,
        true_scm_val=true_scm_val,
        true_scm_main=true_scm_main,
        optimal_kappa=optimal_kappa,
        cost_per_source=list(cost_per_source),
    )
    parallel_utils.get_timeseries_for(
        true_scm_val,
        list(strategies),
        horizons=list(horizons),
        iterations=iterations,
        results_dict=results_dict,
        execute_fn=execute_fn,
        log_path=LOG_PATH,
    )

In [1]:
# Launch the sweep on a worker thread so the kernel stays usable; the same
# dict object is shared with run_in_background, which passes it on to the runner.
results_dict = dict()
thread = threading.Thread(
    target=run_in_background,
    kwargs={"results_dict": results_dict},
)
thread.start()

In [2]:
# Status check: True while the background thread started above is still
# running run_in_background; re-run this cell to poll again.
thread.is_alive()

In [15]:
# Per-strategy styling passed to plot_utils.plot_mse_curve:
# strategy name -> [line style name, colour from the colorblind palette, marker].
# NOTE(review): "loosely dashed", "dashdotted" and "long dash with offset" are
# not matplotlib linestyle strings; presumably plot_utils maps them to dash
# tuples — confirm.
palette = sns.color_palette("colorblind")
name_to_linestyle_color = {
    strategy: [linestyle, palette[color_index], marker]
    for strategy, linestyle, color_index, marker in (
        ("etc_0.1", "solid", 0, "o"),
        ("etc_0.2", "dashed", -3, ">"),
        ("etg_0.1", "dotted", 2, "^"),
        ("etg_0.2", "loosely dashed", 3, "D"),
        ("single_source", "dashdotted", 1, "v"),
        ("oracle", "long dash with offset", -2, "*"),
    )
}

# Legend label overrides; strategies not listed here presumably keep their name.
name_to_label = dict(single_source="fixed")

In [3]:
# Wait for the background sweep to finish before plotting. Without this, a
# Restart & Run All reaches this cell right after thread.start() and plots an
# empty or partial results_dict, reading it while the worker thread is
# presumably still populating it.
thread.join()
plot_utils.plot_mse_curve(
    results_dict=results_dict,
    name_to_linestyle_color=name_to_linestyle_color,
    name_to_label=name_to_label,
)