In [1]:
import numpy as np
from numpy import linalg as la
from time import perf_counter
import os
from joblib import Parallel, delayed

import itertools

from src.model import Nonneg_dagma, MetMulDagma
import src.utils as utils

# Path prefix used when building the file names of the saved result tables.
PATH = './results/tuning/'
# Seed for NumPy's legacy global random generator.
SEED = 10
# Use half of the logical CPUs, but never fewer than one worker:
# os.cpu_count() may return None, and an integer division result of 0
# (single-CPU machine) is not a valid `n_jobs` for joblib.Parallel.
N_CPUS = max(1, (os.cpu_count() or 1) // 2)
np.random.seed(SEED)


In [2]:
def cartesian_product(hyperparams):
    """
    Expand a hyperparameter grid into the list of all its combinations.

    Parameters
    ----------
    hyperparams : dict
        Maps each hyperparameter name to an iterable of candidate values.

    Returns
    -------
    list of dict
        One dict per element of the Cartesian product of the candidate
        values, each mapping every name to a single value. The order is the
        one produced by itertools.product (the last key varies fastest).
    """
    names = tuple(hyperparams)
    grid = []
    for combo in itertools.product(*hyperparams.values()):
        grid.append(dict(zip(names, combo)))
    return grid

def args2str(arguments):
    """
    Render a dict as a compact one-line label.

    Each item becomes ``'<first three characters of the key>=<value> '``
    (note the trailing space) and the pieces are concatenated in the dict's
    iteration order. An empty dict yields an empty string.
    """
    return ''.join(f'{key[:3]}={val} ' for key, val in arguments.items())

def print_best(key, metrics, args_combs, agg_funct='mean'):
    """
    Print the hyperparameter combination that minimizes an aggregated metric.

    Parameters
    ----------
    key : str
        Name of the metric in `metrics` whose aggregated value is minimized.
    metrics : dict
        Maps metric names to array-likes that are aggregated along axis 0.
        Must contain 'shd', 'err', 'acyc' and 'time', since all four are
        reported for the selected combination.
    args_combs : sequence
        Hyperparameter combinations, indexed by the position of the minimum
        of the aggregated metric.
    agg_funct : str, optional
        Name of the NumPy function (looked up on `np`) used to aggregate,
        e.g. 'mean' or 'median'. It is called with ``axis=0``.
    """
    aggregate = getattr(np, agg_funct)
    # A distinct loop variable keeps the `key` argument from being shadowed.
    agg_metric = {name: aggregate(values, axis=0) for name, values in metrics.items()}

    idx = np.argmin(agg_metric[key])

    print(f'Combination with best {key} (agg: {agg_funct}):')
    print(args_combs[idx])
    # Every field is separated by ' | ' (the original omitted the space before 'acyc').
    print(f'shd: {agg_metric["shd"][idx]:.2f} | err: {agg_metric["err"][idx]:.4f} | '
          f'acyc: {agg_metric["acyc"][idx]:.6f} | time: {agg_metric["time"][idx]:.2f}')

def run_grid_search_tuning(g, data_p, args_combs, model_const, norm_x, thr, verb=False):
    """
    Simulate one dataset and evaluate every hyperparameter combination on it.

    Parameters
    ----------
    g : int
        Index of the graph; only used (as ``g+1``) in the printed message.
    data_p : dict
        Keyword arguments forwarded to ``utils.simulate_sem``; its
        'n_samples' entry is also used to scale the fidelity value.
    args_combs : sequence of dict
        Each dict is unpacked as keyword arguments of ``model.fit``.
    model_const : callable
        Called with no arguments to build a fresh model per combination; the
        model must provide ``fit(X, **args)`` and ``dagness(W)``.
    norm_x : bool
        When true, X is divided by its norm taken along axis 1.
    thr : float
        Threshold passed to ``utils.to_bin`` for both true and estimated W.
    verb : bool, optional
        When true, print the metrics of every combination.

    Returns
    -------
    tuple of np.ndarray
        ``(shd, err, acyc, runtime)``, each with one entry per combination.
    """
    # NOTE(review): if utils.simulate_sem draws from NumPy's global RNG, a
    # seed set in the parent process may not carry over to worker
    # processes -- confirm reproducibility of the generated graphs.
    W_true, _, X = utils.simulate_sem(**data_p)
    norm_W_true = np.linalg.norm(W_true)
    W_true_bin = utils.to_bin(W_true, thr)
    if norm_x:
        X = X/np.linalg.norm(X, axis=1, keepdims=True)

    # Squared Frobenius residual of the true model, scaled by the sample count.
    residual = X - X @ W_true
    fidelity = 1/data_p['n_samples']*la.norm(residual, 'fro')**2

    print(f'Graph {g+1}: Fidelity: {fidelity:.3f}')

    n_combs = len(args_combs)
    shd = np.zeros(n_combs)
    err = np.zeros(n_combs)
    acyc = np.zeros(n_combs)
    runtime = np.zeros(n_combs)

    for i, args in enumerate(args_combs):
        model = model_const()

        # Only the call to fit is timed.
        t_init = perf_counter()
        W_est = model.fit(X, **args)
        t_end = perf_counter()
        runtime[i] = t_end - t_init

        W_est_bin = utils.to_bin(W_est, thr)
        # count_accuracy returns three values; only the first is kept.
        shd[i] = utils.count_accuracy(W_true_bin, W_est_bin)[0]
        err[i] = utils.compute_norm_sq_err(W_true, W_est, norm_W_true)
        acyc[i] = model.dagness(W_est)

        if not verb:
            continue
        text = args2str(args)
        print(f'\t- {text}: shd {shd[i]}  -  err: {err[i]:.3f}  -  acyc: {acyc[i]:.5g}  -  time: {runtime[i]:.3f}')

    return shd, err, acyc, runtime

## Experiment parameters

In [3]:
# Model class; run_grid_search_tuning instantiates it with no arguments.
model_const = MetMulDagma

# Evaluation settings.
verb = False      # print per-combination metrics inside each job
thr = .2          # threshold handed to utils.to_bin
n_dags = 30       # number of simulated graphs (one parallel job each)
norm_x = False    # whether X is normalized along axis 1 before fitting

# Synthetic data settings, forwarded as keyword arguments to utils.simulate_sem.
N = 100
data_params = dict(
    n_nodes=N,
    n_samples=500,  # 1000,
    graph_type='er',
    edges=4*N,
    edge_type='positive',
    w_range=(.5, 1),
    var=1,  # 1/np.sqrt(N),
)

# Tuning notes:
#   - stepsize
#   - beta vs iters out
#   - lambda

# Grid of candidate values; every combination is passed to model.fit.
Hyperparams = dict(
    stepsize=[2.5e-3],
    alpha_0=[.01],
    rho_0=[.05],
    beta=[2],
    s=[1],
    lamb=[1e-4],
    iters_in=[1000, 5000, 10000, 20000, 30000],
    iters_out=[5, 10, 15, 20, 30],
    tol=[1e-6],
)

# Hyperparams = {
#     'stepsize': [5e-2],
#     'alpha': [.05, .075, .1, .25, .5, .75, 1],
#     's': [1],
#     'lamb': [5e-4, 1e-3, 5e-3],
#     'max_iters': [10000]
# }

print('CPUs employed:', N_CPUS)

# Get combination of hyperparams for grid search
args_combs = cartesian_product(Hyperparams)

# One job per simulated graph; every job evaluates the whole grid.
t_init = perf_counter()
results = Parallel(n_jobs=N_CPUS)(
    delayed(run_grid_search_tuning)(g, data_params, args_combs, model_const, norm_x, thr, verb)
    for g in range(n_dags)
)
t_end = perf_counter()
# Report the wall-clock time, which was previously measured but never used.
print(f'Grid search ({len(args_combs)} combinations x {n_dags} graphs) took {t_end - t_init:.1f} s')

# Regroup the per-graph result tuples into one sequence per metric
# (one entry per graph, each holding one value per combination).
shd, err, acyc, runtime = zip(*results)
metrics = {'shd': shd, 'err': err, 'acyc': acyc, 'time': runtime}

CPUs employed: 32


Graph 12: Fidelity: 100.457
Graph 3: Fidelity: 100.831
Graph 1: Fidelity: 100.348
Graph 15: Fidelity: 100.268
Graph 5: Fidelity: 99.977
Graph 14: Fidelity: 101.191
Graph 2: Fidelity: 99.826
Graph 9: Fidelity: 99.439
Graph 8: Fidelity: 100.416
Graph 16: Fidelity: 100.473
Graph 25: Fidelity: 99.381
Graph 24: Fidelity: 100.032
Graph 18: Fidelity: 100.441
Graph 4: Fidelity: 100.138
Graph 27: Fidelity: 99.785
Graph 17: Fidelity: 99.891
Graph 30: Fidelity: 100.242
Graph 19: Fidelity: 100.343
Graph 21: Fidelity: 100.051
Graph 6: Fidelity: 98.923
Graph 26: Fidelity: 100.504
Graph 13: Fidelity: 99.867
Graph 20: Fidelity: 98.595
Graph 7: Fidelity: 99.577
Graph 28: Fidelity: 99.543
Graph 29: Fidelity: 101.018
Graph 23: Fidelity: 100.425
Graph 22: Fidelity: 99.245
Graph 11: Fidelity: 99.705Graph 10: Fidelity: 98.078



KeyboardInterrupt: 

In [None]:
# Report the best combination for SHD and for the error, first aggregating
# across graphs with the mean and then with the median; a blank line
# separates the two groups.
for position, agg_name in enumerate(('mean', 'median')):
    if position > 0:
        print()
    for metric_name in ('shd', 'err'):
        print_best(metric_name, metrics, args_combs, agg_funct=agg_name)


Combination with best shd (agg: mean):
{'stepsize': 0.0025, 'alpha_0': 0.01, 'rho_0': 0.05, 'beta': 2, 's': 1, 'lamb': 0.0001, 'iters_in': 30000, 'iters_out': 20, 'tol': 1e-06}
shd: 44.37 | err: 0.1003 |acyc: 0.000205 | time: 255.46
Combination with best err (agg: mean):
{'stepsize': 0.0025, 'alpha_0': 0.01, 'rho_0': 0.05, 'beta': 2, 's': 1, 'lamb': 0.0001, 'iters_in': 30000, 'iters_out': 20, 'tol': 1e-06}
shd: 44.37 | err: 0.1003 |acyc: 0.000205 | time: 255.46

Combination with best shd (agg: median):
{'stepsize': 0.0025, 'alpha_0': 0.01, 'rho_0': 0.05, 'beta': 1.5, 's': 1, 'lamb': 0.0001, 'iters_in': 30000, 'iters_out': 20, 'tol': 1e-06}
shd: 29.50 | err: 0.0650 |acyc: 0.000415 | time: 295.55
Combination with best err (agg: median):
{'stepsize': 0.0025, 'alpha_0': 0.01, 'rho_0': 0.05, 'beta': 1, 's': 1, 'lamb': 0.0001, 'iters_in': 30000, 'iters_out': 20, 'tol': 1e-06}
shd: 30.00 | err: 0.0624 |acyc: 0.011316 | time: 323.13


In [None]:
# One legend entry per hyperparameter combination.
leg = list(map(args2str, args_combs))

# Show and save the mean- and median-aggregated result tables.
for agg_name, out_file in (('mean', f'{PATH}tuning_mean'), ('median', f'{PATH}tuning_med')):
    utils.display_results(leg, metrics, agg=agg_name, file_name=out_file)

Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,93.966667,0.215904,0.3561422,164.513633
1,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,69.933333,0.15923,0.107849,227.545298
2,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,59.233333,0.138979,0.03711267,277.112196
3,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,59.2,0.138586,0.03455077,350.743448
4,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,87.966667,0.20321,0.03748139,198.18125
5,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,67.4,0.15477,0.02209993,243.984569
6,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,60.333333,0.143963,0.02160927,299.663983
7,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,54.333333,0.129121,0.004275546,319.818126
8,ste=0.0025 alp=0.01 rho=0.05 bet=2 s=1 lam=0.0...,79.1,0.178392,0.02011577,190.969188
9,ste=0.0025 alp=0.01 rho=0.05 bet=2 s=1 lam=0.0...,64.6,0.153522,0.02059942,235.871609


DataFrame saved to ./results/tuning/tuning_mean.cs


Unnamed: 0,leg,shd,err,acyc,time
0,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,69.0,0.121028,0.1062515,178.351156
1,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,36.0,0.077939,0.03417634,253.208249
2,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,30.5,0.062362,0.01152436,261.658481
3,ste=0.0025 alp=0.01 rho=0.05 bet=1 s=1 lam=0.0...,30.0,0.062361,0.01131592,323.12754
4,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,54.5,0.099164,0.01561773,212.733377
5,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,32.5,0.072637,0.0006398465,272.510115
6,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,30.0,0.072636,0.0005325537,268.192793
7,ste=0.0025 alp=0.01 rho=0.05 bet=1.5 s=1 lam=0...,29.5,0.065016,0.0004152796,295.547803
8,ste=0.0025 alp=0.01 rho=0.05 bet=2 s=1 lam=0.0...,52.5,0.095892,0.001284519,196.488165
9,ste=0.0025 alp=0.01 rho=0.05 bet=2 s=1 lam=0.0...,32.5,0.072713,0.00044443,261.179911


DataFrame saved to ./results/tuning/tuning_med.cs
