In [2]:
%reload_ext autoreload
%autoreload 2
import torch
import dagrad.flex as flex
from dagrad.utils import utils
import numpy as np
from dagrad.flex.prune import cam_pruning

# Single seed for the notebook's RNGs. utils.set_random_seed is a dagrad helper;
# torch is also seeded explicitly in case the helper does not cover it (unverified).
SEED = 1
utils.set_random_seed(SEED)
torch.manual_seed(SEED);  # trailing ';' hides the torch Generator repr in the cell output

No GPU automatically detected. Setting SETTINGS.GPU to 0, and SETTINGS.NJOBS to cpu_count.
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


<torch._C.Generator at 0x120ce5ad0>

In [3]:
def postprocess(B, graph_thres=0.3):
    """Turn a raw weighted adjacency estimate into a DAG.

    Works on a copy of ``B``; the caller's array is not modified.
      1. Entries whose absolute value is at most ``graph_thres`` are zeroed.
      2. ``utils.threshold_till_dag`` then removes the edges with the smallest
         absolute weight until a DAG is obtained.

    Args:
        B (numpy.ndarray): [d, d] weighted matrix.
        graph_thres (float): Magnitude cutoff used in step 1. Default: 0.3.

    Returns:
        numpy.ndarray: [d, d] weighted matrix of DAG.
    """
    pruned = np.array(B, copy=True)
    weak_edges = np.abs(pruned) <= graph_thres
    pruned[weak_edges] = 0
    dag_matrix, _ = utils.threshold_till_dag(pruned)
    return dag_matrix

## Struct Learn via Flex

`flex.struct_learn` is a function that requires the following 6 inputs:
- `dataset`: A data matrix with `n x d` shape, where `n` is the number of samples and `d` is the number of variables/features.

- `model`: The SEM module. We provide standard SEM implementations such as `LinearModel`, `LogisticModel`, and `MLP`.

- `constrained_solver`: An instance of a ConstrainedSolver class. We provide implementations such as `PathFollowing` and `AugmentedLagrangian`.

- `unconstrained_solver`: An instance of an UnconstrainedSolver class. We provide an implementation, `GradientBasedSolver`.

- `loss_fn`: An instance of a Loss class. All available losses can be found in flex/loss.py.

- `dag_fn`: An instance of a DagFn class. All available DAG functions can be found in flex/dags.py.

## Generate Data

In [3]:
# Generate linear data.
# `n, d, s0, graph_type, sem_type` are deliberately left holding the LINEAR settings
# after this cell, because the linear NOTEARS / DAGMA cells below use `d`.
n, d, s0, graph_type, sem_type = 1000, 20, 20, "ER", "gauss"
linear_B_true = utils.simulate_dag(d, s0, graph_type)
linear_dataset = utils.simulate_linear_sem(linear_B_true, n, sem_type)

# Generate non-linear data under separate names. Rebinding `d` here would leave
# d == 5 and break the linear cells on a fresh Restart & Run All.
# The simulation calls happen in the same order as before, so the generated data is unchanged.
nl_n, nl_d, nl_s0, nl_graph_type, nl_sem_type = 1000, 5, 5, "ER", "mlp"
nonlinear_B_true = utils.simulate_dag(nl_d, nl_s0, nl_graph_type)
nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, nl_n, nl_sem_type)

# Using flex to implement NOTEARS

## Linear NOTEARS

In [12]:
# Linear model, sized from the data itself: the dataset is n x d, so its column
# count is the number of variables. Relying on the global `d` is fragile because
# other cells rebind it to the nonlinear problem size.
model = flex.LinearModel(linear_dataset.shape[1])

# Use the augmented Lagrangian method to solve the constrained problem
cons_solver = flex.AugmentedLagrangian(
    num_iter=10,
    num_steps=[3e4, 6e4],
    l1_coeff=0.03,
)

# Use Adam to solve the unconstrained problem
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4),
)

# Use MSE loss
loss_fn = flex.MSELoss()

# Use trace of matrix exponential as DAG function
dag_fn = flex.Exp()

