# Inerpretation-Net Training

## Specitication of Experiment Settings

In [1]:
import sys
print(sys.version)


3.9.6 (default, Aug 18 2021, 19:38:01) 
[GCC 7.5.0]


In [2]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
sleep_time = 0 #minutes



config = {
    'function_family': {
        'maximum_depth': 4,
        'beta': 1,
        'decision_sparsity': 1,
        'fully_grown': True,   
        'dt_type': 'vanilla', #'vanilla', 'SDT'
    },
    'data': {
        'number_of_variables': 5, 
        'num_classes': 2,
        
        'function_generation_type': 'make_classification_vanilla_decision_tree_trained', # 'make_classification' 'random_decision_tree' 'random_vanilla_decision_tree_trained'
        'objective': 'classification', # 'regression'
        
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', #'normal', 'uniform',       
                
        'lambda_dataset_size': 1000, #number of samples per function
        #'number_of_generated_datasets': 10000,
        
        'noise_injected_level': 0, 
        'noise_injected_type': 'flip_percentage', # '' 'normal' 'uniform' 'normal_range' 'uniform_range'
    }, 
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True, 
        'early_stopping_min_delta_lambda': 1e-2,
        'batch_lambda': 64,
        'dropout_lambda': 0,
        'lambda_network_layers': [256],
        'optimizer_lambda': 'adam',
        'loss_lambda': 'binary_crossentropy', #categorical_crossentropy
        
        'number_of_lambda_weights': None,
        
        'number_initializations_lambda': 1, 
        
        'number_of_trained_lambda_nets': 50000,
    },     
    
    'i_net': {
        'dense_layers': [2048],
        'convolution_layers': None,
        'lstm_layers': None,
        'dropout': [0],
        
        'optimizer': 'adam', #adam
        'learning_rate': 0.01,
        'loss': 'binary_crossentropy',
        'metrics': ['binary_accuracy'],
        
        'epochs': 200, 
        'early_stopping': True,
        'batch_size': 256,

        'interpretation_dataset_size': 10000,
                
        'test_size': 50, #Float for fraction, Int for number 0
        
        'function_representation_type': 2, # 1=standard representation; 2=sparse representation, 3=vanilla_dt

        'optimize_decision_function': True, #False
        'function_value_loss': True, #False
                      
        'data_reshape_version': None, #default to 2 options:(None, 0,1 2)
        
        'nas': True,
        'nas_type': 'SEQUENTIAL', #options:(None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel')      
        'nas_trials': 25,
    },    
    
    'evaluation': {   
        #'inet_holdout_seed_evaluation': False,
            
        'random_evaluation_dataset_size': 500, 
        'per_network_optimization_dataset_size': 5000,

        'sklearn_dt_benchmark': False,
        'sdt_benchmark': False,
        
    },    
    
    'computation':{
        'load_model': False,
        
        'n_jobs': -3,
        'use_gpu': False,
        'gpu_numbers': '0',
        'RANDOM_SEED': 42,   
    }
}


## Imports

In [3]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [4]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import shutil

import logging

#from prettytable import PrettyTable
#import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold, ParameterGrid, ParameterSampler
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score, log_loss
from sklearn.tree import DecisionTreeClassifier, plot_tree



#from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import tensorflow as tf
#import tensorflow_addons as tfa
#import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


import tensorflow.keras.backend as K
from livelossplot import PlotLossesKerasTF
#from keras_tqdm import TQDMNotebookCallback

from matplotlib import pyplot as plt
import seaborn as sns


import random 



from IPython.display import Image
from IPython.display import display, Math, Latex, clear_output



In [5]:
tf.__version__

'2.5.1'

In [6]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [7]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################

config['i_net']['data_reshape_version'] = 2 if data_reshape_version == None and (convolution_layers != None or lstm_layers != None or (nas and nas_type != 'SEQUENTIAL')) else data_reshape_version

#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers if use_gpu else ''
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
#os.environ['XLA_FLAGS'] =  '--xla_gpu_cuda_data_dir=/usr/lib/cuda-10.1'

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
#np.set_printoptions(suppress=True)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if int(tf.__version__[0]) >= 2:
    tf.random.set_seed(RANDOM_SEED)
else:
    tf.set_random_seed(RANDOM_SEED)
    
    
pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 200)
np.set_printoptions(threshold=200)
np.set_printoptions(suppress=True)



In [8]:
from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *
from utilities.DecisionTree_BASIC import *

#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################

