In [62]:
# Import all the necessary modules
%load_ext autoreload
%autoreload 2
from dagrad import dagrad # dagrad is the main class for learning the structure of a DAG
from dagrad import generate_linear_data, generate_nonlinear_data, count_accuracy, threshold_till_dag
from dagrad.hfunction.h_functions import SCCPowerIteration
from dagrad.hfunction.h_functions import h_fn
import torch
import numpy as np
import matplotlib.pyplot as plt 

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
def postprocess(B, graph_thres=None):
    """Post-process an estimated weighted adjacency matrix into a DAG.

    Steps:
        (1) Optional hard thresholding: entries with ``|B_ij| <= graph_thres``
            are set to zero. Skipped when ``graph_thres`` is ``None`` (the
            default), which reproduces how this notebook has been run so far.
        (2) ``threshold_till_dag``: remove the edges with the smallest absolute
            weight until a DAG is obtained.

    Args:
        B (numpy.ndarray): [d, d] weighted matrix. Not modified; a copy is processed.
        graph_thres (float or None): Hard threshold applied before step (2).
            Default: None (no hard thresholding).

    Returns:
        numpy.ndarray: [d, d] weighted matrix of DAG.
    """
    B = np.copy(B)
    # Previously this step was commented out while the parameter was still
    # advertised; it now runs only when a threshold is explicitly requested.
    if graph_thres is not None:
        B[np.abs(B) <= graph_thres] = 0
    B, _ = threshold_till_dag(B)  # second return value is not needed here

    return B

### Linear SEM with equal noise variances (EV)

In [82]:
def sdcd_ev(n, d, s0, graph_type, noise_type, error_var, seed=None):
    """Simulate linear-SEM data, fit it with DAGMA driven by a user-supplied
    h function backed by ``SCCPowerIteration`` ("SDCD"), and score the result.

    Args:
        n: data size (number of samples) forwarded to ``generate_linear_data``.
        d: number of nodes; also the side length of the tensor handed to
            ``SCCPowerIteration``.
        s0: edge count forwarded to ``generate_linear_data``.
        graph_type: random-graph family, e.g. 'ER'.
        noise_type: noise distribution, e.g. 'gauss'.
        error_var: noise-variance setting, e.g. 'eq'.
        seed: optional seed forwarded to ``generate_linear_data``.

    Returns:
        tuple: ``(acc_sdcd, W_sdcd)`` -- the ``count_accuracy`` result and the
        post-processed weighted adjacency matrix.
    """
    X_np, W_true, B_true = generate_linear_data(n, d, s0, graph_type, noise_type, error_var, seed)
    X = torch.from_numpy(X_np).float()

    # All-zero d x d tensor given to SCCPowerIteration (presumably its initial state -- confirm).
    scc_init = torch.zeros(d, d, dtype=torch.double, requires_grad=True, device='cpu')
    user_params = {
        'is_prescreen': False,
        'power_grad': SCCPowerIteration(scc_init, d),
    }

    W_raw = dagrad(
        X,
        model='linear',
        method='dagma',
        compute_lib='torch',
        h_fn='user_h',
        general_options={'user_params': user_params},
    )
    W_sdcd = postprocess(W_raw)

    print("Linear Model")
    print(f"data size: {n}, graph type: {graph_type}, sem type: {noise_type}")
    # Score the binary support of the estimate against the true graph.
    acc_sdcd = count_accuracy(B_true, W_sdcd != 0)
    print('Accuracy of SDCD:', acc_sdcd)

    return acc_sdcd, W_sdcd

In [50]:
# Small smoke run: 5 nodes, 5 edges, equal-variance Gaussian noise (unseeded).
sdcd_ev(n=1000, d=5, s0=5, graph_type='ER', noise_type='gauss', error_var='eq')

Linear Model
data size: 1000, graph type: ER, sem type: gauss
Accuracy of SDCD: {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 5}


