In [1]:
import time
from itertools import product

import pandas as pd
import numpy as np

from sklearn.preprocessing import StandardScaler
from sklearn.utils import shuffle

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader


In [2]:
# Set seeds
torch.manual_seed(0)
np.random.seed(0)

In [3]:
synthetic_calls_path = '../data/binom_synthetic_calls.csv'
synthetic_puts_path = '../data/binom_synthetic_puts.csv'

In [4]:
def reduce_mem_usage(df):
    """ iterate through all the columns of a dataframe and modify the data type
        to reduce memory usage.        
    """    
    for col in df.columns:
        col_type = df[col].dtype
        
        if col_type != object:
            c_min = df[col].min()
            c_max = df[col].max()
            if str(col_type)[:3] == 'int':
                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:
                    df[col] = df[col].astype(np.int8)
                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:
                    df[col] = df[col].astype(np.int16)
                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:
                    df[col] = df[col].astype(np.int32)
                elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:
                    df[col] = df[col].astype(np.int64)  
            else:
                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                    df[col] = df[col].astype(np.float16)
                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                    df[col] = df[col].astype(np.float32)
                else:
                    df[col] = df[col].astype(np.float64)
        else:
            df[col] = df[col].astype('category')
    
    return df

In [5]:
synthetic_calls = pd.read_csv(synthetic_calls_path, index_col=0)
synthetic_puts = pd.read_csv(synthetic_puts_path, index_col=0)

synthetic_calls = reduce_mem_usage(synthetic_calls)
synthetic_puts = reduce_mem_usage(synthetic_puts)

In [6]:
synthetic_options = pd.concat([synthetic_calls, synthetic_puts], axis=0)
synthetic_options = shuffle(synthetic_options, random_state=0)
synthetic_options = synthetic_options.reset_index()
synthetic_options = synthetic_options.drop('index', axis=1)

In [7]:
synthetic_options.head(5)

Unnamed: 0,Price,Strike,Type,Vol,Interest Rate,Time to Expiration,Option Price
0,100,97.0,C,0.799805,0.070007,0.600098,27.125
1,100,101.0,P,0.5,0.099976,0.600098,12.671875
2,100,90.0,P,0.300049,0.059998,0.899902,4.707031
3,100,104.0,P,0.899902,0.090027,0.899902,30.390625
4,100,144.0,C,1.0,0.099976,0.899902,28.078125


# Preprocessing

In [8]:
synthetic_options = pd.get_dummies(synthetic_options, prefix='', prefix_sep='')

In [9]:
input_sc = StandardScaler()
output_sc = StandardScaler()
input_data = input_sc.fit_transform(synthetic_options.drop('Option Price', axis=1))
output_data = output_sc.fit_transform(synthetic_options['Option Price'].values.reshape(-1, 1))

train_size = 0.9
last_train_idx = int(np.round(len(input_data) * train_size))

X_train = input_data[0:last_train_idx]
X_test = input_data[last_train_idx:]

y_train = output_data[0:last_train_idx]
y_test = output_data[last_train_idx:]

In [10]:
X_train = Variable(torch.Tensor(X_train))
X_test = Variable(torch.Tensor(X_test))

y_train = Variable(torch.Tensor(y_train))
y_test = Variable(torch.Tensor(y_test))

 # Model

In [11]:
CUDA = torch.cuda.is_available()
device = 'cuda:0' if CUDA else 'cpu'

In [12]:
class ResBlock(nn.Module):

  def __init__(self, module):
    super(ResBlock, self).__init__()
    self.module = module

  def forward(self, x):
    return self.module(x) + x

In [13]:
class HiddenLayer(nn.Module):

  def __init__(self, layer_size, act_fn):
      super(HiddenLayer, self).__init__()
      
      if act_fn == 'ReLU':
        self.layer = nn.Sequential(
          nn.Linear(layer_size, layer_size),
          nn.ReLU())
      elif act_fn == 'LeakyReLU':
        self.layer = nn.Sequential(
          nn.Linear(layer_size, layer_size),
          nn.LeakyReLU())
      elif act_fn == 'ELU':
        self.layer = nn.Sequential(
          nn.Linear(layer_size, layer_size),
          nn.ELU())
    
  def forward(self, x):
    return self.layer(x)