config['lambda_net']['number_of_lambda_weights'] = get_number_of_lambda_net_parameters(lambda_network_layers, number_of_variables, num_classes)
config['function_family']['basic_function_representation_length'] = (2 ** maximum_depth - 1) * number_of_variables + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes
config['function_family']['function_representation_length'] = ( 
       ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes  if function_representation_type == 1 and dt_type == 'SDT'
  else (2 ** maximum_depth - 1) * decision_sparsity + (2 ** maximum_depth - 1) + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) * num_classes if function_representation_type == 2 and dt_type == 'SDT'
  else ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth)  if function_representation_type == 1 and dt_type == 'vanilla'
  else (2 ** maximum_depth - 1) * decision_sparsity + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) if function_representation_type == 2 and dt_type == 'vanilla'
  else None
                                                            )
#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

#initialize_LambdaNet_config_from_curent_notebook(config)
#initialize_metrics_config_from_curent_notebook(config)
#initialize_utility_functions_config_from_curent_notebook(config)
#initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(config, path_type='interpretation_net'))
create_folders_inet(config)

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [9]:
print(path_identifier_interpretation_net)

print(path_identifier_lambda_net_data)


lNetSize1000_numLNets50000_var5_class2_make_classification_vanilla_decision_tree_trained_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/256_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42/inet_dense2048_drop0e200b256_adam
lNetSize1000_numLNets50000_var5_class2_make_classification_vanilla_decision_tree_trained_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/256_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42


In [10]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
print("Num XLA-GPUs Available: ", len(tf.config.experimental.list_physical_devices('XLA_GPU')))

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Load Data and Generate Datasets

In [11]:
def load_lambda_nets(config, no_noise=False, n_jobs=1):
    
    #def generate_lambda_net()
    
    #if psutil.virtual_memory().percent > 80:
        #raise SystemExit("Out of RAM!")
    
    if no_noise==True:
        config['noise_injected_level'] = 0
    path_dict = generate_paths(config, path_type='interpretation_net')        
        
    directory = './data/weights/' + 'weights_' + path_dict['path_identifier_lambda_net_data'] + '/'
    path_network_parameters = directory + 'weights' + '.txt'
    path_X_data = directory + 'X_test_lambda.txt'
    path_y_data = directory + 'y_test_lambda.txt'        
    
    network_parameters = pd.read_csv(path_network_parameters, sep=",", header=None)
    network_parameters = network_parameters.sort_values(by=0)
    if no_noise == False:
        network_parameters = network_parameters.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    X_test_lambda = pd.read_csv(path_X_data, sep=",", header=None)
    X_test_lambda = X_test_lambda.sort_values(by=0)
    if no_noise == False:
        X_test_lambda = X_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    y_test_lambda = pd.read_csv(path_y_data, sep=",", header=None)
    y_test_lambda = y_test_lambda.sort_values(by=0)
    if no_noise == False:
        y_test_lambda = y_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
        
        
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='loky') #loky

    lambda_nets = parallel(delayed(LambdaNet)(network_parameters_row, 
                                              X_test_lambda_row, 
                                              y_test_lambda_row, 
                                              config) for network_parameters_row, X_test_lambda_row, y_test_lambda_row in zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values))          
    del parallel
    
    base_model = generate_base_model(config)  
    
    def initialize_network_wrapper(config, lambda_net, base_model):
        lambda_net.initialize_network(config, base_model)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_network_wrapper)(config, lambda_net, base_model) for lambda_net in lambda_nets)   
    del parallel
    
    def initialize_target_function_wrapper(config, lambda_net):
        lambda_net.initialize_target_function(config)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_target_function_wrapper)(config, lambda_net) for lambda_net in lambda_nets)   
    del parallel
        
    
    #lambda_nets = [None] * network_parameters.shape[0]
    #for i, (network_parameters_row, X_test_lambda_row, y_test_lambda_row) in tqdm(enumerate(zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values)), total=network_parameters.values.shape[0]):        
    #    lambda_net = LambdaNet(network_parameters_row, X_test_lambda_row, y_test_lambda_row, config)
    #    lambda_nets[i] = lambda_net
                
    lambda_net_dataset = LambdaNetDataset(lambda_nets)
        
    return lambda_net_dataset
    

In [12]:
#LOAD DATA
if noise_injected_level > 0:
    lambda_net_dataset_training = load_lambda_nets(config, no_noise=True, n_jobs=n_jobs)
    lambda_net_dataset_evaluation = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_training, test_split=0.1)
    _, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset_evaluation, test_split=test_size)
    
else:
    lambda_net_dataset = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train_with_valid, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset, test_split=test_size)
    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_train_with_valid, test_split=0.1)

    

[Parallel(n_jobs=-3)]: Using backend LokyBackend with 14 concurrent workers.
[Parallel(n_jobs=-3)]: Done   4 tasks      | elapsed:    7.0s
[Parallel(n_jobs=-3)]: Done 116 tasks      | elapsed:    7.2s
[Parallel(n_jobs=-3)]: Done 4124 tasks      | elapsed:    8.6s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:   10.1s finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:  2.8min finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:   20.7s finished


## Data Inspection

In [13]:
lambda_net_dataset_train.shape

