In [1]:
import numpy as np
from numpy import linalg as la
from IPython.display import display
from time import perf_counter
import os
from joblib import Parallel, delayed

import itertools

from src.model import Nonneg_dagma
import src.utils as utils

SEED = 10
# os.cpu_count() may return None, and on a single-core machine `1 // 2 == 0`,
# which joblib rejects as n_jobs; always use at least one worker.
N_CPUS = max(1, (os.cpu_count() or 1) // 2)
# NOTE: this seeds NumPy's global RNG in the notebook process only.
np.random.seed(SEED)


In [2]:
def cartesian_product(hyperparams):
    """
    Expand a grid of hyperparameters into one dict per combination.

    `hyperparams` maps each parameter name to an iterable of candidate values.
    The result lists every combination (in `itertools.product` order) as a
    dict with the same keys. An empty grid yields a single empty dict.
    """
    names, value_lists = zip(*hyperparams.items()) if hyperparams else ((), ())
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]

def args2str(arguments):
    """Build a compact label such as ``'ste=0.05 alp=0.1 '`` from a hyperparameter dict.

    Each key is truncated to its first three characters, and every
    ``key=value`` pair is followed by a single space (including the last one).
    """
    label = ''
    for name, value in arguments.items():
        label += f'{name[:3]}={value} '
    return label


def run_grid_search_tuning(g, data_p, args_combs, model_const, norm_x, thr, verb=False, seed=None):
    """
    Simulate one random graph and evaluate every hyperparameter combination on it.

    Parameters
    ----------
    g : int
        Index of the graph; used in the log line and, together with ``seed``,
        to derive a per-graph RNG seed.
    data_p : dict
        Keyword arguments forwarded to ``utils.simulate_sem``. Must contain
        ``'n_samples'`` (used to scale the fidelity term).
    args_combs : list of dict
        Hyperparameter combinations; each dict is passed as keyword arguments
        to ``model.fit``.
    model_const : callable
        Zero-argument factory; a fresh model is built for every combination.
    norm_x : bool
        If True, every row of ``X`` is divided by its Euclidean norm before fitting.
    thr : float
        Threshold passed to ``utils.to_bin`` to binarize weight matrices.
    verb : bool, default False
        If True, print the metrics of every combination.
    seed : int or None, default None
        If given, NumPy's global RNG is seeded with ``seed + g`` before the
        data is simulated. Seeding in the parent process does not carry over
        to joblib worker processes, so this is what makes each graph
        reproducible. ``None`` keeps the previous (unseeded) behaviour.

    Returns
    -------
    shd, err, acyc, runtime : np.ndarray
        Arrays of length ``len(args_combs)`` holding, per combination, the
        structural Hamming distance, the value of
        ``utils.compute_norm_sq_err``, ``model.dagness`` of the estimate, and
        the wall-clock time of ``model.fit`` in seconds.
    """
    if seed is not None:
        # Offset by g so each graph gets different but reproducible data.
        # Assumes utils.simulate_sem draws from NumPy's global RNG — TODO confirm.
        np.random.seed(seed + g)

    # Create data
    W_true, _, X = utils.simulate_sem(**data_p)
    norm_W_true = np.linalg.norm(W_true)
    W_true_bin = utils.to_bin(W_true, thr)
    X = X/np.linalg.norm(X, axis=1, keepdims=True) if norm_x else X

    # (1/n_samples) * ||X - X @ W_true_bin||_F^2
    # NOTE(review): this uses the binarized W_true_bin rather than W_true, so it is
    # not the residual of the generating SEM — confirm this is the intended "fidelity".
    fidelity = 1/data_p['n_samples']*la.norm(X - X @ W_true_bin, 'fro')**2

    print(f'Graph {g+1}: Fidelity: {fidelity:.3f}')

    shd, err, acyc, runtime = [np.zeros((len(args_combs))) for _ in range(4)]
    for i, args in enumerate(args_combs):
        model = model_const()
        t_init = perf_counter()
        W_est = model.fit(X, **args)
        t_end = perf_counter()

        W_est_bin = utils.to_bin(W_est, thr)
        shd[i], _, _ = utils.count_accuracy(W_true_bin, W_est_bin)
        err[i] = utils.compute_norm_sq_err(W_true, W_est, norm_W_true)
        runtime[i] = t_end - t_init
        acyc[i] = model.dagness(W_est)

        if verb:
            text = args2str(args)
            print(f'\t- {text}: shd {shd[i]}  -  err: {err[i]:.3f}  -  acyc: {acyc[i]:.5g}  -  time: {runtime[i]:.3f}')

    return shd, err, acyc, runtime

## Experiment parameters

In [3]:
model_const = Nonneg_dagma