# Learn the DAG
W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Score the support (non-zero pattern) of the estimate against the ground truth
acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

100%|██████████| 10/10 [00:37<00:00,  3.79s/it]

Total Time: 37.8957781791687
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## Nonlinear model (GraN-DAG acyclicity function + augmented Lagrangian + CAM pruning)

In [5]:
def grandag_aug_lagrangian(n, d, s0, num_layers=2, noise_type="gauss", graph_type="ER", train_frac=0.8):
    """Run one simulated trial of GraN-DAG-style structure learning with CAM pruning.

    Each call draws a fresh random DAG and nonlinear ("mlp") SEM dataset. It fits
    ``flex.MLP`` with the augmented Lagrangian solver and the ``Grandag_h``
    acyclicity function, then post-processes the estimate into a DAG and refines
    it with CAM pruning. Accuracy is printed before and after pruning.

    Args:
        n (int): Number of samples to simulate.
        d (int): Number of variables (nodes).
        s0 (int): Edge-count parameter forwarded to ``utils.simulate_dag``.
        num_layers (int): Forwarded to ``flex.MLP``. Default: 2.
        noise_type (str): Forwarded to ``utils.simulate_nonlinear_sem``
            (this notebook uses "gauss", "exp" and "gumbel"). Default: "gauss".
        graph_type (str): Random-graph model forwarded to ``utils.simulate_dag``.
            Default: "ER".
        train_frac (float): Fraction of rows (taken from the top) used for
            training. The remaining rows are passed to CAM pruning as held-out
            data. Default: 0.8.

    Returns:
        dict: ``utils.count_accuracy`` result for the CAM-pruned graph.
    """
    sem_type = "mlp"
    nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
    nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type, noise_type=noise_type)

    # First `train_frac` of the rows for training, the rest held out.
    train_samples = int(nonlinear_dataset.shape[0] * train_frac)
    train_dataset = nonlinear_dataset[:train_samples, :]
    test_dataset = nonlinear_dataset[train_samples:, :]

    # Nonlinear model
    model = flex.MLP(dims=[d, 2, d], num_layers=num_layers, hid_dim=10, activation="sigmoid", bias=True)

    # Augmented Lagrangian for the constrained problem. num_steps=[1, 1] presumably
    # means one inner optimizer step per outer iteration (inferred from the
    # solver's "num_steps" warning in the outputs; confirm against flex).
    cons_solver = flex.AugmentedLagrangian(
        num_iter=100000,
        num_steps=[1, 1],
        rho_init=1e-3,
    )

    # Use RMSprop to solve the unconstrained problem
    uncons_solver = flex.GradientBasedSolver(
        optimizer=torch.optim.RMSprop(model.parameters(), lr=1e-3),
    )

    # Use MSE loss
    loss_fn = flex.MSELoss()

    # Use the GraN-DAG acyclicity function as DAG function
    dag_fn = flex.Grandag_h()

    # Learn the DAG. No thresholding happens here (w_threshold=0.0); `postprocess` does it below.
    W_est = flex.struct_learn(
        dataset=train_dataset,
        model=model,
        constrained_solver=cons_solver,
        unconstrained_solver=uncons_solver,
        loss_fn=loss_fn,
        dag_fn=dag_fn,
        w_threshold=0.0,
    )

    W_est = postprocess(W_est)

    acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
    print(f'Results before CAM pruning: {acc}')

    # Restrict the model's adjacency to the edges that survived postprocessing.
    # Use `!= 0` so the mask matches the support scored just above. The old `> 0`
    # silently dropped any negative-weight edge.
    to_keep = (torch.from_numpy(W_est) != 0).type(torch.Tensor)
    B_est = model.adjacency * to_keep

    opt = {
        'cam_pruning_cutoff': np.logspace(-6, 0, 10),
        'exp_path': 'cam_pruning',
    }
    try:
        cam_pruning_cutoff = [float(c) for c in opt['cam_pruning_cutoff']]
    except (TypeError, ValueError):
        # A scalar (or a numeric string) is not a list of cutoffs: treat it as a single one.
        cam_pruning_cutoff = [float(opt['cam_pruning_cutoff'])]

    # NOTE(review): each cutoff prunes the output of the previous one, so this is a
    # cumulative pruning pass, not an independent sweep over cutoffs. Confirm this is intended.
    for cutoff in cam_pruning_cutoff:
        B_est = cam_pruning(B_est, train_dataset, test_dataset, opt, cutoff=cutoff)

    acc = utils.count_accuracy(nonlinear_B_true, B_est.detach().cpu().numpy() != 0)
    print("Results: ", acc)
    return acc