(8955, 1917)

In [14]:
lambda_net_dataset_valid.shape

(995, 1917)

In [15]:
lambda_net_dataset_test.shape

(50, 1917)

In [16]:
lambda_net_dataset_train.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_1693,wb_1694,wb_1695,wb_1696,wb_1697,wb_1698,wb_1699,wb_1700,wb_1701,wb_1702,wb_1703,wb_1704,wb_1705,wb_1706,wb_1707,wb_1708,wb_1709,wb_1710,wb_1711,wb_1712,wb_1713,wb_1714,wb_1715,wb_1716,wb_1717,wb_1718,wb_1719,wb_1720,wb_1721,wb_1722,wb_1723,wb_1724,wb_1725,wb_1726,wb_1727,wb_1728,wb_1729,wb_1730,wb_1731,wb_1732,wb_1733,wb_1734,wb_1735,wb_1736,wb_1737,wb_1738,wb_1739,wb_1740,wb_1741,wb_1742,wb_1743,wb_1744,wb_1745,wb_1746,wb_1747,wb_1748,wb_1749,wb_1750,wb_1751,wb_1752,wb_1753,wb_1754,wb_1755,wb_1756,wb_1757,wb_1758,wb_1759,wb_1760,wb_1761,wb_1762,wb_1763,wb_1764,wb_1765,wb_1766,wb_1767,wb_1768,wb_1769,wb_1770,wb_1771,wb_1772,wb_1773,wb_1774,wb_1775,wb_1776,wb_1777,wb_1778,wb_1779,wb_1780,wb_1781,wb_1782,wb_1783,wb_1784,wb_1785,wb_1786,wb_1787,wb_1788,wb_1789,wb_1790,wb_1791,wb_1792
27989,27989.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.64,0.0,-0.586,-0.081,-0.203,-0.142,0.07,0.669,0.909,-0.087,-0.225,0.061,-0.067,-0.231,-0.483,0.036,-0.647,-0.029,-0.114,-0.121,-0.484,0.743,-0.146,-0.518,-0.126,-0.295,0.42,0.621,-0.152,0.561,-0.042,-0.005,-0.09,-0.001,-0.014,0.682,-0.091,-0.616,0.509,0.418,0.501,-0.151,0.764,0.508,-0.362,-0.527,0.54,0.568,-0.106,0.682,-0.001,0.499,0.36,0.633,-0.333,1.106,1.094,0.882,0.674,0.56,-0.51,-0.716,-0.628,0.111,0.45,0.815,0.014,0.72,0.646,0.481,0.777,-0.529,-0.053,0.589,-0.15,0.577,0.351,0.07,0.041,-0.147,-0.503,-0.012,0.504,0.016,0.028,0.647,-0.501,-0.414,-0.138,-0.07,-0.114,0.545,0.005,0.679,-0.108,0.074,-0.086,-0.649,0.563,-0.188
5046,5046.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.39,0.0,-0.065,-0.081,-0.213,-0.142,0.357,0.316,0.362,-0.359,-0.292,0.282,-0.067,-0.307,-0.289,0.296,-0.331,-0.029,-0.114,-0.38,-0.252,0.315,-0.146,-0.242,-0.126,-0.285,0.314,0.265,-0.152,0.362,-0.042,-0.005,-0.09,0.261,-0.014,0.229,-0.091,-0.288,0.281,0.101,0.264,-0.216,-0.191,0.265,-0.181,-0.283,0.35,0.346,-0.106,0.363,-0.001,0.223,-0.156,0.335,-0.188,0.091,0.377,0.252,0.318,0.35,-0.33,0.002,-0.361,0.362,0.287,-0.291,0.003,-0.194,0.072,0.087,0.323,-0.311,-0.053,0.328,-0.15,0.305,-0.128,0.321,0.041,-0.147,-0.301,-0.012,0.216,0.016,0.026,0.293,-0.318,-0.246,-0.138,-0.07,-0.114,0.292,0.337,0.338,-0.108,0.307,-0.094,-0.369,0.136,-0.048
13661,13661.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.137,0.0,-0.065,-0.076,-0.263,-0.142,0.07,0.012,0.591,-0.528,-0.431,0.482,-0.063,-0.436,-0.377,0.472,-0.421,-0.029,-0.106,-0.312,0.403,0.132,-0.146,0.385,-0.126,-0.373,0.43,-0.017,-0.152,0.395,-0.042,-0.005,-0.09,0.484,-0.014,0.142,-0.091,-0.04,0.428,0.237,-0.332,-0.362,-0.455,0.235,-0.358,0.545,0.386,0.732,-0.106,0.174,-0.001,0.264,-0.34,0.513,-0.359,0.096,0.126,0.495,0.147,0.407,-0.313,-0.496,-0.109,0.477,0.422,-0.379,0.014,-0.329,-0.102,0.094,0.554,-0.608,-0.053,0.117,-0.15,0.067,-0.376,0.517,0.041,-0.147,-0.399,-0.012,0.226,0.016,0.031,0.563,-0.391,-0.165,-0.132,-0.07,-0.114,0.227,0.01,0.515,-0.756,0.456,-0.66,-0.108,0.491,0.114
7568,7568.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.137,0.0,-0.065,-0.081,-0.201,-0.142,0.07,0.014,0.428,-0.337,-0.314,0.43,-0.067,-0.352,-0.353,0.458,-0.424,-0.029,-0.107,-0.404,-0.261,0.122,-0.146,-0.14,-0.126,-0.344,0.403,0.297,-0.152,0.525,-0.042,-0.005,-0.387,0.479,-0.014,0.217,-0.26,-0.035,0.376,0.035,0.472,-0.249,-0.234,0.158,-0.131,-0.267,0.448,0.534,-0.106,0.095,-0.001,0.036,-0.19,0.129,-0.154,0.429,0.46,0.351,0.499,0.3,-0.328,-0.205,-0.106,0.461,0.383,-0.34,0.014,0.348,0.172,0.097,0.062,-0.048,-0.053,0.506,-0.15,0.509,-0.09,0.07,0.041,-0.147,-0.372,-0.012,0.138,0.016,0.385,0.4,-0.336,-0.235,-0.138,-0.07,-0.114,0.459,0.01,0.546,-0.39,0.471,-0.308,-0.148,0.398,-0.008
44361,44361.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.853,-0.001,-0.065,-0.081,-0.103,-0.142,0.07,0.403,0.51,-0.465,-0.599,0.436,-0.067,-0.814,-0.595,0.38,-0.734,-0.029,-0.114,-0.121,0.358,0.115,-0.146,0.395,-0.126,-0.672,0.451,0.229,-0.152,0.108,-0.042,-0.005,-0.083,-0.001,-0.01,0.399,-0.091,-0.435,0.411,0.37,-0.354,0.186,0.336,0.389,-0.332,-0.002,0.441,0.143,-0.106,0.354,-0.001,0.39,-0.17,0.461,-0.273,0.101,0.135,0.467,0.305,0.46,-0.971,0.364,-1.109,0.484,0.408,-0.527,0.014,-0.411,0.32,0.329,0.48,-1.0,-0.053,0.181,-0.15,0.041,-0.274,0.5,0.041,-0.147,-0.9,-0.012,0.384,0.405,0.039,0.413,-0.64,0.48,-0.138,-0.07,-0.114,0.237,0.01,0.478,-0.562,0.39,-0.073,-1.182,0.654,-0.165


