# Inerpretation-Net Training

## Specitication of Experiment Settings

In [1]:
import sys
print(sys.version)


3.9.6 (default, Aug 18 2021, 19:38:01) 
[GCC 7.5.0]


In [2]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
sleep_time = 0 #minutes



config = {
    'function_family': {
        'maximum_depth': 4,
        'beta': 1,
        'decision_sparsity': 1,
        'fully_grown': True,   
        'dt_type': 'vanilla', #'vanilla', 'SDT'
    },
    'data': {
        'number_of_variables': 5, 
        'num_classes': 2,
        
        'function_generation_type': 'random_vanilla_decision_tree_trained', # 'make_classification' 'random_decision_tree' 'random_vanilla_decision_tree_trained'
        'objective': 'classification', # 'regression'
        
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', #'normal', 'uniform',       
                
        'lambda_dataset_size': 1000, #number of samples per function
        #'number_of_generated_datasets': 10000,
        
        'noise_injected_level': 0, 
        'noise_injected_type': 'flip_percentage', # '' 'normal' 'uniform' 'normal_range' 'uniform_range'
    }, 
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True, 
        'early_stopping_min_delta_lambda': 1e-2,
        'batch_lambda': 64,
        'dropout_lambda': 0,
        'lambda_network_layers': [64],
        'optimizer_lambda': 'adam',
        'loss_lambda': 'binary_crossentropy', #categorical_crossentropy
        
        'number_of_lambda_weights': None,
        
        'number_initializations_lambda': 1, 
        
        'number_of_trained_lambda_nets': 10000,
    },     
    
    'i_net': {
        'dense_layers': [2048],
        'convolution_layers': None,
        'lstm_layers': None,
        'dropout': [0],
        
        'optimizer': 'adam', #adam
        'learning_rate': 0.01,
        'loss': 'binary_crossentropy',
        'metrics': ['binary_accuracy'],
        
        'epochs': 200, 
        'early_stopping': True,
        'batch_size': 256,

        'interpretation_dataset_size': 10000,
                
        'test_size': 50, #Float for fraction, Int for number 0
        
        'function_representation_type': 2, # 1=standard representation; 2=sparse representation, 3=vanilla_dt

        'optimize_decision_function': True, #False
        'function_value_loss': True, #False
                      
        'data_reshape_version': None, #default to 2 options:(None, 0,1 2)
        
        'nas': True,
        'nas_type': 'SEQUENTIAL', #options:(None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel')      
        'nas_trials': 10,
    },    
    
    'evaluation': {   
        #'inet_holdout_seed_evaluation': False,
            
        'random_evaluation_dataset_size': 500, 
        'per_network_optimization_dataset_size': 5000,

        'sklearn_dt_benchmark': False,
        'sdt_benchmark': False,
        
    },    
    
    'computation':{
        'load_model': False,
        
        'n_jobs': -3,
        'use_gpu': False,
        'gpu_numbers': '0',
        'RANDOM_SEED': 42,   
    }
}


## Imports

In [3]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [4]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import shutil

import logging

#from prettytable import PrettyTable
#import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold, ParameterGrid, ParameterSampler
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score, log_loss
from sklearn.tree import DecisionTreeClassifier, plot_tree



#from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import tensorflow as tf
#import tensorflow_addons as tfa
#import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


import tensorflow.keras.backend as K
from livelossplot import PlotLossesKerasTF
#from keras_tqdm import TQDMNotebookCallback

from matplotlib import pyplot as plt
import seaborn as sns


import random 



from IPython.display import Image
from IPython.display import display, Math, Latex, clear_output



In [5]:
tf.__version__

'2.5.1'

In [6]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [7]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################

config['i_net']['data_reshape_version'] = 2 if data_reshape_version == None and (convolution_layers != None or lstm_layers != None or (nas and nas_type != 'SEQUENTIAL')) else data_reshape_version

#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers if use_gpu else ''
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
#os.environ['XLA_FLAGS'] =  '--xla_gpu_cuda_data_dir=/usr/lib/cuda-10.1'

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
#np.set_printoptions(suppress=True)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if int(tf.__version__[0]) >= 2:
    tf.random.set_seed(RANDOM_SEED)
else:
    tf.set_random_seed(RANDOM_SEED)
    
    
pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 200)
np.set_printoptions(threshold=200)
np.set_printoptions(suppress=True)



In [8]:
from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *
from utilities.DecisionTree_BASIC import *

#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################

