In [1]:
from yahpo_train.cont_normalization import ContNormalization
from yahpo_train.model  import *
from yahpo_gym import cfg
from yahpo_train.metrics import *
from yahpo_gym.benchmarks import lcbench

import torch
import optuna
from optuna.integration.fastaiv2 import FastAIV2PruningCallback


# 1. Define an objective function to be minimized (the study is created with direction='minimize').
def objective(trial):
    """Train one surrogate configuration and return its validation loss.

    Optuna calls this once per trial. Hyperparameters are sampled from
    ``trial``, a surrogate model is built and trained in two phases of five
    epochs each (the ``wide`` part of the model is frozen before the second
    phase), and the last recorded validation loss is returned.

    Relies on the notebook-level ``dls`` DataLoaders, which must be defined
    before ``study.optimize`` invokes this function.

    Args:
        trial: Optuna trial used to sample hyperparameters; it is also handed
            to the pruning callback.

    Returns:
        ``final_record[1]`` of the learner after the second fit, i.e. the
        ``valid_loss`` column of the training log.
    """
    # 2. Suggest values of the hyperparameters using a trial object.
    n_layers = trial.suggest_int('n_layers', 1, 3)
    n_deep = trial.suggest_categorical('n_deep', [128, 256, 512, 1024])
    arch = trial.suggest_categorical('arch', ['cone', 'block'])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # NOTE(review): list-valued choices work with the in-memory study used here,
    # but Optuna recommends None/bool/int/float/str choices; revisit before
    # switching to a persistent storage backend.
    deeper = trial.suggest_categorical('deeper', [[512, 256, 128], []])

    if arch == 'cone':
        # Halve the width at every layer. Integer floor division keeps the
        # widths as ints (multiplying by 2 ** -i yields floats such as 256.0);
        # it is exact for every n_deep / n_layers combination sampled above.
        layers = [n_deep // (2 ** i) for i in range(n_layers)]
    else:
        layers = [n_deep] * n_layers

    model = FFSurrogateModel(dls, layers=layers, deeper=deeper)
    learner = SurrogateTabularLearner(dls, model, loss_func=nn.MSELoss(reduction='mean'))
    learner.metrics = [AvgTfedMetric(mae), AvgTfedMetric(r2), AvgTfedMetric(spearman)]
    learner.add_cb(MixHandler)
    learner.add_cb(EarlyStoppingCallback(patience=3))
    # NOTE(review): each fit call restarts epoch numbering at 0, so during the
    # second fit the pruning callback presumably reports steps Optuna has
    # already seen -- confirm whether phase-two values reach the pruner.
    learner.add_cb(FastAIV2PruningCallback(trial))

    # Phase 1: train the whole model.
    learner.fit_flat_cos(5, lr)

    # Phase 2: freeze the `wide` sub-module and continue training the rest.
    for param in learner.model.wide.parameters():
        param.requires_grad = False
    learner.fit_flat_cos(5, lr)

    return learner.final_record[1]

In [3]:
# Re-import the config factory under an alias: the assignment below rebinds
# `cfg` to the returned config object, so without this a second run of the
# cell would call that object instead of the factory.
from yahpo_gym import cfg as make_cfg

cfg = make_cfg("lcbench")
dls = dl_from_config(cfg, bs=2048)

# `objective` returns a validation loss, hence direction='minimize'.
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=100)