In [17]:
lambda_net_dataset_valid.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_1693,wb_1694,wb_1695,wb_1696,wb_1697,wb_1698,wb_1699,wb_1700,wb_1701,wb_1702,wb_1703,wb_1704,wb_1705,wb_1706,wb_1707,wb_1708,wb_1709,wb_1710,wb_1711,wb_1712,wb_1713,wb_1714,wb_1715,wb_1716,wb_1717,wb_1718,wb_1719,wb_1720,wb_1721,wb_1722,wb_1723,wb_1724,wb_1725,wb_1726,wb_1727,wb_1728,wb_1729,wb_1730,wb_1731,wb_1732,wb_1733,wb_1734,wb_1735,wb_1736,wb_1737,wb_1738,wb_1739,wb_1740,wb_1741,wb_1742,wb_1743,wb_1744,wb_1745,wb_1746,wb_1747,wb_1748,wb_1749,wb_1750,wb_1751,wb_1752,wb_1753,wb_1754,wb_1755,wb_1756,wb_1757,wb_1758,wb_1759,wb_1760,wb_1761,wb_1762,wb_1763,wb_1764,wb_1765,wb_1766,wb_1767,wb_1768,wb_1769,wb_1770,wb_1771,wb_1772,wb_1773,wb_1774,wb_1775,wb_1776,wb_1777,wb_1778,wb_1779,wb_1780,wb_1781,wb_1782,wb_1783,wb_1784,wb_1785,wb_1786,wb_1787,wb_1788,wb_1789,wb_1790,wb_1791,wb_1792
42575,42575.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.106,0.0,-0.065,-0.35,-0.179,-0.142,0.069,0.268,0.119,-0.088,-0.137,0.074,-0.067,-0.064,-0.181,0.151,-0.062,-0.029,-0.373,-0.415,-0.24,0.38,-0.146,-0.262,-0.126,0.036,0.067,0.328,-0.152,0.278,-0.042,-0.005,-0.09,-0.001,-0.014,0.271,-0.091,-0.277,0.065,0.073,0.238,-0.125,-0.268,0.146,-0.198,-0.289,0.106,0.076,-0.106,0.341,-0.001,0.129,-0.126,0.322,-0.232,0.348,0.386,-0.241,0.346,0.158,-0.314,-0.28,-0.358,0.132,0.084,0.265,0.002,0.268,0.227,0.012,0.344,-0.215,-0.053,0.362,-0.15,0.355,-0.218,0.07,0.041,-0.147,-0.119,-0.012,0.141,0.016,0.302,0.255,-0.266,-0.222,-0.138,-0.07,-0.114,0.224,0.01,0.345,-0.108,0.073,-0.376,-0.378,0.219,0.004
41631,41631.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.131,-1.293,-0.065,-0.077,-0.218,-0.142,0.07,0.608,0.694,-0.801,-0.332,0.204,-0.062,-0.308,-0.444,1.813,-0.077,-0.029,-0.108,-1.233,-0.968,0.832,-0.14,-0.671,-0.126,-0.356,1.144,0.502,-0.152,0.103,-0.042,-0.005,-0.09,-0.001,-0.014,0.571,-0.091,-0.41,0.851,0.081,-1.212,-0.148,-0.212,0.173,-0.607,-1.033,1.152,0.096,-0.106,0.683,-0.001,0.361,-0.177,0.631,-0.473,1.466,1.02,-1.053,0.367,0.938,-0.633,-0.089,-0.459,1.266,0.159,-1.301,1.654,-1.111,0.055,-0.212,0.913,-0.05,-0.053,0.184,-0.144,0.118,-0.278,0.07,0.041,-0.147,-0.396,-0.012,0.13,0.016,1.393,0.505,-0.418,-0.703,-0.135,-0.07,-0.114,0.108,0.01,1.198,-0.108,0.791,-1.172,-0.218,1.276,0.112
41067,41067.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.458,-0.395,-0.065,-0.354,-0.206,-0.142,0.07,0.343,0.48,-0.474,-0.375,0.444,-0.067,-0.391,-0.367,0.03,-0.453,-0.029,-0.408,-0.444,-0.172,0.414,-0.146,-0.277,-0.126,-0.342,0.054,0.332,-0.152,0.462,-0.042,-0.005,-0.09,0.005,-0.383,0.328,-0.51,-0.372,0.242,0.225,0.383,-0.18,-0.348,0.316,-0.265,0.391,0.461,0.444,-0.106,0.382,-0.001,0.327,-0.192,0.446,-0.282,0.39,0.448,0.384,0.381,0.488,-0.34,-0.33,-0.388,0.518,0.265,0.281,0.014,0.195,0.232,0.087,0.453,-0.359,-0.053,0.343,-0.15,0.34,-0.286,0.096,0.041,-0.147,-0.367,-0.012,0.277,0.016,0.027,0.45,-0.354,-0.218,-0.477,-0.07,-0.114,0.301,0.01,0.437,-0.601,0.442,-0.444,-0.405,0.379,-0.067
20399,20399.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-1.311,0.0,-1.198,-0.709,-0.1,-0.142,0.07,0.496,0.114,-1.275,-0.104,0.203,-0.885,-0.024,-1.227,0.601,-1.289,-0.029,-0.114,-1.356,-0.695,0.536,-0.146,-0.487,-0.126,0.061,0.068,0.543,-0.152,0.121,-0.042,-0.005,-0.09,-0.77,-0.014,0.488,-1.037,-0.664,0.2,0.279,-0.843,-0.027,-0.266,0.326,-0.293,-0.783,0.316,0.104,-0.106,0.511,-0.001,0.34,-0.031,0.408,-0.257,0.108,0.138,-0.508,0.465,0.35,-0.451,-0.38,-0.532,0.303,0.205,0.669,0.014,0.567,0.524,0.305,0.061,-1.134,-0.053,0.331,-0.15,0.328,-0.23,0.07,0.041,-1.0,-1.054,-0.012,0.326,0.016,0.037,0.24,-0.737,-0.281,-0.138,-0.07,-0.114,0.273,0.004,0.5,-0.108,0.374,-1.077,-0.578,0.106,-0.14
30801,30801.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.137,1.156,-0.065,-0.081,-0.133,-0.142,0.065,0.548,0.884,-0.085,-0.043,0.92,-0.067,-0.059,-0.173,0.959,-0.076,-0.029,-0.114,-0.127,0.729,0.337,-0.146,-0.385,-0.126,-0.026,0.771,0.318,-0.152,0.967,-0.042,-0.005,-0.09,-0.001,-0.014,0.291,-0.091,-0.033,0.184,0.2,1.041,0.119,-0.178,0.258,-0.246,-0.008,0.664,1.118,-0.106,0.387,-0.001,0.442,0.16,0.694,-0.234,0.083,0.115,0.271,0.323,0.652,-0.617,-0.198,-0.482,0.965,0.173,0.342,1.127,0.281,0.259,0.192,0.959,-0.055,-0.053,0.305,-0.15,0.562,-0.17,0.063,0.041,-0.147,-0.061,-0.012,0.23,0.016,1.129,0.967,-1.015,-0.323,-0.138,-0.07,-0.114,0.259,0.01,0.341,-0.108,0.936,-0.082,-0.467,0.963,-0.165


