In [1]:
import numpy as np
from numpy import linalg as la
from time import perf_counter
import os
from joblib import Parallel, delayed

import itertools

from src.model import Nonneg_dagma, MetMulDagma
import src.utils as utils

PATH = './results/tuning/'
PATH_SACHS = './datasets/sachs/'
SEED = 10
N_CPUS = os.cpu_count() // 2
np.random.seed(SEED)

DATASET = "SACHS"  # SYNTH, SACHS

In [2]:
def cartesian_product(hyperparams):
    """
    Generate all combinations of hyperparameters.
    """
    param_names = list(hyperparams.keys())
    param_values = list(hyperparams.values())    
    param_combinations = [dict(zip(param_names, values)) for values in itertools.product(*param_values)]
    
    return param_combinations

args2str = lambda arguments: ''.join([f'{key[:3]}={val} ' for key, val in arguments.items()])

def print_best(key, metrics, args_combs, agg_funct='mean'):
    agg_metric = {key: getattr(np, agg_funct)(value, axis=0) for key, value in metrics.items()}
        
    idx = np.argmin(agg_metric[key])

    print(f'Combination with best {key} (agg: {agg_funct}):')
    print(args_combs[idx])
    print(f'shd: {agg_metric["shd"][idx]:.2f} | err: {agg_metric["err"][idx]:.4f} |' +
          f'acyc: {agg_metric["acyc"][idx]:.6f} | time: {agg_metric["time"][idx]:.2f}')



def run_grid_search_tuning(g, data_p, args_combs, model_const, std_x, thr, verb=False):
    # Create data
    if DATASET == "SACHS":
        W_true = np.load(PATH_SACHS + "sachs_A_matrix.npy")
        X = np.load(PATH_SACHS + "sachs_X.npy")
    else:
        W_true, _, X = utils.simulate_sem(**data_p)
        
    # X = X/np.linalg.norm(X, axis=1, keepdims=True) if std_x else X
    X = utils.standarize(X) if std_x else X
    norm_W_true = np.linalg.norm(W_true)
    W_true_bin = utils.to_bin(W_true, thr)
    
    fidelity = 1/data_p['n_samples']*la.norm(X - X @ W_true, 'fro')**2

    print(f'Graph {g+1}: Fidelity: {fidelity:.3f}')

    shd, err, acyc, runtime = [np.zeros((len(args_combs)))  for _ in range(4)]
    for i, args in enumerate(args_combs):        
        model = model_const()
        t_init = perf_counter()
        W_est = model.fit(X, **args)
        t_end = perf_counter()
    
        W_est_bin = utils.to_bin(W_est, thr)
        shd[i], _, _ = utils.count_accuracy(W_true_bin, W_est_bin)
        err[i] = utils.compute_norm_sq_err(W_true, W_est, norm_W_true)
        runtime[i] = t_end - t_init
        acyc[i] = model.dagness(W_est)

        if verb:
            text = args2str(args)
            print(f'\t- {text}: shd {shd[i]}  -  err: {err[i]:.3f}  -  acyc: {acyc[i]:.5g}  -  time: {runtime[i]:.3f}')
    
    return shd, err, acyc, runtime

## Experiment parameters

In [None]:
model_const = MetMulDagma

model_args = {
    'primal_opt': 'fista',  # 'adam', 'fista', 'pgd
    'acyclicity': 'logdet',
    'restart': True,  # Only used in FISTA
}

verb = True
thr = .2
n_dags = 30 if DATASET != "SACHS" else 1
std_x = True
N = 100  
data_params = {
    'n_nodes': N,
    'n_samples': 500, # 1000,
    'graph_type': 'er',
    'edges': 4*N,
    'edge_type': 'positive',
    'w_range': (.5, 1),
    'var': 1, # 1/np.sqrt(N),
}

# Stepsize
# beta vs iters out
# lambda

Hyperparams = {
    'stepsize': [1e-5],
    'alpha_0': [.1],
    'rho_0': [.001],
    'beta': [1.5],
    's': [1],
    'lamb': [1e-2],
    'iters_in': [30000],
    'iters_out': [25],
    'tol': [1e-6],
    'Sigma': [.01, 1],
}

