In [1]:
from dagma_nl import dagmats

In [7]:
def dagma(lambda1, lambda2, lr, thresh):
    """Run one DagmaTS/DagmaLinear fit on a newly constructed SyntheticDataset.

    The dataset is built inside the function on every call with fixed sizes
    (n=1000, d=5, p=3) and fixed generator options; only the optimiser
    hyperparameters below vary between calls.

    Args:
        lambda1: Forwarded to ``fit`` as ``lambda1``.
        lambda2: Forwarded to ``fit`` as ``lambda2``.
        lr: Forwarded to ``fit`` as ``lr``.
        thresh: Forwarded to ``fit`` as ``w_threshold``.

    Returns:
        A 2-tuple: the value returned by ``fit`` and the dataset's ``A``
        attribute (used here as the ground truth).
    """
    from generate_data import SyntheticDataset

    # The same three sizes feed both the data generator and the model.
    sizes = {'n': 1000, 'd': 5, 'p': 3}
    dataset = SyntheticDataset(
        sizes['n'], sizes['d'], sizes['p'],
        B_scale=1.0,
        graph_type='ER',
        degree=2,
        A_scale=1.0,
        noise_type='EV',
        mlp=False,
    )
    truth, x_data, y_data = dataset.A, dataset.X, dataset.Y

    estimator = dagmats.DagmaLinear(dagmats.DagmaTS(**sizes), verbose=False)
    fit_options = dict(lambda1=lambda1, lambda2=lambda2, lr=lr, w_threshold=thresh)
    return estimator.fit(x_data, y_data, **fit_options), truth

In [8]:
def dagma_grid(lambda1_list, lambda2_list, lr_list, thresh_list):
    """Expand hyperparameter candidates into the full list of combinations.

    Every combination of one value from each input is emitted as a dict with
    the keys ``'thresh'``, ``'lambda1'``, ``'lambda2'`` and ``'lr'`` (in that
    key order). Combinations are ordered with ``thresh`` varying slowest,
    then ``lambda1``, then ``lambda2``, and ``lr`` varying fastest.

    Args:
        lambda1_list: Iterable of candidate ``lambda1`` values.
        lambda2_list: Iterable of candidate ``lambda2`` values.
        lr_list: Iterable of candidate ``lr`` values.
        thresh_list: Iterable of candidate ``thresh`` values.

    Returns:
        A list of dicts, one per combination. The list is empty if any input
        is empty.
    """
    # Imported locally so the cell does not depend on another cell's imports.
    from itertools import product

    # product() reads each input exactly once, so one-shot iterables
    # (generators, map/filter objects) yield the complete grid; hand-nested
    # loops would exhaust them after the first inner pass.
    return [
        {'thresh': thresh, 'lambda1': lambda1, 'lambda2': lambda2, 'lr': lr}
        for thresh, lambda1, lambda2, lr in product(
            thresh_list, lambda1_list, lambda2_list, lr_list
        )
    ]


In [9]:
# Search space for the sweep: 4 thresh x 4 lambda1 x 3 lambda2 x 3 lr
# = 144 combinations. Arguments are listed in the order dagma_grid nests
# them (thresh varies slowest, lr fastest).
grid = dagma_grid(
    thresh_list=[0.01, 0.1, 0.2, 0.3],
    lambda1_list=[0.02, 0.05, 0.1, 0.2],
    lambda2_list=[0.005, 0.01, 0.03],
    lr_list=[0.001, 0.005, 0.01],
)

In [10]:
import grid_search