In [18]:
lambda_net_dataset_test.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_1693,wb_1694,wb_1695,wb_1696,wb_1697,wb_1698,wb_1699,wb_1700,wb_1701,wb_1702,wb_1703,wb_1704,wb_1705,wb_1706,wb_1707,wb_1708,wb_1709,wb_1710,wb_1711,wb_1712,wb_1713,wb_1714,wb_1715,wb_1716,wb_1717,wb_1718,wb_1719,wb_1720,wb_1721,wb_1722,wb_1723,wb_1724,wb_1725,wb_1726,wb_1727,wb_1728,wb_1729,wb_1730,wb_1731,wb_1732,wb_1733,wb_1734,wb_1735,wb_1736,wb_1737,wb_1738,wb_1739,wb_1740,wb_1741,wb_1742,wb_1743,wb_1744,wb_1745,wb_1746,wb_1747,wb_1748,wb_1749,wb_1750,wb_1751,wb_1752,wb_1753,wb_1754,wb_1755,wb_1756,wb_1757,wb_1758,wb_1759,wb_1760,wb_1761,wb_1762,wb_1763,wb_1764,wb_1765,wb_1766,wb_1767,wb_1768,wb_1769,wb_1770,wb_1771,wb_1772,wb_1773,wb_1774,wb_1775,wb_1776,wb_1777,wb_1778,wb_1779,wb_1780,wb_1781,wb_1782,wb_1783,wb_1784,wb_1785,wb_1786,wb_1787,wb_1788,wb_1789,wb_1790,wb_1791,wb_1792
25056,25056.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.126,0.0,-0.065,-0.081,-0.185,-0.142,0.65,0.691,0.868,-0.422,-0.334,0.331,-0.067,-0.983,-0.029,0.312,-0.069,-0.029,-0.11,-0.439,0.507,0.707,-0.146,-0.321,-0.126,-0.952,0.719,0.183,-0.152,1.044,-0.042,-0.005,-0.09,-0.001,-0.014,0.175,-0.091,-0.432,0.051,0.1,1.287,-0.171,-0.304,0.105,-0.403,-1.38,1.016,1.241,-0.106,0.669,-0.001,0.117,-0.187,0.805,-0.282,1.025,0.632,0.855,0.195,0.713,-0.199,-0.268,-0.373,1.053,0.1,-1.27,0.014,-0.56,0.132,-0.12,0.877,-0.046,-0.053,0.098,-0.15,0.916,-0.339,0.07,0.041,-0.147,-0.671,-0.012,0.101,0.016,1.045,0.805,-0.152,-0.054,-0.138,-0.07,-0.114,-0.046,0.01,1.095,-0.108,0.318,-0.088,-0.442,0.827,0.05
30334,30334.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.583,-1.226,-0.557,-0.07,-0.2,-0.142,0.07,0.009,0.412,-1.641,-0.376,0.771,-0.067,-0.458,-0.45,0.027,-0.507,-0.029,-0.105,-0.128,-0.909,1.068,-0.146,-0.117,-0.126,-0.408,0.656,0.874,-0.152,0.32,-0.042,-0.005,-0.09,-1.536,-0.014,0.19,-0.091,-1.143,0.42,0.182,-1.295,-0.215,0.123,0.338,-0.227,-1.516,0.535,0.074,-0.106,1.177,-0.001,0.336,-0.167,0.234,-0.165,0.09,0.124,1.153,1.051,0.476,-0.592,0.014,-0.104,0.562,0.416,-0.38,0.004,-0.363,-0.146,0.165,0.697,-0.543,-0.053,0.893,-0.15,0.044,-0.13,0.062,0.041,-0.147,-0.476,-0.012,0.279,0.016,0.029,0.247,-0.468,-0.294,-0.138,-0.07,-0.114,0.365,0.01,0.53,-0.108,0.843,-0.085,-0.115,0.087,0.005
17962,17962.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.11,1.008,-0.065,-0.984,-0.046,-0.142,0.07,0.015,0.877,-0.085,-0.101,0.54,-0.067,-1.16,-0.987,1.155,-1.091,-0.029,-1.078,-1.326,-0.193,0.332,-0.142,-0.766,-0.126,-1.05,0.146,0.215,-0.152,0.273,-0.042,-0.005,-0.09,-0.309,-0.014,0.214,-0.091,-0.019,0.249,0.137,-0.955,-0.194,-0.227,0.133,0.282,-0.272,0.129,0.174,-0.106,0.091,-0.001,0.072,-0.13,0.76,-0.087,1.179,1.258,0.333,0.261,0.785,-1.072,-0.236,-1.115,0.119,0.24,-1.127,0.941,0.402,0.234,0.176,0.065,-1.098,-0.053,0.257,-0.143,0.24,0.051,0.063,0.041,-0.147,-1.116,-0.012,0.156,0.016,0.857,0.133,-0.955,-0.068,-0.138,-0.07,-0.114,0.238,0.01,0.082,-0.108,0.817,-0.087,-1.057,0.563,-0.04
39588,39588.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.137,0.313,-0.065,-0.081,-0.191,-0.142,0.066,0.013,0.111,-0.089,-0.148,0.077,-0.067,-0.216,-0.421,0.023,-0.474,-0.029,-0.107,-0.429,-0.272,0.309,-0.146,-0.299,-0.126,-0.204,0.353,0.356,-0.152,0.405,-0.042,-0.005,-0.09,-0.337,-0.014,0.318,-0.091,-0.033,0.322,0.096,0.3,-0.197,-0.213,0.281,-0.182,-0.314,0.431,0.415,-0.106,0.422,-0.001,0.205,-0.137,0.102,-0.188,0.1,0.126,-0.004,0.373,0.301,-0.389,-0.211,-0.511,0.122,0.289,0.317,0.281,0.263,0.251,0.217,0.059,-0.403,-0.053,0.384,-0.15,0.341,0.05,0.07,0.041,-0.147,-0.369,-0.012,0.248,0.016,0.034,0.093,-0.404,-0.288,-0.138,-0.07,-0.114,0.35,0.009,0.421,-0.108,0.059,-0.085,-0.099,0.103,-0.074
34107,34107.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.123,-0.003,-0.056,-0.081,-0.067,-0.142,0.07,1.1,0.106,-1.418,-0.224,0.429,-0.067,-0.194,-0.813,0.028,-0.061,-0.029,-1.337,-1.489,0.323,0.12,-0.146,-0.123,-0.126,-0.222,0.961,0.061,-0.152,0.111,-0.042,-0.005,-0.09,-0.001,-0.014,0.229,-0.091,-0.946,0.83,0.269,-0.32,-0.04,-0.211,0.214,-0.195,-1.633,0.9,0.093,-0.106,0.121,-0.001,0.256,-0.135,0.253,-0.168,0.102,0.352,0.465,0.063,0.243,-0.45,-0.166,-1.235,0.997,0.659,1.304,0.009,-0.776,0.219,0.215,0.057,-0.038,-0.053,0.085,-0.15,0.044,-0.211,0.063,0.041,-0.147,0.144,-0.012,0.233,0.016,0.028,0.385,-0.564,0.206,-1.466,-0.07,-0.114,0.805,0.01,0.077,-0.108,0.438,-1.423,-1.341,0.239,-0.095