config['lambda_net']['number_of_lambda_weights'] = get_number_of_lambda_net_parameters(lambda_network_layers, number_of_variables, num_classes)
config['function_family']['basic_function_representation_length'] = (2 ** maximum_depth - 1) * number_of_variables + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes
config['function_family']['function_representation_length'] = ( 
       ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes  if function_representation_type == 1 and dt_type == 'SDT'
  else (2 ** maximum_depth - 1) * decision_sparsity + (2 ** maximum_depth - 1) + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) * num_classes if function_representation_type == 2 and dt_type == 'SDT'
  else ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth)  if function_representation_type == 1 and dt_type == 'vanilla'
  else (2 ** maximum_depth - 1) * decision_sparsity + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) if function_representation_type == 2 and dt_type == 'vanilla'
  else None
                                                            )
#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

#initialize_LambdaNet_config_from_curent_notebook(config)
#initialize_metrics_config_from_curent_notebook(config)
#initialize_utility_functions_config_from_curent_notebook(config)
#initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(config, path_type='interpretation_net'))
create_folders_inet(config)

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [9]:
print(path_identifier_interpretation_net)

print(path_identifier_lambda_net_data)


lNetSize1000_numLNets10000_var5_class2_random_vanilla_decision_tree_trained_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42/inet_dense2048_drop0e200b256_adam
lNetSize1000_numLNets10000_var5_class2_random_vanilla_decision_tree_trained_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42


In [10]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
print("Num XLA-GPUs Available: ", len(tf.config.experimental.list_physical_devices('XLA_GPU')))

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Load Data and Generate Datasets

In [11]:
def load_lambda_nets(config, no_noise=False, n_jobs=1):
    
    #def generate_lambda_net()
    
    #if psutil.virtual_memory().percent > 80:
        #raise SystemExit("Out of RAM!")
    
    if no_noise==True:
        config['noise_injected_level'] = 0
    path_dict = generate_paths(config, path_type='interpretation_net')        
        
    directory = './data/weights/' + 'weights_' + path_dict['path_identifier_lambda_net_data'] + '/'
    path_network_parameters = directory + 'weights' + '.txt'
    path_X_data = directory + 'X_test_lambda.txt'
    path_y_data = directory + 'y_test_lambda.txt'        
    
    network_parameters = pd.read_csv(path_network_parameters, sep=",", header=None)
    network_parameters = network_parameters.sort_values(by=0)
    if no_noise == False:
        network_parameters = network_parameters.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    X_test_lambda = pd.read_csv(path_X_data, sep=",", header=None)
    X_test_lambda = X_test_lambda.sort_values(by=0)
    if no_noise == False:
        X_test_lambda = X_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    y_test_lambda = pd.read_csv(path_y_data, sep=",", header=None)
    y_test_lambda = y_test_lambda.sort_values(by=0)
    if no_noise == False:
        y_test_lambda = y_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
        
        
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='loky') #loky

    lambda_nets = parallel(delayed(LambdaNet)(network_parameters_row, 
                                              X_test_lambda_row, 
                                              y_test_lambda_row, 
                                              config) for network_parameters_row, X_test_lambda_row, y_test_lambda_row in zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values))          
    del parallel
    
    base_model = generate_base_model(config)  
    
    def initialize_network_wrapper(config, lambda_net, base_model):
        lambda_net.initialize_network(config, base_model)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_network_wrapper)(config, lambda_net, base_model) for lambda_net in lambda_nets)   
    del parallel
    
    def initialize_target_function_wrapper(config, lambda_net):
        lambda_net.initialize_target_function(config)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_target_function_wrapper)(config, lambda_net) for lambda_net in lambda_nets)   
    del parallel
        
    
    #lambda_nets = [None] * network_parameters.shape[0]
    #for i, (network_parameters_row, X_test_lambda_row, y_test_lambda_row) in tqdm(enumerate(zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values)), total=network_parameters.values.shape[0]):        
    #    lambda_net = LambdaNet(network_parameters_row, X_test_lambda_row, y_test_lambda_row, config)
    #    lambda_nets[i] = lambda_net
                
    lambda_net_dataset = LambdaNetDataset(lambda_nets)
        
    return lambda_net_dataset
    

In [12]:
#LOAD DATA
if noise_injected_level > 0:
    lambda_net_dataset_training = load_lambda_nets(config, no_noise=True, n_jobs=n_jobs)
    lambda_net_dataset_evaluation = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_training, test_split=0.1)
    _, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset_evaluation, test_split=test_size)
    