In [72]:
# Ten independent trials (each call draws a fresh random DAG): 5 nodes, s0=5, Gaussian noise.
shds = [grandag_aug_lagrangian(1000, 5, 5, num_layers=2)['shd'] for _ in range(10)]
print(f'mean shd: {np.mean(shds)}')

Using the first value from num_steps for the first 99998 iterations
100%|██████████| 100000/100000 [00:45<00:00, 2207.26it/s]


Total Time: 45.3078351020813
Results before CAM pruning: {'fdr': 0.42857142857142855, 'tpr': 0.8, 'fpr': 0.6, 'shd': 3, 'nnz': 7}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.3333333333333333, 'tpr': 0.8, 'fpr': 0.4, 'shd': 2, 'nnz': 6}


100%|██████████| 100000/100000 [00:28<00:00, 3476.31it/s] 


Total Time: 28.776642084121704
Results before CAM pruning: {'fdr': 0.5, 'tpr': 0.8, 'fpr': 0.8, 'shd': 5, 'nnz': 8}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.0, 'tpr': 0.6, 'fpr': 0.0, 'shd': 2, 'nnz': 3}


100%|██████████| 100000/100000 [00:24<00:00, 4104.80it/s]


Total Time: 24.369171857833862
Results before CAM pruning: {'fdr': 0.5, 'tpr': 0.8, 'fpr': 0.8, 'shd': 4, 'nnz': 8}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 3, 'nnz': 2}


100%|██████████| 100000/100000 [00:49<00:00, 2008.40it/s]


Total Time: 49.796138048172
Results before CAM pruning: {'fdr': 0.5555555555555556, 'tpr': 0.8, 'fpr': 1.0, 'shd': 6, 'nnz': 9}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 3, 'nnz': 2}


100%|██████████| 100000/100000 [00:46<00:00, 2164.50it/s]


Total Time: 46.204646825790405
Results before CAM pruning: {'fdr': 0.5, 'tpr': 0.8, 'fpr': 0.8, 'shd': 4, 'nnz': 8}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.25, 'tpr': 0.6, 'fpr': 0.2, 'shd': 2, 'nnz': 4}


100%|██████████| 100000/100000 [00:51<00:00, 1950.18it/s]


Total Time: 51.28387141227722
Results before CAM pruning: {'fdr': 0.5555555555555556, 'tpr': 0.8, 'fpr': 1.0, 'shd': 5, 'nnz': 9}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.42857142857142855, 'tpr': 0.8, 'fpr': 0.6, 'shd': 3, 'nnz': 7}


100%|██████████| 100000/100000 [01:02<00:00, 1603.85it/s]


Total Time: 62.35616064071655
Results before CAM pruning: {'fdr': 0.2, 'tpr': 0.8, 'fpr': 0.2, 'shd': 2, 'nnz': 5}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.0, 'tpr': 0.8, 'fpr': 0.0, 'shd': 1, 'nnz': 4}


100%|██████████| 100000/100000 [00:35<00:00, 2847.02it/s]


Total Time: 35.13168430328369
Results before CAM pruning: {'fdr': 0.4444444444444444, 'tpr': 1.0, 'fpr': 0.8, 'shd': 4, 'nnz': 9}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.16666666666666666, 'tpr': 1.0, 'fpr': 0.2, 'shd': 1, 'nnz': 6}


100%|██████████| 100000/100000 [00:18<00:00, 5477.45it/s]


Total Time: 18.262706756591797
Results before CAM pruning: {'fdr': 0.6, 'tpr': 0.8, 'fpr': 1.2, 'shd': 6, 'nnz': 10}