## Interpretation Network Training

In [19]:
#%load_ext autoreload

In [20]:
#%autoreload 2
((X_valid, y_valid), 
 (X_test, y_test),
 history,

 model) = interpretation_net_training(
                                      lambda_net_dataset_train, 
                                      lambda_net_dataset_valid, 
                                      lambda_net_dataset_test,
                                      config,
                                      #callback_names=['plot_losses']
                                     )



Trial 25 Complete [00h 12m 50s]
val_loss: 0.5219460725784302

Best val_loss So Far: 0.49273085594177246
Total elapsed time: 04h 15m 50s
Training Time: 4:16:10
---------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------ LOADING MODELS -----------------------------------------------------
Loading Time: 0:00:01


In [21]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 1793)]       0                                            
__________________________________________________________________________________________________
cast_to_float32 (CastToFloat32) (None, 1793)         0           input_1[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 64)           114816      cast_to_float32[0][0]            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 64)           256         dense[0][0]                      
______________________________________________________________________________________________

In [22]:
if nas:
    for trial in history_list[-1]: 
        print(trial.summary())

NameError: name 'history_list' is not defined

In [None]:
y_test_inet_dt_list = []
y_test_distilled_sklearn_dt_list = []

binary_crossentropy_distilled_sklearn_dt_list =[]
accuracy_distilled_sklearn_dt_list = []
f1_score_distilled_sklearn_dt_list = []