In [11]:
# Run `dagma` over the entries of `grid`.
# NOTE(review): results are presumably written to 'dagmats01.jsonl' even though
# the keyword is called `output_dir` - confirm against grid_search.
# NOTE(review): the meaning of the extra `thresh=0.01` argument is not visible
# here, and each grid entry already carries its own 'thresh' - TODO confirm.
# NOTE(review): the saved output of this cell reports lambda1=0.005 and lr=0.05,
# values that are not in the grid defined above, and ends in a
# KeyboardInterrupt; re-run from a fresh kernel before trusting these numbers.
grid_search.perform_grid_search(
    grid=grid,
    model_func=dagma,
    output_dir='dagmats01.jsonl',
    thresh=0.01,
)

  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.3620689655172414, 'fpr': 0.0, 'shd': 37, 'pred_size': 21, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.005, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.05555555555555555, 'tpr': 0.25757575757575757, 'fpr': 0.008064516129032258, 'shd': 49, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.005, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.16, 'tpr': 0.2916666666666667, 'fpr': 0.03389830508474576, 'shd': 51, 'pred_size': 25, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.005, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.125, 'tpr': 0.23728813559322035, 'fpr': 0.015267175572519083, 'shd': 45, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.005, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.30158730158730157, 'fpr': 0.0, 'shd': 44, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.01, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.058823529411764705, 'tpr': 0.2909090909090909, 'fpr': 0.007407407407407408, 'shd': 39, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.01, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.05555555555555555, 'tpr': 0.2833333333333333, 'fpr': 0.007692307692307693, 'shd': 43, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.01, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.1, 'tpr': 0.3, 'fpr': 0.015384615384615385, 'shd': 42, 'pred_size': 20, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.01, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.34545454545454546, 'fpr': 0.0, 'shd': 36, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.03, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.2711864406779661, 'fpr': 0.0, 'shd': 43, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.03, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.11764705882352941, 'tpr': 0.234375, 'fpr': 0.015873015873015872, 'shd': 49, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.03, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.35294117647058826, 'tpr': 0.15714285714285714, 'fpr': 0.05, 'shd': 59, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.03, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.34782608695652173, 'fpr': 0.0, 'shd': 45, 'pred_size': 24, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.05, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.125, 'tpr': 0.22950819672131148, 'fpr': 0.015503875968992248, 'shd': 47, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.05, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.296875, 'fpr': 0.0, 'shd': 45, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.05, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.35294117647058826, 'tpr': 0.15942028985507245, 'fpr': 0.049586776859504134, 'shd': 58, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.005, 'lambda2': 0.05, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.26785714285714285, 'fpr': 0.0, 'shd': 41, 'pred_size': 15, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.005, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.30357142857142855, 'fpr': 0.0, 'shd': 39, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.005, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.05, 'tpr': 0.3392857142857143, 'fpr': 0.007462686567164179, 'shd': 37, 'pred_size': 20, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.005, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.19047619047619047, 'tpr': 0.27419354838709675, 'fpr': 0.03125, 'shd': 45, 'pred_size': 21, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.005, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.3469387755102041, 'fpr': 0.0, 'shd': 32, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.01, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.04, 'tpr': 0.3870967741935484, 'fpr': 0.0078125, 'shd': 38, 'pred_size': 25, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.01, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.09523809523809523, 'tpr': 0.3114754098360656, 'fpr': 0.015503875968992248, 'shd': 42, 'pred_size': 21, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.01, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.32075471698113206, 'fpr': 0.0, 'shd': 36, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.01, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.3333333333333333, 'fpr': 0.0, 'shd': 36, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.03, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.2830188679245283, 'fpr': 0.0, 'shd': 38, 'pred_size': 15, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.03, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.31666666666666665, 'fpr': 0.0, 'shd': 41, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.03, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.2777777777777778, 'tpr': 0.20967741935483872, 'fpr': 0.0390625, 'shd': 49, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.03, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.3333333333333333, 'fpr': 0.0, 'shd': 42, 'pred_size': 21, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.05, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.11764705882352941, 'tpr': 0.23809523809523808, 'fpr': 0.015748031496062992, 'shd': 48, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.05, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.058823529411764705, 'tpr': 0.3076923076923077, 'fpr': 0.007246376811594203, 'shd': 36, 'pred_size': 17, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.05, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.13333333333333333, 'tpr': 0.24528301886792453, 'fpr': 0.014598540145985401, 'shd': 40, 'pred_size': 15, 'thresh': 0.01, 'lambda1': 0.01, 'lambda2': 0.05, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.5142857142857142, 'fpr': 0.0, 'shd': 17, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.005, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.45, 'fpr': 0.0, 'shd': 22, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.005, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 27, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.005, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.1111111111111111, 'tpr': 0.32653061224489793, 'fpr': 0.014184397163120567, 'shd': 33, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.005, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.4318181818181818, 'fpr': 0.0, 'shd': 25, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.01, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.1111111111111111, 'tpr': 0.38095238095238093, 'fpr': 0.013513513513513514, 'shd': 26, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.01, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.041666666666666664, 'tpr': 0.5, 'fpr': 0.006944444444444444, 'shd': 23, 'pred_size': 24, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.01, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.25, 'tpr': 0.25862068965517243, 'fpr': 0.03787878787878788, 'shd': 43, 'pred_size': 20, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.01, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 24, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.03, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.13333333333333333, 'tpr': 0.30952380952380953, 'fpr': 0.013513513513513514, 'shd': 29, 'pred_size': 15, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.03, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0625, 'tpr': 0.38461538461538464, 'fpr': 0.006622516556291391, 'shd': 24, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.03, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.46153846153846156, 'fpr': 0.0, 'shd': 21, 'pred_size': 18, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.03, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.07142857142857142, 'tpr': 0.3023255813953488, 'fpr': 0.006802721088435374, 'shd': 30, 'pred_size': 14, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.05, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.10526315789473684, 'tpr': 0.3695652173913043, 'fpr': 0.013888888888888888, 'shd': 29, 'pred_size': 19, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.05, 'lr': 0.01}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.2, 'tpr': 0.2962962962962963, 'fpr': 0.029411764705882353, 'shd': 38, 'pred_size': 20, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.05, 'lr': 0.02}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.2857142857142857, 'tpr': 0.24193548387096775, 'fpr': 0.046875, 'shd': 47, 'pred_size': 21, 'thresh': 0.01, 'lambda1': 0.02, 'lambda2': 0.05, 'lr': 0.05}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

{'fdr': 0.0, 'tpr': 0.7619047619047619, 'fpr': 0.0, 'shd': 5, 'pred_size': 16, 'thresh': 0.01, 'lambda1': 0.05, 'lambda2': 0.005, 'lr': 0.005}


  0%|          | 0/230000.0 [00:00<?, ?it/s]

KeyboardInterrupt: 