Using the first value from num_steps for the first 99998 iterations


Results:  {'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 3, 'nnz': 2}


100%|██████████| 100000/100000 [00:32<00:00, 3080.43it/s]


Total Time: 32.470055103302
Results before CAM pruning: {'fdr': 0.5, 'tpr': 1.0, 'fpr': 1.0, 'shd': 5, 'nnz': 10}
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 5}
mean shd: 2.0


In [7]:
# Ten independent trials: 5 nodes, s0=10 (denser graphs), Gaussian noise.
shds = [grandag_aug_lagrangian(1000, 5, 10, num_layers=2)['shd'] for _ in range(10)]
print(f'mean shd: {np.mean(shds)}')

Using the first value from num_steps for the first 99998 iterations


Total Time: 27.739933967590332
Results before CAM pruning: {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'sid': 0.0, 'nnz': 10}
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'sid': 0.0, 'nnz': 10}


Using the first value from num_steps for the first 99998 iterations


Total Time: 28.82768940925598
Results before CAM pruning: {'fdr': 0.1111111111111111, 'tpr': 0.8, 'fpr': 1.0, 'shd': 2, 'sid': 11.0, 'nnz': 9}
Results:  {'fdr': 0.25, 'tpr': 0.3, 'fpr': 1.0, 'shd': 7, 'sid': 17.0, 'nnz': 4}


Using the first value from num_steps for the first 99998 iterations


Total Time: 13.399073123931885
Results before CAM pruning: {'fdr': 0.1, 'tpr': 0.9, 'fpr': 1.0, 'shd': 1, 'sid': 2.0, 'nnz': 10}
Results:  {'fdr': 0.0, 'tpr': 0.1, 'fpr': 0.0, 'shd': 9, 'sid': 12.0, 'nnz': 1}


Using the first value from num_steps for the first 99998 iterations


Total Time: 37.92148494720459
Results before CAM pruning: {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'sid': 0.0, 'nnz': 10}
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'sid': 0.0, 'nnz': 10}


Using the first value from num_steps for the first 99998 iterations


Total Time: 20.857305765151978
Results before CAM pruning: {'fdr': 0.2222222222222222, 'tpr': 0.7, 'fpr': 2.0, 'shd': 3, 'sid': 9.0, 'nnz': 9}
Results:  {'fdr': 0.4, 'tpr': 0.3, 'fpr': 2.0, 'shd': 7, 'sid': 13.0, 'nnz': 5}


Using the first value from num_steps for the first 99998 iterations


Total Time: 35.0725371837616
Results before CAM pruning: {'fdr': 0.2857142857142857, 'tpr': 0.5, 'fpr': 2.0, 'shd': 5, 'sid': 11.0, 'nnz': 7}
Results:  {'fdr': 0.3333333333333333, 'tpr': 0.4, 'fpr': 2.0, 'shd': 6, 'sid': 12.0, 'nnz': 6}


Using the first value from num_steps for the first 99998 iterations


Total Time: 44.350411891937256
Results before CAM pruning: {'fdr': 0.0, 'tpr': 0.8, 'fpr': 0.0, 'shd': 2, 'sid': 5.0, 'nnz': 8}
Results:  {'fdr': 0.0, 'tpr': 0.5, 'fpr': 0.0, 'shd': 5, 'sid': 8.0, 'nnz': 5}


Using the first value from num_steps for the first 99998 iterations


Total Time: 37.63465762138367
Results before CAM pruning: {'fdr': 0.1111111111111111, 'tpr': 0.8, 'fpr': 1.0, 'shd': 2, 'sid': 5.0, 'nnz': 9}
Results:  {'fdr': 0.1111111111111111, 'tpr': 0.8, 'fpr': 1.0, 'shd': 2, 'sid': 5.0, 'nnz': 9}


Using the first value from num_steps for the first 99998 iterations


Total Time: 33.43969488143921
Results before CAM pruning: {'fdr': 0.5, 'tpr': 0.4, 'fpr': 4.0, 'shd': 6, 'sid': 16.0, 'nnz': 8}
Results:  {'fdr': 0.4, 'tpr': 0.3, 'fpr': 2.0, 'shd': 7, 'sid': 17.0, 'nnz': 5}


