In [2]:
from fastai.tabular.all import *
from sklearn.model_selection import KFold, GroupKFold
from optuna.integration import FastAIPruningCallback
import optuna

In [3]:
# Number of stocks per time_id. Data loading, shuffling and the model's reshapes all
# assume rows arrive in consecutive groups of exactly this many rows.
STOCK_COUNT = 112
# Width of the continuous-feature input (used as ``inp_size`` when building the model).
FEATURE_COUNT = 144


In [4]:
def fill_missing(train_df):
    """Return a new frame with one row for every (time_id, stock_id) combination.

    The grid is the Cartesian product of the distinct ``time_id`` and ``stock_id``
    values (each in order of first appearance), laid out time-major. Combinations
    absent from ``train_df`` are added, and every NaN in the result (including ones
    already present in the input) is replaced with 0. ``train_df`` is not modified.
    """
    grid = pd.MultiIndex.from_product(
        [train_df.time_id.unique(), train_df.stock_id.unique()],
        names=['time_id', 'stock_id'],
    )
    return (
        train_df
        .set_index(['time_id', 'stock_id'])
        .reindex(grid)
        .reset_index()
        .fillna(0)
    )

In [5]:
class MaskTfm(ItemTransform):
    """Batch transform that zeroes a random 10% of the rows of every tensor in the batch.

    Active only when called with ``split_idx == 0`` (the training split); otherwise
    the batch is returned unchanged. The same row indices are zeroed in every element
    of the batch, presumably ``(x_cat, x_cont, y)`` from ``ReadTabBatch`` — TODO confirm.
    Masked rows therefore also get a target of 0.
    """
    # Refreshed by ``__call__`` for every batch; the class-level default keeps
    # ``encodes`` a no-op if it is ever reached without going through ``__call__``.
    do_transform = False

    def mask(self, x, indices):
        """Set the rows of ``x`` selected by ``indices`` to 0 in place and return ``x``."""
        # as_tensor is a no-op when ``indices`` is already a tensor on x's device.
        x[torch.as_tensor(indices, device=x.device)] = 0
        return x

    def __call__(self, b, split_idx=None, **kwargs):
        # Only the training split (split_idx == 0) is masked.
        self.do_transform = (split_idx == 0)
        return super().__call__(b, split_idx=split_idx, **kwargs)

    def encodes(self, x):
        if not self.do_transform: return x
        n = len(x[0])
        # One random 10% subset shared by all elements so features and targets stay aligned.
        indices = np.random.choice(n, n // 10, replace=False)
        # Build the index tensor once instead of once per batch element.
        idx = torch.as_tensor(indices, device=x[0].device)
        return [self.mask(y, idx) for y in x]

class MyDataLoader(TabDataLoader):
    """``TabDataLoader`` that appends ``MaskTfm`` and shuffles whole groups of rows.

    Assumes the dataset is laid out as consecutive blocks of ``STOCK_COUNT`` rows
    (presumably one block per time_id, as produced by ``fill_missing``), so shuffling
    permutes blocks while keeping the row order inside each block.
    """
    def __init__(self, dataset, bs=16, shuffle=False, after_batch=None, num_workers=0, **kwargs):
        # Default pipeline: TransformBlock batch tfms + ReadTabBatch, then the row mask.
        if after_batch is None:
            after_batch = L(TransformBlock().batch_tfms) + ReadTabBatch(dataset) + [MaskTfm()]
        super().__init__(dataset, bs=bs, shuffle=shuffle, after_batch=after_batch, num_workers=num_workers, **kwargs)

    def shuffle_fn(self, idxs):
        """Shuffle ``idxs`` in blocks of ``STOCK_COUNT`` consecutive indices.

        Raises ValueError (from ``reshape``) if ``len(idxs)`` is not a multiple of
        ``STOCK_COUNT``.
        """
        groups = np.array(idxs).reshape(-1, STOCK_COUNT)
        np.random.shuffle(groups)  # permutes rows of the 2-D array, i.e. whole groups
        return groups.reshape(-1).tolist()

def get_dls(train_df, bs, trn_idx, val_idx):
    """Build tabular DataLoaders for one train/validation split.

    Args:
        train_df: frame with a ``target`` column; ``row_id`` and ``time_id`` are
            excluded from the categorical inputs.
        bs: number of ``STOCK_COUNT``-row groups per batch. The row batch size is
            ``bs * STOCK_COUNT``, so every batch holds whole groups.
        trn_idx, val_idx: row indices of the training / validation split.

    Returns:
        DataLoaders built with ``MyDataLoader``. The training dataset gets
        ``split_idx = 0`` and the validation dataset ``split_idx = 1``, presumably so
        the batch pipeline passes that value on and ``MaskTfm`` (which only acts when
        ``split_idx == 0``) fires during training only — TODO confirm.
    """
    cont_nn, cat_nn = cont_cat_split(train_df, max_card=9000, dep_var='target')
    # Identifier columns must not become model inputs.
    cat_nn = [c for c in cat_nn if c not in ('row_id', 'time_id')]

    procs_nn = [Categorify, Normalize]
    to_nn = TabularPandas(train_df, procs_nn, cat_nn, cont_nn,
                          splits=[list(trn_idx), list(val_idx)], y_names='target')
    # Batch size in rows: ``bs`` whole groups of STOCK_COUNT rows.
    dls = to_nn.dataloaders(bs=STOCK_COUNT * bs, shuffle=True, dl_type=MyDataLoader)
    dls.train_ds.split_idx = 0
    dls.valid_ds.split_idx = 1
    return dls

In [6]:
# Raw feature table; densified to the full time_id x stock_id grid in the next cell.
train_df = pd.read_csv(filepath_or_buffer='train_with_features_NO_ST.csv')

In [7]:
# Alternative input (not used): pd.read_feather('train_24cols.feather')
# Densify to the full time_id x stock_id grid, then hold out the first GroupKFold
# fold (grouped by time_id) as a single train/validation split.
train_df = fill_missing(train_df)
trn_idx0, val_idx0 = next(iter(GroupKFold().split(train_df, groups=train_df.time_id)))

dls0 = get_dls(train_df, 100, trn_idx0, val_idx0)

In [8]:
class TimeEncoding(nn.Module):
    """Residual block that shares information between all rows of a STOCK_COUNT group.

    Every row is projected to ``bottleneck`` features. The ``STOCK_COUNT`` projections
    of each group are concatenated, batch-normalised, mapped back to ``inp_size`` and
    passed through tanh. That per-group vector, scaled by ``multiplier``, is added to
    every row of the group.
    """
    def __init__(self, inp_size, bottleneck, p, multiplier):
        super().__init__()
        # Stored as given (not an nn.Parameter), so it is not trained.
        self.multiplier = multiplier
        self.initial_layers = LinBnDrop(inp_size, bottleneck, act=nn.ReLU(True), p=p, bn=False)
        group_width = bottleneck * STOCK_COUNT
        self.concat_layers = nn.Sequential(
            nn.BatchNorm1d(group_width),
            nn.Linear(group_width, inp_size),
            nn.Tanh(),
        )

    def forward(self, x):
        # Assumes x is (n_groups * STOCK_COUNT, inp_size) with each group's rows
        # contiguous — TODO confirm against the data loader.
        projected = self.initial_layers(x)
        n_groups = projected.shape[0] // STOCK_COUNT
        group_vec = self.concat_layers(projected.view(n_groups, -1))
        # Copy each group's vector onto each of its STOCK_COUNT rows.
        per_row = (group_vec.unsqueeze(1)
                   .expand(n_groups, STOCK_COUNT, -1)
                   .contiguous()
                   .view(n_groups * STOCK_COUNT, -1))
        return x + per_row * self.multiplier

class BN(nn.Module):
    """BatchNorm1d over groups of STOCK_COUNT rows.

    The input is regrouped to ``(-1, STOCK_COUNT * features)``, so each
    (row-position, feature) pair gets its own normalisation statistics. The result is
    then restored to the input's shape.
    """
    def __init__(self, features):
        super().__init__()
        self.num_features = features
        self.bn = nn.BatchNorm1d(STOCK_COUNT * features)

    def forward(self, x):
        original_shape = x.shape
        grouped = x.view(-1, STOCK_COUNT * self.num_features)
        return self.bn(grouped).view(original_shape)
    
class ParallelModel(nn.Module):
    """Per-row MLP with a positional stock embedding and cross-row TimeEncoding blocks.

    ``x_cont`` is concatenated with ``stock_emb`` tiled once per group of
    ``STOCK_COUNT`` rows, so row ``r`` of the batch receives embedding
    ``r % STOCK_COUNT``. ``x_cat`` is used only for its row count.

    Each hidden layer is Linear -> BN -> optional Dropout -> ReLU -> TimeEncoding.
    The head is a single linear unit followed by ``SigmoidRange(0, .1)``.
    """
    def __init__(self, inp_size, emb_size, lin_sizes, ps, bottleneck, time_ps, multiplier):
        super().__init__()
        # One embedding row per position within a group of STOCK_COUNT rows.
        self.stock_emb = nn.Parameter(torch.empty(STOCK_COUNT, emb_size))
        nn.init.normal_(self.stock_emb)

        widths = [inp_size + emb_size] + lin_sizes
        blocks = []
        for n_in, n_out, p, time_p in zip(widths[:-1], widths[1:], ps, time_ps):
            blocks += [nn.Linear(n_in, n_out), BN(n_out)]
            if p:
                blocks.append(nn.Dropout(p))
            blocks += [nn.ReLU(True), TimeEncoding(n_out, bottleneck, time_p, multiplier)]
        blocks += [LinBnDrop(widths[-1], 1, bn=False), SigmoidRange(0, .1)]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x_cat, x_cont):
        n_groups = x_cat.shape[0] // STOCK_COUNT
        # Relies on rows being in the same position order within every group
        # (TODO confirm against the data pipeline).
        x = torch.cat([x_cont, self.stock_emb.repeat(n_groups, 1)], dim=1)
        for layer in self.layers.children():
            x = layer(x)
        return x