binary_crossentropy_inet_dt_list =[]
accuracy_inet_dt_list = []
f1_score_inet_dt_list = []

#inet_metric_function_list = []

number = lambda_net_dataset_test.y_test_lambda_array.shape[0]#10

for lambda_net_parameters, lambda_net, X_test_lambda, y_test_lambda in tqdm(zip(lambda_net_dataset_test.network_parameters_array[:number], lambda_net_dataset_test.network_list[:number], lambda_net_dataset_test.X_test_lambda_array[:number], lambda_net_dataset_test.y_test_lambda_array[:number]), total=lambda_net_dataset_test.y_test_lambda_array[:number].shape[0]):
    dt_inet = model.predict(np.array([lambda_net_parameters]))[0]
    if nas:
        dt_inet = dt_inet[:function_representation_length]

    
    X_data_random = generate_random_data_points_custom(config['data']['x_min'], config['data']['x_max'], config['evaluation']['random_evaluation_dataset_size'], config['data']['number_of_variables'])
    y_data_random_lambda_pred = lambda_net.predict(X_data_random)
    y_data_random_lambda_pred = np.round(y_data_random_lambda_pred).astype(np.int64)
    
    dt_sklearn_distilled = DecisionTreeClassifier(max_depth=config['function_family']['maximum_depth'])
    dt_sklearn_distilled.fit(X_data_random, y_data_random_lambda_pred)
    
    
    if dt_type == 'SDT':
        y_test_inet_dt  = calculate_function_value_from_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    elif dt_type == 'vanilla':
        y_test_inet_dt  = calculate_function_value_from_vanilla_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    y_test_distilled_sklearn_dt = dt_sklearn_distilled.predict(X_test_lambda)
    
    y_test_lambda_pred = lambda_net.predict(X_test_lambda)
    y_test_lambda_pred = np.round(y_test_lambda_pred)
    
    
    #random_model = generate_base_model(config)        
    #random_network_parameters = random_model.get_weights()
    #network_parameters_structure = [network_parameter.shape for network_parameter in random_network_parameters]         
    
    #function_true_placeholder = np.array([0 for i in range(basic_function_representation_length)])
    #function_true_with_network_parameters = np.concatenate([function_true_placeholder, lambda_net_parameters])
    #inet_metric_function = inet_decision_function_fv_metric_wrapper(X_test_lambda, 
    #                                                                 random_model, 
    #                                                                 network_parameters_structure, 
    #                                                                 config, 
    #                                                                 'binary_accuracy')(np.array([function_true_with_network_parameters]), 
    #                                                                                     np.array([dt_inet]))    
    #inet_metric_function_list.append(inet_metric_function)
    
    
    binary_crossentropy_distilled_sklearn_dt = log_loss(y_test_lambda_pred, y_test_distilled_sklearn_dt)
    accuracy_distilled_sklearn_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_dt))
    f1_score_distilled_sklearn_dt = f1_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_dt))
    
    binary_crossentropy_inet_dt = log_loss(y_test_lambda_pred, y_test_inet_dt)
    accuracy_inet_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_inet_dt))
    f1_score_inet_dt = f1_score(y_test_lambda_pred, np.round(y_test_inet_dt))
    
    
    y_test_inet_dt_list.append(y_test_inet_dt)
    y_test_distilled_sklearn_dt_list.append(y_test_distilled_sklearn_dt)    

    binary_crossentropy_distilled_sklearn_dt_list.append(np.nan_to_num(binary_crossentropy_distilled_sklearn_dt))
    accuracy_distilled_sklearn_dt_list.append(np.nan_to_num(accuracy_distilled_sklearn_dt))
    f1_score_distilled_sklearn_dt_list.append(np.nan_to_num(f1_score_distilled_sklearn_dt))

    binary_crossentropy_inet_dt_list.append(np.nan_to_num(binary_crossentropy_inet_dt))
    accuracy_inet_dt_list.append(np.nan_to_num(accuracy_inet_dt))
    f1_score_inet_dt_list.append(np.nan_to_num(f1_score_inet_dt))
    
