In [17]:
%reload_ext autoreload
%autoreload 2
import torch
import dagrad.flex as flex
from dagrad.utils import utils
import numpy as np

utils.set_random_seed(1)
torch.manual_seed(1)

<torch._C.Generator at 0x1282b22d0>

In [3]:
def postprocess(B, graph_thres=0.3):
    """Post-process estimated solution:
        (1) Thresholding.
        (2) Remove the edges with smallest absolute weight until a DAG
            is obtained.

    Args:
        B (numpy.ndarray): [d, d] weighted matrix.
        graph_thres (float): Threshold for weighted matrix. Default: 0.3.

    Returns:
        numpy.ndarray: [d, d] weighted matrix of DAG.
    """
    B = np.copy(B)
    B[np.abs(B) <= graph_thres] = 0    # Thresholding
    B, _ = utils.threshold_till_dag(B)

    return B

## Struct Learn via Flex

`flex.struct_learn` is a function that requires the following 5 inputs:
- `dataset`: A data matrix with `n x d` shape, where `n` is the number of samples and `d` is the number of variables/features.

- `model`: The SEM module. We provide standard SEM implementations such as---LinearModel, LogisticModel, MLP

- `constrained_solver`: An instance of a ConstrainedSolver class. We provide implementations such as `PathFollowing` and `AugmentedLagrangian`

- `unconstrained_solver`: An instance of an UnconstrainedSolver class. We provide implementation for `GradientBasedSolver`

- `loss_fn`: Instance of a Loss class. All available losses can be found at flex/loss.py

- `dag_fn`: Instance of a DagFn class. All available DAG functions can be found at flex/dags.py

## Generate Data

In [5]:
# Generate linear data
n, d, s0, graph_type, sem_type = 1000, 20, 20, "ER", "gauss"
linear_B_true = utils.simulate_dag(d, s0, graph_type)
linear_dataset = utils.simulate_linear_sem(linear_B_true, n, sem_type)

# Generate non-linear data
n, d, s0, graph_type, sem_type = 1000, 20, 20, "ER", "mlp"
nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type)

# Using flex to implement NOTEARS

## linear NOTEARS

In [12]:
# Linear model
model = flex.LinearModel(d)

# Use AML to solve the constrained problem
cons_solver = flex.AugmentedLagrangian(
    num_iter=10,
    num_steps=[3e4,6e4],
    l1_coeff=0.03,
)

# Use Adam to solve the unconstrained problem
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4),
)

# Use MSE loss
loss_fn = flex.MSELoss()

# Use Trace of matrix exponential as DAG function
dag_fn = flex.Exp()

# Learn the DAG
W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

100%|██████████| 10/10 [00:37<00:00,  3.79s/it]

Total Time: 37.8957781791687
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## nonlinear NOTEARS

In [61]:
def grandag_aug_lagrangian(n, d, s0, num_layers=0):
    graph_type, sem_type = "ER", "mlp"
    nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
    nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type)

    model = flex.MLP(dims=[d, 2, d], num_layers=num_layers, hid_dim=10, activation="sigmoid", bias=True)

    # Use AML to solve the constrained problem
    cons_solver = flex.AugmentedLagrangian(
        num_iter=10,
        num_steps=[4e4,6e4],
        # l1_coeff=0,
        # weight_decay=0,
        rho_init=1e-3,
    )

    # Use Adam to solve the unconstrained problem
    uncons_solver = flex.GradientBasedSolver(
        optimizer=torch.optim.RMSprop(model.parameters(), lr=1e-3),
    )

    # Use MSE loss
    loss_fn = flex.MSELoss()

    # Use Trace of matrix exponential as DAG function
    dag_fn = flex.Grandag_h()

    print(f'b true is {nonlinear_B_true}')

    # Learn the DAG
    W_est = flex.struct_learn(
        dataset=nonlinear_dataset,
        model=model,
        constrained_solver=cons_solver,
        unconstrained_solver=uncons_solver,
        loss_fn=loss_fn,
        dag_fn=dag_fn,
        # w_threshold=0.3,
        w_threshold=0,
    )

    W_est = postprocess(W_est)
    print(f'W est is {W_est}')
    acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
    print("Results: ", acc)
    return acc

In [63]:
grandag_aug_lagrangian(1000, 5, 5, num_layers=2)

b true is [[0. 0. 1. 1. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]]


100%|██████████| 10/10 [35:37<00:00, 213.77s/it]