In [9]:
def rmspe(preds, targs):
    """Root mean squared percentage error, ignoring entries whose target is exactly 0.

    Computes ``sqrt(mean(((targ - pred) / targ) ** 2))`` over the non-zero targets.
    Zero targets presumably mark padding rows filled by ``fill_missing`` or rows
    zeroed by ``MaskTfm``.

    Raises:
        ValueError: if every target is zero, or if the result is NaN
            (e.g. NaN predictions).
    """
    mask = targs != 0
    targs, preds = torch.masked_select(targs, mask), torch.masked_select(preds, mask)
    if targs.numel() == 0:
        raise ValueError('rmspe: every target is zero, nothing to score')
    rel_err = (targs - preds) / targs
    res = (rel_err ** 2).mean().sqrt()
    if torch.isnan(res):
        # Dump the offending values to help locate the bad batch before aborting.
        print(targs)
        print(preds)
        raise ValueError('rmspe is NaN (check predictions for NaN values)')
    return res

In [10]:
def train(trial, dls, save_as=None, n_epoch=50, n_last=5):
    """Train one ParallelModel with hyper-parameters drawn from ``trial``.

    Args:
        trial: optuna trial (or a frozen trial from ``study.trials``) that supplies
            the hyper-parameters and is handed to ``FastAIPruningCallback``.
        dls: DataLoaders from ``get_dls``.
        save_as: if truthy, the learner is saved under this name after training.
        n_epoch: number of epochs for ``fit_flat_cos``.
        n_last: number of final epochs (>= 1) whose ``rmspe`` metric is averaged.

    Returns:
        Mean ``rmspe`` metric over the last ``n_last`` recorded epochs.
    """
    inp_size = FEATURE_COUNT
    emb_size = trial.suggest_int('emb_size', 3, 30)
    max_sizes = [2000, 1000, 500]
    lin_sizes = [trial.suggest_int(f'lin_size{i}', 10, ms) for i, ms in enumerate(max_sizes)]
    # No dropout after the first hidden layer; p1 and p2 apply to the other two.
    ps = [0] + [trial.suggest_float(f'p{i}', 0, .8) for i in range(1, 3)]

    bottleneck = trial.suggest_int('bottleneck', 5, 100)
    time_ps = [trial.suggest_float(f'time_p{i}', 0, .5) for i in range(3)]
    multiplier = trial.suggest_float('multiplier', .01, .5)
    lr = trial.suggest_float('lr', 1e-3, 1e-2)

    model = ParallelModel(inp_size, emb_size, lin_sizes, ps, bottleneck, time_ps, multiplier)

    learn = Learner(dls, model=model, loss_func=rmspe, metrics=AccumMetric(rmspe), opt_func=ranger,
                    cbs=FastAIPruningCallback(trial, 'rmspe')).to_fp16()
    learn.fit_flat_cos(n_epoch, lr)
    if save_as:
        learn.save(save_as)
    # Recorder rows follow the logged columns train_loss, valid_loss, rmspe -> index 2.
    last_scores = L(learn.recorder.values).itemgot(2)[-n_last:]
    return np.mean(last_scores)

