In [27]:
%reload_ext autoreload
%autoreload 2
import torch
import dagrad.flex as flex
from dagrad.utils import utils

# Single seed shared by every seeding call so the run can be reproduced
# (or varied) by editing one value.
SEED = 1

# dagrad's own seeding helper, followed by PyTorch's global seed.
utils.set_random_seed(SEED)
# torch.manual_seed returns a Generator; the trailing semicolon keeps the
# notebook from echoing its repr as cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x10b7adb70>

## Struct Learn via Flex

`flex.struct_learn` is a function that requires the following 6 inputs:
- `dataset`: A data matrix with `n x d` shape, where `n` is the number of samples and `d` is the number of variables/features.

- `model`: The SEM module. We provide standard SEM implementations such as `LinearModel`, `LogisticModel`, and `MLP`.

- `constrained_solver`: An instance of a `ConstrainedSolver` class. We provide implementations such as `PathFollowing` and `AugmentedLagrangian`.

- `unconstrained_solver`: An instance of an `UnconstrainedSolver` class. We provide one implementation, `GradientBasedSolver`.

- `loss_fn`: An instance of a `Loss` class. All available losses can be found in `flex/loss.py`.

- `dag_fn`: An instance of a `DagFn` class. All available DAG functions can be found in `flex/dags.py`.

## Generate Data

In [3]:
# Settings shared by both simulations: n samples, d variables; s0 and
# graph_type are forwarded unchanged to utils.simulate_dag.
n, d, s0, graph_type = 1000, 20, 20, "ER"

# Linear dataset: draw a ground-truth graph, then sample from a linear SEM
# using the "gauss" setting.
sem_type = "gauss"
linear_B_true = utils.simulate_dag(d, s0, graph_type)
linear_dataset = utils.simulate_linear_sem(linear_B_true, n, sem_type)

# Non-linear dataset: a second, independently drawn ground-truth graph,
# sampled from a non-linear SEM using the "mlp" setting.
sem_type = "mlp"
nonlinear_B_true = utils.simulate_dag(d, s0, graph_type)
nonlinear_dataset = utils.simulate_nonlinear_sem(nonlinear_B_true, n, sem_type)

# Using flex to implement NOTEARS

## Linear NOTEARS

In [12]:
# SEM: a linear model over the d variables.
model = flex.LinearModel(d)

# Constrained problem: solved with the augmented Lagrangian method (ALM).
cons_solver = flex.AugmentedLagrangian(num_iter=10, num_steps=[3e4, 6e4], l1_coeff=0.03)

# Unconstrained subproblems: solved with Adam over the model's parameters.
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4)
)

# Objective pieces: MSE loss and the trace-of-matrix-exponential DAG function.
loss_fn = flex.MSELoss()
dag_fn = flex.Exp()

# Learn the DAG from the linear dataset.
# w_threshold presumably prunes small estimated weights -- TODO confirm in flex.struct_learn.
W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Compare the non-zero pattern of the estimate with the ground-truth graph.
acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

100%|██████████| 10/10 [00:37<00:00,  3.79s/it]

Total Time: 37.8957781791687
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## Nonlinear NOTEARS

In [28]:
# SEM: an MLP with layer sizes [d, 10, 1], sigmoid activation and bias terms.
model = flex.MLP(dims=[d, 10, 1], activation="sigmoid", bias=True)

# Constrained problem: solved with the augmented Lagrangian method (ALM).
cons_solver = flex.AugmentedLagrangian(
    num_iter=10, num_steps=[4e4, 6e4], l1_coeff=0.01, weight_decay=0.01
)

# Unconstrained subproblems: solved with Adam over the model's parameters.
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4)
)

# Objective pieces: MSE loss and the trace-of-matrix-exponential DAG function.
loss_fn = flex.MSELoss()
dag_fn = flex.Exp()

# Learn the DAG from the non-linear dataset.
# NOTE(review): the recorded run of this cell returned an empty graph
# (nnz=0, tpr=0.0, shd=20). Presumably the hyperparameters here (w_threshold,
# l1_coeff, weight_decay, lr) need tuning for this dataset -- confirm.
W_est = flex.struct_learn(
    dataset=nonlinear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Compare the non-zero pattern of the estimate with the ground-truth graph.
acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
print("Results: ", acc)

100%|██████████| 10/10 [03:29<00:00, 20.97s/it]

Total Time: 209.70779299736023
Results:  {'fdr': 0.0, 'tpr': 0.0, 'fpr': 0.0, 'shd': 20, 'nnz': 0}





# Using flex to implement DAGMA

## Linear DAGMA

In [14]:
# SEM: a linear model over the d variables.
model = flex.LinearModel(d)

# Constrained problem: solved by path following, with one logdet_coeff
# entry listed for each of the 5 iterations.
cons_solver = flex.PathFollowing(
    num_iter=5,
    mu_init=1.0,
    mu_scale=0.1,
    logdet_coeff=[1.0, .9, .8, .7, .6],
    num_steps=[3e4, 6e4],
    l1_coeff=0.03,
)

# Unconstrained subproblems: solved with Adam (non-default betas).
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.99, 0.999))
)