Using the first value from num_steps for the first 99998 iterations


Total Time: 36.471829891204834
Results before CAM pruning: {'fdr': 0.1, 'tpr': 0.9, 'fpr': 1.0, 'shd': 1, 'sid': 6.0, 'nnz': 10}
Results:  {'fdr': 0.1, 'tpr': 0.9, 'fpr': 1.0, 'shd': 1, 'sid': 6.0, 'nnz': 10}
mean shd: 4.4


In [8]:
# Ten independent trials: 10 nodes, s0=20, Gaussian noise.
shds = [
    grandag_aug_lagrangian(1000, 10, 20, num_layers=2, noise_type="gauss")['shd']
    for _ in range(10)
]
print(f'mean shd: {np.mean(shds)}')

Using the first value from num_steps for the first 99998 iterations


Total Time: 225.47241616249084
Results before CAM pruning: {'fdr': 0.525, 'tpr': 0.95, 'fpr': 0.84, 'shd': 22, 'sid': 2.0, 'nnz': 40}
Results:  {'fdr': 0.10526315789473684, 'tpr': 0.85, 'fpr': 0.08, 'shd': 5, 'sid': 16.0, 'nnz': 19}


Using the first value from num_steps for the first 99998 iterations


Total Time: 200.37697792053223
Results before CAM pruning: {'fdr': 0.5384615384615384, 'tpr': 0.9, 'fpr': 0.84, 'shd': 22, 'sid': 18.0, 'nnz': 39}
Results:  {'fdr': 0.2, 'tpr': 0.8, 'fpr': 0.16, 'shd': 7, 'sid': 24.0, 'nnz': 20}


Using the first value from num_steps for the first 99998 iterations


Total Time: 201.64833092689514
Results before CAM pruning: {'fdr': 0.56, 'tpr': 0.55, 'fpr': 0.56, 'shd': 21, 'sid': 36.0, 'nnz': 25}
Results:  {'fdr': 0.25, 'tpr': 0.45, 'fpr': 0.12, 'shd': 12, 'sid': 44.0, 'nnz': 12}


Using the first value from num_steps for the first 99998 iterations


Total Time: 174.58065724372864
Results before CAM pruning: {'fdr': 0.5128205128205128, 'tpr': 0.95, 'fpr': 0.8, 'shd': 20, 'sid': 7.0, 'nnz': 39}
Results:  {'fdr': 0.14285714285714285, 'tpr': 0.6, 'fpr': 0.08, 'shd': 9, 'sid': 34.0, 'nnz': 14}


Using the first value from num_steps for the first 99998 iterations


Total Time: 1711.46369099617
Results before CAM pruning: {'fdr': 0.5142857142857142, 'tpr': 0.85, 'fpr': 0.72, 'shd': 19, 'sid': 17.0, 'nnz': 35}
Results:  {'fdr': 0.29411764705882354, 'tpr': 0.6, 'fpr': 0.2, 'shd': 11, 'sid': 36.0, 'nnz': 17}


Using the first value from num_steps for the first 99998 iterations


Total Time: 206.94866394996643
Results before CAM pruning: {'fdr': 0.5777777777777777, 'tpr': 0.95, 'fpr': 1.04, 'shd': 26, 'sid': 3.0, 'nnz': 45}
Results:  {'fdr': 0.19047619047619047, 'tpr': 0.85, 'fpr': 0.16, 'shd': 6, 'sid': 15.0, 'nnz': 21}


Using the first value from num_steps for the first 99998 iterations


Total Time: 1332.8905787467957
Results before CAM pruning: {'fdr': 0.5714285714285714, 'tpr': 0.9, 'fpr': 0.96, 'shd': 24, 'sid': 21.0, 'nnz': 42}
Results:  {'fdr': 0.1, 'tpr': 0.45, 'fpr': 0.04, 'shd': 12, 'sid': 53.0, 'nnz': 10}


Using the first value from num_steps for the first 99998 iterations


