In [1]:
import os
import threading

import seaborn as sns

from causal_models.jtpa_scms import JTPADataBootstrapSCM
from jtpa_iv_late_main import execute_strategy_iteration
from utils import parallel_utils, plot_utils

In [2]:
# Root of the repository checkout. Set the JTPA_REPO_DIR environment variable to
# point the notebook at your checkout without editing it; otherwise the
# placeholder below is used and must be replaced before the data can be found.
REPO_DIR = os.environ.get("JTPA_REPO_DIR") or "/path/to/repo/directory"
LOG_DIR = os.path.join(REPO_DIR, "logs", "jtpa")
# Passed as `log_path` to parallel_utils.get_timeseries_for by the background run.
LOG_PATH = os.path.join(LOG_DIR, "progress.txt")
RESULTS_DIR = os.path.join(REPO_DIR, "results")
PLOTS_DIR = os.path.join(RESULTS_DIR, "plots")

In [3]:
# Directory holding the processed JTPA dataset, and the file the SCM is built from.
DATA_DIR = os.path.join(REPO_DIR, "datasets")
DATA_FILEPATH = os.path.join(
    DATA_DIR,
    "jtpa_processed.pkl",
)

In [4]:
# Reference ("true") SCM that every strategy is evaluated against; it is passed
# as the first argument to parallel_utils.get_timeseries_for in the next cell.
# The class name suggests bootstrap resampling of the JTPA data — confirm in
# causal_models.jtpa_scms.
# NOTE(review): DATA_FILEPATH is a .pkl file; if it is unpickled, only point it
# at trusted data (unpickling can execute arbitrary code).
true_scm = JTPADataBootstrapSCM(
    data_filepath=DATA_FILEPATH,
)

In [5]:
# Default experiment grid: every strategy below is evaluated at each horizon.
DEFAULT_STRATEGIES = (
    "complete_data_cross_fit_mlp",
    "oracle_mlp",
    "fixed_equal_mlp",
    "etc_0.1_mlp",
    "etg_0.1_mlp",
    "etc_0.2_mlp",
    "etg_0.2_mlp",
)
DEFAULT_HORIZONS = (2000, 4000, 6000, 8000)
DEFAULT_ITERATIONS = 2000


def run_in_background(
    results_dict,
    strategies=DEFAULT_STRATEGIES,
    horizons=DEFAULT_HORIZONS,
    iterations=DEFAULT_ITERATIONS,
):
    """Run the JTPA strategy comparison, filling ``results_dict``.

    Used as a ``threading.Thread`` target so the long simulation does not
    block the notebook kernel.

    Args:
        results_dict: Dict handed to ``parallel_utils.get_timeseries_for`` as
            ``results_dict``; presumably filled in place, keyed by strategy
            name — TODO confirm against parallel_utils.
        strategies: Strategy names to evaluate (defaults to the full grid).
            Pass a subset for a quick smoke run.
        horizons: Horizon values forwarded as ``horizons``.
        iterations: Iteration count forwarded as ``iterations``.

    Returns:
        None. Results are communicated only through ``results_dict``.
    """
    parallel_utils.get_timeseries_for(
        true_scm,
        # Fresh lists: keeps the list type the callee always received and
        # never shares the (immutable) defaults or a caller's sequence.
        list(strategies),
        horizons=list(horizons),
        iterations=iterations,
        results_dict=results_dict,
        execute_fn=execute_strategy_iteration,
        log_path=LOG_PATH,
    )

In [8]:
# Launch the long-running experiment on a background thread so the kernel stays
# responsive; poll it with `thread.is_alive()` (the run is given LOG_PATH as its
# log file).
# Re-running this cell while a previous run is still alive would start a
# duplicate run on the same log and rebind `results_dict`, orphaning the partial
# results the first thread is still writing — so refuse instead.
_previous_thread = globals().get("thread")
if isinstance(_previous_thread, threading.Thread) and _previous_thread.is_alive():
    raise RuntimeError(
        "A background JTPA run is still in progress; wait for it to finish "
        "(or restart the kernel) before starting another."
    )

results_dict = {}
thread = threading.Thread(
    target=run_in_background,
    args=(results_dict,),
    name="jtpa-background-run",
)
thread.start()

In [9]:
# True while the background experiment thread is still working; re-run to poll.
is_running = thread.is_alive()
is_running

In [6]:
# Per-strategy plot styling, strategy name -> [linestyle, color, marker], passed
# to plot_utils.plot_mse_curve. "complete_data_cross_fit_mlp" is run in the
# background job but has no entry here — presumably intentional; confirm.
# NOTE(review): "loosely dashed", "dashdotted" and "long dash with offset" are
# not matplotlib's built-in linestyle strings (they are names used in
# matplotlib's linestyle gallery); presumably plot_utils maps them to dash
# tuples — confirm.
palette = sns.color_palette("colorblind")

name_to_linestyle_color = dict()
name_to_linestyle_color["etc_0.1_mlp"] = ["solid", palette[0], "o"]
name_to_linestyle_color["etc_0.2_mlp"] = ["dashed", palette[-3], ">"]
name_to_linestyle_color["etg_0.1_mlp"] = ["dotted", palette[2], "^"]
name_to_linestyle_color["etg_0.2_mlp"] = ["loosely dashed", palette[3], "D"]
name_to_linestyle_color["fixed_equal_mlp"] = ["dashdotted", palette[1], "v"]
name_to_linestyle_color["oracle_mlp"] = ["long dash with offset", palette[-2], "*"]

# Legend label overrides.
# NOTE(review): every strategy key above ends in "_mlp", so the "fixed_equal"
# key only takes effect if plot_utils strips that suffix before lookup — confirm.
name_to_label = dict(fixed_equal="fixed")

In [12]:
# Wait for the background run to finish before plotting. `results_dict` is being
# filled by the worker thread; reading it mid-run yields a partial figure (always,
# under Restart & Run All), and iterating a dict while another thread inserts
# into it can raise "RuntimeError: dictionary changed size during iteration".
# join() returns immediately if the run has already completed.
thread.join()

plot_utils.plot_mse_curve(
    results_dict=results_dict,
    name_to_linestyle_color=name_to_linestyle_color,
    name_to_label=name_to_label,
)