Total Time: 2137.7368030548096
W est is [[  0.          46.96578181  23.21168747   0.         124.45029529]
 [  0.           0.          17.7703808    0.           0.        ]
 [  0.           0.           0.           0.           0.        ]
 [  0.           0.          15.69918247   0.           0.        ]
 [  0.           6.58772124  15.50309389   0.           0.        ]]
Results:  {'fdr': 0.42857142857142855, 'tpr': 0.8, 'fpr': 0.6, 'shd': 4, 'nnz': 7}





{'fdr': 0.42857142857142855, 'tpr': 0.8, 'fpr': 0.6, 'shd': 4, 'nnz': 7}

In [65]:
grandag_aug_lagrangian(1000, 5, 5, num_layers=2)

b true is [[0. 0. 1. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 1. 1. 0.]]


100%|██████████| 10/10 [1:23:57<00:00, 503.73s/it]  

Total Time: 5037.3168268203735
at the end, model.adj() is tensor([[0.0000e+00, 0.0000e+00, 1.3680e+01, 0.0000e+00, 0.0000e+00],
        [2.0044e-01, 0.0000e+00, 4.2398e+01, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.1160e+01, 2.7062e+00, 3.4308e+01, 0.0000e+00, 3.1298e+02],
        [1.0862e+02, 1.4081e+00, 1.6743e+01, 0.0000e+00, 0.0000e+00]],
       grad_fn=<TBackward0>) and model.adjacency is tensor([[0., 0., 1., 0., 0.],
        [1., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 1.],
        [1., 1., 1., 0., 0.]])
W est is [[  0.           0.          13.68010699   0.           0.        ]
 [  0.           0.          42.3983111    0.           0.        ]
 [  0.           0.           0.           0.           0.        ]
 [ 71.16035462   2.70618804  34.30839006   0.         312.98285533]
 [108.62411618   1.40805567  16.74264408   0.           0.        ]]
