In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam, LBFGS
from nodag_gumbel_softmax import train_gumbel_sgd
from SCM_data import generate_scm_data
import numpy as np
from numpy.linalg import LinAlgError, inv
from scipy.linalg import sqrtm
import MEC
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [3]:
def nodag_findbest(R_hat, lam=0.5, delta=1e-6, max_steps=5000, tau_start=0.2, tau_end=0.2, times=100, trainer=None):
    """Run the NODAG Gumbel-softmax trainer from ``times`` random starts and keep the best run.

    Restart ``t`` (``t = 0 .. times-1``) uses seed ``t``. NumPy's global RNG and
    torch's RNG are both seeded with it. ``B_init`` is then drawn as a
    standard-normal matrix with the shape of ``R_hat``. The run with the strictly
    smallest ``info["final_loss"]`` wins: ties keep the earliest seed, and runs
    whose loss is NaN (or inf) are never selected.

    Parameters
    ----------
    R_hat : np.ndarray
        Sample covariance matrix, forwarded to the trainer as ``Rhat_np``.
    lam, delta, max_steps, tau_start, tau_end :
        Forwarded unchanged to the trainer.
    times : int
        Number of random restarts; must be >= 1.
    trainer : callable, optional
        Function with the keyword interface of ``train_gumbel_sgd``
        (``Rhat_np, lam, delta, max_steps, tau_start, tau_end, B_init``) that
        returns ``(B, G, info)``. Defaults to ``train_gumbel_sgd``; mainly useful
        for testing.

    Returns
    -------
    tuple
        ``(best_G, best_B, best_loss, best_likelihood, best_penalty, best_seed)``
        from the winning run, where the likelihood and penalty are
        ``info["final_likelihood"]`` and ``info["final_penalty"]``.

    Raises
    ------
    ValueError
        If ``times < 1``.
    RuntimeError
        If no restart produced a ``final_loss`` below ``inf`` (e.g. all were NaN).
    """
    if times < 1:
        raise ValueError(f"times must be >= 1, got {times}")
    if trainer is None:
        trainer = train_gumbel_sgd

    best = None  # (G, B, loss, likelihood, penalty, seed) of the best run so far
    best_loss = np.inf
    for seed in range(times):
        # Seed NumPy (B_init, plus any NumPy sampling inside the trainer) and torch.
        # The trainer may draw its Gumbel noise from torch's RNG (TODO confirm), so
        # seeding both makes the returned seed reproduce the run.
        np.random.seed(seed)
        torch.manual_seed(seed)
        B_init = np.random.randn(*R_hat.shape)
        B_final, G_final, info = trainer(
            Rhat_np=R_hat,
            lam=lam,
            delta=delta,
            max_steps=max_steps,
            tau_start=tau_start,
            tau_end=tau_end,
            B_init=B_init,
        )
        loss = info["final_loss"]
        # Strict `<`: ties keep the earliest seed; NaN compares False and is skipped.
        if loss < best_loss:
            best_loss = loss
            best = (G_final, B_final, loss, info["final_likelihood"], info["final_penalty"], seed)

    if best is None:
        raise RuntimeError(
            f"none of the {times} restarts produced a final_loss below inf (all NaN/inf?)"
        )
    return best

In [4]:
# Experiment settings (previously inline magic numbers).
N_SAMPLES = 10000      # samples drawn per SCM by generate_scm_data
N_RESTARTS = 100       # random restarts per nodag_findbest call
SCM_IDS = range(1, 6)  # SCM indices passed to generate_scm_data

# For each SCM: sample data, compute the loss of the ground-truth model, then
# fit NODAG with random restarts and print the best fit next to the truth.
for i in SCM_IDS:
    X, Y, Z, G_true, CPDAG = generate_scm_data(i, N_SAMPLES)
    # Derive the dimension from G_true instead of hard-coding 3.
    A_true = np.eye(len(G_true)) - G_true
    # Assumes X, Y, Z are 1-D sample vectors, so data is (n_samples, 3) — TODO confirm.
    data = np.array([X, Y, Z]).T
    R_hat = np.cov(data.T)
    # Ground-truth objective: -2 log|det A| + tr(A^T R_hat A).
    # slogdet returns log|det A| and stays finite when det(A) < 0, where
    # np.log(np.linalg.det(A)) would give NaN.
    _, logabsdet_true = np.linalg.slogdet(A_true)
    likelihood_true = -2 * logabsdet_true + np.trace(A_true.T @ R_hat @ A_true)
    best_G, best_B, best_loss, best_likelihood, best_penalty, best_seed = nodag_findbest(R_hat=R_hat, times=N_RESTARTS)
    print("SCM: ", i)
    print("likelihood_true = ", likelihood_true)
    print("G_true = \n", G_true)
    print("G_est = \n", best_G)
    # TODO(review): MEC check disabled. It passes best_B (weights) rather than
    # best_G (graph) — confirm which one MEC.is_in_markov_equiv_class expects.
    # print("Is in MEC: ", MEC.is_in_markov_equiv_class(G_true, best_B))
    print("Final Loss = ", best_loss)
    print("Final penalty = ", best_penalty)
    print("Final likelihood = ", best_likelihood)
    print("seed = ", best_seed)
    print("")
    
    

SCM:  1
likelihood_true =  3.037318063988293
G_true = 
 [[0 0 0]
 [0 0 0]
 [0 0 0]]
G_est = 
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Final Loss =  3.0370697380252674
Final penalty =  1.955358835544745e-12
Final likelihood =  3.037069738023312
seed =  87

SCM:  2
likelihood_true =  3.0179642443567762
G_true = 
 [[0 1 0]
 [0 0 0]
 [0 0 0]]
G_est = 
 [[0. 0. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]
Final Loss =  3.5174811335654548
Final penalty =  0.49999999996025857
Final likelihood =  3.0174811336051963
seed =  87

SCM:  3
likelihood_true =  3.017964244356776
G_true = 
 [[0 1 0]
 [0 0 0]
 [0 1 0]]
G_est = 
 [[0. 1. 0.]
 [0. 0. 0.]
 [0. 1. 0.]]
Final Loss =  4.017474717799517
Final penalty =  0.9999999999597986
Final likelihood =  3.0174747178397183
seed =  74

SCM:  4
likelihood_true =  3.017964244356773
G_true = 
 [[0 1 0]
 [0 0 1]
 [0 0 0]]
G_est = 
 [[0. 0. 0.]
 [1. 0. 1.]
 [0. 0. 0.]]
Final Loss =  4.017433666743141
Final penalty =  1.0000000000007825
Final likelihood =  3.0174336667423587
seed =