else:
    lambda_net_dataset = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train_with_valid, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset, test_split=test_size)
    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_train_with_valid, test_split=0.1)

    

[Parallel(n_jobs=-3)]: Using backend LokyBackend with 14 concurrent workers.
[Parallel(n_jobs=-3)]: Done   4 tasks      | elapsed:    6.1s
[Parallel(n_jobs=-3)]: Done 170 tasks      | elapsed:    6.2s
[Parallel(n_jobs=-3)]: Done 5006 tasks      | elapsed:    7.8s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:    9.4s finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.1s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:  3.0min finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:   30.9s finished


## Data Inspection

In [13]:
lambda_net_dataset_train.shape

(8955, 573)

In [14]:
lambda_net_dataset_valid.shape

(995, 573)

In [15]:
lambda_net_dataset_test.shape

(50, 573)

In [16]:
lambda_net_dataset_train.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
6671,6671.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.282,0.413,-0.349,-0.277,-0.356,0.496,-0.382,-0.008,-0.382,-0.37,0.0,-0.427,-0.385,-0.342,-0.348,-0.359,-0.061,-0.336,0.463,-0.348,0.398,-0.34,0.0,0.0,0.518,0.481,-0.387,-0.421,0.435,0.324,-0.377,0.524,-0.303,0.403,-0.328,-0.508,-0.59,0.684,-0.795,-0.778,0.661,0.113,-0.609,-0.665,-0.272,0.773,-0.625,-0.56,-0.861,0.519,0.134,-0.716,-0.652,-0.158,-0.676,-0.706,-0.576,-0.605,0.441,0.545,0.781,-0.61,0.778,-0.73,-0.414,-0.625,0.739,0.732,0.557,-0.735,0.598,0.018,0.79,0.565,-0.294,0.372,0.716,0.69,0.703,0.889,0.134,0.708,-0.832,0.651,-0.716,0.581,-0.187,-0.257,-0.877,-0.679,0.701,0.426,-0.623,-0.497,0.824,-0.778,0.621,-0.763,0.684,-0.272
3274,3274.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.043,-0.071,-0.004,-0.038,0.027,-0.072,-0.061,0.454,0.099,-0.064,0.0,-0.087,0.339,-0.087,-0.062,0.376,0.388,0.059,-0.038,0.018,-0.029,-0.099,0.0,0.0,-0.028,-0.048,0.242,-0.059,-0.045,-0.044,0.32,-0.031,0.359,-0.099,-0.116,-0.423,-0.08,0.45,-0.852,-0.827,0.085,1.065,-0.313,-0.662,-0.272,0.327,-0.575,-0.394,-0.154,0.05,1.031,-0.709,-0.405,-0.245,-0.367,-0.717,-0.516,-0.406,-0.19,0.073,0.64,-0.318,0.825,-0.748,-0.011,-0.37,0.21,0.262,0.116,-0.686,0.045,0.786,0.531,0.015,-0.294,-0.048,0.787,0.15,0.136,0.925,0.963,0.343,-0.897,0.203,-0.684,0.112,-0.187,-0.257,-0.136,-0.112,0.633,-0.033,-0.498,-0.443,0.822,-0.711,0.87,-0.106,0.207,0.022
3095,3095.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.284,0.086,-0.073,0.008,-0.037,0.122,-0.057,0.385,-0.027,-0.287,0.0,0.126,0.333,-0.049,-0.029,0.376,0.383,0.104,-0.014,-0.004,0.127,-0.138,0.0,0.0,-0.021,0.373,0.271,0.056,0.147,0.065,0.323,0.012,0.353,0.072,-0.153,-0.305,0.37,0.766,-0.516,-0.473,0.932,1.236,-0.325,-0.431,-0.272,0.343,-0.327,-0.198,-0.147,0.163,0.855,1.325,-0.274,-0.144,-0.322,-0.459,-0.313,-0.241,-0.123,0.126,0.875,-0.324,1.093,-0.467,1.153,-0.31,0.348,0.405,0.208,-0.459,0.394,1.189,0.395,0.933,-0.294,-0.098,1.148,0.192,0.189,1.197,1.226,0.56,-0.546,0.397,-0.485,0.485,-0.187,-0.257,-0.142,0.97,0.934,-0.044,-0.397,-0.335,1.088,-0.4,1.189,-0.668,0.406,-0.073
8379,8379.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.007,0.281,-0.054,-0.021,-0.072,-0.029,-0.052,-0.014,-0.074,-0.007,0.0,-0.115,-0.062,-0.038,-0.042,-0.062,-0.022,-0.079,-0.043,-0.024,0.001,-0.021,0.0,0.0,0.405,0.339,-0.121,-0.102,0.041,-0.114,-0.068,-0.04,-0.093,0.335,-0.089,-0.077,-0.696,0.145,-0.252,-0.093,0.928,0.139,-0.9,-0.175,-0.272,0.501,-0.12,0.735,-0.988,0.71,0.361,0.854,-0.944,0.56,-0.964,-0.235,-0.057,0.849,0.605,0.673,0.004,-0.828,0.119,-0.179,0.766,-1.006,0.936,1.158,0.664,-0.088,0.725,0.013,0.276,0.701,-0.294,0.561,0.073,0.892,0.935,0.198,0.172,0.596,-0.266,0.806,-0.317,0.744,-0.187,-0.257,-1.018,-0.769,0.081,0.583,-0.256,-0.153,0.164,-0.099,0.173,-0.92,0.804,-0.081
3043,3043.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.31,-0.005,0.191,0.281,0.116,-0.026,-0.07,-0.008,-0.107,0.262,0.0,-0.103,0.275,0.203,0.201,-0.078,0.324,0.087,-0.058,0.249,-0.001,-0.118,0.0,0.0,-0.016,-0.026,-0.097,-0.078,-0.08,-0.003,0.25,0.007,0.291,0.022,0.013,-0.144,-0.387,0.071,-0.253,-0.232,0.032,0.696,-0.54,-0.585,-0.272,0.2,-0.504,-0.325,-0.149,0.267,0.724,0.561,-0.543,-0.249,-0.427,-0.644,-0.464,-0.387,-0.11,0.143,0.001,-0.382,0.734,-0.695,0.525,-0.423,0.665,0.811,0.419,-0.093,0.039,0.018,0.228,0.532,-0.294,-0.134,0.625,0.64,0.631,0.184,0.744,0.443,-0.258,0.606,-0.533,0.1,-0.187,-0.257,-0.147,-0.071,0.13,-0.087,-0.114,-0.496,0.733,-0.6,0.728,-0.587,0.379,-0.016