verb = False    # print metrics of every combination inside the workers
thr = .2        # threshold passed to utils.to_bin to binarize weight matrices
n_dags = 10     # number of random graphs evaluated (one joblib task each)
norm_x = False  # if True, each row of X is divided by its Euclidean norm
N = 50
data_params = {
    'n_nodes': N,
    'n_samples': 500, # 1000,
    'graph_type': 'er',
    'edges': 2*N,
    'edge_type': 'positive',
    'w_range': (.5, 1),
    'var': 1/np.sqrt(N),
}

Hyperparams = {
    'stepsize': [5e-2],
    'alpha': [.05, .075, .1, .25, .5, .75, 1],
    's': [1],
    'lamb': [5e-4, 1e-3, 5e-3],
    'max_iters': [10000]
}

print('CPUs employed:', N_CPUS)

# Get combination of hyperparams for grid search
args_combs = cartesian_product(Hyperparams)

# NOTE(review): np.random.seed(SEED) in the first cell seeds only this process;
# joblib's process-based workers presumably do not inherit it, so the simulated
# graphs are not reproducible across runs unless each task seeds itself.
t_init = perf_counter()
results = Parallel(n_jobs=N_CPUS)(
    delayed(run_grid_search_tuning)(g, data_params, args_combs, model_const, norm_x, thr, verb)
    for g in range(n_dags)
)
t_end = perf_counter()
print(f'Grid search: {len(args_combs)} combinations x {n_dags} graphs in {t_end - t_init:.1f} s')

# `results` holds one (shd, err, acyc, runtime) tuple per graph; zip(*) transposes it
# so each metric becomes a tuple of n_dags arrays, each of length len(args_combs).
shd, err, acyc, runtime = zip(*results)
metrics = {'shd': shd, 'err': err, 'acyc': acyc, 'time': runtime}

CPUs employed: 32


Graph 4: Fidelity: 11.321
Graph 9: Fidelity: 9.601
Graph 5: Fidelity: 10.895
Graph 2: Fidelity: 13.548
Graph 10: Fidelity: 20.807
Graph 8: Fidelity: 12.880
Graph 3: Fidelity: 10.852
Graph 7: Fidelity: 10.687
Graph 6: Fidelity: 15.050
Graph 1: Fidelity: 11.993


In [4]:
# One legend label per hyperparameter combination, in grid-search order.
leg = list(map(args2str, args_combs))
# Aggregate each metric across the simulated graphs: first the mean, then the median.
utils.display_results(leg, metrics, agg='mean')
utils.display_results(leg, metrics, agg='median')

Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.05 alp=0.05 s=1 lam=0.0005 max=10000,7.6,0.072993,8.215874,1.748738
1,ste=0.05 alp=0.05 s=1 lam=0.001 max=10000,7.9,0.070044,8.399597,1.27682
2,ste=0.05 alp=0.05 s=1 lam=0.005 max=10000,19.0,0.107066,9.652709,2.339229
3,ste=0.05 alp=0.075 s=1 lam=0.0005 max=10000,11.6,0.174116,0.211919,3.164424
4,ste=0.05 alp=0.075 s=1 lam=0.001 max=10000,11.5,0.170331,0.253124,2.279736
5,ste=0.05 alp=0.075 s=1 lam=0.005 max=10000,11.3,0.159459,0.912529,1.74649
6,ste=0.05 alp=0.1 s=1 lam=0.0005 max=10000,0.6,0.017608,0.009066,1.588221
7,ste=0.05 alp=0.1 s=1 lam=0.001 max=10000,0.6,0.01626,0.013288,1.544562
8,ste=0.05 alp=0.1 s=1 lam=0.005 max=10000,0.7,0.014319,0.040879,1.441479
9,ste=0.05 alp=0.25 s=1 lam=0.0005 max=10000,6.2,0.051193,0.000192,1.346123


Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.05 alp=0.05 s=1 lam=0.0005 max=10000,7.0,0.076256,8.232812,1.745772
1,ste=0.05 alp=0.05 s=1 lam=0.001 max=10000,8.0,0.07448,8.398825,1.302973
2,ste=0.05 alp=0.05 s=1 lam=0.005 max=10000,18.5,0.108124,9.619226,2.586661
3,ste=0.05 alp=0.075 s=1 lam=0.0005 max=10000,0.0,0.014326,0.178628,3.252227
4,ste=0.05 alp=0.075 s=1 lam=0.001 max=10000,0.0,0.012247,0.231667,2.156977
5,ste=0.05 alp=0.075 s=1 lam=0.005 max=10000,0.0,0.015252,0.938976,1.750269
6,ste=0.05 alp=0.1 s=1 lam=0.0005 max=10000,0.0,0.014708,0.003753,1.663795
7,ste=0.05 alp=0.1 s=1 lam=0.001 max=10000,0.0,0.013541,0.00519,1.569353
8,ste=0.05 alp=0.1 s=1 lam=0.005 max=10000,0.0,0.011231,0.028852,1.392687
9,ste=0.05 alp=0.25 s=1 lam=0.0005 max=10000,6.0,0.047495,0.000192,1.317268