Total Time: 370.53898429870605
Results before CAM pruning: {'fdr': 0.45714285714285713, 'tpr': 0.95, 'fpr': 0.64, 'shd': 17, 'sid': 1.0, 'nnz': 35}
Results:  {'fdr': 0.06666666666666667, 'tpr': 0.7, 'fpr': 0.04, 'shd': 7, 'sid': 27.0, 'nnz': 15}


Using the first value from num_steps for the first 99998 iterations


Total Time: 200.92305779457092
Results before CAM pruning: {'fdr': 0.55, 'tpr': 0.9, 'fpr': 0.88, 'shd': 22, 'sid': 14.0, 'nnz': 40}
Results:  {'fdr': 0.17647058823529413, 'tpr': 0.7, 'fpr': 0.12, 'shd': 7, 'sid': 31.0, 'nnz': 17}


Using the first value from num_steps for the first 99998 iterations


Total Time: 253.19810009002686
Results before CAM pruning: {'fdr': 0.5641025641025641, 'tpr': 0.85, 'fpr': 0.88, 'shd': 22, 'sid': 23.0, 'nnz': 39}
Results:  {'fdr': 0.4, 'tpr': 0.45, 'fpr': 0.24, 'shd': 14, 'sid': 43.0, 'nnz': 15}
mean shd: 9.0


In [9]:
# Ten independent trials: 20 nodes, s0=40, Gaussian noise.
shds = [
    grandag_aug_lagrangian(1000, 20, 40, num_layers=2, noise_type="gauss")['shd']
    for _ in range(10)
]
print(f'mean shd: {np.mean(shds)}')

Using the first value from num_steps for the first 99998 iterations


Total Time: 404.33473110198975
Results before CAM pruning: {'fdr': 0.782608695652174, 'tpr': 0.875, 'fpr': 0.84, 'shd': 128, 'sid': 65.0, 'nnz': 161}
Results:  {'fdr': 0.20588235294117646, 'tpr': 0.675, 'fpr': 0.04666666666666667, 'shd': 17, 'sid': 131.0, 'nnz': 34}


Using the first value from num_steps for the first 99998 iterations


Total Time: 296.53825187683105
Results before CAM pruning: {'fdr': 0.7625, 'tpr': 0.95, 'fpr': 0.8133333333333334, 'shd': 123, 'sid': 22.0, 'nnz': 160}
Results:  {'fdr': 0.06896551724137931, 'tpr': 0.675, 'fpr': 0.013333333333333334, 'shd': 14, 'sid': 109.0, 'nnz': 29}


Using the first value from num_steps for the first 99998 iterations


Total Time: 338.94224095344543
Results before CAM pruning: {'fdr': 0.8, 'tpr': 0.775, 'fpr': 0.8266666666666667, 'shd': 128, 'sid': 72.0, 'nnz': 155}
Results:  {'fdr': 0.3, 'tpr': 0.525, 'fpr': 0.06, 'shd': 23, 'sid': 163.0, 'nnz': 30}


Using the first value from num_steps for the first 99998 iterations


Total Time: 351.58502411842346
Results before CAM pruning: {'fdr': 0.7978142076502732, 'tpr': 0.925, 'fpr': 0.9733333333333334, 'shd': 146, 'sid': 22.0, 'nnz': 183}
Results:  {'fdr': 0.09375, 'tpr': 0.725, 'fpr': 0.02, 'shd': 13, 'sid': 92.0, 'nnz': 32}


Using the first value from num_steps for the first 99998 iterations


Total Time: 314.57723116874695
Results before CAM pruning: {'fdr': 0.7677419354838709, 'tpr': 0.9, 'fpr': 0.7933333333333333, 'shd': 120, 'sid': 34.0, 'nnz': 155}
Results:  {'fdr': 0.06896551724137931, 'tpr': 0.675, 'fpr': 0.013333333333333334, 'shd': 14, 'sid': 121.0, 'nnz': 29}


Using the first value from num_steps for the first 99998 iterations