In [17]:
lambda_net_dataset_valid.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
3466,3466.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.07,0.12,-0.015,-0.117,-0.05,-0.056,-0.012,-0.005,0.008,0.051,0.0,-0.079,-0.054,-0.099,-0.106,-0.064,-0.045,-0.055,0.01,0.034,0.21,-0.136,0.0,0.0,0.05,0.028,-0.134,-0.127,-0.0,0.04,0.002,0.285,-0.134,-0.074,-0.095,-0.113,-0.228,0.219,-0.613,-0.203,0.524,0.468,-0.417,-0.462,-0.272,0.276,-0.573,-0.202,-0.482,0.132,0.771,0.323,-0.136,0.353,-0.238,-0.506,-0.341,-0.383,0.368,0.421,0.255,-0.335,0.753,-0.635,0.096,-0.283,0.477,0.822,0.349,-0.376,0.49,0.02,0.645,0.483,-0.294,0.168,0.118,0.18,0.428,0.183,0.491,0.419,-0.366,0.512,-0.546,0.139,-0.187,-0.257,-0.515,-0.349,0.37,0.287,-0.384,-0.157,0.449,-0.568,0.774,-0.335,0.453,-0.074
689,689.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.096,0.061,0.024,0.101,-0.013,0.101,0.01,0.0,0.059,0.023,0.0,-0.016,0.06,0.045,0.05,0.125,0.125,-0.018,0.128,-0.026,0.054,0.055,0.0,0.0,0.145,0.087,0.04,-0.015,0.092,0.056,0.048,-0.032,0.103,0.096,-0.017,-0.274,-0.126,0.21,-0.41,-0.447,0.255,0.369,-0.321,-0.21,-0.272,0.388,-0.105,0.156,-0.368,0.249,0.405,0.238,-0.358,-0.097,-0.34,-0.249,-0.058,-0.124,-0.012,0.193,0.105,-0.259,0.199,-0.18,0.24,-0.273,0.341,0.462,0.133,-0.288,0.132,0.025,0.389,0.109,-0.294,0.052,0.283,0.365,0.353,0.436,0.409,0.29,-0.467,0.101,-0.397,0.328,-0.187,-0.257,-0.36,-0.223,0.32,0.043,-0.325,-0.346,0.321,-0.091,0.412,-0.393,0.363,-0.016
4148,4148.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.541,0.322,-0.165,-0.068,-0.188,-0.031,-0.176,0.0,-0.32,-0.125,0.0,-0.437,-0.049,-0.094,-0.079,-0.049,-0.033,-0.416,-0.07,-0.136,0.294,-0.102,0.0,0.0,0.482,-0.416,-0.1,-0.297,-0.038,0.079,-0.467,0.463,-0.07,-0.083,-0.186,0.055,-0.579,0.978,-0.236,-0.183,0.032,0.131,-1.097,-1.02,-0.272,0.506,-0.691,-0.506,-0.151,0.294,0.137,1.704,-0.135,0.065,-0.127,-1.129,-0.608,-0.56,0.475,0.515,1.056,-0.705,0.111,-0.824,1.33,-0.658,0.622,0.208,0.557,-0.093,0.56,0.025,0.888,0.591,-0.294,0.428,0.081,0.155,0.148,0.211,0.164,0.936,-0.244,0.654,-1.21,0.117,-0.187,-0.257,-1.002,1.037,0.129,0.452,-0.065,-0.554,1.293,-0.794,0.176,-0.162,0.543,-0.182
2815,2815.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.273,0.265,-0.061,0.007,-0.223,0.372,-0.183,-0.012,-0.084,-0.147,0.0,-0.295,-0.051,-0.263,-0.272,0.208,0.26,-0.088,-0.043,-0.069,0.33,-0.255,0.0,0.0,0.446,0.353,-0.098,-0.254,0.312,0.184,0.071,-0.057,-0.081,0.354,-0.155,-0.338,-0.512,0.059,-0.243,-0.733,0.686,0.974,-0.653,-0.265,-0.272,0.339,-0.101,-0.055,-0.842,0.55,0.585,0.77,-0.597,-0.071,-0.656,-0.182,0.621,-0.434,0.226,0.438,0.003,-0.565,0.971,-0.152,0.754,-0.533,0.183,0.451,0.347,-0.717,0.529,0.016,0.328,0.62,-0.294,0.184,0.078,0.721,0.721,0.801,1.002,0.19,-0.24,0.063,-0.789,0.606,-0.187,-0.257,-0.894,-0.579,0.13,0.271,-0.611,-0.31,0.472,-0.065,0.175,-0.737,0.559,-0.162
5185,5185.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.09,0.099,0.022,0.102,-0.046,0.158,-0.068,-0.008,-0.098,0.047,0.0,-0.034,0.049,0.008,0.011,-0.058,0.12,-0.023,0.141,0.051,0.116,-0.077,0.0,0.0,0.179,0.099,-0.074,-0.047,0.085,0.088,0.06,0.179,0.101,0.131,-0.055,-0.23,-0.259,0.064,-0.599,-0.546,0.023,0.495,-0.407,-0.461,-0.272,0.181,-0.371,-0.178,-0.544,0.181,0.485,0.36,-0.413,-0.139,-0.377,-0.521,-0.325,-0.247,-0.037,0.148,0.003,-0.273,0.454,-0.498,0.367,-0.294,0.484,0.584,0.262,-0.475,0.039,0.016,0.219,0.302,-0.294,-0.034,0.451,0.486,0.47,0.184,0.531,0.368,-0.625,0.447,-0.53,0.092,-0.187,-0.257,-0.585,-0.259,0.134,-0.016,-0.388,-0.394,0.5,-0.454,0.563,-0.457,0.334,-0.078