[32m[I 2021-08-12 11:38:12,216][0m A new study created in memory with name: no-name-62b6af1f-a724-4e47-8085-b08b61176269[0m


epoch,train_loss,valid_loss,mae,r2,spearman,time
0,0.022399,0.019245,[1.40643355e+02 1.34281317e+01 3.91302942e-01 1.23962153e-01  4.36931771e-01 1.22713331e-01],[0.15744615 0.43429933 0.25592702 0.46654063 0.28806383 0.46977858],[0.73354542 0.69009034 0.83506665 0.7119077 0.88356738 0.70535249],04:16
1,0.013862,0.01103,[1.31204213e+02 1.05068523e+01 2.78542223e-01 9.27069401e-02  2.22293473e-01 9.24314804e-02],[0.24895111 0.62877236 0.66180222 0.68801271 0.8073185 0.68382155],[0.79203882 0.77338779 0.86503713 0.81504889 0.92656959 0.81305056],04:08
2,0.011522,0.009282,[1.03520335e+02 9.82039807e+00 2.43373796e-01 8.43127077e-02  1.35222575e-01 8.38608280e-02],[0.47364465 0.65623211 0.7501447 0.72300803 0.92421855 0.72031693],[0.90373312 0.79226543 0.86592398 0.83486678 0.94135502 0.83460848],04:18
3,0.010446,0.008539,[75.00131715 9.47184332 0.22462388 0.08075463 0.11377404 0.08037957],[0.68115686 0.67557616 0.77495835 0.73839116 0.93835496 0.73551077],[0.94913241 0.80723046 0.86843364 0.84404635 0.94605866 0.84335457],04:09
4,0.010076,0.008287,[70.42020312 9.28482137 0.22013967 0.07911245 0.11044473 0.07879953],[0.71085691 0.6852334 0.78240295 0.74578997 0.94270384 0.74269256],[0.95393611 0.81489114 0.87015136 0.84840776 0.94764034 0.84730275],04:17


epoch,train_loss,valid_loss,mae,r2,spearman,time
0,0.009408,0.007758,[64.66685467 8.91664661 0.21822952 0.07615364 0.10871433 0.07597436],[0.73938738 0.70785778 0.79093749 0.76037945 0.95111729 0.75700043],[0.96036976 0.83254131 0.87448686 0.85772256 0.95004001 0.85575812],04:06
1,0.008934,0.007317,[62.44399364 8.58324631 0.2195159 0.07339 0.1095516 0.07317902],[0.75235786 0.72336053 0.788864 0.77213943 0.95652892 0.76955664],[0.96254593 0.84365824 0.87928766 0.86561484 0.95232303 0.86399351],04:08
2,0.008447,0.006874,[61.69196691 8.25214416 0.21577335 0.07013595 0.10745432 0.06978664],[0.75539139 0.7385986 0.78946323 0.78646212 0.96343413 0.78473939],[0.9630975 0.85231903 0.88647604 0.87417647 0.95437977 0.87288891],04:06
3,0.007968,0.006397,[62.18083644 7.89579314 0.20981561 0.06647126 0.10044162 0.06618808],[0.75280396 0.75462366 0.79070261 0.80281039 0.96897453 0.80108297],[0.96331334 0.86083378 0.89484966 0.88242733 0.95648689 0.8813064 ],04:08
4,0.007798,0.006236,[62.59675591 7.72857057 0.20613156 0.06490057 0.09586275 0.06464858],[0.75215865 0.75941802 0.79590708 0.80770296 0.97123269 0.80596336],[0.96296847 0.8638748 0.89745451 0.88548527 0.95676948 0.88451679],04:05


[32m[I 2021-08-12 12:20:26,657][0m Trial 0 finished with value: 0.006235813722014427 and parameters: {'n_layers': 2, 'n_deep': 512, 'arch': 'block', 'lr': 1.1246190762031614e-05, 'deeper': [512, 256, 128]}. Best is trial 0 with value: 0.006235813722014427.[0m


epoch,train_loss,valid_loss,mae,r2,spearman,time


> [1;32mc:\users\flo\documents\yahpo_gym\yahpo_train\yahpo_train\metrics.py[0m(16)[0;36maccumulate[1;34m()[0m
[1;32m     14 [1;33m        [1;32mif[0m [0mtorch[0m[1;33m.[0m[0many[0m[1;33m([0m[0mtorch[0m[1;33m.[0m[0misnan[0m[1;33m([0m[0mlearn[0m[1;33m.[0m[0mtfpred[0m[1;33m)[0m[1;33m)[0m[1;33m:[0m[1;33m[0m[1;33m[0m[0m
[0m[1;32m     15 [1;33m            [1;32mimport[0m [0mpdb[0m[1;33m;[0m [0mpdb[0m[1;33m.[0m[0mset_trace[0m[1;33m([0m[1;33m)[0m[1;33m[0m[1;33m[0m[0m
[0m[1;32m---> 16 [1;33m        [0mself[0m[1;33m.[0m[0mtotal[0m [1;33m+=[0m [0mlearn[0m[1;33m.[0m[0mto_detach[0m[1;33m([0m[0mself[0m[1;33m.[0m[0mfunc[0m[1;33m([0m[1;33m*[0m[0mlearn[0m[1;33m.[0m[0mtfyb[0m[1;33m,[0m [0mlearn[0m[1;33m.[0m[0mtfpred[0m[1;33m)[0m[1;33m)[0m[1;33m*[0m[0mbs[0m[1;33m[0m[1;33m[0m[0m
[0m[1;32m     17 [1;33m        [0mself[0m[1;33m.[0m[0mcount[0m [1;33m+=[0m [0mbs[0m[1;33m[0m

[33m[W 2021-08-12 14:22:04,107][0m Trial 1 failed because of the following error: BdbQuit()
Traceback (most recent call last):
  File "C:\Users\flo\AppData\Local\r-miniconda\envs\yahpo\lib\site-packages\optuna\study\_optimize.py", line 213, in _run_trial
    value_or_values = func(trial)
  File "<ipython-input-1-4fbd0e67cc29>", line 32, in objective
    l.fit_flat_cos(5, lr)
  File "C:\Users\flo\AppData\Local\r-miniconda\envs\yahpo\lib\site-packages\fastai\callback\schedule.py", line 136, in fit_flat_cos
    self.fit(n_epoch, cbs=ParamScheduler(scheds)+L(cbs), reset_opt=reset_opt, wd=wd)
  File "C:\Users\flo\AppData\Local\r-miniconda\envs\yahpo\lib\site-packages\fastai\learner.py", line 221, in fit
    self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)
  File "C:\Users\flo\AppData\Local\r-miniconda\envs\yahpo\lib\site-packages\fastai\learner.py", line 163, in _with_events
    try: self(f'before_{event_type}');  f()
  File "C:\Users\flo\AppData\Local\r-minic

BdbQuit: 