In [1]:
from tqdm import tqdm 
import numpy as np

from sempler.generators import dag_avg_deg
from sempler import LGANM

import ges
from scores import HuberScore
from our_utils import *

import matplotlib.pyplot as plt 

from noisy_causal_discovery import noisy_fit

In [2]:
# Experimental parameters
trials = 100    # Monte Carlo repetitions per (d, n) grid cell
avg_deg = 1     # avg_deg argument passed to dag_avg_deg when sampling random DAGs
err_lvl = 0.05  # error level passed to is_valid; also drawn as the dashed reference line

# Parameters for noisy GES.
# NOTE(review): max_iter and delta are not referenced by any cell shown here
# (noisy_fit is called as noisy_fit(HuberScore(data))) — confirm whether they
# should be passed through to noisy_fit or removed.
max_iter = 5
delta = 0.5

# Grid: dimensions d = d_inc * (1..n_d_steps), sample sizes n = n_inc * (1..n_n_steps)
d_inc = 5
n_inc = 100
n_d_steps = 4
n_n_steps = 5

# Estimated error probabilities: one row per d, one column per n.
# A zero entry marks a grid cell that the simulation loop has not filled in yet.
results_rand = np.zeros((n_d_steps, n_n_steps))

In [None]:
# Monte Carlo estimate of the error probability for every (d, n) grid cell.
# Grid cells that already hold a non-zero result are skipped, so this cell can
# be re-run to resume after an interruption.
for d_iter in range(results_rand.shape[0]):
    for n_iter in range(results_rand.shape[1]):

        # Skip finished cells. A finished cell whose true result is exactly 0
        # looks the same as "not yet run" and will simply be recomputed.
        if results_rand[d_iter, n_iter] != 0:
            continue

        d, n = d_inc * (d_iter + 1), n_inc * (n_iter + 1)

        # Accumulate in a local counter and only store the finished average.
        # Accumulating directly into results_rand would leave a partial,
        # un-normalized count behind if the run is interrupted, and the skip
        # guard above would then treat that garbage value as a final result.
        failures = 0
        for trial in tqdm(range(trials)):
            np.random.seed(trial)  # seed by trial index for reproducibility
            G = dag_avg_deg(d, avg_deg, w_min=2, w_max=4)
            data = LGANM(G, (0, 0), (1, 1)).sample(n=n)

            # Change to HuberScore(data, delta=0.5) to see the behavior of
            # classical GES with the Huber score used in the paper.
            cpdag_estimate, _ = noisy_fit(HuberScore(data))

            # Count a failure whenever is_valid(...) is falsy (assumes it
            # returns a bool / 0-1 value). For true effect coverage, use
            # is_valid(data, cpdag_estimate, err_lvl, G, true_effect=True).
            failures += (1 - is_valid(data, cpdag_estimate, err_lvl, G))

        results_rand[d_iter, n_iter] = failures / trials
        print("d, n=", d, n, ", metrics=", results_rand[d_iter, n_iter])

100%|██████████| 100/100 [00:38<00:00,  2.57it/s]


d, n= 5 100 , metrics= 0.08


100%|██████████| 100/100 [00:40<00:00,  2.45it/s]


d, n= 5 200 , metrics= 0.06


100%|██████████| 100/100 [00:39<00:00,  2.52it/s]


d, n= 5 300 , metrics= 0.05


100%|██████████| 100/100 [00:38<00:00,  2.57it/s]


d, n= 5 400 , metrics= 0.05


100%|██████████| 100/100 [00:39<00:00,  2.54it/s]


d, n= 5 500 , metrics= 0.04


100%|██████████| 100/100 [29:49<00:00, 17.89s/it]  


d, n= 10 100 , metrics= 0.09


100%|██████████| 100/100 [40:22<00:00, 24.23s/it]  


d, n= 10 200 , metrics= 0.11


100%|██████████| 100/100 [03:38<00:00,  2.18s/it]


d, n= 10 300 , metrics= 0.03


100%|██████████| 100/100 [03:48<00:00,  2.28s/it]


d, n= 10 400 , metrics= 0.12


100%|██████████| 100/100 [03:49<00:00,  2.30s/it]


d, n= 10 500 , metrics= 0.08


100%|██████████| 100/100 [15:23<00:00,  9.24s/it]


d, n= 15 100 , metrics= 0.09


100%|██████████| 100/100 [15:48<00:00,  9.48s/it]


d, n= 15 200 , metrics= 0.07


100%|██████████| 100/100 [15:50<00:00,  9.50s/it]


d, n= 15 300 , metrics= 0.1


100%|██████████| 100/100 [23:06<00:00, 13.87s/it]  


d, n= 15 400 , metrics= 0.08


100%|██████████| 100/100 [1:44:30<00:00, 62.70s/it]   


d, n= 15 500 , metrics= 0.1


100%|██████████| 100/100 [53:48<00:00, 32.28s/it] 


d, n= 20 100 , metrics= 0.16


100%|██████████| 100/100 [11:27:03<00:00, 412.23s/it]   


d, n= 20 200 , metrics= 0.14


100%|██████████| 100/100 [1:07:50<00:00, 40.70s/it]


d, n= 20 300 , metrics= 0.08


 22%|██▏       | 22/100 [11:52<33:28, 25.74s/it]  

**Note:** To get the plots for classical GES with the Huber score, pass `delta=0.5` to `HuberScore` in the loop above, as described in the code comment. By default, `delta` is set to infinity, which makes the score equivalent to BIC. To get the true effect coverage instead, pass `true_effect=True` to `is_valid`, as also described in the code comment.

In [None]:
# Dimensions and sample sizes of the simulation grid, derived from the shape of
# results_rand so they match the (d, n) values used in the loop:
# row i of results_rand <-> d_list[i], column j <-> n_list[j].
d_list = d_inc * np.arange(1, results_rand.shape[0] + 1)
n_list = n_inc * np.arange(1, results_rand.shape[1] + 1)

In [None]:
# Estimated error probability vs. sample size, one curve per dimension d, with
# the error level err_lvl shown as a dashed reference line.
# NOTE(review): the loop above calls noisy_fit; confirm that the "Classical GES"
# title matches the configuration that produced results_rand.
colors = ['darkorchid', 'royalblue', 'firebrick', 'darkorange', 'forestgreen']
fig, ax = plt.subplots(figsize=(7, 5))
for i, (d, row) in enumerate(zip(d_list, results_rand)):
    # Cycle through the palette in case the grid has more d values than colors.
    ax.plot(n_list, row, color=colors[i % len(colors)], label="d = " + str(d), linewidth=2)
ax.plot(n_list, np.full(len(n_list), err_lvl), '--', color="black", linewidth=2)
ax.set_xlabel("sample size", fontsize=20)
ax.set_ylabel("error probability", fontsize=20)
ax.tick_params(axis="both", labelsize=14)
ax.legend(fontsize=16)
ax.set_title("Classical GES", fontsize=20)
ax.set_ylim(0.0, 0.5)
fig.tight_layout();