In [1]:
from fastai.tabular.all import *
from sklearn.model_selection import KFold, GroupKFold
from optuna.integration import FastAIPruningCallback
import optuna

In [2]:
# Rows per time_id after `fill_missing`; data loaders and the model work on
# consecutive groups of this many rows.
STOCK_COUNT = 112
# Number of continuous input features the model expects (previously 20).
FEATURE_COUNT = 240


In [3]:
def fill_missing(train_df):
    """Return a copy of `train_df` padded to the full time_id x stock_id grid.

    Every combination of the observed `time_id` and `stock_id` values gets a
    row, ordered time-major (all stocks of the first time_id, then the next, ...),
    with time_ids and stock_ids in order of first appearance. Rows that did not
    exist, and any other NaN in the frame, are filled with 0. The input frame
    is not modified.
    """
    full_grid = pd.MultiIndex.from_product(
        [train_df.time_id.unique(), train_df.stock_id.unique()],
        names=['time_id', 'stock_id'],
    )
    return (
        train_df
        .set_index(['time_id', 'stock_id'])  # returns a new frame
        .reindex(full_grid)
        .reset_index()
        .fillna(0)
    )

In [4]:
class MaskTfm(ItemTransform):
    """Training-only augmentation: zero the same random ~10% of rows in every batch element.

    The batch arrives as a sequence of tensors; the same row indices are zeroed
    in each of them (presumably cats, conts and targets from `ReadTabBatch` —
    TODO confirm). Rows whose target becomes 0 are skipped by `rmspe`, which
    masks out zero targets. Masking happens only when the transform is called
    with `split_idx == 0`.
    """
    # Read by `encodes`; overwritten on every call in `__call__`.
    do_transform = False

    def mask(self, x, indices):
        """Zero the rows of tensor `x` at `indices` in place and return `x`."""
        x[torch.tensor(indices, device=x.device)] = 0
        return x

    def __call__(self, b, split_idx=None, **kwargs):
        # Remember whether this call is on the training split so `encodes`
        # leaves validation/inference batches untouched.
        self.do_transform = (split_idx == 0)
        return super().__call__(b, split_idx=split_idx, **kwargs)

    def encodes(self, x):
        if not self.do_transform: return x
        n = len(x[0])
        # Same draw as choosing from np.array(range(n)), without building the array.
        indices = np.random.choice(n, n // 10, replace=False)
        return [self.mask(y, indices) for y in x]

class MyDataLoader(TabDataLoader):
    """`TabDataLoader` that shuffles whole time groups and appends `MaskTfm`.

    Rows are expected in consecutive groups of `STOCK_COUNT` (one group per
    time_id, as laid out by `fill_missing`). Shuffling permutes those groups
    but keeps the stock order inside each one, which the model's positional
    stock embedding relies on.
    """

    def shuffle_fn(self, idxs):
        # Reshape to (n_times, STOCK_COUNT): np.random.shuffle permutes only the
        # first axis, i.e. whole time groups.
        idxs = np.array(idxs).reshape(-1, STOCK_COUNT)
        np.random.shuffle(idxs)
        return idxs.reshape(-1).tolist()

    def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
        # NOTE(review): `bs` must be a multiple of STOCK_COUNT for batches to hold
        # whole time groups; the default of 16 is not.
        if after_batch is None:
            after_batch = L(TransformBlock().batch_tfms) + ReadTabBatch(dataset) + [MaskTfm()]
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)

    



def get_dls(train_df, bs, trn_idx, val_idx):
    """Build tabular DataLoaders whose batches contain `bs` complete time groups.

    Args:
        train_df: frame laid out in consecutive groups of STOCK_COUNT rows per
            time_id (see `fill_missing`), with a 'target' column.
        bs: number of time_ids per batch; the row batch size is `STOCK_COUNT * bs`.
        trn_idx, val_idx: row indices of the training / validation split.

    Returns:
        fastai DataLoaders built with `MyDataLoader`.
    """
    cont_nn, cat_nn = cont_cat_split(train_df, max_card=9000, dep_var='target')
    # Identifier columns are not model inputs.
    cat_nn = [x for x in cat_nn if x not in ['row_id', 'time_id']]

    procs_nn = [Categorify, Normalize]
    to_nn = TabularPandas(train_df, procs_nn, cat_nn, cont_nn,
                          splits=[list(trn_idx), list(val_idx)], y_names='target')
    dls = to_nn.dataloaders(bs=STOCK_COUNT * bs, shuffle=True, dl_type=MyDataLoader)
    # MaskTfm masks only when split_idx == 0, so tag both splits explicitly.
    dls.train_ds.split_idx = 0
    dls.valid_ds.split_idx = 1
    return dls

In [5]:
TRAIN_CSV = 'train_with_features.csv'  # relative to the notebook's working directory
train_df = pd.read_csv(TRAIN_CSV)

In [6]:
# (alternative input: train_df = pd.read_feather('train_24cols.feather'))
# Pad so that every time_id has a row for every stock.
train_df = fill_missing(train_df)
# Hold out the first GroupKFold fold; grouping by time_id keeps each time group in one split.
time_folds = GroupKFold().split(train_df, groups=train_df.time_id)
trn_idx0, val_idx0 = first(time_folds)

In [7]:
# DataLoaders for the single hold-out fold used by the hyperparameter search.
dls = get_dls(train_df, bs=100, trn_idx=trn_idx0, val_idx=val_idx0)

In [19]:
class TimeEncoding(nn.Module):
    """Residual block that shares information across all stocks of a time_id.

    Each row is reduced to `bottleneck` features; the STOCK_COUNT rows of a
    time group are concatenated and projected back to `inp_size` through a
    tanh; that per-time vector, scaled by the fixed `multiplier`, is added to
    every row of the group.
    """
    def __init__(self, inp_size, bottleneck, p, multiplier):
        super().__init__()
        # Plain float scale, not a learned parameter.
        self.multiplier = multiplier
        self.initial_layers = LinBnDrop(inp_size, bottleneck, act=nn.ReLU(True), p=p, bn=False)
        self.concat_layers = nn.Sequential(
            nn.Linear(bottleneck * STOCK_COUNT, inp_size),
            nn.Tanh(),
        )

    def forward(self, x):
        # Rows are assumed to come in consecutive groups of STOCK_COUNT per time_id.
        n_times = x.shape[0] // STOCK_COUNT
        squeezed = self.initial_layers(x).view(n_times, -1)
        per_time = self.concat_layers(squeezed)
        # Copy each time's vector onto all STOCK_COUNT rows of that time group.
        per_row = per_time.repeat_interleave(STOCK_COUNT, dim=0)
        return x + per_row * self.multiplier

class BN(nn.Module):
    """BatchNorm with separate statistics for every (stock, feature) position.

    The (rows, features) input is regrouped into one row per time group of
    STOCK_COUNT rows, normalised by `BatchNorm1d(STOCK_COUNT * features)`, and
    returned in its original shape.
    """
    def __init__(self, features):
        super().__init__()
        self.num_features = features
        self.bn = nn.BatchNorm1d(STOCK_COUNT * self.num_features)

    def forward(self, x):
        original_shape = x.shape
        grouped = x.view(-1, STOCK_COUNT * self.num_features)
        return self.bn(grouped).view(*original_shape)
    
class ParallelModel(nn.Module):
    """MLP over per-stock rows with cross-stock `TimeEncoding` blocks.

    Each hidden stage is BN -> optional Dropout -> Linear -> ReLU -> TimeEncoding;
    a final linear head is squashed into (0, 0.1) by `SigmoidRange`.

    Args:
        inp_size: number of continuous features.
        emb_size: width of the learned per-stock embedding.
        lin_sizes: hidden layer widths (list or tuple).
        ps: dropout probability per hidden stage (falsy -> no Dropout layer).
        bottleneck, time_ps, multiplier: forwarded to each `TimeEncoding`.
    """
    def __init__(self, inp_size, emb_size, lin_sizes, ps, bottleneck, time_ps, multiplier):
        super().__init__()

        self.stock_emb = nn.Parameter(torch.empty(STOCK_COUNT, emb_size))
        nn.init.normal_(self.stock_emb)

        sizes = [inp_size + emb_size] + list(lin_sizes)
        layers = []
        for n_in, n_out, p, time_p in zip(sizes, sizes[1:], ps, time_ps):
            layers.append(BN(n_in))
            if p: layers.append(nn.Dropout(p))
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU(True))
            layers.append(TimeEncoding(n_out, bottleneck, time_p, multiplier))
        layers.append(LinBnDrop(sizes[-1], 1, bn=False))
        layers.append(SigmoidRange(0, .1))
        self.layers = nn.Sequential(*layers)

    def forward(self, x_cat, x_cont):
        # x_cat only supplies the batch size: stock identity is positional.
        # Tiling the embedding table gives row k the embedding of stock k % STOCK_COUNT.
        times = x_cat.shape[0] // STOCK_COUNT
        s_e = self.stock_emb.repeat(times, 1)
        x = torch.cat([x_cont, s_e], dim=1)
        return self.layers(x)

In [9]:

def rmspe(preds, targs):
    """Root mean squared percentage error over the rows whose target is non-zero.

    Zero targets come from rows padded by `fill_missing` (fillna(0)) and,
    presumably, rows zeroed by `MaskTfm`, so they are excluded.

    Both inputs are flattened first: `torch.masked_select` broadcasts the mask
    against its input, so an (N, 1) prediction with an (N,) target would
    otherwise select the wrong elements.

    Raises:
        ValueError: if `preds` and `targs` hold different numbers of elements.
        Exception: if the result is NaN (e.g. no non-zero targets); the offending
            tensors are printed first.
    """
    preds, targs = preds.reshape(-1), targs.reshape(-1)
    if preds.numel() != targs.numel():
        raise ValueError(f'rmspe: got {preds.numel()} predictions for {targs.numel()} targets')
    mask = targs != 0
    targs, preds = torch.masked_select(targs, mask), torch.masked_select(preds, mask)
    x = (targs - preds) / targs
    res = (x**2).mean().sqrt()
    if torch.isnan(res):
        print(targs)
        print(preds)
        raise Exception('fck loss is nan')
    return res

In [10]:
def train(trial, dls, save_as = None):
    """Train one `ParallelModel` with hyperparameters drawn from `trial`.

    Args:
        trial: Optuna trial (live or frozen) used to suggest hyperparameters.
        dls: DataLoaders from `get_dls`.
        save_as: optional name passed to `learn.save` after training.

    Returns:
        The rmspe metric recorded after the final epoch.
    """
    # Keep names, ranges and call order unchanged: existing studies and
    # FrozenTrial replays depend on them.
    emb_size = trial.suggest_int('emb_size', 3, 30)
    lin_sizes = [trial.suggest_int(f'lin_size{i}', 10, cap)
                 for i, cap in enumerate([2000, 1000, 500])]
    ps = [0] + [trial.suggest_float(f'p{i}', 0, .8) for i in (1, 2)]
    bottleneck = trial.suggest_int('bottleneck', 5, 100)
    time_ps = [trial.suggest_float(f'time_p{i}', 0, .5) for i in range(3)]
    multiplier = trial.suggest_float('multiplier', .01, .5)
    lr = float(trial.suggest_float('lr', 1e-3, 1e-2))

    model = ParallelModel(FEATURE_COUNT, emb_size, lin_sizes, ps, bottleneck, time_ps, multiplier)
    learn = Learner(dls, model=model, loss_func=rmspe, metrics=AccumMetric(rmspe),
                    opt_func=ranger).to_fp16()
    learn.fit_flat_cos(50, lr)
    if save_as:
        learn.save(save_as)
    # Each recorder row is [train_loss, valid_loss, rmspe].
    last_epoch = learn.recorder.values[-1]
    return last_epoch[2]

In [11]:
def train_cross_valid(trial, dlss, save_as=None):
    """Train `trial`'s configuration on every fold and return the mean score.

    Args:
        trial: Optuna trial (live or frozen) passed through to `train`.
        dlss: iterable of per-fold DataLoaders.
        save_as: optional prefix; fold `i` is saved as f"{save_as}{i}".

    Returns:
        The average of the per-fold scores (divides by the actual fold count,
        not a fixed 5).

    Raises:
        ValueError: if `dlss` yields no folds.
    """
    scores = []
    for idx, dls in enumerate(dlss):
        v = train(trial, dls, save_as + str(idx) if save_as else None)
        print(f'fold {idx}: {v}')
        scores.append(v)
    if not scores:
        raise ValueError('train_cross_valid: no folds given')
    return sum(scores) / len(scores)

In [None]:
# Keyword arguments: optuna.load_study is keyword-only from Optuna 3.0 on.
# NOTE(review): the next cell rebinds `study`; run only the one you need.
study = optuna.load_study(study_name='parallel4', storage='sqlite:///optuna.db')

In [None]:
import functools  # explicit, rather than relying on the fastai star-import

study = optuna.create_study(
    direction="minimize",
    study_name='parallel_fixed_mult',
    storage='sqlite:///optuna.db',
    load_if_exists=True,
    pruner=optuna.pruners.NopPruner(),
    sampler=None,
)
# No n_trials/timeout: this runs until interrupted, so "Run All" will not finish here.
study.optimize(functools.partial(train, dls=dls))

[32m[I 2021-09-20 21:17:06,409][0m Using an existing study with name 'parallel_fixed_mult' instead of creating a new one.[0m


epoch,train_loss,valid_loss,rmspe,time
0,19.878239,20.14529,20.158358,00:00
1,18.255226,16.074665,16.084969,00:00
2,15.925702,12.088929,12.096025,00:00
3,13.254589,8.724669,8.728976,00:00
4,10.691133,6.529365,6.532411,00:00
5,8.431808,4.772106,4.774443,00:00
6,6.540692,3.407113,3.408715,00:00
7,4.987409,2.347702,2.348614,00:00
8,3.728291,1.565269,1.565905,00:00
9,2.733438,0.997414,0.997912,00:00


[32m[I 2021-09-20 21:17:50,406][0m Trial 45 finished with value: 0.25317320227622986 and parameters: {'emb_size': 16, 'lin_size0': 827, 'lin_size1': 338, 'lin_size2': 356, 'p1': 0.35934634877626714, 'p2': 0.21500230229686917, 'bottleneck': 60, 'time_p0': 0.17240412789238566, 'time_p1': 0.2722105194052167, 'time_p2': 0.26333635094558483, 'multiplier': 0.4637178722284131, 'lr': 0.007690704414846785}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,20.839403,20.400064,20.412752,00:00
1,18.643988,14.968852,14.977967,00:00
2,15.449907,9.412355,9.417967,00:00
3,11.891209,5.279018,5.282197,00:00
4,8.630199,2.907056,2.909145,00:00
5,6.016031,1.54857,1.550181,00:00
6,4.070316,0.809631,0.813076,00:00
7,2.691597,0.424722,0.426809,00:00
8,1.745035,0.333833,0.33415,00:00
9,1.138657,0.382207,0.382249,00:00


[32m[I 2021-09-20 21:18:37,414][0m Trial 46 finished with value: 0.2678394019603729 and parameters: {'emb_size': 19, 'lin_size0': 673, 'lin_size1': 885, 'lin_size2': 453, 'p1': 0.7191184773219781, 'p2': 0.17295930968778012, 'bottleneck': 48, 'time_p0': 0.11390485055951897, 'time_p1': 0.16906075624925493, 'time_p2': 0.16939931523115287, 'multiplier': 0.40404629316755425, 'lr': 0.009513754112396525}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.809326,22.943882,22.957985,00:00
1,21.674484,20.065699,20.078032,00:00
2,19.927855,16.951544,16.96138,00:00
3,17.737959,13.662332,13.669275,00:00
4,15.280061,10.562997,10.567408,00:00
5,12.82242,8.192679,8.195509,00:00
6,10.535361,6.219559,6.221547,00:00
7,8.527883,4.809021,4.810439,00:00
8,6.827046,3.599938,3.601285,00:00
9,5.39408,2.630313,2.63132,00:00


[32m[I 2021-09-20 21:19:23,740][0m Trial 47 finished with value: 0.2724132835865021 and parameters: {'emb_size': 17, 'lin_size0': 1125, 'lin_size1': 558, 'lin_size2': 414, 'p1': 0.5921363259214671, 'p2': 0.12084170749749651, 'bottleneck': 31, 'time_p0': 0.04280118548692129, 'time_p1': 0.0050064473706220025, 'time_p2': 0.13079233155633627, 'multiplier': 0.4450155488616887, 'lr': 0.005044381389254371}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,19.718704,20.500626,20.512833,00:01
1,18.086752,16.195118,16.20372,00:01
2,15.676429,11.791406,11.796894,00:00
3,12.900021,8.18495,8.188195,00:01
4,10.195558,5.646598,5.648945,00:01
5,7.81638,3.775404,3.777476,00:01
6,5.877832,2.630944,2.63295,00:01
7,4.34467,1.722127,1.724125,00:01
8,3.144967,1.008434,1.010063,00:01
9,2.237191,0.59994,0.601281,00:01


[32m[I 2021-09-20 21:20:17,145][0m Trial 48 finished with value: 0.2630521357059479 and parameters: {'emb_size': 12, 'lin_size0': 1322, 'lin_size1': 945, 'lin_size2': 478, 'p1': 0.4345555359591736, 'p2': 0.42792129488901914, 'bottleneck': 52, 'time_p0': 0.2723476003695576, 'time_p1': 0.03506061724065335, 'time_p2': 0.23864726906188757, 'multiplier': 0.2895732805051772, 'lr': 0.006860001286141467}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,21.306438,21.806517,21.818867,00:01
1,19.804508,17.692661,17.701429,00:01
2,17.599781,14.029181,14.035334,00:00
3,15.005515,10.627084,10.631156,00:00
4,12.356673,7.925675,7.928259,00:00
5,9.895537,5.783321,5.785051,00:00
6,7.792716,4.300923,4.302387,00:01
7,6.037657,3.142462,3.143619,00:00
8,4.613082,2.253572,2.254498,00:00
9,3.471844,1.531634,1.532306,00:00


[32m[I 2021-09-20 21:21:07,127][0m Trial 49 finished with value: 0.24860641360282898 and parameters: {'emb_size': 21, 'lin_size0': 899, 'lin_size1': 739, 'lin_size2': 443, 'p1': 0.1843509412036595, 'p2': 0.23782447264043438, 'bottleneck': 97, 'time_p0': 0.13382416674276515, 'time_p1': 0.32249407645055816, 'time_p2': 0.07481871788386384, 'multiplier': 0.25369383726412037, 'lr': 0.005862847558864709}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,23.575188,22.583326,22.5965,00:00
1,22.54394,20.468033,20.479424,00:00
2,20.89352,17.471384,17.480482,00:00
3,18.777391,14.202074,14.208923,00:00
4,16.291752,10.931805,10.93626,00:00
5,13.699503,8.080691,8.083328,00:00
6,11.202935,5.900263,5.901887,00:00
7,8.941507,4.211949,4.213162,00:00
8,7.054009,3.07325,3.074394,00:00
9,5.505979,2.208753,2.210006,00:00


[32m[I 2021-09-20 21:21:52,511][0m Trial 50 finished with value: 0.3342914283275604 and parameters: {'emb_size': 25, 'lin_size0': 831, 'lin_size1': 998, 'lin_size2': 172, 'p1': 0.5161936617256797, 'p2': 0.47628937713627195, 'bottleneck': 60, 'time_p0': 0.1111345283019777, 'time_p1': 0.27886870778199113, 'time_p2': 0.4825650035006274, 'multiplier': 0.4877697095960559, 'lr': 0.004202863927135916}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.53953,21.747013,21.760855,00:00
1,21.302094,19.054951,19.067574,00:00
2,19.365877,15.515023,15.525336,00:00
3,16.897173,11.794358,11.801994,00:00
4,14.145802,8.478539,8.483878,00:00
5,11.40941,5.891302,5.895002,00:00
6,8.9188,4.035565,4.038164,00:00
7,6.810081,2.772004,2.77417,00:00
8,5.134624,1.909817,1.912771,00:00
9,3.823649,1.304908,1.312235,00:00


[32m[I 2021-09-20 21:22:36,987][0m Trial 51 finished with value: 0.2836240231990814 and parameters: {'emb_size': 28, 'lin_size0': 454, 'lin_size1': 945, 'lin_size2': 262, 'p1': 0.6574116862972329, 'p2': 0.3146119546831467, 'bottleneck': 47, 'time_p0': 0.09201201608392304, 'time_p1': 0.08231888606937007, 'time_p2': 0.4433184111409438, 'multiplier': 0.39130414971690003, 'lr': 0.0056781957485626386}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,23.258972,20.916361,20.93075,00:00
1,21.289572,16.657579,16.668776,00:00
2,18.549192,12.305298,12.312658,00:00
3,15.475727,8.991248,8.995938,00:00
4,12.445647,6.35391,6.357471,00:00
5,9.800851,4.573267,4.577315,00:00
6,7.591658,3.243666,3.247605,00:00
7,5.791998,2.117573,2.120838,00:00
8,4.352562,1.331925,1.334434,00:00
9,3.223237,0.779481,0.781605,00:00


[32m[I 2021-09-20 21:23:22,266][0m Trial 52 finished with value: 0.37783271074295044 and parameters: {'emb_size': 13, 'lin_size0': 744, 'lin_size1': 890, 'lin_size2': 308, 'p1': 0.09536407736270935, 'p2': 0.7156585098505288, 'bottleneck': 70, 'time_p0': 0.16396824850571362, 'time_p1': 0.3658716280457267, 'time_p2': 0.32811177679459547, 'multiplier': 0.42065983138777563, 'lr': 0.007510933512385047}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,21.79014,21.80792,21.819788,00:00
1,20.683893,18.89045,18.899755,00:00
2,18.909611,15.464463,15.471371,00:00
3,16.637886,11.884442,11.889147,00:00
4,14.050561,8.577096,8.579968,00:00
5,11.451515,6.044156,6.045952,00:00
6,9.047281,4.129146,4.130404,00:00
7,7.037058,2.989029,2.99017,00:00
8,5.398673,2.125164,2.126393,00:00
9,4.094447,1.445328,1.446343,00:00


[32m[I 2021-09-20 21:24:07,160][0m Trial 53 finished with value: 0.2691757082939148 and parameters: {'emb_size': 23, 'lin_size0': 1012, 'lin_size1': 477, 'lin_size2': 240, 'p1': 0.7419138466225085, 'p2': 0.013128984701536173, 'bottleneck': 40, 'time_p0': 0.02316632408366215, 'time_p1': 0.11956556347114164, 'time_p2': 0.4035722858046365, 'multiplier': 0.37163082047232077, 'lr': 0.006505675028260034}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,19.203896,19.829252,19.841318,00:00
1,18.410013,17.51992,17.530144,00:00
2,17.07762,14.831208,14.839536,00:00
3,15.352813,11.843651,11.849974,00:00
4,13.361501,8.973585,8.97814,00:00
5,11.238249,6.447828,6.450891,00:00
6,9.200076,4.519145,4.521229,00:00
7,7.364558,3.190921,3.192427,00:00
8,5.785406,2.205986,2.207276,00:00
9,4.510122,1.519836,1.521037,00:00


[32m[I 2021-09-20 21:24:51,164][0m Trial 54 finished with value: 0.309998482465744 and parameters: {'emb_size': 26, 'lin_size0': 885, 'lin_size1': 803, 'lin_size2': 96, 'p1': 0.7984102816245183, 'p2': 0.36280051862729434, 'bottleneck': 52, 'time_p0': 0.061597800967191044, 'time_p1': 0.3671938557841434, 'time_p2': 0.4307138609048356, 'multiplier': 0.4338207043020119, 'lr': 0.0037838644845701293}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,19.362846,20.931213,20.942129,00:00
1,18.884644,19.549639,19.55929,00:00
2,18.218353,18.225933,18.234667,00:00
3,17.382048,16.693632,16.701183,00:00
4,16.343306,15.144014,15.150579,00:00
5,15.178044,13.53038,13.535854,00:00
6,13.939257,11.926387,11.930909,00:00
7,12.690105,10.516117,10.519927,00:00
8,11.474007,9.176434,9.17946,00:00
9,10.285743,7.931108,7.933548,00:00


[32m[I 2021-09-20 21:25:36,135][0m Trial 55 finished with value: 0.3822156488895416 and parameters: {'emb_size': 30, 'lin_size0': 581, 'lin_size1': 863, 'lin_size2': 291, 'p1': 0.458048155577863, 'p2': 0.5253616799456929, 'bottleneck': 61, 'time_p0': 0.18798302877665182, 'time_p1': 0.4026604971954169, 'time_p2': 0.3054270857838978, 'multiplier': 0.4708299438957329, 'lr': 0.0018184020857105678}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,20.571285,22.080591,22.093542,00:00
1,20.386475,21.279659,21.292183,00:00
2,20.038794,20.548223,20.560179,00:00
3,19.514912,19.72225,19.733572,00:00
4,18.930195,18.835651,18.846323,00:00
5,18.272673,17.907301,17.917187,00:00
6,17.509317,16.876741,16.8859,00:00
7,16.684719,15.891022,15.899531,00:00
8,15.826885,14.86725,14.874978,00:00
9,14.974992,13.871886,13.879042,00:00


[32m[I 2021-09-20 21:26:23,163][0m Trial 56 finished with value: 0.4632316827774048 and parameters: {'emb_size': 18, 'lin_size0': 1276, 'lin_size1': 930, 'lin_size2': 397, 'p1': 0.4900837127527627, 'p2': 0.29049810049273994, 'bottleneck': 26, 'time_p0': 0.21131075963045032, 'time_p1': 0.31261235624746164, 'time_p2': 0.20735678889280698, 'multiplier': 0.07282600271478731, 'lr': 0.001150168201774729}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.820364,22.838587,22.852852,00:00
1,21.874153,20.577509,20.590719,00:00
2,20.274975,17.601454,17.612844,00:00
3,18.21953,14.389405,14.398832,00:00
4,15.889342,11.271331,11.27902,00:00
5,13.424427,8.551717,8.557808,00:00
6,11.016035,6.254649,6.259499,00:00
7,8.848495,4.58092,4.584893,00:00
8,6.977038,3.321599,3.32494,00:00
9,5.452108,2.441018,2.443864,00:00


[32m[I 2021-09-20 21:27:14,778][0m Trial 57 finished with value: 0.2715389132499695 and parameters: {'emb_size': 20, 'lin_size0': 1485, 'lin_size1': 964, 'lin_size2': 369, 'p1': 0.6585121630517039, 'p2': 0.24091168422337902, 'bottleneck': 56, 'time_p0': 0.1282211172815797, 'time_p1': 0.4337256904918618, 'time_p2': 0.3832172589380621, 'multiplier': 0.16691761360436397, 'lr': 0.004961366378624931}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.009304,21.688379,21.700315,00:00
1,20.118233,17.373568,17.38151,00:00
2,17.248966,12.096671,12.100978,00:00
3,13.845164,7.562442,7.564411,00:00
4,10.503949,4.619419,4.620802,00:00
5,7.677067,2.80217,2.803628,00:00
6,5.479935,1.704244,1.705788,00:00
7,3.840217,1.003849,1.009574,00:00
8,2.652594,0.602486,0.624953,00:00
9,1.821402,0.418845,0.468841,00:00


[32m[I 2021-09-20 21:28:02,684][0m Trial 58 finished with value: 0.2645501494407654 and parameters: {'emb_size': 13, 'lin_size0': 1851, 'lin_size1': 732, 'lin_size2': 188, 'p1': 0.5907441887897013, 'p2': 0.6329806086021034, 'bottleneck': 16, 'time_p0': 0.09884350298058606, 'time_p1': 0.4660225676763554, 'time_p2': 0.045746231884200905, 'multiplier': 0.39526169240686276, 'lr': 0.009983608608709433}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,21.034758,20.725283,20.737177,00:00
1,19.266027,16.495857,16.504892,00:00
2,16.673658,12.069489,12.075942,00:00
3,13.66043,8.325815,8.330361,00:00
4,10.733887,5.869061,5.872566,00:00
5,8.184299,3.958754,3.9616,00:00
6,6.091196,2.523832,2.526009,00:00
7,4.424753,1.640634,1.642362,00:00
8,3.147445,1.011583,1.012895,00:00
9,2.184031,0.575358,0.576233,00:00


[32m[I 2021-09-20 21:28:47,376][0m Trial 59 finished with value: 0.2421894669532776 and parameters: {'emb_size': 11, 'lin_size0': 690, 'lin_size1': 664, 'lin_size2': 324, 'p1': 0.3330718169244043, 'p2': 0.1433948401485488, 'bottleneck': 37, 'time_p0': 0.2500492364466964, 'time_p1': 0.23509327304250577, 'time_p2': 0.09533759448502918, 'multiplier': 0.2176939614912403, 'lr': 0.008013556235117963}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,19.835304,19.816589,19.828522,00:01
1,18.116741,14.983673,14.99126,00:01
2,15.480129,10.667566,10.672028,00:00
3,12.532777,7.396301,7.399299,00:00
4,9.802831,5.138919,5.141248,00:00
5,7.518901,3.626578,3.628469,00:01
6,5.651479,2.441394,2.442759,00:00
7,4.151475,1.524528,1.526322,00:01
8,2.993428,0.969565,0.97089,00:01
9,2.123839,0.571788,0.572663,00:01


[32m[I 2021-09-20 21:29:39,372][0m Trial 60 finished with value: 0.25201669335365295 and parameters: {'emb_size': 7, 'lin_size0': 1979, 'lin_size1': 817, 'lin_size2': 90, 'p1': 0.13943617902383382, 'p2': 0.45633833138586505, 'bottleneck': 31, 'time_p0': 0.036753050825039116, 'time_p1': 0.3871676172627552, 'time_p2': 0.43449575069596846, 'multiplier': 0.12707698676283627, 'lr': 0.007188045816207476}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,23.977039,23.059855,23.07217,00:01
1,22.484955,20.020494,20.030115,00:01
2,20.296961,16.27426,16.281239,00:01
3,17.52644,12.278036,12.282391,00:00
4,14.552212,8.964683,8.967113,00:01
5,11.719074,6.476048,6.478084,00:01
6,9.305873,4.88971,4.891857,00:00
7,7.327226,3.648948,3.65105,00:01
8,5.715187,2.66057,2.662784,00:00
9,4.424444,1.933456,1.935465,00:01


[32m[I 2021-09-20 21:30:30,859][0m Trial 61 finished with value: 0.2702171504497528 and parameters: {'emb_size': 6, 'lin_size0': 1903, 'lin_size1': 762, 'lin_size2': 37, 'p1': 0.027492879585290453, 'p2': 0.6985832561272438, 'bottleneck': 44, 'time_p0': 0.06250059833693417, 'time_p1': 0.4393649463357375, 'time_p2': 0.4598717137306366, 'multiplier': 0.034038893098809794, 'lr': 0.006947412339976444}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,24.078104,21.972111,21.985374,00:00
1,22.075809,17.757524,17.767897,00:00
2,19.188778,13.210363,13.217951,00:00
3,15.818932,9.318147,9.323964,00:00
4,12.565232,6.493186,6.497771,00:00
5,9.709909,4.450469,4.453999,00:00
6,7.36908,3.037438,3.039894,00:00
7,5.516152,2.062227,2.064,00:00
8,4.060573,1.277718,1.279217,00:00
9,2.945995,0.785535,0.787406,00:00


[32m[I 2021-09-20 21:31:17,053][0m Trial 62 finished with value: 0.32688260078430176 and parameters: {'emb_size': 4, 'lin_size0': 983, 'lin_size1': 912, 'lin_size2': 484, 'p1': 0.09573724985365739, 'p2': 0.5674935135850591, 'bottleneck': 25, 'time_p0': 0.012724622777861266, 'time_p1': 0.45094562330382015, 'time_p2': 0.4957013078556345, 'multiplier': 0.06520788239647694, 'lr': 0.008705531973121338}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,25.346838,23.296949,23.310688,00:00
1,24.061539,20.913549,20.925735,00:00
2,21.99659,17.314028,17.323584,00:01
3,19.329067,13.583177,13.590782,00:00
4,16.338596,10.174198,10.179688,00:00
5,13.405506,7.641473,7.645399,00:00
6,10.710112,5.635963,5.638783,00:00
7,8.442627,4.188441,4.19056,00:00
8,6.578631,3.150747,3.152698,00:00
9,5.063357,2.238152,2.239584,00:00


[32m[I 2021-09-20 21:32:08,765][0m Trial 63 finished with value: 0.23963971436023712 and parameters: {'emb_size': 8, 'lin_size0': 1567, 'lin_size1': 859, 'lin_size2': 155, 'p1': 0.41142039139074305, 'p2': 0.17828410937176847, 'bottleneck': 67, 'time_p0': 0.14565402088873378, 'time_p1': 0.06567665353168159, 'time_p2': 0.4557258119658645, 'multiplier': 0.012822953893257022, 'lr': 0.006448536048539333}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,23.63982,21.306147,21.319874,00:00
1,21.53471,17.323538,17.336329,00:00
2,18.550116,12.826759,12.836513,00:00
3,15.194683,9.077552,9.084474,00:00
4,12.023356,6.544971,6.550508,00:00
5,9.364533,4.769939,4.774793,00:00
6,7.222773,3.549955,3.55437,00:00
7,5.51021,2.579037,2.582383,00:00
8,4.160512,1.853408,1.856128,00:00
9,3.107505,1.256181,1.257998,00:00


[32m[I 2021-09-20 21:32:56,838][0m Trial 64 finished with value: 0.23522697389125824 and parameters: {'emb_size': 25, 'lin_size0': 1730, 'lin_size1': 921, 'lin_size2': 13, 'p1': 0.23758205613297914, 'p2': 0.33062303425288586, 'bottleneck': 16, 'time_p0': 0.08948142126824823, 'time_p1': 0.34332668676029, 'time_p2': 0.412309755072905, 'multiplier': 0.32615593431642786, 'lr': 0.005983246912274402}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.218935,20.985792,20.997952,00:00
1,20.51178,17.316114,17.325611,00:00
2,17.774763,12.456568,12.462506,00:00
3,14.363826,8.297196,8.300386,00:00
4,11.077156,5.635948,5.637991,00:00
5,8.305234,3.868632,3.870281,00:00
6,6.131367,2.720305,2.721686,00:00
7,4.453885,1.73309,1.734155,00:00
8,3.18035,1.154371,1.155227,00:00
9,2.229755,0.725093,0.725895,00:00


[32m[I 2021-09-20 21:33:43,725][0m Trial 65 finished with value: 0.23099711537361145 and parameters: {'emb_size': 23, 'lin_size0': 1114, 'lin_size1': 969, 'lin_size2': 36, 'p1': 0.18992132592534025, 'p2': 0.2607414907400571, 'bottleneck': 40, 'time_p0': 0.17429846391725926, 'time_p1': 0.488497433525554, 'time_p2': 0.3541424517801241, 'multiplier': 0.14373780325628993, 'lr': 0.007379202889930904}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,18.780708,20.442921,20.456163,00:00
1,18.091103,18.158175,18.169937,00:00
2,16.960707,15.906333,15.916214,00:00
3,15.534282,13.670047,13.678346,00:00
4,13.956864,11.569945,11.576828,00:00
5,12.321887,9.778275,9.783921,00:00
6,10.686412,7.97436,7.978732,00:00
7,9.162226,6.560084,6.563512,00:00
8,7.808372,5.51598,5.518723,00:00
9,6.601364,4.533773,4.536195,00:00


[32m[I 2021-09-20 21:34:31,031][0m Trial 66 finished with value: 0.30533432960510254 and parameters: {'emb_size': 16, 'lin_size0': 1177, 'lin_size1': 837, 'lin_size2': 419, 'p1': 0.2866907529615898, 'p2': 0.3818277209775137, 'bottleneck': 29, 'time_p0': 0.02996412404565882, 'time_p1': 0.4194065456143467, 'time_p2': 0.4830386093285268, 'multiplier': 0.44900190789376015, 'lr': 0.0026712037238123165}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,25.573689,21.648575,21.661737,00:00
1,23.826593,17.812937,17.823883,00:00
2,21.197159,13.305284,13.313343,00:00
3,17.877975,9.087379,9.092745,00:00
4,14.373614,5.75266,5.756198,00:01
5,11.16175,3.641381,3.643894,00:00
6,8.405613,2.152505,2.155138,00:00
7,6.274446,1.291684,1.296909,00:00
8,4.61023,0.790736,0.806107,00:00
9,3.373658,0.51939,0.554894,00:00


[32m[I 2021-09-20 21:35:21,697][0m Trial 67 finished with value: 0.4441341459751129 and parameters: {'emb_size': 3, 'lin_size0': 1434, 'lin_size1': 873, 'lin_size2': 459, 'p1': 0.7682666182095848, 'p2': 0.7805629274941042, 'bottleneck': 51, 'time_p0': 0.05769258129899968, 'time_p1': 0.49648709553830006, 'time_p2': 0.13993255998723544, 'multiplier': 0.3646130772777791, 'lr': 0.009591651459937579}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,23.563665,23.32382,23.337648,00:00
1,22.00292,19.644571,19.656681,00:00
2,19.607014,15.169299,15.178997,00:00
3,16.590729,10.670786,10.67802,00:00
4,13.345427,6.92441,6.930401,00:00
5,10.239308,4.339926,4.344551,00:00
6,7.624036,2.815578,2.819466,00:00
7,5.567365,1.77106,1.774696,00:00
8,3.999683,1.114303,1.120651,00:00
9,2.841275,0.700968,0.718524,00:00


[32m[I 2021-09-20 21:36:09,956][0m Trial 68 finished with value: 0.28229624032974243 and parameters: {'emb_size': 21, 'lin_size0': 1034, 'lin_size1': 999, 'lin_size2': 374, 'p1': 0.6869467936386081, 'p2': 0.4215981181678515, 'bottleneck': 44, 'time_p0': 0.1244042944460651, 'time_p1': 0.3888040690758089, 'time_p2': 0.41927218280251105, 'multiplier': 0.0942686580460671, 'lr': 0.009119673089370598}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,22.544031,20.842684,20.85433,00:00
1,19.470615,14.570962,14.577313,00:00
2,15.588395,9.602162,9.605283,00:00
3,11.905502,6.275138,6.277641,00:00
4,8.813865,4.137307,4.139039,00:00
5,6.28394,2.463076,2.464658,00:00
6,4.336133,1.375065,1.376637,00:00
7,2.910421,0.740256,0.743449,00:00
8,1.908007,0.426296,0.431745,00:00
9,1.236368,0.304072,0.306914,00:00


[32m[I 2021-09-20 21:36:58,757][0m Trial 69 finished with value: 0.249508798122406 and parameters: {'emb_size': 14, 'lin_size0': 1608, 'lin_size1': 757, 'lin_size2': 126, 'p1': 0.03867447696991147, 'p2': 0.21599836519530877, 'bottleneck': 34, 'time_p0': 0.07457624509561464, 'time_p1': 0.1530173311229674, 'time_p2': 0.17239285964670986, 'multiplier': 0.4284300273005639, 'lr': 0.008021842086017387}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,21.309282,20.990568,21.003408,00:00
1,19.582909,17.007355,17.017729,00:00
2,16.903507,12.508971,12.516782,00:00
3,13.792915,8.788277,8.793698,00:00
4,10.788732,6.145353,6.149113,00:00
5,8.216824,4.23641,4.239372,00:00
6,6.109562,2.8487,2.850686,00:00
7,4.438314,1.872426,1.873722,00:00
8,3.166567,1.244431,1.245361,00:00
9,2.225968,0.802325,0.802968,00:00


[32m[I 2021-09-20 21:37:44,858][0m Trial 70 finished with value: 0.22928987443447113 and parameters: {'emb_size': 23, 'lin_size0': 1108, 'lin_size1': 936, 'lin_size2': 29, 'p1': 0.20296675911148818, 'p2': 0.26120106074909955, 'bottleneck': 38, 'time_p0': 0.23411922578253758, 'time_p1': 0.4679781919934001, 'time_p2': 0.3543650084912808, 'multiplier': 0.058298543741355124, 'lr': 0.007358575227602423}. Best is trial 29 with value: 0.2165665328502655.[0m


epoch,train_loss,valid_loss,rmspe,time
0,20.733707,19.816072,19.826887,00:00
1,18.890039,15.75553,15.762752,00:00
2,16.262896,11.546837,11.551209,00:00
3,13.265852,8.15892,8.161892,00:00
4,10.440565,5.738366,5.740596,00:00
5,8.033134,4.103488,4.105716,00:00
6,6.063903,2.85987,2.861841,00:00


In [14]:
# Keep only trials that produced a score (e.g. failed/interrupted ones have value None).
trials = list(filter(lambda frozen: frozen.value is not None, study.trials))

In [15]:
# Ascending by score, so the lowest-value trial comes first.
trials.sort(key=lambda frozen: frozen.value)

In [16]:
# One DataLoaders per GroupKFold fold; grouping by time_id keeps time groups intact.
fold_splits = GroupKFold().split(train_df, groups=train_df.time_id)
dlss = [get_dls(train_df, 100, trn_idx, val_idx) for trn_idx, val_idx in fold_splits]

In [17]:
# Best trial of the current `study`, according to the study's optimization direction.
best = study.best_trial

In [21]:
# Hyperparameters of the best trial (shown as the cell output).
best.params

{'bottleneck': 41,
 'emb_size': 11,
 'lin_size0': 764,
 'lin_size1': 842,
 'lin_size2': 75,
 'lr': 0.006267668552837353,
 'multiplier': 0.18618690920965522,
 'p1': 0.7149990749997216,
 'p2': 0.29950526254816967,
 'time_p0': 0.34702106813252764,
 'time_p1': 0.07144546307698028,
 'time_p2': 0.12147426684839502}

In [25]:
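# Evaluate a hand-picked configuration on all folds. `my_params` is not defined in
# this notebook: set it to a dict with the same keys as `best.params` first.
# `value=42` is only a placeholder score for the synthetic trial.
# my_trial = optuna.create_trial(value=42, params=my_params, distributions=best.distributions)

# train_cross_valid(my_trial, dlss)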

In [None]:
# Re-score up to the 10 best single-fold trials on every fold in `dlss`.
# Slicing avoids an IndexError when fewer than 10 scored trials exist.
for i, trial in enumerate(trials[:10]):
    r = train_cross_valid(trial, dlss)
    print('trial', i,':',r)