In [18]:
lambda_net_dataset_test.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
7217,7217.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.135,0.29,0.053,0.108,-0.066,0.138,0.015,0.0,0.061,0.106,0.0,-0.039,0.164,0.03,0.028,-0.007,0.147,-0.064,0.44,0.084,0.385,-0.043,0.0,0.0,0.0,0.144,-0.057,-0.046,0.096,0.053,0.11,0.455,0.084,0.239,-0.036,-0.151,-0.368,0.308,-1.217,-1.039,0.298,0.827,-0.485,-1.083,-0.272,0.299,-0.784,-0.132,-1.047,0.157,0.734,0.753,-0.632,0.088,-0.679,-1.114,-0.342,-0.296,0.168,0.223,-0.432,-0.399,0.467,-1.047,0.649,-0.673,0.127,0.584,0.1,-0.26,0.22,0.025,0.306,0.275,-0.294,0.16,0.575,0.092,0.101,0.192,0.705,0.111,-1.22,0.246,-0.972,0.309,-0.187,-0.257,-0.159,-0.318,0.285,0.173,-0.311,-0.08,0.378,-1.237,0.357,-0.735,0.257,-0.057
8291,8291.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.051,0.145,-0.069,-0.047,-0.06,0.221,-0.025,0.07,0.023,-0.055,0.0,-0.093,-0.047,-0.099,-0.089,0.072,-0.032,-0.06,0.24,-0.073,0.162,-0.034,0.0,0.0,0.0,0.218,-0.043,-0.086,0.201,-0.118,-0.066,-0.001,-0.074,0.132,-0.069,-0.379,0.319,0.413,-0.626,-0.593,0.406,0.135,-0.137,-0.372,-0.272,0.533,-0.304,-0.271,-0.145,0.175,0.245,-0.009,-0.136,0.262,-0.49,-0.234,-0.089,0.336,0.228,0.359,0.361,-0.409,0.168,-0.434,-0.007,-0.409,0.166,0.247,0.101,-0.45,0.37,0.42,0.546,-0.016,-0.294,0.287,0.08,0.154,0.142,0.544,0.156,0.269,-0.638,0.112,-0.574,0.459,-0.187,-0.257,-0.159,-0.381,0.469,0.307,-0.465,-0.114,0.159,-0.137,0.177,-0.423,0.331,-0.07
4607,4607.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.081,0.359,-0.083,-0.072,-0.198,0.39,-0.178,0.0,-0.143,0.095,0.0,-0.217,-0.052,-0.108,-0.099,-0.142,-0.028,-0.173,0.418,-0.098,0.248,-0.148,0.0,0.0,0.0,0.415,-0.114,-0.214,0.386,-0.127,-0.071,0.444,-0.082,-0.052,-0.185,-0.587,0.507,0.6,-0.83,-0.803,1.187,0.136,-0.213,-0.897,-0.272,1.165,-0.669,-0.534,-0.154,0.441,0.143,-0.646,-1.141,0.463,-0.136,-0.266,-0.514,0.565,0.487,0.537,0.567,-0.613,0.121,-0.76,0.502,-0.664,0.154,0.223,0.319,-0.635,0.608,0.025,1.317,-0.653,-0.294,0.502,0.082,0.141,0.13,1.667,0.168,0.5,-0.865,0.097,-1.151,0.622,-0.187,-0.257,-0.159,-0.589,0.651,0.517,-0.626,-0.156,0.156,-0.837,0.17,-0.51,0.45,-0.167
5114,5114.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.359,-0.041,0.297,0.434,0.128,-0.028,-0.05,-0.008,-0.072,0.31,0.0,-0.132,0.385,0.272,0.301,-0.064,0.455,0.12,0.192,0.281,-0.046,-0.109,0.0,0.0,-0.032,-0.059,-0.098,-0.084,0.121,-0.026,0.42,-0.049,0.456,-0.027,0.026,-0.188,-0.504,0.07,-0.234,-1.488,0.032,1.173,-0.495,-0.606,-0.272,0.199,-0.527,-0.38,-0.154,0.176,1.197,0.928,-0.493,-0.246,-0.354,-0.635,-0.539,-0.487,-0.131,0.132,0.007,-0.34,1.197,-0.781,0.751,-0.366,1.057,1.39,0.349,-0.089,0.076,0.018,0.234,0.69,-0.294,-0.112,0.779,0.889,0.957,0.195,1.266,0.446,-1.527,0.702,-0.518,0.109,-0.187,-0.257,-0.138,-0.288,0.128,-0.06,0.206,-0.438,1.451,-0.839,1.347,-0.487,0.352,0.003
1859,1859.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.16,-0.047,0.064,0.594,0.034,-0.03,-0.341,-0.022,-0.136,0.048,0.0,-0.122,0.122,0.08,0.066,-0.071,0.617,-0.006,-0.06,0.101,-0.08,-0.145,0.0,0.0,-0.005,-0.016,0.249,-0.274,-0.083,0.232,0.384,-0.048,0.584,0.319,-0.031,-0.071,-0.469,-0.22,-0.242,-0.188,0.021,2.209,-2.276,-0.494,-0.272,0.268,-0.082,0.524,-0.146,0.335,2.203,0.337,-0.553,1.205,-0.137,-0.555,-2.218,-0.524,0.653,0.746,-2.081,-0.063,1.863,-0.151,0.295,-0.086,0.4,2.278,0.362,-0.088,1.469,0.007,0.212,0.367,-0.294,0.593,0.447,0.363,0.361,0.197,2.285,0.344,-0.255,0.377,-0.238,0.391,-0.187,-0.257,-0.154,1.269,1.056,0.972,-0.112,-0.45,1.965,-0.094,2.116,-0.521,0.364,-0.163


