In [1]:
import numpy as np
from numpy import linalg as la
from time import perf_counter
import os
from joblib import Parallel, delayed

import itertools

from src.model import Nonneg_dagma, MetMulDagma
import src.utils as utils

SEED = 10
# Use half of the available cores, but never fewer than one worker:
# os.cpu_count() can return None, and halving a single core gives 0,
# which joblib.Parallel rejects as a value for n_jobs.
N_CPUS = max(1, (os.cpu_count() or 1) // 2)
# Seeds the global NumPy RNG of this (notebook) process.
# NOTE(review): joblib workers running in separate processes may not share
# this RNG state -- confirm that the simulated data is reproducible.
np.random.seed(SEED)


In [10]:
def cartesian_product(hyperparams):
    """
    Expand a hyperparameter grid into the list of all its combinations.

    `hyperparams` maps each parameter name to an iterable of candidate values.
    The result is a list of dicts, one per element of the Cartesian product of
    the candidate values, each mapping every parameter name to one value.
    Combinations are produced in `itertools.product` order (last key varies fastest).
    """
    names = list(hyperparams)
    grid = []
    for combo in itertools.product(*hyperparams.values()):
        grid.append(dict(zip(names, combo)))
    return grid

def args2str(arguments):
    """
    Render a dict of arguments as a compact one-line label.

    Each item becomes '<first three characters of the key>=<value> ' (note the
    trailing space) and the pieces are concatenated in the dict's iteration order.
    """
    pieces = []
    for key, val in arguments.items():
        pieces.append(f'{key[:3]}={val} ')
    return ''.join(pieces)

def print_best(key, metrics, args_combs, agg_funct='mean'):
    """
    Print the hyperparameter combination with the lowest aggregated `key` metric.

    Every entry of `metrics` is reduced along axis 0 with the NumPy function
    named `agg_funct`; the position of the smallest aggregated `key` value
    selects the entry of `args_combs` that is reported, together with its
    aggregated 'shd', 'err', 'acyc' and 'time' values.

    Parameters
    ----------
    key : str
        Name of the metric in `metrics` to minimise.
    metrics : dict
        Maps metric names to array-likes that are aggregated along axis 0.
        Must contain the keys 'shd', 'err', 'acyc' and 'time'.
    args_combs : sequence
        Indexed with the position of the best aggregated value.
    agg_funct : str, optional
        Name of a NumPy function accepting an `axis` argument (default 'mean').

    Notes
    -----
    Combinations whose aggregated `key` is NaN are skipped: a plain argmin
    returns the position of the first NaN, which would report such a
    combination as the best one. If no combination has a valid value, a
    message is printed and nothing else is reported.
    """
    aggregate = getattr(np, agg_funct)
    agg_metric = {name: np.asarray(aggregate(values, axis=0))
                  for name, values in metrics.items()}

    scores = agg_metric[key]
    if np.isnan(scores).all():
        # Covers both "every value is NaN" and "no combinations at all".
        print(f'No combination has a valid {key} (agg: {agg_funct}).')
        return

    idx = np.nanargmin(scores)

    print(f'Combination with best {key} (agg: {agg_funct}):')
    print(args_combs[idx])
    print(f'shd: {agg_metric["shd"][idx]:.2f} | err: {agg_metric["err"][idx]:.4f} | ' +
          f'acyc: {agg_metric["acyc"][idx]:.6f} | time: {agg_metric["time"][idx]:.2f}')

def run_grid_search_tuning(g, data_p, args_combs, model_const, norm_x, thr, verb=False):
    """
    Fit one model per hyperparameter combination on a single simulated dataset.

    Parameters
    ----------
    g : int
        Index of the graph; only used (as g+1) in the printed header.
    data_p : dict
        Keyword arguments for `utils.simulate_sem`; 'n_samples' is also read
        to scale the printed fidelity value.
    args_combs : sequence of dict
        Each dict is passed as keyword arguments to `model.fit`.
    model_const : callable
        Called without arguments to build a fresh model for every combination.
    norm_x : bool
        If true, X is divided by its norm taken along axis 1.
    thr : float
        Threshold handed to `utils.to_bin` for both true and estimated weights.
    verb : bool, optional
        If true, print one line of metrics per combination.

    Returns
    -------
    tuple of four 1-D arrays of length len(args_combs): shd, err, acyc, runtime.
    """
    # Simulate the data and derive the reference quantities
    W_true, _, X = utils.simulate_sem(**data_p)
    norm_W_true = np.linalg.norm(W_true)
    W_true_bin = utils.to_bin(W_true, thr)
    if norm_x:
        X = X/np.linalg.norm(X, axis=1, keepdims=True)

    # Residual of the data w.r.t. the binarised true weights
    residual = X - X @ W_true_bin
    fidelity = 1/data_p['n_samples']*la.norm(residual, 'fro')**2

    print(f'Graph {g+1}: Fidelity: {fidelity:.3f}')

    n_combs = len(args_combs)
    shd = np.zeros(n_combs)
    err = np.zeros(n_combs)
    acyc = np.zeros(n_combs)
    runtime = np.zeros(n_combs)

    for i, args in enumerate(args_combs):
        # A fresh model per combination; only the call to fit is timed
        model = model_const()
        start = perf_counter()
        W_est = model.fit(X, **args)
        elapsed = perf_counter() - start

        W_est_bin = utils.to_bin(W_est, thr)
        shd[i], _, _ = utils.count_accuracy(W_true_bin, W_est_bin)
        err[i] = utils.compute_norm_sq_err(W_true, W_est, norm_W_true)
        runtime[i] = elapsed
        acyc[i] = model.dagness(W_est)

        if not verb:
            continue
        text = args2str(args)
        print(f'\t- {text}: shd {shd[i]}  -  err: {err[i]:.3f}  -  acyc: {acyc[i]:.5g}  -  time: {runtime[i]:.3f}')

    return shd, err, acyc, runtime

## Experiment parameters

In [3]:
# Constructor of the model whose hyperparameters are tuned
model_const = MetMulDagma

# Evaluation settings
n_dags = 30      # number of simulated graphs (one parallel task each)
thr = .2         # threshold passed to utils.to_bin
norm_x = False   # whether to normalise the samples before fitting
verb = False     # per-combination progress printing

# Settings forwarded as keyword arguments to utils.simulate_sem
N = 100
data_params = dict(
    n_nodes=N,
    n_samples=1000,
    graph_type='er',
    edges=4*N,
    edge_type='positive',
    w_range=(.5, 1),
    var=1,  # alternative previously noted here: 1/np.sqrt(N)
)

# Search grid: every combination of these values is passed to model.fit
Hyperparams = dict(
    stepsize=[1e-5, 1e-4, 1e-3],
    alpha_0=[.01],
    rho_0=[.05],
    beta=[2, 5],
    s=[1],
    lamb=[1e-5, 1e-4, 1e-3],
    iters_in=[15000, 30000],
    iters_out=[7],
    tol=[1e-6],
)

# Hyperparams = {
#     'stepsize': [5e-2],
#     'alpha': [.05, .075, .1, .25, .5, .75, 1],
#     's': [1],
#     'lamb': [5e-4, 1e-3, 5e-3],
#     'max_iters': [10000]
# }

print('CPUs employed:', N_CPUS)

# Expand the grid into the list of keyword-argument dicts to evaluate
args_combs = cartesian_product(Hyperparams)

# One joblib task per simulated graph; each task sweeps the whole grid
t_init = perf_counter()
tasks = (
    delayed(run_grid_search_tuning)(g, data_params, args_combs, model_const, norm_x, thr, verb)
    for g in range(n_dags)
)
results = Parallel(n_jobs=N_CPUS)(tasks)
t_end = perf_counter()

# Regroup the per-graph result tuples into one tuple of arrays per metric
shd, err, acyc, runtime = zip(*results)
metrics = dict(shd=shd, err=err, acyc=acyc, time=runtime)

CPUs employed: 32


Graph 4: Fidelity: 18.371
Graph 8: Fidelity: 16.425
Graph 5: Fidelity: 15.080
Graph 2: Fidelity: 19.963
Graph 6: Fidelity: 15.837
Graph 11: Fidelity: 14.189
Graph 1: Fidelity: 17.659
Graph 16: Fidelity: 14.842
Graph 9: Fidelity: 15.035
Graph 13: Fidelity: 13.643
Graph 7: Fidelity: 15.445
Graph 26: Fidelity: 18.455
Graph 14: Fidelity: 15.825
Graph 12: Fidelity: 14.761
Graph 23: Fidelity: 17.671
Graph 28: Fidelity: 14.068
Graph 15: Fidelity: 14.283
Graph 29: Fidelity: 17.638
Graph 10: Fidelity: 18.679
Graph 22: Fidelity: 14.504
Graph 20: Fidelity: 14.604
Graph 3: Fidelity: 14.668
Graph 24: Fidelity: 15.304
Graph 19: Fidelity: 13.657
Graph 25: Fidelity: 13.732
Graph 27: Fidelity: 16.098
Graph 17: Fidelity: 13.925
Graph 30: Fidelity: 16.380
Graph 21: Fidelity: 16.767
Graph 18: Fidelity: 15.536


In [11]:
# Report the best combination by SHD and by error, first for the mean and
# then for the median across graphs.
for agg_name in ('mean', 'median'):
    for metric_name in ('shd', 'err'):
        print_best(metric_name, metrics, args_combs, agg_funct=agg_name)


Combination with best shd (agg: mean):
{'stepsize': 0.0005, 'alpha_0': 0.01, 'beta': 0.1, 's': 1, 'lamb': 0.005, 'iters_in': 20000, 'tol': 1e-06}
shd: 9.57 | err: 0.0555 |acyc: 5.150962 | time: 226.20
Combination with best err (agg: mean):
{'stepsize': 0.0005, 'alpha_0': 0.01, 'beta': 0.1, 's': 1, 'lamb': 0.005, 'iters_in': 20000, 'tol': 1e-06}
shd: 9.57 | err: 0.0555 |acyc: 5.150962 | time: 226.20
Combination with best shd (agg: median):
{'stepsize': 0.05, 'alpha_0': 0.001, 'beta': 1, 's': 1, 'lamb': 0.005, 'iters_in': 20000, 'tol': 1e-06}
shd: 7.00 | err: 0.0422 |acyc: 0.005493 | time: 22.24
Combination with best err (agg: median):
{'stepsize': 0.05, 'alpha_0': 0.001, 'beta': 1, 's': 1, 'lamb': 0.005, 'iters_in': 20000, 'tol': 1e-06}
shd: 7.00 | err: 0.0422 |acyc: 0.005493 | time: 22.24


In [12]:
# One short label per hyperparameter combination, used as the table legend
leg = list(map(args2str, args_combs))
# Summary tables aggregated across graphs: mean first, then median
utils.display_results(leg, metrics, agg='mean')
utils.display_results(leg, metrics, agg='median')

Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.0005 it...,19.933333,0.086771,9.028635,322.065551
1,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.001 ite...,17.066667,0.076391,9.118577,527.092364
2,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.005 ite...,14.000000,0.070466,9.857354,474.626107
3,ste=0.0005 alp=0.001 bet=1 s=1 lam=0.0005 ite=...,27.000000,0.117654,0.000739,494.084724
4,ste=0.0005 alp=0.001 bet=1 s=1 lam=0.001 ite=2...,25.866667,0.110105,0.000680,387.690575
...,...,...,...,...,...
175,ste=0.05 alp=0.1 bet=5 s=1 lam=0.001 ite=20000...,21.300000,0.092491,0.000004,23.403639
176,ste=0.05 alp=0.1 bet=5 s=1 lam=0.005 ite=20000...,26.666667,0.155240,0.000005,26.225400
177,ste=0.05 alp=0.1 bet=10 s=1 lam=0.0005 ite=200...,23.066667,0.102841,0.000001,22.516795
178,ste=0.05 alp=0.1 bet=10 s=1 lam=0.001 ite=2000...,21.300000,0.092510,0.000001,22.266336


Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.0005 it...,19.0,0.082342,9.141627e+00,335.130343
1,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.001 ite...,17.0,0.072844,9.318603e+00,555.722431
2,ste=0.0005 alp=0.001 bet=0.1 s=1 lam=0.005 ite...,14.5,0.072080,1.031252e+01,505.608322
3,ste=0.0005 alp=0.001 bet=1 s=1 lam=0.0005 ite=...,27.5,0.112746,7.463773e-04,532.360593
4,ste=0.0005 alp=0.001 bet=1 s=1 lam=0.001 ite=2...,25.0,0.106827,6.979991e-04,387.918798
...,...,...,...,...,...
175,ste=0.05 alp=0.1 bet=5 s=1 lam=0.001 ite=20000...,18.0,0.074361,3.142821e-06,23.534863
176,ste=0.05 alp=0.1 bet=5 s=1 lam=0.005 ite=20000...,15.0,0.079073,5.056382e-06,20.380758
177,ste=0.05 alp=0.1 bet=10 s=1 lam=0.0005 ite=200...,19.0,0.085098,7.382117e-07,20.479945
178,ste=0.05 alp=0.1 bet=10 s=1 lam=0.001 ite=2000...,18.0,0.074366,7.717280e-07,20.288939