def train_cross_valid(trial, dlss, save_as=None):
    """Train ``trial``'s configuration on every fold and return the mean fold score.

    Args:
        trial: optuna trial (or frozen trial) forwarded to ``train``.
        dlss: iterable of per-fold DataLoaders.
        save_as: optional name prefix; fold ``i`` is saved as ``f"{save_as}{i}"``.

    Returns:
        Arithmetic mean of the per-fold scores returned by ``train``.

    Raises:
        ValueError: if ``dlss`` yields no folds.
    """
    scores = []
    for idx, dls in enumerate(dlss):
        v = train(trial, dls, save_as + str(idx) if save_as else None)
        print(f'fold {idx}: {v}')
        scores.append(v)
    if not scores:
        raise ValueError('train_cross_valid: dlss contained no folds')
    # Average over the actual number of folds rather than assuming five.
    return sum(scores) / len(scores)

In [11]:
# Re-open the existing optuna study persisted in the local SQLite database.
study = optuna.load_study(study_name='parallel_no_st2', storage='sqlite:///optuna.db')

In [None]:
# study = optuna.create_study(direction="minimize", study_name = 'parallel_no_st', storage='sqlite:///optuna.db', load_if_exists=True, pruner=optuna.pruners.NopPruner(), sampler=None)
# study.optimize(functools.partial(train, dls=dls))