## Interpretation Network Training

In [19]:
#%load_ext autoreload

In [25]:
#%autoreload 2
((X_valid, y_valid), 
 (X_test, y_test),
 history,

 model) = interpretation_net_training(
                                      lambda_net_dataset_train, 
                                      lambda_net_dataset_valid, 
                                      lambda_net_dataset_test,
                                      config,
                                      #callback_names=['plot_losses']
                                     )



Trial 10 Complete [00h 26m 23s]
val_loss: 0.5649985671043396

Best val_loss So Far: 0.489719420671463
Total elapsed time: 02h 09m 21s
Training Time: 2:09:49
---------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------ LOADING MODELS -----------------------------------------------------
Loading Time: 0:00:02


In [26]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 449)]        0                                            
__________________________________________________________________________________________________
cast_to_float32 (CastToFloat32) (None, 449)          0           input_1[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          230400      cast_to_float32[0][0]            
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512)          2048        dense[0][0]                      
______________________________________________________________________________________________

In [None]:
if nas:
    for trial in history_list[-1]: 
        print(trial.summary())

In [30]:
y_test_inet_dt_list = []
y_test_distilled_sklearn_dt_list = []

binary_crossentropy_distilled_sklearn_dt_list =[]
accuracy_distilled_sklearn_dt_list = []
f1_score_distilled_sklearn_dt_list = []