Results:  {'fdr': 0.7777777777777778, 




{'fdr': 0.7777777777777778, 'tpr': 0.4, 'fpr': 1.4, 'shd': 8, 'nnz': 9}

In [66]:
grandag_aug_lagrangian(1000, 5, 5, num_layers=0)

b true is [[0. 1. 0. 1. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0.]]


100%|██████████| 10/10 [03:46<00:00, 22.62s/it]

Total Time: 226.22626566886902
at the end, model.adj() is tensor([[0.0000, 0.0000, 0.0000, 1.3291, 0.0578],
        [0.0000, 0.0000, 0.4339, 0.0000, 0.0231],
        [0.1237, 0.0000, 0.0000, 0.0000, 0.0150],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0096],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<TBackward0>) and model.adjacency is tensor([[0., 0., 0., 1., 1.],
        [0., 0., 1., 0., 1.],
        [1., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])
W est is [[0.         0.         0.         1.32909476 0.        ]
 [0.         0.         0.43391254 0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]]
Results:  {'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 3, 'nnz': 2}





{'fdr': 0.0, 'tpr': 0.4, 'fpr': 0.0, 'shd': 3, 'nnz': 2}

In [46]:
grandag_aug_lagrangian(1000, 5, 5, num_layers=0)

b true is [[0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 1. 1. 0.]]


100%|██████████| 10/10 [04:32<00:00, 27.27s/it]

Total Time: 272.707524061203
W est is [[0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         1.00312622]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]]
Results:  {'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}





{'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}

In [48]:
# lr = 9e-4
grandag_aug_lagrangian(1000, 5, 5, num_layers=0)

  self.vwarn(f"Using the first value from num_steps for the first {missing_steps} iterations")


b true is [[0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [1. 0. 1. 1. 0.]]


100%|██████████| 10/10 [02:31<00:00, 15.16s/it]

Total Time: 151.58235120773315
W est is [[0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [1.32461408 0.         0.         0.         0.        ]]
Results:  {'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}





{'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}

In [52]:
# lr = 8e-4
grandag_aug_lagrangian(1000, 5, 5, num_layers=0)

  self.vwarn(f"Using the first value from num_steps for the first {missing_steps} iterations")


b true is [[0. 1. 1. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 1. 0. 0.]]


100%|██████████| 10/10 [03:09<00:00, 18.90s/it]

Total Time: 189.0419738292694
W est is [[0.         0.49815513 0.81447565 0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.48558764 1.22250752 0.         0.        ]
 [0.         0.         1.25534521 0.         0.        ]]
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 5}





{'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 5}

In [53]:
# lr = 8e-4
grandag_aug_lagrangian(1000, 5, 10, num_layers=0)

b true is [[0. 1. 1. 1. 0.]
 [0. 0. 0. 0. 0.]
 [0. 1. 0. 1. 0.]
 [0. 1. 0. 0. 0.]
 [1. 1. 1. 1. 0.]]


100%|██████████| 10/10 [05:24<00:00, 32.48s/it]

Total Time: 324.8534641265869
W est is [[0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.45826668 0.         0.        ]
 [1.09484874 0.         0.         0.         0.        ]]
Results:  {'fdr': 0.5, 'tpr': 0.1, 'fpr': 1.0, 'shd': 9, 'nnz': 2}





{'fdr': 0.5, 'tpr': 0.1, 'fpr': 1.0, 'shd': 9, 'nnz': 2}

In [54]:
# lr = 8e-4
grandag_aug_lagrangian(1000, 10, 10, num_layers=0)

b true is [[0. 0. 1. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 1. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]


100%|██████████| 10/10 [11:16<00:00, 67.68s/it]

Total Time: 676.7651886940002
W est is [[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.31340062 0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         1.46143676]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.44987638 0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        




{'fdr': 0.25, 'tpr': 0.3, 'fpr': 0.02857142857142857, 'shd': 8, 'nnz': 4}

In [56]:
# lr = 8e-4
grandag_aug_lagrangian(1000, 5, 5, num_layers=0)

b true is [[0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0.]]


100%|██████████| 10/10 [08:03<00:00, 48.36s/it]

Total Time: 483.60016798973083
W est is [[0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.         0.         0.         2.31426757]
 [0.         0.         0.         0.         0.        ]]
Results:  {'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}





{'fdr': 0.0, 'tpr': 0.2, 'fpr': 0.0, 'shd': 4, 'nnz': 1}

In [50]:
# lr = 6e-4
grandag_aug_lagrangian(1000, 5, 5, num_layers=1)

  self.vwarn(f"Using the first value from num_steps for the first {missing_steps} iterations")


b true is [[0. 0. 0. 0. 0.]
 [1. 0. 0. 1. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0.]]


100%|██████████| 10/10 [19:28<00:00, 116.83s/it]

Total Time: 1168.270653963089
W est is [[0.         1.01714285 4.01585949 3.61862998 0.87874441]
 [0.         0.         6.47749027 6.49723058 0.        ]
 [0.         0.         0.         1.80849105 0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.33163645 2.73502428 7.69451729 0.        ]]
Results:  {'fdr': 0.7, 'tpr': 0.6, 'fpr': 1.4, 'shd': 7, 'nnz': 10}





{'fdr': 0.7, 'tpr': 0.6, 'fpr': 1.4, 'shd': 7, 'nnz': 10}

# Using flex to implement DAGMA

## Linear DAGMA

In [14]:
# Linear model
model = flex.LinearModel(d)

# Use path following to solve the constrained problem
cons_solver = flex.PathFollowing(
    num_iter=5,
    mu_init=1.0,
    mu_scale=0.1,
    logdet_coeff=[1.0, .9, .8, .7, .6],
    num_steps=[3e4, 6e4],
    l1_coeff=0.03,
)

# use Adam to solve the unconstrained problem
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.99, 0.999))
)

# Use MSE loss
loss_fn = flex.MSELoss()

# Use LogDet as DAG function
dag_fn = flex.LogDet()

W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
100%|██████████| 5/5 [00:08<00:00,  1.77s/it]

Total Time: 8.867684841156006
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## Nonlinear DAGMA

In [4]:
n, d, s0, graph_type, sem_type = 1000, 5, 5, "ER", "mlp"
nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type)

# nonlinear model
model = flex.MLP(dims=[d, 2, 5], num_layers=2, hid_dim=10, activation="sigmoid", bias=True)

# Use path following to solve the constrained problem
cons_solver = flex.PathFollowing(
    num_iter=4,
    mu_init=0.1,
    mu_scale=0.1,
    logdet_coeff=1.0,
    num_steps=[5e4, 8e4],
    weight_decay=0.02,
    l1_coeff=0.005,
)

# use Adam to solve the unconstrained problem
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.99, 0.999))
)

# Use NLL loss
loss_fn = flex.NLLLoss()

# Use LogDet as DAG function
dag_fn = flex.Grandag_h()

# Learn the DAG
W_est = flex.struct_learn(
    dataset=nonlinear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

W_est = postprocess(W_est)
acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
100%|██████████| 4/4 [36:56<00:00, 554.15s/it]   

Total Time: 2216.6051728725433
Results:  {'fdr': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 5, 'nnz': 0}



