In [None]:
import sys
sys.path.insert(0, "../pycre/")

from cre import *
from parsers import get_parser
from dataset import dataset_generator

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from tqdm import tqdm

# Simulations

In [None]:
# --- Experiment configuration ------------------------------------------------
# Grid of effect sizes swept by the simulation (start 0, stop 4 exclusive, step 0.2).
effect_sizes = np.arange(0, 4, 0.2)
# Value passed as `N` to dataset_generator.
N = 2000
# Value passed as `t_ss` to CRE.
t_ss = 0.5
# Each configuration is repeated for seed = 0 .. n_seeds - 1.
n_seeds = 10
# Estimator names handed to CRE through its `method` argument.
methods = [
    "tlearner",
    "slearner",
    "xlearner",
    "causalforest",
    "drlearner",
    "aipw",
]

# --- Ground truth ------------------------------------------------------------
# Reference rules; predicted rules are compared against these exact strings.
rules = [
    "(X['x1']>0.5) & (X['x2']<=0.5)",
    "(X['x5']>0.5) & (X['x6']<=0.5)",
]

In [None]:
# Run CRE for every (estimator, effect size, seed) combination and score the
# discovered rules against the ground-truth `rules` by exact string match.
#
# Rows are collected in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0, and growing a frame row by
# row is quadratic and leaves every column with object dtype.
result_columns = ['estimator', 'effect_size', 'seed', 'recall', 'precision', 'f1_score']
true_rules = set(rules)  # loop-invariant, so build the set only once
records = []
for method in tqdm(methods):
    for effect_size in effect_sizes:
        for seed in range(n_seeds):
            X, y, z, ite = dataset_generator(N = N,
                                             P = 10,
                                             binary_cov = True,
                                             binary_out = False,
                                             effect_size = effect_size,
                                             confounding = "no",
                                             M = 2,
                                             seed = seed)

            model = CRE(verbose = False,
                        method = method,
                        t_ss = t_ss,
                        seed = seed,)
            model.fit(X, y, z)
            rules_pred = model.rules
            pred_rules = set(rules_pred)

            # Confusion counts over rule strings.
            TP = len(pred_rules & true_rules)
            FP = len(pred_rules - true_rules)
            FN = len(true_rules - pred_rules)

            # Every ratio falls back to 0 when its denominator is empty
            # (e.g. no rule predicted, or an empty ground-truth list).
            recall = TP/(TP+FN) if (TP+FN)>0 else 0
            precision = TP/(TP+FP) if (TP+FP)>0 else 0
            F1 = 2*TP/(2*TP+FP+FN) if (2*TP+FP+FN)>0 else 0

            records.append({'estimator': method,
                            'effect_size': effect_size,
                            'seed': seed,
                            'recall': recall,
                            'precision': precision,
                            'f1_score': F1})

# Passing `columns` keeps the column order fixed, and the header present even
# if no run was recorded.
result = pd.DataFrame(records, columns = result_columns)
result

# Visualization

In [None]:
# Plot recall, precision and F1 against effect size, one panel per metric and
# one line per estimator.
#
# The theme must be set BEFORE the figure is created: set_theme works by
# changing matplotlib rcParams, and axes pick those up at creation time. Set
# afterwards, the "whitegrid" style is missing on a fresh kernel and only
# appears when the cell is run a second time.
sns.set_theme(style="whitegrid")
fig, axes = plt.subplots(3, 1, figsize=(10, 15), sharey=True)

# (column in `result`, y-axis label) for each panel, top to bottom.
panels = [("recall", "Recall"),
          ("precision", "Precision"),
          ("f1_score", "F1 score")]
for ax, (metric, label) in zip(axes, panels):
    sns.lineplot(ax=ax, x="effect_size", y=metric, hue="estimator", data=result)
    ax.set_ylabel(label)
    ax.set_xlabel("Effect size")
plt.show()