binary_crossentropy_inet_dt_list =[]
accuracy_inet_dt_list = []
f1_score_inet_dt_list = []

#inet_metric_function_list = []

number = lambda_net_dataset_test.y_test_lambda_array.shape[0]#10

for lambda_net_parameters, lambda_net, X_test_lambda, y_test_lambda in tqdm(zip(lambda_net_dataset_test.network_parameters_array[:number], lambda_net_dataset_test.network_list[:number], lambda_net_dataset_test.X_test_lambda_array[:number], lambda_net_dataset_test.y_test_lambda_array[:number]), total=lambda_net_dataset_test.y_test_lambda_array[:number].shape[0]):
    dt_inet = model.predict(np.array([lambda_net_parameters]))[0]
    if nas:
        dt_inet = dt_inet[:function_representation_length]

    
    X_data_random = generate_random_data_points_custom(config['data']['x_min'], config['data']['x_max'], config['evaluation']['random_evaluation_dataset_size'], config['data']['number_of_variables'])
    y_data_random_lambda_pred = lambda_net.predict(X_data_random)
    y_data_random_lambda_pred = np.round(y_data_random_lambda_pred).astype(np.int64)
    
    dt_sklearn_distilled = DecisionTreeClassifier(max_depth=config['function_family']['maximum_depth'])
    dt_sklearn_distilled.fit(X_data_random, y_data_random_lambda_pred)
    
    
    if dt_type == 'SDT':
        y_test_inet_dt  = calculate_function_value_from_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    elif dt_type == 'vanilla':
        y_test_inet_dt  = calculate_function_value_from_vanilla_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    y_test_distilled_sklearn_dt = dt_sklearn_distilled.predict(X_test_lambda)
    
    y_test_lambda_pred = lambda_net.predict(X_test_lambda)
    y_test_lambda_pred = np.round(y_test_lambda_pred)
    
    binary_crossentropy_distilled_sklearn_dt = log_loss(y_test_lambda_pred, y_test_distilled_sklearn_dt)
    accuracy_distilled_sklearn_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_dt))
    f1_score_distilled_sklearn_dt = f1_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_dt))
    
    binary_crossentropy_inet_dt = log_loss(y_test_lambda_pred, y_test_inet_dt)
    accuracy_inet_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_inet_dt))
    f1_score_inet_dt = f1_score(y_test_lambda_pred, np.round(y_test_inet_dt))
    
    
    y_test_inet_dt_list.append(y_test_inet_dt)
    y_test_distilled_sklearn_dt_list.append(y_test_distilled_sklearn_dt)    

    binary_crossentropy_distilled_sklearn_dt_list.append(np.nan_to_num(binary_crossentropy_distilled_sklearn_dt))
    accuracy_distilled_sklearn_dt_list.append(np.nan_to_num(accuracy_distilled_sklearn_dt))
    f1_score_distilled_sklearn_dt_list.append(np.nan_to_num(f1_score_distilled_sklearn_dt))

    binary_crossentropy_inet_dt_list.append(np.nan_to_num(binary_crossentropy_inet_dt))
    accuracy_inet_dt_list.append(np.nan_to_num(accuracy_inet_dt))
    f1_score_inet_dt_list.append(np.nan_to_num(f1_score_inet_dt))
    