In [14]:
class Net(nn.Module):

  def __init__(self, input_size, output_size, hidden_size, num_layers, act_fn):
    super(Net, self).__init__()
    self.input_size = input_size
    self.output_size = output_size
    self.hidden_size = hidden_size

    if act_fn == 'ReLU':
      self.initial_layer = nn.Sequential(
          nn.Linear(self.input_size, self.hidden_size),
          nn.ReLU())
    elif act_fn == 'LeakyReLU':
      self.initial_layer = nn.Sequential(
          nn.Linear(self.input_size, self.hidden_size),
          nn.LeakyReLU())
    elif act_fn == 'ELU':
      self.initial_layer = nn.Sequential(
          nn.Linear(self.input_size, self.hidden_size),
          nn.ELU())

    self.hidden_layers_list = []

    for i in range(num_layers // 2):
      self.hidden_layers_list.append(
          ResBlock(
            nn.Sequential(
                HiddenLayer(self.hidden_size, act_fn),
                HiddenLayer(self.hidden_size, act_fn)
            )
        )
      )

    self.hidden_layers = nn.Sequential(*self.hidden_layers_list)

    self.net = nn.Sequential(
        self.initial_layer,
        self.hidden_layers,
        nn.Linear(self.hidden_size, self.output_size)
    )
  
  def forward(self, x):
    return self.net(x)

In [15]:
def init_weights(m, init_m: str):

  @torch.no_grad()
  def init_uniform(m):
    if isinstance(m, nn.Linear):
      torch.nn.init.uniform_(m.weight)
      m.bias.data.fill_(0.01)

  @torch.no_grad()
  def init_normal(m):
    if isinstance(m, nn.Linear):
      torch.nn.init.normal_(m.weight)
      m.bias.data.fill_(0.01)

  @torch.no_grad()
  def init_xuniform(m):
    if isinstance(m, nn.Linear):
      torch.nn.init.xavier_uniform_(m.weight)
      m.bias.data.fill_(0.01)

  @torch.no_grad()
  def init_xnormal(m):
    if isinstance(m, nn.Linear):
      torch.nn.init.xavier_normal_(m.weight)
      m.bias.data.fill_(0.01)

  if init_m == 'uniform':
    m.apply(init_uniform)
  elif init_m == 'normal':
    m.apply(init_normal)
  elif init_m == 'xaiver uniform':
    m.apply(init_xuniform)
  elif init_m == 'xavier normal':
    m.apply(init_xnormal)

## Hyperparameter options

In [16]:
hidden_size = [200, 400, 600]
n_layers = [4, 6, 8]
act_fun = ['ReLU', 'LeakyReLU', 'ELU']
init_methods = ['xavier uniform', 'xavier normal']
epochs = 25
n_folds = 5

In [17]:
cv_params = list(product(hidden_size,
                         n_layers,
                         act_fun,
                         init_methods))
n_cv_params = len(cv_params)
n_cv_params

54

In [18]:
sample_proportion = 0.6
sample_size = int(sample_proportion * n_cv_params)

cv_param_sample = np.random.choice(list(range(n_cv_params)),
                                     size=int(sample_size),
                                     replace=False)
cv_params_ = [cv_params[i] for i in cv_param_sample]
print('# CV parameters:', len(cv_params_))

# CV parameters: 32


# Training

In [19]:
input_size = 7
output_size = 1
batch_size = 1208
lr = 1e-4

loss_fn = nn.MSELoss()

In [20]:
class OptDataset(Dataset):

  def __init__(self, X, y):
    self.X = X
    self.y = y

  def __getitem__(self, idx):
    return self.X[idx], self.y[idx]

  def __len__(self):
    return len(self.X)

In [21]:
def evaluate(model, loss_fn, X_val, y_val):
  model.eval()
  with torch.no_grad():
    out = model(X_val)
    loss = loss_fn(out, y_val)
    return loss.item()

In [22]:
crossval_result = pd.DataFrame(columns=
                               ['hidden_size',
                                'n_layers',
                                'act_fun',
                                'init_methods',
                                'mean_val_result',
                                'std_val_result'])

In [23]:
def cross_validation(
  cv_params_,
  epochs,
  n_folds,
  batch_size,
  X,
  y,
  loss_fn
):
  crossval_result = pd.DataFrame(columns=
                               ['hidden_size',
                                'n_layers',
                                'act_fun',
                                'init_methods',
                                'mean_val_result',
                                'std_val_result'])

  for h_size, n_layers, act, init_m in cv_params_:
    model = Net(input_size, output_size, h_size, n_layers, act).to(device)
    init_weights(model, init_m)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    c_val_size = len(X) // n_folds

    validation_losses = []
    cv_res = {
        'hidden_size': h_size,
        'n_layers': n_layers,
        'act_fun': act,
        'init_methods': init_m
    }

    print('Model: ', cv_res)

    for i in range(n_folds):
      model.train()

      X_train = torch.cat((X[0:c_val_size*i], X[c_val_size*(i+1):]))
      y_train = torch.cat((y[0:c_val_size*i], y[c_val_size*(i+1):]))
      X_val = X[c_val_size*i:c_val_size*(i+1)]
      y_val = y[c_val_size*i:c_val_size*(i+1)]

      for epoch in range(epochs):

        for batch, batch_labels in DataLoader(OptDataset(
            X_train, 
            y_train), 
            batch_size=batch_size):
          out = model(batch.to(device))
          optimizer.zero_grad()

          loss = loss_fn(out, batch_labels.to(device))
          loss.backward()
          optimizer.step()

        validation_losses.append(evaluate(model, 
                                          loss_fn, 
                                          X_val.to(device), 
                                          y_val.to(device)))
        print('fold: ', i + 1 ,', epoch: ', epoch + 1, ', val loss: ', validation_losses[-1])

    validation_losses = np.array(validation_losses)
    cv_res['mean_val_result'] = validation_losses.mean()
    cv_res['std_val_result'] = validation_losses.std()
    print('Model results: ', cv_res, '\n')
    crossval_result = crossval_result.append(cv_res, ignore_index=True)


  return crossval_result

In [24]:
crossval_results = cross_validation(cv_params_, 
                                    epochs, 
                                    n_folds, 
                                    batch_size,
                                    X_train,
                                    y_train,
                                    loss_fn)

Model:  {'hidden_size': 600, 'n_layers': 8, 'act_fun': 'ELU', 'init_methods': 'xavier normal'}
fold:  1 , epoch:  1 , val loss:  0.004184442572295666
fold:  1 , epoch:  2 , val loss:  0.001227384665980935
fold:  1 , epoch:  3 , val loss:  0.0015790387988090515
fold:  1 , epoch:  4 , val loss:  0.00038043083623051643
fold:  1 , epoch:  5 , val loss:  0.0003624866367317736
fold:  1 , epoch:  6 , val loss:  0.00028280678088776767
fold:  1 , epoch:  7 , val loss:  0.0010647728340700269
fold:  1 , epoch:  8 , val loss:  0.00032244017347693443
fold:  1 , epoch:  9 , val loss:  0.00016348068311344832
fold:  1 , epoch:  10 , val loss:  0.0005539879784919322
fold:  1 , epoch:  11 , val loss:  0.00019334722310304642
fold:  1 , epoch:  12 , val loss:  0.0009476420818828046
fold:  1 , epoch:  13 , val loss:  0.0010139707010239363
fold:  1 , epoch:  14 , val loss:  0.0016273148357868195
fold:  1 , epoch:  15 , val loss:  0.00017711175314616412
fold:  1 , epoch:  16 , val loss:  0.000102713216620031

fold:  1 , epoch:  13 , val loss:  6.318851956166327e-05
fold:  1 , epoch:  14 , val loss:  2.0704970665974542e-05
fold:  1 , epoch:  15 , val loss:  3.3405063732061535e-05
fold:  1 , epoch:  16 , val loss:  7.820063183316961e-05
fold:  1 , epoch:  17 , val loss:  0.00025096317403949797
fold:  1 , epoch:  18 , val loss:  2.0367167962831445e-05
fold:  1 , epoch:  19 , val loss:  2.4869646949809976e-05
fold:  1 , epoch:  20 , val loss:  8.727824024390429e-05
fold:  1 , epoch:  21 , val loss:  2.6011586669483222e-05
fold:  1 , epoch:  22 , val loss:  0.00011127620382467285
fold:  1 , epoch:  23 , val loss:  2.6818608603207394e-05
fold:  1 , epoch:  24 , val loss:  1.7316519006271847e-05
fold:  1 , epoch:  25 , val loss:  5.472290649777278e-05
fold:  2 , epoch:  1 , val loss:  0.00010275805834680796
fold:  2 , epoch:  2 , val loss:  1.7212822058354504e-05
fold:  2 , epoch:  3 , val loss:  1.416673057974549e-05
fold:  2 , epoch:  4 , val loss:  1.8451924916007556e-05
fold:  2 , epoch:  5 , 

fold:  2 , epoch:  2 , val loss:  1.159082239610143e-05
fold:  2 , epoch:  3 , val loss:  9.692519597592764e-06
fold:  2 , epoch:  4 , val loss:  1.0515496796870138e-05
fold:  2 , epoch:  5 , val loss:  9.716713975649327e-06
fold:  2 , epoch:  6 , val loss:  6.135208877822151e-06
fold:  2 , epoch:  7 , val loss:  1.2552953194244765e-05
fold:  2 , epoch:  8 , val loss:  5.644363409373909e-05
fold:  2 , epoch:  9 , val loss:  2.062359453702811e-05
fold:  2 , epoch:  10 , val loss:  2.7827649319078773e-05
fold:  2 , epoch:  11 , val loss:  1.8319253285881132e-05
fold:  2 , epoch:  12 , val loss:  1.9400720702833496e-05
fold:  2 , epoch:  13 , val loss:  5.754057383455802e-06
fold:  2 , epoch:  14 , val loss:  7.725527211732697e-06
fold:  2 , epoch:  15 , val loss:  4.020398046122864e-05
fold:  2 , epoch:  16 , val loss:  9.869249879557174e-06
fold:  2 , epoch:  17 , val loss:  1.835109833336901e-05
fold:  2 , epoch:  18 , val loss:  6.290745204751147e-06
fold:  2 , epoch:  19 , val loss: 

fold:  2 , epoch:  16 , val loss:  1.6725447494536638e-05
fold:  2 , epoch:  17 , val loss:  1.1329660082992632e-05
fold:  2 , epoch:  18 , val loss:  3.343671414768323e-05
fold:  2 , epoch:  19 , val loss:  3.340365583426319e-05
fold:  2 , epoch:  20 , val loss:  8.479610187350772e-06
fold:  2 , epoch:  21 , val loss:  1.0786008715513162e-05
fold:  2 , epoch:  22 , val loss:  7.143617040128447e-06
fold:  2 , epoch:  23 , val loss:  2.3277048967429437e-05
fold:  2 , epoch:  24 , val loss:  1.9494404114084318e-05
fold:  2 , epoch:  25 , val loss:  1.0860469046747312e-05
fold:  3 , epoch:  1 , val loss:  6.672677045571618e-06
fold:  3 , epoch:  2 , val loss:  6.508530987048289e-06
fold:  3 , epoch:  3 , val loss:  9.940004019881599e-06
fold:  3 , epoch:  4 , val loss:  7.879998520365916e-06
fold:  3 , epoch:  5 , val loss:  8.648913535580505e-06
fold:  3 , epoch:  6 , val loss:  1.5004959095676895e-05
fold:  3 , epoch:  7 , val loss:  2.4938532078522258e-05
fold:  3 , epoch:  8 , val los

fold:  3 , epoch:  5 , val loss:  0.00021887135517317802
fold:  3 , epoch:  6 , val loss:  0.0001276827388210222
fold:  3 , epoch:  7 , val loss:  6.152045534690842e-05
fold:  3 , epoch:  8 , val loss:  7.441697380272672e-05
fold:  3 , epoch:  9 , val loss:  0.00010634979844326153
fold:  3 , epoch:  10 , val loss:  5.4751300922362134e-05
fold:  3 , epoch:  11 , val loss:  1.6675321603543125e-05
fold:  3 , epoch:  12 , val loss:  2.3750064428895712e-05
fold:  3 , epoch:  13 , val loss:  4.908817572868429e-05
fold:  3 , epoch:  14 , val loss:  2.0656982087530196e-05
fold:  3 , epoch:  15 , val loss:  4.756159978569485e-05
fold:  3 , epoch:  16 , val loss:  2.458396556903608e-05
fold:  3 , epoch:  17 , val loss:  3.485723209450953e-05
fold:  3 , epoch:  18 , val loss:  4.467489634407684e-05
fold:  3 , epoch:  19 , val loss:  3.063111216761172e-05
fold:  3 , epoch:  20 , val loss:  2.5302209905930795e-05
fold:  3 , epoch:  21 , val loss:  2.7519768991624005e-05
fold:  3 , epoch:  22 , val 

fold:  3 , epoch:  19 , val loss:  1.054645690601319e-05
fold:  3 , epoch:  20 , val loss:  9.741471330926288e-06
fold:  3 , epoch:  21 , val loss:  8.728995453566313e-06
fold:  3 , epoch:  22 , val loss:  8.512499334756285e-06
fold:  3 , epoch:  23 , val loss:  9.638598385208752e-06
fold:  3 , epoch:  24 , val loss:  3.2957603252725676e-05
fold:  3 , epoch:  25 , val loss:  9.17117631615838e-06
fold:  4 , epoch:  1 , val loss:  6.948313966859132e-06
fold:  4 , epoch:  2 , val loss:  8.255613465735223e-06
fold:  4 , epoch:  3 , val loss:  7.786987225699704e-06
fold:  4 , epoch:  4 , val loss:  8.5743222371093e-06
fold:  4 , epoch:  5 , val loss:  8.824002179608215e-06
fold:  4 , epoch:  6 , val loss:  9.222632797900587e-06
fold:  4 , epoch:  7 , val loss:  8.323160727741197e-06
fold:  4 , epoch:  8 , val loss:  7.308913609449519e-06
fold:  4 , epoch:  9 , val loss:  8.946257366915233e-06
fold:  4 , epoch:  10 , val loss:  3.214102616766468e-05
fold:  4 , epoch:  11 , val loss:  6.79136

fold:  4 , epoch:  8 , val loss:  4.962236744177062e-06
fold:  4 , epoch:  9 , val loss:  4.263534265191993e-06
fold:  4 , epoch:  10 , val loss:  3.1434187803824898e-06
fold:  4 , epoch:  11 , val loss:  2.0694886188721284e-05
fold:  4 , epoch:  12 , val loss:  1.12796269604587e-05
fold:  4 , epoch:  13 , val loss:  5.274873274174752e-06
fold:  4 , epoch:  14 , val loss:  3.880604708683677e-05
fold:  4 , epoch:  15 , val loss:  6.736616342095658e-06
fold:  4 , epoch:  16 , val loss:  6.484168352471897e-06
fold:  4 , epoch:  17 , val loss:  4.3173098674742505e-05
fold:  4 , epoch:  18 , val loss:  3.2605199521640316e-05
fold:  4 , epoch:  19 , val loss:  4.706342679128284e-06
fold:  4 , epoch:  20 , val loss:  6.751640739821596e-06
fold:  4 , epoch:  21 , val loss:  5.84919580433052e-05
fold:  4 , epoch:  22 , val loss:  3.232081144233234e-05
fold:  4 , epoch:  23 , val loss:  6.577080966962967e-06
fold:  4 , epoch:  24 , val loss:  5.165461516298819e-06
fold:  4 , epoch:  25 , val los

fold:  4 , epoch:  22 , val loss:  3.1508282063441584e-06
fold:  4 , epoch:  23 , val loss:  4.208840891806176e-06
fold:  4 , epoch:  24 , val loss:  4.543131581158377e-06
fold:  4 , epoch:  25 , val loss:  3.4350057831034064e-05
fold:  5 , epoch:  1 , val loss:  1.4781789104745258e-05
fold:  5 , epoch:  2 , val loss:  3.1155821034190012e-06
fold:  5 , epoch:  3 , val loss:  2.667716898940853e-06
fold:  5 , epoch:  4 , val loss:  1.7985674276133068e-05
fold:  5 , epoch:  5 , val loss:  4.739789437735453e-05
fold:  5 , epoch:  6 , val loss:  9.635716196498834e-06
fold:  5 , epoch:  7 , val loss:  3.1576203127769986e-06
fold:  5 , epoch:  8 , val loss:  4.037908183818217e-06
fold:  5 , epoch:  9 , val loss:  5.042232260166202e-06
fold:  5 , epoch:  10 , val loss:  2.558826508902712e-06
fold:  5 , epoch:  11 , val loss:  4.374978743726388e-06
fold:  5 , epoch:  12 , val loss:  2.5801960873650387e-05
fold:  5 , epoch:  13 , val loss:  1.4856833331577946e-05
fold:  5 , epoch:  14 , val loss

fold:  5 , epoch:  11 , val loss:  5.747675459133461e-06
fold:  5 , epoch:  12 , val loss:  1.1212934623472393e-05
fold:  5 , epoch:  13 , val loss:  2.371923255850561e-05
fold:  5 , epoch:  14 , val loss:  7.329771324293688e-05
fold:  5 , epoch:  15 , val loss:  6.556339940289035e-05
fold:  5 , epoch:  16 , val loss:  4.383475607028231e-05
fold:  5 , epoch:  17 , val loss:  1.5804809663677588e-05
fold:  5 , epoch:  18 , val loss:  2.445751306368038e-05
fold:  5 , epoch:  19 , val loss:  1.4445254237216432e-05
fold:  5 , epoch:  20 , val loss:  2.782216370178503e-06
fold:  5 , epoch:  21 , val loss:  9.455184226681013e-06
fold:  5 , epoch:  22 , val loss:  1.2977646292711142e-05
fold:  5 , epoch:  23 , val loss:  3.397380396563676e-06
fold:  5 , epoch:  24 , val loss:  4.1011358007381205e-06
fold:  5 , epoch:  25 , val loss:  7.584265858895378e-06
Model results:  {'hidden_size': 600, 'n_layers': 6, 'act_fun': 'LeakyReLU', 'init_methods': 'xavier normal', 'mean_val_result': 4.3718253771

fold:  5 , epoch:  25 , val loss:  3.580029442673549e-06
Model results:  {'hidden_size': 400, 'n_layers': 8, 'act_fun': 'ReLU', 'init_methods': 'xavier uniform', 'mean_val_result': 2.617532322256011e-05, 'std_val_result': 3.987268616473944e-05} 

Model:  {'hidden_size': 200, 'n_layers': 4, 'act_fun': 'ELU', 'init_methods': 'xavier uniform'}
fold:  1 , epoch:  1 , val loss:  0.00830057729035616
fold:  1 , epoch:  2 , val loss:  0.0032745273783802986
fold:  1 , epoch:  3 , val loss:  0.0017530195182189345
fold:  1 , epoch:  4 , val loss:  0.0010464494116604328
fold:  1 , epoch:  5 , val loss:  0.0007053331355564296
fold:  1 , epoch:  6 , val loss:  0.00047634379006922245
fold:  1 , epoch:  7 , val loss:  0.00041956763016059995
fold:  1 , epoch:  8 , val loss:  0.00028877155273221433
fold:  1 , epoch:  9 , val loss:  0.00024333709734492004
fold:  1 , epoch:  10 , val loss:  0.00025341869331896305
fold:  1 , epoch:  11 , val loss:  0.00022344828175846487
fold:  1 , epoch:  12 , val loss:  

fold:  1 , epoch:  9 , val loss:  0.00023952421906869859
fold:  1 , epoch:  10 , val loss:  0.0002454384812153876
fold:  1 , epoch:  11 , val loss:  0.00025011395337060094
fold:  1 , epoch:  12 , val loss:  0.00026483225519768894
fold:  1 , epoch:  13 , val loss:  0.00023413004237227142
fold:  1 , epoch:  14 , val loss:  0.00013625918654724956
fold:  1 , epoch:  15 , val loss:  0.0002724182268138975
fold:  1 , epoch:  16 , val loss:  0.0002368777641095221
fold:  1 , epoch:  17 , val loss:  0.0001664396550040692
fold:  1 , epoch:  18 , val loss:  0.00017473161278758198
fold:  1 , epoch:  19 , val loss:  0.0001421313063474372
fold:  1 , epoch:  20 , val loss:  0.00012182283535366878
fold:  1 , epoch:  21 , val loss:  0.00021040748106315732
fold:  1 , epoch:  22 , val loss:  0.00018481227743905038
fold:  1 , epoch:  23 , val loss:  0.00012539404269773513
fold:  1 , epoch:  24 , val loss:  0.00013991040759719908
fold:  1 , epoch:  25 , val loss:  0.0001597641676198691
fold:  2 , epoch:  1 

fold:  1 , epoch:  23 , val loss:  0.00013427971862256527
fold:  1 , epoch:  24 , val loss:  0.0001673876104177907
fold:  1 , epoch:  25 , val loss:  0.00029119194368831813
fold:  2 , epoch:  1 , val loss:  3.7555590097326785e-05
fold:  2 , epoch:  2 , val loss:  0.0001124343543779105
fold:  2 , epoch:  3 , val loss:  0.00017180520808324218
fold:  2 , epoch:  4 , val loss:  0.0001275026152143255
fold:  2 , epoch:  5 , val loss:  0.00013632816262543201
fold:  2 , epoch:  6 , val loss:  0.00012437274563126266
fold:  2 , epoch:  7 , val loss:  0.00011826420814031735
fold:  2 , epoch:  8 , val loss:  0.00023573494399897754
fold:  2 , epoch:  9 , val loss:  5.354562745196745e-05
fold:  2 , epoch:  10 , val loss:  2.5078443286474794e-05
fold:  2 , epoch:  11 , val loss:  8.980787970358506e-05
fold:  2 , epoch:  12 , val loss:  5.262682316242717e-05
fold:  2 , epoch:  13 , val loss:  0.0002149976062355563
fold:  2 , epoch:  14 , val loss:  3.9073147490853444e-05
fold:  2 , epoch:  15 , val lo

fold:  2 , epoch:  12 , val loss:  0.00017973179637920111
fold:  2 , epoch:  13 , val loss:  5.9207395679550245e-05
fold:  2 , epoch:  14 , val loss:  0.00015433531370945275
fold:  2 , epoch:  15 , val loss:  5.722710193367675e-05
fold:  2 , epoch:  16 , val loss:  5.2902356401318684e-05
fold:  2 , epoch:  17 , val loss:  4.740234726341441e-05
fold:  2 , epoch:  18 , val loss:  0.00018079709843732417
fold:  2 , epoch:  19 , val loss:  0.00010950999421766028
fold:  2 , epoch:  20 , val loss:  0.0001370843092445284
fold:  2 , epoch:  21 , val loss:  3.332293272251263e-05
fold:  2 , epoch:  22 , val loss:  6.670098809991032e-05
fold:  2 , epoch:  23 , val loss:  4.238612382323481e-05
fold:  2 , epoch:  24 , val loss:  6.765355647075921e-05
fold:  2 , epoch:  25 , val loss:  6.855765968794003e-05
fold:  3 , epoch:  1 , val loss:  6.259299698285758e-05
fold:  3 , epoch:  2 , val loss:  2.1465757527039386e-05
fold:  3 , epoch:  3 , val loss:  7.571392779937014e-05
fold:  3 , epoch:  4 , val 

fold:  3 , epoch:  1 , val loss:  1.0150904017791618e-05
fold:  3 , epoch:  2 , val loss:  4.059540151502006e-05
fold:  3 , epoch:  3 , val loss:  3.3699361665640026e-05
fold:  3 , epoch:  4 , val loss:  1.8704316971707158e-05
fold:  3 , epoch:  5 , val loss:  4.092376912012696e-05
fold:  3 , epoch:  6 , val loss:  5.6026415222731885e-06
fold:  3 , epoch:  7 , val loss:  1.354144205834018e-05
fold:  3 , epoch:  8 , val loss:  1.7387634215992875e-05
fold:  3 , epoch:  9 , val loss:  8.323827387357596e-06
fold:  3 , epoch:  10 , val loss:  6.188461156853009e-06
fold:  3 , epoch:  11 , val loss:  4.691140020440798e-06
fold:  3 , epoch:  12 , val loss:  1.3194739040045533e-05
fold:  3 , epoch:  13 , val loss:  1.9939030607929453e-05
fold:  3 , epoch:  14 , val loss:  2.373227289353963e-05
fold:  3 , epoch:  15 , val loss:  1.3277843208925333e-05
fold:  3 , epoch:  16 , val loss:  2.188275902881287e-05
fold:  3 , epoch:  17 , val loss:  1.804665407689754e-05
fold:  3 , epoch:  18 , val loss

fold:  3 , epoch:  15 , val loss:  6.663750536972657e-05
fold:  3 , epoch:  16 , val loss:  4.853371501667425e-05
fold:  3 , epoch:  17 , val loss:  1.0949336683552247e-05
fold:  3 , epoch:  18 , val loss:  5.626115125778597e-06
fold:  3 , epoch:  19 , val loss:  3.455574915278703e-05
fold:  3 , epoch:  20 , val loss:  1.3836936886946205e-05
fold:  3 , epoch:  21 , val loss:  8.279305802716408e-06
fold:  3 , epoch:  22 , val loss:  9.172630598186515e-06
fold:  3 , epoch:  23 , val loss:  2.270833465445321e-05
fold:  3 , epoch:  24 , val loss:  1.4623971765104216e-05
fold:  3 , epoch:  25 , val loss:  2.4132654289132915e-05
fold:  4 , epoch:  1 , val loss:  3.0328636057674885e-05
fold:  4 , epoch:  2 , val loss:  6.533285613841144e-06
fold:  4 , epoch:  3 , val loss:  5.041093572799582e-06
fold:  4 , epoch:  4 , val loss:  2.3534410502179526e-05
fold:  4 , epoch:  5 , val loss:  2.4929922801675275e-05
fold:  4 , epoch:  6 , val loss:  4.838784207095159e-06
fold:  4 , epoch:  7 , val los

fold:  4 , epoch:  4 , val loss:  1.3518754713004455e-05
fold:  4 , epoch:  5 , val loss:  5.77936725676409e-06
fold:  4 , epoch:  6 , val loss:  1.2299838999751955e-05
fold:  4 , epoch:  7 , val loss:  1.9339360733283684e-05
fold:  4 , epoch:  8 , val loss:  4.5496020902646706e-05
fold:  4 , epoch:  9 , val loss:  5.64782658329932e-06
fold:  4 , epoch:  10 , val loss:  4.127661213715328e-06
fold:  4 , epoch:  11 , val loss:  6.146889518277021e-06
fold:  4 , epoch:  12 , val loss:  1.4421576452150475e-05
fold:  4 , epoch:  13 , val loss:  1.1083021490776446e-05
fold:  4 , epoch:  14 , val loss:  4.89795820612926e-06
fold:  4 , epoch:  15 , val loss:  3.798671968979761e-05
fold:  4 , epoch:  16 , val loss:  3.99235614167992e-06
fold:  4 , epoch:  17 , val loss:  4.6999077312648296e-05
fold:  4 , epoch:  18 , val loss:  5.294113179843407e-06
fold:  4 , epoch:  19 , val loss:  5.8929990700562485e-06
fold:  4 , epoch:  20 , val loss:  4.442273166205268e-06
fold:  4 , epoch:  21 , val loss:

fold:  4 , epoch:  18 , val loss:  3.703924812725745e-05
fold:  4 , epoch:  19 , val loss:  1.6483792933286168e-05
fold:  4 , epoch:  20 , val loss:  7.866682608437259e-06
fold:  4 , epoch:  21 , val loss:  1.4514896065520588e-05
fold:  4 , epoch:  22 , val loss:  1.3749947356700432e-05
fold:  4 , epoch:  23 , val loss:  1.0378079423389863e-05
fold:  4 , epoch:  24 , val loss:  9.250067705579568e-06
fold:  4 , epoch:  25 , val loss:  1.1877657925651874e-05
fold:  5 , epoch:  1 , val loss:  1.735805744829122e-05
fold:  5 , epoch:  2 , val loss:  1.139953383244574e-05
fold:  5 , epoch:  3 , val loss:  9.003314517030958e-06
fold:  5 , epoch:  4 , val loss:  6.756386483175447e-06
fold:  5 , epoch:  5 , val loss:  7.0452329055115115e-06
fold:  5 , epoch:  6 , val loss:  8.076985068328213e-06
fold:  5 , epoch:  7 , val loss:  1.557153882458806e-05
fold:  5 , epoch:  8 , val loss:  1.2008764315396547e-05
fold:  5 , epoch:  9 , val loss:  9.800472980714403e-06
fold:  5 , epoch:  10 , val loss:

fold:  5 , epoch:  7 , val loss:  4.640969109459547e-06
fold:  5 , epoch:  8 , val loss:  4.595111022354104e-06
fold:  5 , epoch:  9 , val loss:  1.8835507944459096e-05
fold:  5 , epoch:  10 , val loss:  6.270097401284147e-06
fold:  5 , epoch:  11 , val loss:  6.2564618019678164e-06
fold:  5 , epoch:  12 , val loss:  2.16861444641836e-05
fold:  5 , epoch:  13 , val loss:  5.949753813183634e-06
fold:  5 , epoch:  14 , val loss:  4.944524334860034e-06
fold:  5 , epoch:  15 , val loss:  8.112140494631603e-05
fold:  5 , epoch:  16 , val loss:  7.065831596264616e-05
fold:  5 , epoch:  17 , val loss:  4.106108917767415e-06
fold:  5 , epoch:  18 , val loss:  6.158736596262315e-06
fold:  5 , epoch:  19 , val loss:  7.190014457592042e-06
fold:  5 , epoch:  20 , val loss:  1.047793102770811e-05
fold:  5 , epoch:  21 , val loss:  4.847199306823313e-05
fold:  5 , epoch:  22 , val loss:  5.028551640862133e-06
fold:  5 , epoch:  23 , val loss:  6.163247599033639e-05
fold:  5 , epoch:  24 , val loss:

fold:  5 , epoch:  21 , val loss:  1.5562098269583657e-05
fold:  5 , epoch:  22 , val loss:  8.302020432893187e-06
fold:  5 , epoch:  23 , val loss:  6.151723027869593e-06
fold:  5 , epoch:  24 , val loss:  4.9623657105257735e-05
fold:  5 , epoch:  25 , val loss:  2.834973747667391e-05
Model results:  {'hidden_size': 400, 'n_layers': 6, 'act_fun': 'LeakyReLU', 'init_methods': 'xavier normal', 'mean_val_result': 3.3759756281142474e-05, 'std_val_result': 5.0020378084945505e-05} 

Model:  {'hidden_size': 400, 'n_layers': 8, 'act_fun': 'ELU', 'init_methods': 'xavier normal'}
fold:  1 , epoch:  1 , val loss:  0.0015917542623355985
fold:  1 , epoch:  2 , val loss:  0.0005563193699344993
fold:  1 , epoch:  3 , val loss:  0.0003825286985374987
fold:  1 , epoch:  4 , val loss:  0.00027356354985386133
fold:  1 , epoch:  5 , val loss:  0.00018227938562631607
fold:  1 , epoch:  6 , val loss:  0.000424072437454015
fold:  1 , epoch:  7 , val loss:  0.003267091466113925
fold:  1 , epoch:  8 , val los

fold:  1 , epoch:  5 , val loss:  0.000120751210488379
fold:  1 , epoch:  6 , val loss:  5.0785969506250694e-05
fold:  1 , epoch:  7 , val loss:  4.1924522520275787e-05
fold:  1 , epoch:  8 , val loss:  2.522069371480029e-05
fold:  1 , epoch:  9 , val loss:  7.866224041208625e-05
fold:  1 , epoch:  10 , val loss:  3.6710946005769074e-05
fold:  1 , epoch:  11 , val loss:  0.0003915123816113919
fold:  1 , epoch:  12 , val loss:  2.9533979613916017e-05
fold:  1 , epoch:  13 , val loss:  8.494159555993974e-05
fold:  1 , epoch:  14 , val loss:  0.00011854091280838475
fold:  1 , epoch:  15 , val loss:  2.777911868179217e-05
fold:  1 , epoch:  16 , val loss:  5.338281698641367e-05
fold:  1 , epoch:  17 , val loss:  2.9934028134448454e-05
fold:  1 , epoch:  18 , val loss:  1.9110157154500484e-05
fold:  1 , epoch:  19 , val loss:  2.0902354663121514e-05
fold:  1 , epoch:  20 , val loss:  5.140591019880958e-05
fold:  1 , epoch:  21 , val loss:  4.419676406541839e-05
fold:  1 , epoch:  22 , val l

fold:  1 , epoch:  19 , val loss:  1.601023723196704e-05
fold:  1 , epoch:  20 , val loss:  1.4632341844844632e-05
fold:  1 , epoch:  21 , val loss:  1.7366759493597783e-05
fold:  1 , epoch:  22 , val loss:  2.783955460472498e-05
fold:  1 , epoch:  23 , val loss:  1.9506540411384776e-05
fold:  1 , epoch:  24 , val loss:  2.0179870261927135e-05
fold:  1 , epoch:  25 , val loss:  1.2945416528964415e-05
fold:  2 , epoch:  1 , val loss:  1.3896257769374643e-05
fold:  2 , epoch:  2 , val loss:  1.536909803689923e-05
fold:  2 , epoch:  3 , val loss:  1.3058765944151673e-05
fold:  2 , epoch:  4 , val loss:  1.05908356999862e-05
fold:  2 , epoch:  5 , val loss:  1.0512671906326432e-05
fold:  2 , epoch:  6 , val loss:  1.5011098184913862e-05
fold:  2 , epoch:  7 , val loss:  1.816241092456039e-05
fold:  2 , epoch:  8 , val loss:  9.669908649811987e-06
fold:  2 , epoch:  9 , val loss:  9.547574336465914e-06
fold:  2 , epoch:  10 , val loss:  9.8920918389922e-06
fold:  2 , epoch:  11 , val loss: 

fold:  2 , epoch:  8 , val loss:  2.193654108850751e-05
fold:  2 , epoch:  9 , val loss:  3.263246253482066e-05
fold:  2 , epoch:  10 , val loss:  0.0001368016382912174
fold:  2 , epoch:  11 , val loss:  3.435228791204281e-05
fold:  2 , epoch:  12 , val loss:  4.556195926852524e-05
fold:  2 , epoch:  13 , val loss:  3.079640373471193e-05
fold:  2 , epoch:  14 , val loss:  2.648264307936188e-05
fold:  2 , epoch:  15 , val loss:  1.154280107584782e-05
fold:  2 , epoch:  16 , val loss:  3.979731991421431e-05
fold:  2 , epoch:  17 , val loss:  2.0016113921883516e-05
fold:  2 , epoch:  18 , val loss:  2.0485775166889653e-05
fold:  2 , epoch:  19 , val loss:  3.160139749525115e-05
fold:  2 , epoch:  20 , val loss:  6.27650078968145e-05
fold:  2 , epoch:  21 , val loss:  5.783419692306779e-05
fold:  2 , epoch:  22 , val loss:  1.0325024959456641e-05
fold:  2 , epoch:  23 , val loss:  3.284899867139757e-05
fold:  2 , epoch:  24 , val loss:  2.5148687200271524e-05
fold:  2 , epoch:  25 , val lo

fold:  2 , epoch:  22 , val loss:  1.280612195841968e-05
fold:  2 , epoch:  23 , val loss:  9.154409781331196e-06
fold:  2 , epoch:  24 , val loss:  0.00024645408848300576
fold:  2 , epoch:  25 , val loss:  2.70529963017907e-05
fold:  3 , epoch:  1 , val loss:  1.2643630725506227e-05
fold:  3 , epoch:  2 , val loss:  5.25939030922018e-05
fold:  3 , epoch:  3 , val loss:  1.0130000191566069e-05
fold:  3 , epoch:  4 , val loss:  1.1279546015430242e-05
fold:  3 , epoch:  5 , val loss:  3.207470217603259e-05
fold:  3 , epoch:  6 , val loss:  3.215340984752402e-05
fold:  3 , epoch:  7 , val loss:  4.898651604889892e-05
fold:  3 , epoch:  8 , val loss:  3.414863385842182e-05
fold:  3 , epoch:  9 , val loss:  6.669776485068724e-05
fold:  3 , epoch:  10 , val loss:  1.6042353308876045e-05
fold:  3 , epoch:  11 , val loss:  1.368981884297682e-05
fold:  3 , epoch:  12 , val loss:  0.00011881098907906562
fold:  3 , epoch:  13 , val loss:  1.1313275535940193e-05
fold:  3 , epoch:  14 , val loss:  

fold:  3 , epoch:  11 , val loss:  3.146986273350194e-05
fold:  3 , epoch:  12 , val loss:  1.4691865544591565e-05
fold:  3 , epoch:  13 , val loss:  5.664492346113548e-05
fold:  3 , epoch:  14 , val loss:  0.00012884302122984082
fold:  3 , epoch:  15 , val loss:  1.2547247933980543e-05
fold:  3 , epoch:  16 , val loss:  2.5638193619670346e-05
fold:  3 , epoch:  17 , val loss:  1.3617508557217661e-05
fold:  3 , epoch:  18 , val loss:  1.1544309018063359e-05
fold:  3 , epoch:  19 , val loss:  2.1275562176015228e-05
fold:  3 , epoch:  20 , val loss:  1.611373409104999e-05
fold:  3 , epoch:  21 , val loss:  1.431594409950776e-05
fold:  3 , epoch:  22 , val loss:  1.6287227481370792e-05
fold:  3 , epoch:  23 , val loss:  1.4558628208760638e-05
fold:  3 , epoch:  24 , val loss:  9.323383892478887e-06
fold:  3 , epoch:  25 , val loss:  1.645852353249211e-05
fold:  4 , epoch:  1 , val loss:  1.1346380233590025e-05
fold:  4 , epoch:  2 , val loss:  9.728254553920124e-06
fold:  4 , epoch:  3 , 

fold:  3 , epoch:  25 , val loss:  4.187091326457448e-05
fold:  4 , epoch:  1 , val loss:  3.85563907912001e-05
fold:  4 , epoch:  2 , val loss:  4.8259898903779685e-05
fold:  4 , epoch:  3 , val loss:  3.580267002689652e-05
fold:  4 , epoch:  4 , val loss:  4.7057470510480925e-05
fold:  4 , epoch:  5 , val loss:  2.7675498131429777e-05
fold:  4 , epoch:  6 , val loss:  3.0850449547870085e-05
fold:  4 , epoch:  7 , val loss:  5.180664447834715e-05
fold:  4 , epoch:  8 , val loss:  4.343307591625489e-05
fold:  4 , epoch:  9 , val loss:  4.856322630075738e-05
fold:  4 , epoch:  10 , val loss:  4.093089228263125e-05
fold:  4 , epoch:  11 , val loss:  2.479388058418408e-05
fold:  4 , epoch:  12 , val loss:  2.6433379389345646e-05
fold:  4 , epoch:  13 , val loss:  2.9633980375365354e-05
fold:  4 , epoch:  14 , val loss:  4.960068326909095e-05
fold:  4 , epoch:  15 , val loss:  2.0436302293092012e-05
fold:  4 , epoch:  16 , val loss:  2.3494811102864332e-05
fold:  4 , epoch:  17 , val loss:

fold:  4 , epoch:  14 , val loss:  5.334925663191825e-05
fold:  4 , epoch:  15 , val loss:  1.3055963790975511e-05
fold:  4 , epoch:  16 , val loss:  3.177545659127645e-05
fold:  4 , epoch:  17 , val loss:  2.1771613319288008e-05
fold:  4 , epoch:  18 , val loss:  1.5791856640134938e-05
fold:  4 , epoch:  19 , val loss:  1.3307469089340884e-05
fold:  4 , epoch:  20 , val loss:  0.00011406256089685485
fold:  4 , epoch:  21 , val loss:  4.772954707732424e-05
fold:  4 , epoch:  22 , val loss:  8.146541222231463e-05
fold:  4 , epoch:  23 , val loss:  6.579519686056301e-05
fold:  4 , epoch:  24 , val loss:  2.3045982743497007e-05
fold:  4 , epoch:  25 , val loss:  8.44261194288265e-06
fold:  5 , epoch:  1 , val loss:  4.1428640543017536e-05
fold:  5 , epoch:  2 , val loss:  1.9503420844557695e-05
fold:  5 , epoch:  3 , val loss:  9.160751687886659e-06
fold:  5 , epoch:  4 , val loss:  1.6469864931423217e-05
fold:  5 , epoch:  5 , val loss:  1.878477087302599e-05
fold:  5 , epoch:  6 , val l

fold:  5 , epoch:  3 , val loss:  1.3437806046567857e-05
fold:  5 , epoch:  4 , val loss:  5.999774657539092e-06
fold:  5 , epoch:  5 , val loss:  6.815914275648538e-06
fold:  5 , epoch:  6 , val loss:  2.0813336959690787e-05
fold:  5 , epoch:  7 , val loss:  1.04509426819277e-05
fold:  5 , epoch:  8 , val loss:  6.6681118369160686e-06
fold:  5 , epoch:  9 , val loss:  5.583111033047317e-06
fold:  5 , epoch:  10 , val loss:  9.3320086307358e-05
fold:  5 , epoch:  11 , val loss:  6.388183464878239e-06
fold:  5 , epoch:  12 , val loss:  8.890546268958133e-06
fold:  5 , epoch:  13 , val loss:  6.518696409330005e-06
fold:  5 , epoch:  14 , val loss:  5.4360829381039366e-05
fold:  5 , epoch:  15 , val loss:  8.393461939704139e-06
fold:  5 , epoch:  16 , val loss:  8.440250894636847e-06
fold:  5 , epoch:  17 , val loss:  8.241205250669736e-06
fold:  5 , epoch:  18 , val loss:  1.7108835891122e-05
fold:  5 , epoch:  19 , val loss:  4.7179546527331695e-06
fold:  5 , epoch:  20 , val loss:  8.7

fold:  5 , epoch:  17 , val loss:  4.8409343435196206e-05
fold:  5 , epoch:  18 , val loss:  2.7483396479510702e-05
fold:  5 , epoch:  19 , val loss:  4.6788868530711625e-06
fold:  5 , epoch:  20 , val loss:  4.404319952300284e-06
fold:  5 , epoch:  21 , val loss:  5.827518634760054e-06
fold:  5 , epoch:  22 , val loss:  4.20197920902865e-06
fold:  5 , epoch:  23 , val loss:  2.9480603188858368e-05
fold:  5 , epoch:  24 , val loss:  4.2083620428456925e-06
fold:  5 , epoch:  25 , val loss:  3.571508204913698e-05
Model results:  {'hidden_size': 600, 'n_layers': 6, 'act_fun': 'ELU', 'init_methods': 'xavier uniform', 'mean_val_result': 9.108797466251417e-05, 'std_val_result': 0.000209640278954211} 

Model:  {'hidden_size': 400, 'n_layers': 4, 'act_fun': 'LeakyReLU', 'init_methods': 'xavier uniform'}
fold:  1 , epoch:  1 , val loss:  0.000314035831252113
fold:  1 , epoch:  2 , val loss:  0.00013095929170958698
fold:  1 , epoch:  3 , val loss:  0.00011370726133463904
fold:  1 , epoch:  4 , v

fold:  1 , epoch:  1 , val loss:  0.0021956481505185366
fold:  1 , epoch:  2 , val loss:  0.0007669304613955319
fold:  1 , epoch:  3 , val loss:  0.0011890748282894492
fold:  1 , epoch:  4 , val loss:  0.0012491738889366388
fold:  1 , epoch:  5 , val loss:  0.0005461494438350201
fold:  1 , epoch:  6 , val loss:  0.00019976316252723336
fold:  1 , epoch:  7 , val loss:  0.0009781289845705032
fold:  1 , epoch:  8 , val loss:  0.0001903671509353444
fold:  1 , epoch:  9 , val loss:  0.0001526595588074997
fold:  1 , epoch:  10 , val loss:  0.0002948542241938412
fold:  1 , epoch:  11 , val loss:  0.0001709515490802005
fold:  1 , epoch:  12 , val loss:  0.00010046294482890517
fold:  1 , epoch:  13 , val loss:  0.00036627304507419467
fold:  1 , epoch:  14 , val loss:  0.00012706322013400495
fold:  1 , epoch:  15 , val loss:  0.00014690094394609332
fold:  1 , epoch:  16 , val loss:  0.0004330354568082839
fold:  1 , epoch:  17 , val loss:  0.0006296837236732244
fold:  1 , epoch:  18 , val loss:  

In [25]:
crossval_results

Unnamed: 0,hidden_size,n_layers,act_fun,init_methods,mean_val_result,std_val_result
0,600,8,ELU,xavier normal,0.000228,0.000452
1,400,8,LeakyReLU,xavier normal,3.9e-05,5.9e-05
2,600,8,ReLU,xavier uniform,2.8e-05,5.9e-05
3,400,6,LeakyReLU,xavier uniform,2.4e-05,3e-05
4,200,6,ELU,xavier normal,0.000115,0.000409
5,200,4,LeakyReLU,xavier uniform,3.9e-05,0.000126
6,400,8,LeakyReLU,xavier uniform,2.2e-05,2.9e-05
7,600,6,ReLU,xavier uniform,2.4e-05,2.9e-05
8,600,6,LeakyReLU,xavier normal,4.4e-05,6.8e-05
9,400,8,ReLU,xavier uniform,2.6e-05,4e-05


In [27]:
hyperparam_res = crossval_results.groupby(by=['hidden_size', 'n_layers', 'act_fun', 'init_methods'])

In [28]:
hyperparam_res.apply(lambda x : x.nsmallest(3, 'mean_val_result')).drop(['hidden_size', 'n_layers', 'act_fun', 'init_methods'], axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,mean_val_result,std_val_result
hidden_size,n_layers,act_fun,init_methods,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
200,4,ELU,xavier normal,26,0.000194,0.000812
200,4,ELU,xavier uniform,10,0.000192,0.000805
200,4,LeakyReLU,xavier uniform,5,3.9e-05,0.000126
200,6,ELU,xavier normal,4,0.000115,0.000409
200,6,ELU,xavier uniform,11,0.000159,0.000417
200,6,ReLU,xavier normal,17,7e-05,0.00016
200,8,ELU,xavier uniform,28,0.000125,0.000468
200,8,LeakyReLU,xavier normal,25,5.2e-05,9.5e-05
200,8,LeakyReLU,xavier uniform,18,2.9e-05,6.1e-05
400,4,ELU,xavier uniform,13,0.000124,0.000326


In [26]:
crossval_results.to_csv('../results/hyperparams_optimization_res.csv')

## Top 5 Models

- `{'n_hidden': 400, 'n_layers': 8, 'act_fun': 'LeakyReLU', 'init_method': 'xavier uniform'}`
- `{'n_hidden': 400, 'n_layers': 4, 'act_fun': 'ReLU', 'init_method': 'xavier uniform'}`
- `{'n_hidden': 400, 'n_layers': 4, 'act_fun': 'LeakyReLU', 'init_method': 'xavier uniform'}`
- `{'n_hidden': 400, 'n_layers': 6, 'act_fun': 'LeakyReLU', 'init_method': 'xavier uniform'}`
- `{'n_hidden': 400, 'n_layers': 8, 'act_fun': 'LeakyReLU', 'init_method': 'xavier uniform'}`