Total Time: 352.3067240715027
Results before CAM pruning: {'fdr': 0.7727272727272727, 'tpr': 0.875, 'fpr': 0.7933333333333333, 'shd': 121, 'sid': 50.0, 'nnz': 154}
Results:  {'fdr': 0.16129032258064516, 'tpr': 0.65, 'fpr': 0.03333333333333333, 'shd': 16, 'sid': 127.0, 'nnz': 31}


Using the first value from num_steps for the first 99998 iterations


Total Time: 348.37236499786377
Results before CAM pruning: {'fdr': 0.7662337662337663, 'tpr': 0.9, 'fpr': 0.7866666666666666, 'shd': 120, 'sid': 38.0, 'nnz': 154}
Results:  {'fdr': 0.10714285714285714, 'tpr': 0.625, 'fpr': 0.02, 'shd': 17, 'sid': 94.0, 'nnz': 28}


Using the first value from num_steps for the first 99998 iterations


Total Time: 288.392498254776
Results before CAM pruning: {'fdr': 0.8206896551724138, 'tpr': 0.65, 'fpr': 0.7933333333333333, 'shd': 128, 'sid': 140.0, 'nnz': 145}
Results:  {'fdr': 0.3181818181818182, 'tpr': 0.375, 'fpr': 0.04666666666666667, 'shd': 32, 'sid': 155.0, 'nnz': 22}


Using the first value from num_steps for the first 99998 iterations


Total Time: 346.60823917388916
Results before CAM pruning: {'fdr': 0.8108108108108109, 'tpr': 0.525, 'fpr': 0.6, 'shd': 102, 'sid': 192.0, 'nnz': 111}
Results:  {'fdr': 0.3076923076923077, 'tpr': 0.45, 'fpr': 0.05333333333333334, 'shd': 28, 'sid': 141.0, 'nnz': 26}


Using the first value from num_steps for the first 99998 iterations


Total Time: 338.1101818084717
Results before CAM pruning: {'fdr': 0.8133333333333334, 'tpr': 0.7, 'fpr': 0.8133333333333334, 'shd': 127, 'sid': 129.0, 'nnz': 150}
Results:  {'fdr': 0.2222222222222222, 'tpr': 0.525, 'fpr': 0.04, 'shd': 22, 'sid': 153.0, 'nnz': 27}
mean shd: 19.6


In [77]:
# Single trial, 5 nodes, non-Gaussian noise (noise_type='exp')
grandag_aug_lagrangian(n=1000, d=5, s0=5, num_layers=2, noise_type='exp')

Using the first value from num_steps for the first 99998 iterations
100%|██████████| 100000/100000 [11:11<00:00, 148.89it/s] 


Total Time: 671.6974756717682
Results before CAM pruning: {'fdr': 0.8, 'tpr': 0.4, 'fpr': 1.6, 'shd': 8, 'sid': 15.0, 'nnz': 10}
Results:  {'fdr': 0.6666666666666666, 'tpr': 0.4, 'fpr': 0.8, 'shd': 4, 'sid': 16.0, 'nnz': 6}


{'fdr': 0.6666666666666666,
 'tpr': 0.4,
 'fpr': 0.8,
 'shd': 4,
 'sid': 16.0,
 'nnz': 6}

In [78]:
# Single trial, 5 nodes, non-Gaussian noise (noise_type='gumbel')
grandag_aug_lagrangian(n=1000, d=5, s0=5, num_layers=2, noise_type='gumbel')

Using the first value from num_steps for the first 99998 iterations


Total Time: 976.2653551101685
Results before CAM pruning: {'fdr': 0.5714285714285714, 'tpr': 0.6, 'fpr': 0.8, 'shd': 4, 'sid': 7.0, 'nnz': 7}
Results:  {'fdr': 0.5, 'tpr': 0.4, 'fpr': 0.4, 'shd': 3, 'sid': 12.0, 'nnz': 4}


{'fdr': 0.5, 'tpr': 0.4, 'fpr': 0.4, 'shd': 3, 'sid': 12.0, 'nnz': 4}

In [76]:
# Single trial, 10 nodes, s0=10, default (Gaussian) noise
grandag_aug_lagrangian(n=1000, d=10, s0=10, num_layers=2, noise_type="gauss")

