In [1]:
from tqdm import tqdm 
import numpy as np

from sempler.generators import dag_avg_deg
from sempler import LGANM
from numpy.random import laplace

import ges
from our_utils import *

import matplotlib.pyplot as plt 

from noisy_causal_discovery import *

In [2]:
# --- Experiment configuration -------------------------------------------
# trials  : number of repetitions run for every (d, n) setting
# avg_deg : average degree passed to dag_avg_deg when sampling a DAG
# err_lvl : nominal error level (also drawn as the dashed line in the plot)
trials, avg_deg, err_lvl = 100, 2, 0.05

# --- Noisy GES parameters -----------------------------------------------
# delta is forwarded to get_sensitivity() and score(); max_iter is not
# referenced by the cells visible in this notebook.
max_iter, delta = 5, 0.5

# --- Grid increments and candidate-graph settings -----------------------
# d = d_inc * (row + 1) and n = n_inc * (col + 1) in the experiment loop.
d_inc, n_inc = 5, 50
# no_graphs and p_add are handed to generate_graphs().
no_graphs, p_add = 10, 0.01


# --- Results matrix -----------------------------------------------------
# One row per dimension setting, one column per sample-size setting;
# a zero entry is interpreted by the experiment loop as "not computed yet".
results_rand = np.zeros(shape=(4, 5))

In [None]:
# Noisy-select experiment: for every (d, n) setting, repeat `trials` times:
# sample a random DAG and data, build a set of candidate graphs, pick the
# candidate with the smallest Laplace-perturbed score, and record how often
# is_valid() rejects the selected graph. The empirical rate is stored in
# results_rand[d_iter, n_iter].
eps = 0.02  # passed to alpha_tilde() and noise_scale()

for d_iter in range(results_rand.shape[0]):
    for n_iter in range(results_rand.shape[1]):

        # A non-zero entry is treated as "already computed", so the cell can
        # be re-run after an interruption without redoing finished settings.
        if results_rand[d_iter, n_iter] != 0:
            continue

        d, n = d_inc * (d_iter+1), n_inc * (n_iter+1)

        err_lvl_adj = alpha_tilde(err_lvl, eps, n)

        # Count failures in a local variable and only write the final rate to
        # results_rand once all trials finished. Accumulating in place would
        # leave a raw, un-normalised count behind if the cell is interrupted,
        # and the resume check above would then skip it as "done".
        n_invalid = 0

        for trial in tqdm(range(trials)):

            G = dag_avg_deg(d, avg_deg, w_min=2, w_max=4)

            # Guard against an edgeless DAG: 2.5/0 evaluates to inf (clipped
            # to 1) but emits a divide-by-zero RuntimeWarning. Use the same
            # resulting value directly.
            n_edges = np.sum(G != 0)
            p_remove = np.clip(2.5/n_edges, 0, 1) if n_edges > 0 else 1.0

            data = LGANM(G, (0,0), (1,1)).sample(n=n)

            G_options = generate_graphs(G, no_graphs, p_remove, p_add)

            # Perturb every candidate's score with Laplace noise whose scale
            # is derived from eps and the sensitivity, then take the argmin.
            sensitivity = get_sensitivity(data, G_options, delta=delta)
            noise_lvl = noise_scale(eps, sensitivity)
            noisy_scores = [score(data, G_options[j], delta=delta) + laplace(scale=noise_lvl) for j in range(no_graphs)]
            stable_G_est = G_options[np.argmin(noisy_scores)]

            # To get true-effect coverage, call is_valid with true_effect=True
            # instead. NOTE(review): the original note suggested passing the
            # unadjusted err_lvl in that case -- confirm which level is intended.
            n_invalid += (1-is_valid(data, stable_G_est, err_lvl_adj, G, true_effect=False))

        results_rand[d_iter, n_iter] = n_invalid / trials
        print("eps, d, n=", eps, d, n, ", metrics=", results_rand[d_iter, n_iter])

100%|██████████| 100/100 [02:01<00:00,  1.21s/it]


eps, d, n= 0.02 5 50 , metrics= 0.03