y_test_inet_dt_list = np.array(y_test_inet_dt_list)
y_test_distilled_sklearn_dt_list = np.array(y_test_distilled_sklearn_dt_list)

binary_crossentropy_distilled_sklearn_dt_list = np.array(binary_crossentropy_distilled_sklearn_dt_list)
accuracy_distilled_sklearn_dt_list = np.array(accuracy_distilled_sklearn_dt_list)
f1_score_distilled_sklearn_dt_list = np.array(f1_score_distilled_sklearn_dt_list)

binary_crossentropy_inet_dt_list = np.array(binary_crossentropy_inet_dt_list)
accuracy_inet_dt_list = np.array(accuracy_inet_dt_list)
f1_score_inet_dt_list = np.array(f1_score_inet_dt_list)    

    
print('Binary Crossentropy:\t\t', np.round(np.mean(binary_crossentropy_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(binary_crossentropy_inet_dt_list), 3), '(I-Net DT)')
print('Accuracy:\t\t', np.round(np.mean(accuracy_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(accuracy_inet_dt_list), 3), '(I-Net DT)')
print('F1 Score:\t\t', np.round(np.mean(f1_score_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(f1_score_inet_dt_list), 3), '(I-Net DT)')
      

In [None]:
dt_inet.shape

In [None]:
get_shaped_parameters_for_decision_tree(dt_inet, config)