100%|██████████| 100000/100000 [04:59<00:00, 334.20it/s]


Total Time: 299.23108100891113
Results before CAM pruning: {'fdr': 0.76, 'tpr': 0.6, 'fpr': 0.5428571428571428, 'shd': 23, 'sid': 13.0, 'nnz': 25}
Results:  {'fdr': 0.4444444444444444, 'tpr': 0.5, 'fpr': 0.11428571428571428, 'shd': 9, 'sid': 20.0, 'nnz': 9}


{'fdr': 0.4444444444444444,
 'tpr': 0.5,
 'fpr': 0.11428571428571428,
 'shd': 9,
 'sid': 20.0,
 'nnz': 9}

In [37]:
# Single trial, 20 nodes, s0=20, default (Gaussian) noise
grandag_aug_lagrangian(n=1000, d=20, s0=20, num_layers=2, noise_type="gauss")

  self.vwarn(f"Using the first value from num_steps for the first {missing_steps} iterations")


b_true is [[0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0

100%|██████████| 100000/100000 [07:06<00:00, 234.64it/s] 

h new is 0.0 and h tol is 1e-08 so ending at 13021th iteration
Total Time: 426.18568086624146
Results:  {'fdr': 0.8689655172413793, 'tpr': 0.95, 'fpr': 0.7411764705882353, 'shd': 126, 'nnz': 145}





# Using flex to implement DAGMA

## Linear DAGMA

In [14]:
# Linear model, sized from the data itself (the dataset is n x d). The global `d`
# is not reliable here because other cells rebind it to the nonlinear problem size.
model = flex.LinearModel(linear_dataset.shape[1])

# Use path following to solve the constrained problem.
# logdet_coeff gives one coefficient per outer iteration (5 values for num_iter=5).
cons_solver = flex.PathFollowing(
    num_iter=5,
    mu_init=1.0,
    mu_scale=0.1,
    logdet_coeff=[1.0, .9, .8, .7, .6],
    num_steps=[3e4, 6e4],
    l1_coeff=0.03,
)

# Use Adam to solve the unconstrained problem
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.99, 0.999))
)

# Use MSE loss
loss_fn = flex.MSELoss()

# Use LogDet as DAG function
dag_fn = flex.LogDet()

# Learn the DAG
W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Score the support (non-zero pattern) of the estimate against the ground truth
acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
100%|██████████| 5/5 [00:08<00:00,  1.77s/it]

Total Time: 8.867684841156006
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## Nonlinear DAGMA

In [48]:
n, d, s0, graph_type, sem_type = 1000, 5, 5, "ER", "mlp"
nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type)

# Nonlinear model. The last entry of `dims` is tied to `d` instead of a hard-coded 5,
# matching the layout used in grandag_aug_lagrangian. This is identical for d == 5.
model = flex.MLP(dims=[d, 2, d], num_layers=2, hid_dim=10, activation="sigmoid", bias=True)

# Use path following to solve the constrained problem
cons_solver = flex.PathFollowing(
    num_iter=4,
    mu_init=0.1,
    mu_scale=0.1,
    logdet_coeff=1.0,
    num_steps=[5e4, 8e4],
    weight_decay=0.02,
    l1_coeff=0.005,
)

# Use plain SGD to solve the unconstrained problem.
# An earlier commented-out alternative was:
#   torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.99, 0.999))
# NOTE(review): the recorded run with this SGD setting recovered no edges
# (tpr 0.0, nnz 0). The optimizer/learning rate likely need tuning.
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.SGD(model.parameters(), lr=2e-4)
)

# Use NLL loss
loss_fn = flex.NLLLoss()

# Use LogDet as DAG function
dag_fn = flex.LogDet()

# Learn the DAG
W_est = flex.struct_learn(
    dataset=nonlinear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Threshold and then remove the weakest edges until the estimate is a DAG
W_est = postprocess(W_est)
acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
100%|██████████| 4/4 [09:40<00:00, 145.12s/it]

Total Time: 580.489175081253
Results:  {'fdr': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5, 'nnz': 0}



