In [None]:
from yahpo_train.cont_normalization import ContNormalization
from yahpo_train.model  import *
from yahpo_gym import cfg
from yahpo_train.metrics import *
from yahpo_gym.benchmarks import lcbench

import torch
import optuna
from optuna.integration.fastaiv2 import FastAIV2PruningCallback


# 1. Define an objective function to be minimized.
def objective(trial):
    """Train one sampled surrogate configuration and return its score.

    Samples the network hyperparameters from ``trial``, builds an
    ``FFSurrogateModel`` on the module-level ``dls``, trains it in two
    ``fit_flat_cos`` stages (the second one with the parameters of
    ``model.wide`` frozen) and returns the value the study optimizes.

    Args:
        trial: The Optuna trial used to sample hyperparameters; it is also
            handed to ``FastAIV2PruningCallback`` so the run can be pruned.

    Returns:
        ``final_record[1]`` of the learner after training. Presumably the
        validation loss under fastai's recorder layout
        (train loss, valid loss, metrics...) -- TODO confirm for
        ``SurrogateTabularLearner``.
    """
    # 2. Suggest values of the hyperparameters using a trial object.
    n_layers = trial.suggest_int('n_layers', 1, 3)
    n_deep = trial.suggest_categorical('n_deep', [128, 256, 512, 1024])
    arch = trial.suggest_categorical('arch', ['cone', 'block'])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # NOTE(review): Optuna documents categorical choices as
    # None/bool/int/float/str; list-valued choices emit a warning and may not
    # work with persistent storages. Kept as-is to preserve the search space.
    deeper = trial.suggest_categorical('deeper', [[512, 256, 128], []])

    if arch == 'cone':
        # Halve the width at every layer. Integer division keeps the sizes
        # ints (multiplying by 2 ** -i would yield floats such as 64.0).
        layers = [n_deep // (2 ** i) for i in range(n_layers)]
    else:
        layers = [n_deep] * n_layers

    model = FFSurrogateModel(dls, layers=layers, deeper=deeper)
    learner = SurrogateTabularLearner(
        dls, model, loss_func=nn.MSELoss(reduction='mean'), metrics=nn.MSELoss
    )
    # Replace the metrics passed to the constructor with the transformed ones.
    learner.metrics = [AvgTfedMetric(mae), AvgTfedMetric(r2), AvgTfedMetric(spearman)]
    learner.add_cb(MixHandler)
    learner.add_cb(EarlyStoppingCallback(patience=3))
    learner.add_cb(FastAIV2PruningCallback(trial))

    # Stage 1: train all parameters.
    learner.fit_flat_cos(5, lr)
    # Stage 2: disable gradients for the ``wide`` sub-module, then train again.
    for param in learner.model.wide.parameters():
        param.requires_grad = False
    learner.fit_flat_cos(5, lr)

    # Return the recorded value instead of the previously undefined
    # ``accuracy`` name, so the study receives a numeric objective.
    return learner.final_record[1]

# 3. Build a study that minimizes the value returned by `objective`,
#    then run the search for 100 trials.
study = optuna.create_study(direction="minimize")
study.optimize(func=objective, n_trials=100)