y_test_inet_dt_list = np.array(y_test_inet_dt_list)
y_test_distilled_sklearn_dt_list = np.array(y_test_distilled_sklearn_dt_list)

binary_crossentropy_distilled_sklearn_dt_list = np.array(binary_crossentropy_distilled_sklearn_dt_list)
accuracy_distilled_sklearn_dt_list = np.array(accuracy_distilled_sklearn_dt_list)
f1_score_distilled_sklearn_dt_list = np.array(f1_score_distilled_sklearn_dt_list)

binary_crossentropy_inet_dt_list = np.array(binary_crossentropy_inet_dt_list)
accuracy_inet_dt_list = np.array(accuracy_inet_dt_list)
f1_score_inet_dt_list = np.array(f1_score_inet_dt_list)    

    
print('Binary Crossentropy:\t\t', np.round(np.mean(binary_crossentropy_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(binary_crossentropy_inet_dt_list), 3), '(I-Net DT)')
print('Accuracy:\t\t', np.round(np.mean(accuracy_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(accuracy_inet_dt_list), 3), '(I-Net DT)')
print('F1 Score:\t\t', np.round(np.mean(f1_score_distilled_sklearn_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(f1_score_inet_dt_list), 3), '(I-Net DT)')
      

  0%|          | 0/50 [00:00<?, ?it/s]

Binary Crossentropy:		 2.357 (Sklearn DT) 	 0.494 (I-Net DT)
Accuracy:		 0.932 (Sklearn DT) 	 0.742 (I-Net DT)
F1 Score:		 0.928 (Sklearn DT) 	 0.722 (I-Net DT)


In [31]:
dt_inet

array([0.49746728, 0.50673634, 0.50530845, 0.47406688, 0.51658875,
       0.4707519 , 0.48858735, 0.48731878, 0.4944548 , 0.49451634,
       0.5213153 , 0.535468  , 0.50523174, 0.515518  , 0.48073077,
       0.21596205, 0.20765604, 0.19602226, 0.1849192 , 0.19544046,
       0.20178193, 0.21388663, 0.19310056, 0.18829891, 0.20293203,
       0.19134349, 0.19843073, 0.19121753, 0.20719475, 0.21181346,
       0.19892852, 0.17564636, 0.19636513, 0.20807502, 0.22098495,
       0.19388397, 0.20222628, 0.1996802 , 0.1865783 , 0.2176312 ,
       0.1875158 , 0.20455629, 0.22917129, 0.20001702, 0.17873959,
       0.20704377, 0.18870327, 0.18478656, 0.2103651 , 0.20910135,
       0.18761574, 0.22929119, 0.20454343, 0.19199656, 0.18655305,
       0.22081643, 0.17805547, 0.2173234 , 0.17730759, 0.20649706,
       0.23124366, 0.1776883 , 0.18093817, 0.20386213, 0.2062677 ,
       0.21015465, 0.20074128, 0.18030185, 0.23034449, 0.17845777,
       0.20929116, 0.1773132 , 0.20590067, 0.20858929, 0.19890

In [32]:
get_shaped_parameters_for_decision_tree(dt_inet, config)

(<tf.Tensor: shape=(15, 5), dtype=float32, numpy=
 array([[0.49746728, 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.50673634, 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.50530845],
        [0.        , 0.        , 0.        , 0.        , 0.47406688],
        [0.        , 0.        , 0.        , 0.        , 0.51658875],
        [0.        , 0.        , 0.4707519 , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.48858735, 0.        ],
        [0.        , 0.48731878, 0.        , 0.        , 0.        ],
        [0.4944548 , 0.        , 0.        , 0.        , 0.        ],
        [0.49451634, 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.5213153 , 0.        ],
        [0.535468  , 0.        , 0.        , 0.        , 0.        ],
        [0.50523174, 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.51551