In [2]:
import os
import numpy as np

from models import GolemModel
from trainers import GolemTrainer
from data_loader import SyntheticDataset
from data_loader import SCM_data

from utils import MEC


# Lower the verbosity of TensorFlow's native (C++) logging; level '2' is documented
# by TensorFlow as filtering out INFO and WARNING messages.
# NOTE(review): this variable is normally read when TensorFlow is first loaded. The
# project imports above may already have imported TensorFlow — confirm the ordering.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def golem(X, lambda_1, lambda_2, equal_variances=True,
          num_iter=1e+5, learning_rate=1e-3, seed=1,
          checkpoint_iter=None, output_dir=None, B_init=None):
    """Solve the unconstrained optimization problem of GOLEM, which involves
        GolemModel and GolemTrainer.

    Args:
        X (numpy.ndarray): [n, d] data matrix. Array-likes are converted with
            numpy.asarray. The input is not modified; a centered copy is used.
        lambda_1 (float): Coefficient of L1 penalty.
        lambda_2 (float): Coefficient of DAG penalty.
        equal_variances (bool): Whether to assume equal noise variances
            for likelihood objective. Default: True.
        num_iter (int): Number of iterations for training. Float values such
            as 1e+4 are accepted and coerced with int(). Default: 1e+5.
        learning_rate (float): Learning rate of Adam optimizer. Default: 1e-3.
        seed (int): Random seed. Default: 1.
        checkpoint_iter (int): Number of iterations between each checkpoint.
            Set to None to disable. Default: None.
        output_dir (str): Output directory to save training outputs.
        B_init (numpy.ndarray or None): [d, d] weighted matrix for initialization.
            Set to None to disable. Default: None.

    Returns:
        numpy.ndarray: [d, d] estimated weighted matrix.

    Raises:
        ValueError: If X is not two-dimensional.

    Hyperparameters:
        (1) GOLEM-NV: equal_variances=False, lambda_1=2e-3, lambda_2=5.0.
        (2) GOLEM-EV: equal_variances=True, lambda_1=2e-2, lambda_2=5.0.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(
            "X must be a 2-D [n, d] data matrix, got {} dimension(s).".format(X.ndim)
        )

    # Center each column; this builds a new array, so the caller's data is untouched
    X = X - X.mean(axis=0, keepdims=True)

    # The iteration count is documented as an int, but the default and typical
    # call sites use float literals (1e+5, 1e+4); normalize it once here.
    num_iter = int(num_iter)

    # Set up model
    n, d = X.shape
    model = GolemModel(n, d, lambda_1, lambda_2, equal_variances, seed, B_init)

    # Training
    trainer = GolemTrainer(learning_rate)
    B_est = trainer.train(model, X, num_iter, checkpoint_iter, output_dir)

    return B_est    # Not thresholded yet

def weight_to_adjacency(W, threshold=0.05):
    """
    Convert a weight matrix to a binary adjacency matrix.

    Entry (i, j) of the result is 1 when abs(W[i, j]) is strictly greater than
    ``threshold`` and 0 otherwise.

    Parameters:
        W (np.ndarray): Weight matrix; must be two-dimensional and square.
        threshold (float): Values with absolute weight <= threshold are treated as 0.

    Returns:
        np.ndarray: Integer (0/1) adjacency matrix of the same shape as W.

    Raises:
        TypeError: If W is not a numpy array.
        ValueError: If W is not a two-dimensional square matrix.
    """
    if not isinstance(W, np.ndarray):
        raise TypeError("Input W must be a numpy array.")
    # Check the rank first: indexing shape[1] on a 1-D array would raise IndexError,
    # and a 3-D array with equal leading dimensions would otherwise pass as "square".
    if W.ndim != 2 or W.shape[0] != W.shape[1]:
        raise ValueError("Input W must be a square matrix.")

    return (np.abs(W) > threshold).astype(int)


In [3]:
import logging

from data_loader import SyntheticDataset
from data_loader.synthetic_dataset import dataset_based_on_B
from utils.train import postprocess
from utils.utils import count_accuracy, set_seed

# Configure the root logger at INFO level. Per the original note, this is needed for
# the training histories to be printed when checkpointing is activated.
logging.basicConfig(
    format='%(asctime)s %(levelname)s - %(name)s - %(message)s',
    level=logging.INFO,
)

# Fix the random seed for reproducibility
set_seed(1)

# Settings for the synthetic dataset. They are only consumed by the
# SyntheticDataset call that is commented out below.
n = 1000
d = 4
graph_type = 'ER'
# NOTE(review): the original comment labelled this an "ER2 graph" although the
# degree is set to 0.5 — confirm which was intended.
degree = 0.5
B_scale = 1.0
noise_type = 'gaussian_ev'
#dataset = SyntheticDataset(n, d, graph_type, degree,
#                           noise_type, B_scale, seed=1)

# Disabled experiment: the triple-quoted block below is a bare string literal, so none
# of it is executed. It holds the repeated-seed Markov-equivalence-class accuracy loop
# with equal_variances=True; the cell at In [4] runs the same loop with
# equal_variances=False.
'''
times = 20
for i in range(1, 6):
    true_count = [0] * 6
    for seed in range(times):
        X, Y, Z, G_true, CPDAG = SCM_data.generate_scm_data(i,10000, seed = seed)
        data = np.array([X, Y, Z]).T
        #print(data.T@ data / 10000)
        W_est = golem(data, lambda_1=2e-2, lambda_2=5.0,
                equal_variances=True, num_iter=1e+4)
        G_est = weight_to_adjacency(W_est, 0.05)
        if MEC.is_in_markov_equiv_class(G_true, G_est): true_count[i-1] += 1
    print(f"SCM {i} : {true_count[i-1]/times}")
'''

# Single-seed check: for SCM patterns 1..5, fit GOLEM (equal variances) on 10000
# samples and print the thresholded estimate next to the true adjacency matrix.
for pattern in range(1, 6):
    X, Y, Z, G_true, CPDAG = SCM_data.generate_scm_data(pattern, 10000, seed=1)
    # Stack the three variables as columns
    data = np.transpose(np.array([X, Y, Z]))
    W_est = golem(
        data,
        lambda_1=2e-2,
        lambda_2=5.0,
        equal_variances=True,
        num_iter=1e+4,
    )
    G_est = weight_to_adjacency(W_est, 0.05)
    print(f"pattern {pattern}")
    print(f"G_true = \n {G_true}")
    print(f"G_est = \n {G_est}")





2025-08-04 21:43:18,762 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.


pattern 1
G_true = 
 [[0 0 0]
 [0 0 0]
 [0 0 0]]
G_est = 
 [[0 0 0]
 [0 0 0]
 [0 0 0]]


2025-08-04 21:43:29,482 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.


pattern 2
G_true = 
 [[0 1 0]
 [0 0 0]
 [0 0 0]]
G_est = 
 [[0 1 0]
 [0 0 0]
 [0 0 0]]


2025-08-04 21:43:40,430 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.


pattern 3
G_true = 
 [[0 1 0]
 [0 0 0]
 [0 1 0]]
G_est = 
 [[0 1 0]
 [0 0 0]
 [0 1 0]]


2025-08-04 21:43:51,250 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.


pattern 4
G_true = 
 [[0 1 0]
 [0 0 1]
 [0 0 0]]
G_est = 
 [[0 1 0]
 [0 0 1]
 [0 0 0]]


2025-08-04 21:44:02,306 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.


pattern 5
G_true = 
 [[0 1 1]
 [0 0 1]
 [0 0 0]]
G_est = 
 [[0 1 1]
 [0 0 1]
 [0 0 0]]


In [4]:
# For each SCM pattern 1..5, report the fraction of seeds for which the thresholded
# GOLEM estimate (equal_variances=False) lies in the Markov equivalence class of the
# true graph.
# NOTE(review): lambda_1=2e-2 is the value golem()'s docstring lists for GOLEM-EV; for
# equal_variances=False (GOLEM-NV) it lists lambda_1=2e-3. Left unchanged here —
# confirm which was intended before interpreting the results.
times = 20
mec_accuracy = {}    # SCM index -> fraction of seeds recovered up to MEC
for i in range(1, 6):
    n_in_mec = 0    # per-SCM tally, reset for every pattern
    for seed in range(times):
        X, Y, Z, G_true, CPDAG = SCM_data.generate_scm_data(i, 10000, seed=seed)
        data = np.array([X, Y, Z]).T
        W_est = golem(data, lambda_1=2e-2, lambda_2=5.0,
                      equal_variances=False, num_iter=1e+4)
        G_est = weight_to_adjacency(W_est, 0.05)
        if MEC.is_in_markov_equiv_class(G_true, G_est):
            n_in_mec += 1
    # Keep every pattern's score so the results remain available after the loop
    mec_accuracy[i] = n_in_mec / times
    print(f"SCM {i} : {mec_accuracy[i]}")

2025-08-04 21:46:07,680 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:46:18,282 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:46:29,203 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:46:39,717 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:46:50,493 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:01,736 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:12,495 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:23,337 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:34,305 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:45,034 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:47:56,769 INFO -

SCM 1 : 1.0


2025-08-04 21:49:45,285 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:49:56,043 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:50:06,844 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:50:17,441 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:50:28,593 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:50:39,044 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:50:49,622 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:51:00,292 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:51:11,204 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:51:22,342 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:51:35,973 INFO -

SCM 2 : 1.0


2025-08-04 21:53:38,976 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:53:50,325 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:01,685 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:12,528 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:23,336 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:34,797 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:45,744 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:54:56,777 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:55:08,137 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:55:22,345 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:55:36,420 INFO -

SCM 3 : 0.0


2025-08-04 21:57:46,849 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:57:58,711 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:58:10,199 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:58:22,610 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:58:33,852 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:58:44,919 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:58:57,038 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:59:10,232 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:59:26,842 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:59:43,150 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 21:59:57,707 INFO -

SCM 4 : 0.0


2025-08-04 22:02:27,387 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:02:42,914 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:02:57,929 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:03:13,070 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:03:27,840 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:03:42,884 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:03:58,447 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:04:13,983 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:04:27,877 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:04:40,683 INFO - trainers.golem_trainer - Started training for 10000.0 iterations.
2025-08-04 22:50:13,727 INFO -

SCM 5 : 1.0