# Hyperparams = {
#     'stepsize': [5e-2],
#     'alpha': [.05, .075, .1, .25, .5, .75, 1],
#     's': [1],
#     'lamb': [5e-4, 1e-3, 5e-3],
#     'max_iters': [10000]
# }

if DATASET == "SACHS":
    N_CPUS = 1

print('CPUs employed:', N_CPUS)
print('Looking hyperparameters for dataset', DATASET)
# Get combination of hyperparams for grid search
args_combs = cartesian_product(Hyperparams)    

t_init = perf_counter()
results = Parallel(n_jobs=N_CPUS)(delayed(run_grid_search_tuning)
                  (g, data_params, args_combs, model_const, std_x, thr, verb) for g in range(n_dags))
t_end = perf_counter()

shd, err, acyc, runtime = zip(*results)
metrics = {'shd': shd, 'err': err, 'acyc': acyc, 'time': runtime}

CPUs employed: 1
Looking hyperparameters for dataset SACHS
Graph 1: Fidelity: 43.797
	- ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0.01 ite=30000 ite=25 tol=1e-06 Sig=0.01 : shd 13.0  -  err: 1.239  -  acyc: 7.1545e-07  -  time: 25.754
	- ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0.01 ite=30000 ite=25 tol=1e-06 Sig=1 : shd 14.0  -  err: 1.125  -  acyc: 0.00027289  -  time: 41.579


In [4]:
print_best('shd', metrics, args_combs)
print_best('err', metrics, args_combs)
print()
print_best('shd', metrics, args_combs, agg_funct='median')
print_best('err', metrics, args_combs, agg_funct='median')


Combination with best shd (agg: mean):
{'stepsize': 1e-05, 'alpha_0': 0.1, 'rho_0': 0.001, 'beta': 1.5, 's': 1, 'lamb': 0.01, 'iters_in': 30000, 'iters_out': 25, 'tol': 1e-06, 'Sigma': 0.01}
shd: 13.00 | err: 1.2390 |acyc: 0.000001 | time: 25.75
Combination with best err (agg: mean):
{'stepsize': 1e-05, 'alpha_0': 0.1, 'rho_0': 0.001, 'beta': 1.5, 's': 1, 'lamb': 0.01, 'iters_in': 30000, 'iters_out': 25, 'tol': 1e-06, 'Sigma': 1}
shd: 14.00 | err: 1.1254 |acyc: 0.000273 | time: 41.58

Combination with best shd (agg: median):
{'stepsize': 1e-05, 'alpha_0': 0.1, 'rho_0': 0.001, 'beta': 1.5, 's': 1, 'lamb': 0.01, 'iters_in': 30000, 'iters_out': 25, 'tol': 1e-06, 'Sigma': 0.01}
shd: 13.00 | err: 1.2390 |acyc: 0.000001 | time: 25.75
Combination with best err (agg: median):
{'stepsize': 1e-05, 'alpha_0': 0.1, 'rho_0': 0.001, 'beta': 1.5, 's': 1, 'lamb': 0.01, 'iters_in': 30000, 'iters_out': 25, 'tol': 1e-06, 'Sigma': 1}
shd: 14.00 | err: 1.1254 |acyc: 0.000273 | time: 41.58


In [5]:
leg = [args2str(args) for args in args_combs]
utils.display_results(leg, metrics, agg='mean', file_name=f'{PATH}tuning_mean')
utils.display_results(leg, metrics, agg='median', file_name=f'{PATH}tuning_med')

Unnamed: 0,leg,shd,err,acyc,time
0,ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0....,13.0000 ± 0.0000,1.2390 ± 0.0000,0.0000 ± 0.0000,25.7536 ± 0.0000
1,ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0....,14.0000 ± 0.0000,1.1254 ± 0.0000,0.0003 ± 0.0000,41.5788 ± 0.0000


DataFrame saved to ./results/tuning/tuning_mean.csv


Unnamed: 0,leg,shd,err,acyc,time
0,ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0....,13.0000 ± 0.0000,1.2390 ± 0.0000,0.0000 ± 0.0000,25.7536 ± 0.0000
1,ste=1e-05 alp=0.1 rho=0.001 bet=1.5 s=1 lam=0....,14.0000 ± 0.0000,1.1254 ± 0.0000,0.0003 ± 0.0000,41.5788 ± 0.0000


DataFrame saved to ./results/tuning/tuning_med.csv
