In [1]:
import os

import numpy as np
from tqdm import tqdm

from causal_models.jtpa_scms import JTPADataSCM
from jtpa_iv_late_main import execute_strategy_iteration

In [2]:
# Placeholder for the local repository checkout; point this at the real path before running.
REPO_DIR = "/path/to/repo/directory"

# Both locations are derived directly from REPO_DIR:
#   <REPO_DIR>/datasets                      -> DATA_DIR
#   <REPO_DIR>/datasets/jtpa_processed.pkl   -> DATA_FILEPATH
DATA_DIR = os.path.join(REPO_DIR, "datasets")
DATA_FILEPATH = os.path.join(REPO_DIR, "datasets", "jtpa_processed.pkl")

In [3]:
# Build the SCM from the data file at DATA_FILEPATH. This object is read as a
# module-level global by compute_oracle_late_value, so this cell must run first.
true_scm = JTPADataSCM(data_filepath=DATA_FILEPATH)

In [4]:
def compute_oracle_late_value(num_iters: int, z_score: float = 1.96) -> tuple[float, float]:
    """Average the final LATE estimate over repeated full-horizon runs.

    Every iteration calls ``execute_strategy_iteration`` with the
    ``"complete_data_cross_fit_mlp"`` strategy on the module-level ``true_scm``,
    using ``true_scm.get_max_size()`` as the horizon, and keeps the last entry
    of the returned ``ate_hats``.

    Args:
        num_iters: Number of iterations to run; must be at least 1. The loop
            index ``0 .. num_iters - 1`` is forwarded as ``iteration_num``.
        z_score: Multiplier applied to the standard error of the mean to get
            the interval half-width. Defaults to 1.96, the value originally
            hard-coded here (two-sided 95% normal critical value).

    Returns:
        A ``(mean, half_width)`` tuple of plain Python floats, where ``mean``
        is the average of the collected estimates and ``half_width`` is
        ``z_score * std / sqrt(num_iters)``.

    Raises:
        ValueError: If ``num_iters`` is less than 1. Without this guard the
            mean and std of an empty list come back as NaN with only a
            RuntimeWarning.
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be a positive integer, got {num_iters!r}")

    vals = []
    for i in tqdm(range(num_iters)):
        res = execute_strategy_iteration(
            true_scm=true_scm,
            strategy_name="complete_data_cross_fit_mlp",
            iteration_num=i,
            horizon=true_scm.get_max_size(),
        )
        # Only the final estimate of each run is used.
        vals.append(res.ate_hats[-1])

    # NOTE: np.std defaults to ddof=0 (population standard deviation). This is
    # kept as-is so results match previously recorded outputs.
    mean_estimate = np.mean(vals)
    half_width = z_score * np.std(vals) / np.sqrt(num_iters)

    # Cast NumPy scalars to builtin floats to match the declared return type.
    return float(mean_estimate), float(half_width)

In [5]:
# Long-running cell: the recorded progress bar shows about 17 minutes for 2000 iterations.
oracle_late = compute_oracle_late_value(num_iters=2000)
# Printed tuple is (mean of the estimates, 1.96 * std / sqrt(num_iters)).
print(oracle_late)

100%|██████████| 2000/2000 [17:20<00:00,  1.92it/s]

(0.13658784058673604, 0.000301806878367079)