In [12]:
# Drop trials with no recorded objective value (value is None).
trials = list(filter(lambda t: t.value is not None, study.trials))

In [13]:
# Ascending by objective value, i.e. best first if the study minimises.
trials = sorted(trials, key=lambda t: t.value)

In [14]:
# One DataLoaders per GroupKFold fold (grouped by time_id), with bs=100 passed to get_dls.
dlss = [
    get_dls(train_df, 100, fold_trn_idx, fold_val_idx)
    for fold_trn_idx, fold_val_idx in GroupKFold().split(train_df, groups=train_df.time_id)
]

In [15]:
# Handle on the study's best trial, to inspect its hyper-parameters.
best = study.best_trial

In [16]:
# Hyper-parameters of the study's best trial.
best.params

{'bottleneck': 10,
 'emb_size': 5,
 'lin_size0': 446,
 'lin_size1': 477,
 'lin_size2': 434,
 'lr': 0.006451237192100813,
 'multiplier': 0.12849162558259716,
 'p1': 0.6252926917957898,
 'p2': 0.02656541478851094,
 'time_p0': 0.45265688727038544,
 'time_p1': 0.02835514720594759,
 'time_p2': 0.21139460201890878}

In [25]:
# my_trial = optuna.create_trial(value=42, params=my_params, distributions=best.distributions)

# train_cross_valid(my_trial, dlss)