# Objective pieces: MSE loss and the log-determinant DAG function.
loss_fn = flex.MSELoss()
dag_fn = flex.LogDet()

# Learn the DAG from the linear dataset.
W_est = flex.struct_learn(
    dataset=linear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Compare the non-zero pattern of the estimate with the ground-truth graph.
acc = utils.count_accuracy(linear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
100%|██████████| 5/5 [00:08<00:00,  1.77s/it]

Total Time: 8.867684841156006
Results:  {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}





## Nonlinear DAGMA

In [24]:
# SEM: an MLP with layer sizes [d, 10, 1], sigmoid activation and bias terms.
model = flex.MLP(dims=[d, 10, 1], activation="sigmoid", bias=True)

# Constrained problem: solved by path following with a single logdet_coeff
# value.
cons_solver = flex.PathFollowing(
    num_iter=4,
    mu_init=0.1,
    mu_scale=0.1,
    logdet_coeff=1.0,
    num_steps=[5e4, 8e4],
    weight_decay=0.02,
    l1_coeff=0.005,
)

# Unconstrained subproblems: solved with Adam (non-default betas).
uncons_solver = flex.GradientBasedSolver(
    optimizer=torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.99, 0.999))
)

# Objective pieces: NLL loss and the log-determinant DAG function.
loss_fn = flex.NLLLoss()
dag_fn = flex.LogDet()

# Learn the DAG from the non-linear dataset.
# NOTE(review): the recorded run of this cell was stopped with a
# KeyboardInterrupt after roughly two minutes per outer iteration, so no
# accuracy is shown for it -- re-run to completion before relying on it.
W_est = flex.struct_learn(
    dataset=nonlinear_dataset,
    model=model,
    constrained_solver=cons_solver,
    unconstrained_solver=uncons_solver,
    loss_fn=loss_fn,
    dag_fn=dag_fn,
    w_threshold=0.3,
)

# Compare the non-zero pattern of the estimate with the ground-truth graph.
acc = utils.count_accuracy(nonlinear_B_true, W_est != 0)
print("Results: ", acc)

  self.vwarn(
  0%|          | 0/4 [00:00<?, ?it/s]

total loss: 1.4852686697050566
total loss: 1.4841601892731302
total loss: 1.4834838066836706
total loss: 1.482928097093023
total loss: 1.482375262516888
total loss: 1.4817856942234555
total loss: 1.4811618199503622
total loss: 1.4805227948717286
total loss: 1.4798847157404067
total loss: 1.4792609563114199
total loss: 1.478651813302038
total loss: 1.4780549576059157
total loss: 1.477468605470451
total loss: 1.4768890552518772
total loss: 1.4763153483318576
total loss: 1.4757467053333848
total loss: 1.4751827564338196
total loss: 1.4746212177736737
total loss: 1.4740629073990479
total loss: 1.4735063117024405
total loss: 1.4729519224867202
total loss: 1.4724001077484477
total loss: 1.4718511641816108
total loss: 1.4713039774329395
total loss: 1.4707593546204538
total loss: 1.4702186845141165
total loss: 1.4696819093424902
total loss: 1.4691494374945675
total loss: 1.468621887006432
total loss: 1.4681000237161
total loss: 1.4675831006353497
total loss: 1.4670710710370916
total loss: 1.46

 25%|██▌       | 1/4 [02:04<06:12, 124.28s/it]

total loss: 1.3283758479828394
total loss: 1.32837654391242
total loss: 1.328377439846091
total loss: 1.3283783707378158
total loss: 1.328379137944846
total loss: 1.3283797002070017
total loss: 1.328380141654114
total loss: 1.328380414477312
total loss: 1.3283804922371996
total loss: 1.3283803463906065
total loss: 1.3283800545648614
total loss: 1.3283795581008353
total loss: 1.3283789205010035
total loss: 1.3283782791292136
total loss: 1.328377507712292
total loss: 1.328376694238888
total loss: 1.328376002217927
total loss: 1.328375435942531
total loss: 1.328374967760475
total loss: 1.328374700562015
total loss: 1.3283746170954005
total loss: 1.328374672971972
total loss: 1.3283748714308865
total loss: 1.3283752562918243
total loss: 1.328375706545504
total loss: 1.3283762809806823
total loss: 1.3283768963016265
total loss: 1.328377511353171
total loss: 1.328378118983105
total loss: 1.3283787018700257
total loss: 1.328379249249545
total loss: 1.3283797482382884
total loss: 1.32838018887

 25%|██▌       | 1/4 [03:23<10:09, 203.32s/it]

total loss: 1.3283778810749365
total loss: 1.3283785191873876
total loss: 1.3283790245406415
total loss: 1.328379338703759
total loss: 1.3283794875701747
total loss: 1.3283794923976278





KeyboardInterrupt: 