100%|██████████| 100/100 [02:06<00:00,  1.27s/it]


eps, d, n= 0.02 5 100 , metrics= 0.05


100%|██████████| 100/100 [02:31<00:00,  1.51s/it]


eps, d, n= 0.02 5 150 , metrics= 0.06


  p_remove = np.clip(2.5/np.sum(G != 0), 0, 1)
100%|██████████| 100/100 [02:13<00:00,  1.33s/it]


eps, d, n= 0.02 5 200 , metrics= 0.05


100%|██████████| 100/100 [02:05<00:00,  1.26s/it]


eps, d, n= 0.02 5 250 , metrics= 0.05


100%|██████████| 100/100 [04:05<00:00,  2.46s/it]


eps, d, n= 0.02 10 50 , metrics= 0.07


100%|██████████| 100/100 [04:01<00:00,  2.41s/it]


eps, d, n= 0.02 10 100 , metrics= 0.05


100%|██████████| 100/100 [04:00<00:00,  2.40s/it]


eps, d, n= 0.02 10 150 , metrics= 0.03


100%|██████████| 100/100 [03:27<00:00,  2.08s/it]


eps, d, n= 0.02 10 200 , metrics= 0.01


100%|██████████| 100/100 [03:10<00:00,  1.91s/it]


eps, d, n= 0.02 10 250 , metrics= 0.03


100%|██████████| 100/100 [04:59<00:00,  2.99s/it]


eps, d, n= 0.02 15 50 , metrics= 0.07


100%|██████████| 100/100 [04:45<00:00,  2.85s/it]


eps, d, n= 0.02 15 100 , metrics= 0.06


100%|██████████| 100/100 [04:53<00:00,  2.94s/it]


eps, d, n= 0.02 15 150 , metrics= 0.02


100%|██████████| 100/100 [04:31<00:00,  2.72s/it]


eps, d, n= 0.02 15 200 , metrics= 0.01


100%|██████████| 100/100 [04:37<00:00,  2.77s/it]


eps, d, n= 0.02 15 250 , metrics= 0.06


100%|██████████| 100/100 [06:18<00:00,  3.78s/it]


eps, d, n= 0.02 20 50 , metrics= 0.08


  5%|▌         | 5/100 [00:18<06:25,  4.05s/it]

[Note:] To get the true effect coverage, set the flag true_effect=True (instead of true_effect=False) in the call to the is_valid function, as described in the code comment above.

In [None]:
# Axis values for the results grid, derived from the shape of results_rand so
# they always match the settings used by the experiment loop:
# row i corresponds to d = d_inc * (i + 1), column j to n = n_inc * (j + 1).
d_list = np.arange(1, results_rand.shape[0] + 1) * d_inc
n_list = np.arange(1, results_rand.shape[1] + 1) * n_inc

In [None]:
# Plot the empirical error probability against sample size, one curve per
# dimension d, with the nominal level err_lvl as a dashed reference line.
fig, ax = plt.subplots(figsize=(7, 5))
colors = ['darkorchid', 'royalblue', 'firebrick', 'darkorange', 'forestgreen']

# Iterate over the rows of results_rand instead of a hard-coded count so the
# plot follows the size of the results matrix.
for d_iter, error_rates in enumerate(results_rand):
    d = d_inc * (d_iter + 1)
    ax.plot(n_list, error_rates, color=colors[d_iter % len(colors)],
            label="d = " + str(d), linewidth=2)

# Reference line at the nominal error level, spanning the plotted sample sizes.
ax.plot(n_list, np.full(len(n_list), err_lvl), '--', color="black", linewidth=2)

ax.set_xlabel("sample size", fontsize=20)
ax.set_ylabel("error probability", fontsize=20)
ax.tick_params(axis="both", labelsize=14)
ax.legend(fontsize=16)
# Raw string: "\e" is not a valid Python escape sequence, so the backslash
# must be kept literal for the mathtext renderer.
ax.set_title(r"Noisy Select ($\epsilon = 0.02$)", fontsize=20)
ax.set_ylim((0.0, 0.5))
fig.tight_layout();