In [None]:
# Cross-validate up to 10 trials from the front of ``trials`` (sorted by ascending value).
# Slicing avoids an IndexError when fewer than 10 trials have a recorded value.
for i, trial in enumerate(trials[:10]):
    r = train_cross_valid(trial, dlss)
    print('trial', i, ':', r)

epoch,train_loss,valid_loss,rmspe,time
0,3.830726,3.38974,3.39314,00:01
1,1.708331,2.188046,2.19702,00:00
2,1.046502,0.756532,0.757254,00:00
3,0.723091,0.630178,0.630705,00:00
4,0.538146,0.482312,0.482688,00:00
5,0.443042,0.460234,0.461542,00:00
6,0.384866,0.351852,0.352232,00:00
7,0.337456,0.311869,0.313243,00:00
8,0.309291,0.268633,0.269134,00:00
9,0.293854,0.234379,0.234426,00:00


fold 0: 0.2134610116481781


epoch,train_loss,valid_loss,rmspe,time
0,3.755021,2.770694,2.774179,00:00
1,1.692757,1.858882,1.866384,00:00
2,1.014256,1.115581,1.121314,00:00
3,0.715814,0.790698,0.794458,00:00
4,0.548291,0.565959,0.568125,00:00
5,0.449375,0.435739,0.435926,00:00
6,0.392237,0.340529,0.341839,00:00
7,0.355597,0.432239,0.432906,00:00
8,0.336967,0.28599,0.286502,00:00
9,0.313269,0.287177,0.287278,00:00


fold 1: 0.21612372398376464


epoch,train_loss,valid_loss,rmspe,time
0,3.637094,2.54869,2.550514,00:00
1,1.667799,1.662484,1.672782,00:00
2,1.022148,2.022485,2.068091,00:00
3,0.706548,1.182131,1.250339,00:00
4,0.534368,0.988939,1.06086,00:00
5,0.4642,0.543143,0.549194,00:00
6,0.401239,0.436284,0.437945,00:00
7,0.364758,0.365745,0.368335,00:00
8,0.333554,0.396239,0.397723,00:00
9,0.308899,0.28412,0.285711,00:00


fold 2: 0.21918282210826873


epoch,train_loss,valid_loss,rmspe,time
0,4.167779,2.951963,2.953215,00:00
1,1.813349,1.350996,1.353384,00:00
2,1.057302,1.685941,1.697217,00:00
3,0.714992,0.561894,0.562516,00:00
4,0.530336,0.661941,0.664235,00:00
5,0.425604,0.518899,0.521306,00:00
6,0.366111,0.31529,0.315466,00:00
7,0.336019,0.400675,0.401696,00:00
8,0.314793,0.273982,0.274206,00:00
9,0.299656,0.28875,0.28888,00:00


fold 3: 0.21495648920536042


epoch,train_loss,valid_loss,rmspe,time
0,3.367785,2.336866,2.339811,00:00
1,1.5361,1.715944,1.724469,00:00
2,0.936715,0.940078,0.947508,00:00
3,0.651596,0.674625,0.682147,00:00
4,0.507052,0.642757,0.647584,00:00
5,0.42709,0.403375,0.404805,00:00
6,0.374426,0.405833,0.40631,00:00
7,0.335665,0.372722,0.373054,00:00
8,0.315352,0.45809,0.459369,00:00
9,0.302478,0.28171,0.28213,00:00


fold 4: 0.21493141651153563
trial 0 : 0.2157310926914215


epoch,train_loss,valid_loss,rmspe,time
0,4.314839,2.66197,2.662731,00:00
1,1.843563,1.426043,1.43011,00:00
2,1.096597,0.8542,0.856444,00:00
3,0.719867,0.367348,0.367869,00:00
4,0.532112,0.385206,0.385393,00:00
5,0.436047,0.397018,0.400019,00:00
6,0.376292,0.357081,0.357535,00:00
7,0.331563,0.262782,0.263255,00:00
8,0.308612,0.382136,0.382995,00:00
9,0.302036,0.302299,0.302407,00:00