{'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 5}

In [98]:
# Fit a 10-node graph, then compare the acyclicity functions on it after adding self-loops.
# NOTE: no seed is passed, so W_pred (and every number below) changes from run to run.
_, W_pred = sdcd_ev(1000, 10, 10, 'ER', 'gauss', 'eq')
print(f'W_pred: {W_pred}')

# Add two self-loops on a copy, so the learned W_pred stays intact and re-running
# this cell does not compound edits.
W_cyclic = W_pred.copy()
W_cyclic[0, 0] = 0.5
W_cyclic[1, 1] = 0.5
W_cyclic_t = torch.from_numpy(W_cyclic).double()

# Size the power-iteration tensor from the matrix itself rather than a hard-coded 10,
# so it stays consistent if the d passed to sdcd_ev above changes.
d = W_cyclic.shape[0]
sdcd_h = h_fn.user_h(W_cyclic_t, user_params={
    'is_prescreen': False,
    'power_grad': SCCPowerIteration(
        torch.zeros(d, d, dtype=torch.double, requires_grad=True, device='cpu'),
        d,
    ),
})
print(f'sdcd h: {sdcd_h}')
logdet_h = h_fn.h_logdet_sq(W_cyclic_t)
print(f'logdet h: {logdet_h}')
logdet_abs_h = h_fn.h_logdet_abs(W_cyclic_t)
print(f'logdet_abs h: {logdet_abs_h}')
h_exp_sq = h_fn.h_exp_sq(W_cyclic_t)
print(f'h_exp_sq: {h_exp_sq}')

Linear Model
data size: 1000, graph type: ER, sem type: gauss


Python(76940) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Accuracy of SDCD: {'fdr': 0.4444444444444444, 'tpr': 0.5, 'fpr': 0.11428571428571428, 'shd': 8, 'sid': 15.0, 'nnz': 9}
W_pred: [[ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.39283375 -1.77910505 -1.30054712 -1.34660752  0.          0.43409608
   0.          1.80742869  1.17600714  0.        ]
 [ 0.          0.          0.          0.          0.          0.
   0.          0.          0.          0.        ]
 [ 0.60381409  0.          0.          0.          0.          0.
   0.          0.          0.          1.7855964 ]
 [ 0.          0.          0.          0.     

In [37]:
# 100-node ER graph with 50 edges, seeded for reproducibility.
sdcd_ev(n=1000, d=100, s0=50, graph_type='ER', noise_type='gauss', error_var='eq', seed=2)

Linear Model
data size: 1000, graph type: ER, sem type: gauss
Accuracy of SDCD: {'fdr': 0.0, 'tpr': 0.66, 'fpr': 0.0, 'shd': 17, 'nnz': 33}


{'fdr': 0.0, 'tpr': 0.66, 'fpr': 0.0, 'shd': 17, 'nnz': 33}

In [6]:
# NOTE(review): golem_ev is neither defined nor imported anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel; the output below comes from an
# earlier session (In [6]). TODO: restore/import golem_ev before this cell.
golem_ev(1000, 50, 100, 'SF', 'gauss', 'eq', seed=2)

Linear Model
data size: 1000, graph type: SF, nodes: 50, edges: 100, sem type: gauss


Python(55874) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


Accuracy of Golem: {'fdr': 0.09523809523809523, 'tpr': 0.979381443298969, 'fpr': 0.008865248226950355, 'shd': 10, 'sid': 19.0, 'nnz': 105}


{'fdr': 0.09523809523809523,
 'tpr': 0.979381443298969,
 'fpr': 0.008865248226950355,
 'shd': 10,
 'sid': 19.0,
 'nnz': 105}

In [7]:
# ER1 graph with 100 nodes, as in https://arxiv.org/pdf/2006.10201 5.1
n, d, s0 = 1000, 100, 50  # data size, number of nodes, number of edges
graph_type, noise_type = 'ER', 'gauss'
X, W_true, B_true = generate_linear_data(
    n, d, s0, graph_type, noise_type, error_var='eq', seed=2
)
X = torch.from_numpy(X).float()
model = 'linear'

# Baseline: DAGMA with its default options on the torch backend.
W_dagma = dagrad(X, model=model, method='dagma', compute_lib='torch')
print(f"Linear Model\ndata size: {n}, graph type: {graph_type}, sem type: {noise_type}")

# Score the binary support of the estimate against the true graph.
acc_dagma = count_accuracy(B_true, W_dagma != 0)
print('Accuracy of Dagma:', acc_dagma)


Linear Model
data size: 1000, graph type: ER, sem type: gauss
Accuracy of Dagma: {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'sid': 0.0, 'nnz': 50}


In [8]:
# NOTE(review): golem_ev is neither defined nor imported anywhere in the visible notebook,
# so this cell raises NameError on a fresh kernel; the output below comes from an
# earlier session (In [8]). TODO: restore/import golem_ev before this cell.
golem_ev(1000, 100, 200, 'ER', 'gauss', 'eq', seed=2)

Linear Model
data size: 1000, graph type: ER, nodes: 100, edges: 200, sem type: gauss
Accuracy of Golem: {'fdr': 0.009950248756218905, 'tpr': 0.995, 'fpr': 0.0004210526315789474, 'shd': 2, 'sid': 25.0, 'nnz': 201}


{'fdr': 0.009950248756218905,
 'tpr': 0.995,
 'fpr': 0.0004210526315789474,
 'shd': 2,
 'sid': 25.0,
 'nnz': 201}

In [9]:
# ER4 graph with 100 nodes, as in https://arxiv.org/pdf/2006.10201 5.1
n, d, s0 = 1000, 100, 200  # data size, number of nodes, number of edges
graph_type, noise_type = 'ER', 'gauss'
X, W_true, B_true = generate_linear_data(
    n, d, s0, graph_type, noise_type, error_var='eq', seed=2
)
X = torch.from_numpy(X).float()
model = 'linear'

# Baseline: DAGMA with its default options on the torch backend.
W_dagma = dagrad(X, model=model, method='dagma', compute_lib='torch')
print(f"Linear Model\ndata size: {n}, graph type: {graph_type}, sem type: {noise_type}")

# Score the binary support of the estimate against the true graph.
acc_dagma = count_accuracy(B_true, W_dagma != 0)
print('Accuracy of Dagma:', acc_dagma)


Linear Model
data size: 1000, graph type: ER, sem type: gauss
Accuracy of Dagma: {'fdr': 0.0, 'tpr': 0.985, 'fpr': 0.0, 'shd': 3, 'sid': 134.0, 'nnz': 197}


### Sanity checks for the acyclicity functions (`h_fn`): two DAGs vs. the same graphs with self-loops

In [102]:
def _weights_from_edges(edges, d=10):
    """Return a d x d float64 tensor that is zero except for the given {(row, col): weight} entries."""
    W = torch.zeros(d, d, dtype=torch.double)
    for (row, col), weight in edges.items():
        W[row, col] = weight
    return W


# DAG 1
W_dag1 = _weights_from_edges({
    (2, 1): -1.92773958, (2, 3): -1.58384229, (2, 9): 2.19178483,
    (4, 0): -1.40118736, (4, 5): -1.16847115,
    (6, 2): 0.48648503,
    (8, 2): 1.67483952, (8, 7): -0.36516031, (8, 9): 1.49810611,
})

# DAG 2
W_dag2 = _weights_from_edges({
    (4, 0): 0.39283375, (4, 1): -1.77910505, (4, 2): -1.30054712, (4, 3): -1.34660752,
    (4, 5): 0.43409608, (4, 7): 1.80742869, (4, 8): 1.17600714,
    (6, 0): 0.60381409, (6, 9): 1.7855964,
})

# Non-DAGs: the same two graphs with self-loops (non-zero diagonal entries) added.
W_cyc1 = W_dag1.clone()
W_cyc1[0, 0] = 1.0
W_cyc2 = W_dag2.clone()
W_cyc2[0, 0] = 0.5
W_cyc2[1, 1] = -0.5

W_preds = [W_dag1, W_dag2, W_cyc1, W_cyc2]

# Evaluate every acyclicity function on each test matrix (the first two are DAGs,
# the last two contain self-loops) and print the values.
for i, W_pred in enumerate(W_preds, start=1):
    print(f"W_pred {i}:")
    # Match the power-iteration tensor to this matrix's size instead of hard-coding 10.
    d = W_pred.shape[0]
    sdcd_h = h_fn.user_h(W_pred, user_params={
        'is_prescreen': False,
        'power_grad': SCCPowerIteration(
            torch.zeros(d, d, dtype=torch.double, requires_grad=True, device='cpu'),
            d,
        ),
    })
    print(f'sdcd h: {sdcd_h}')
    logdet_h = h_fn.h_logdet_sq(W_pred)
    print(f'logdet h: {logdet_h}')
    logdet_abs_h = h_fn.h_logdet_abs(W_pred)
    print(f'logdet_abs h: {logdet_abs_h}')
    h_exp_sq = h_fn.h_exp_sq(W_pred)
    print(f'h_exp_sq: {h_exp_sq}')

W_pred 1:
sdcd h: 0.0
logdet h: 0.0
logdet_abs h: 0.0
h_exp_sq: 0.0
W_pred 2:
sdcd h: 0.0
logdet h: -2.220446049250313e-16
logdet_abs h: 2.7755575615628914e-16
h_exp_sq: 0.0
W_pred 3:
sdcd h: 100.0
logdet h: inf
logdet_abs h: inf
h_exp_sq: 1.7182818284590446
W_pred 4:
sdcd h: 50.0
logdet h: 0.5753641449035616
logdet_abs h: 1.3862943611198908
h_exp_sq: 0.5680508333754819
