In [None]:
import optuna
from optuna.visualization import plot_pareto_front, plot_optimization_history, plot_slice

import pandas

from params import lsh_test, hypercube_test, gnn_test, mrng_test, nsg_test

In [None]:
# MNIST dataset / query files consumed by the benchmark wrappers in `params`.
# Kept as bytes, not str: presumably the wrappers forward them to C code — TODO confirm.
input_path, query_path = b'../MNIST/input.dat', b'../MNIST/query.dat'

In [None]:
def objective_gnns(trial):
    """Optuna objective for the GNNS benchmark.

    Samples the integer hyperparameters k, E and R from ``trial``, runs
    ``gnn_test`` on the module-level ``input_path`` / ``query_path`` with
    100 queries and N=5, and returns the two values the study minimizes.

    Args:
        trial: the Optuna trial supplying the hyperparameter suggestions.

    Returns:
        tuple: ``(maf, average_time)``, each unwrapped via ``.value`` from
        what ``gnn_test`` returns (presumably ctypes scalars — TODO confirm).
    """
    # k must be sampled before E, because k is E's upper bound.
    k = trial.suggest_int('k', 40, 100)
    e = trial.suggest_int('E', 40, k)
    r = trial.suggest_int('R', 1, 10)
    sampled = dict(k=k, E=e, R=r)

    print("Trial params", sampled)

    avg_time, maf = gnn_test(input_path, query_path, queries_num=100, k=k, E=e, R=r, N=5)
    return maf.value, avg_time.value

In [None]:
%%time
for i in range(10):
    try:
        gnns_study = optuna.create_study(study_name='gnns', directions=['minimize', 'minimize'])
        gnns_study.optimize(objective_gnns, n_trials=50)
        print("-------------------- Best trials --------------------")
        trials = sorted(gnns_study.best_trials, key=lambda x: x.values)
        for trial in trials:
            print("Trial no. {}".format(trial.number))
            print(" Values = {}".format(trial.values))
            print(" Params = {}".format(trial.params))
        break
    except:
        print("Trial failed, trying again...")
        continue

In [None]:
# Pareto front of the two minimized objectives; the names follow the order
# objective_gnns returns them in: (maf, average_time).
plot_pareto_front(
    gnns_study,
    target_names=['maf', 'average_time'],
)

In [None]:
# Optimization history of objective 0 (maf).
plot_optimization_history(
    gnns_study,
    target=lambda frozen_trial: frozen_trial.values[0],
    target_name='maf',
)

In [None]:
# Optimization history of objective 1 (average_time).
plot_optimization_history(
    gnns_study,
    target=lambda frozen_trial: frozen_trial.values[1],
    target_name='average_time',
)

In [None]:
# Per-hyperparameter slice plot of objective 0 (maf).
plot_slice(
    gnns_study,
    target=lambda frozen_trial: frozen_trial.values[0],
    target_name='maf',
)

In [None]:
# Per-hyperparameter slice plot of objective 1 (average_time).
plot_slice(
    gnns_study,
    target=lambda frozen_trial: frozen_trial.values[1],
    target_name='average_time',
)