# Inerpretation-Net Training

## Specitication of Experiment Settings

In [1]:
import sys
print(sys.version)


3.9.6 (default, Aug 18 2021, 19:38:01) 
[GCC 7.5.0]


In [2]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
sleep_time = 0 #minutes



config = {
    'function_family': {
        'maximum_depth': 4,
        'beta': 1,
        'decision_sparsity': 1,
        'fully_grown': True,   
        'dt_type': 'SDT', #'vanilla', 'SDT'
    },
    'data': {
        'number_of_variables': 5, 
        'num_classes': 2,
        
        'function_generation_type': 'make_classification', # 'make_classification' 'random_decision_tree' 'random_vanilla_decision_tree_trained'
        'objective': 'classification', # 'regression'
        
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', #'normal', 'uniform',       
                
        'lambda_dataset_size': 1000, #number of samples per function
        #'number_of_generated_datasets': 10000,
        
        'noise_injected_level': 0, 
        'noise_injected_type': 'flip_percentage', # '' 'normal' 'uniform' 'normal_range' 'uniform_range'
    }, 
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True, 
        'early_stopping_min_delta_lambda': 1e-2,
        'batch_lambda': 64,
        'dropout_lambda': 0,
        'lambda_network_layers': [64],
        'optimizer_lambda': 'adam',
        'loss_lambda': 'binary_crossentropy', #categorical_crossentropy
        
        'number_of_lambda_weights': None,
        
        'number_initializations_lambda': 1, 
        
        'number_of_trained_lambda_nets': 10000,
    },     
    
    'i_net': {
        'dense_layers': [2048],
        'convolution_layers': None,
        'lstm_layers': None,
        'dropout': [0],
        
        'optimizer': 'adam', #adam
        'learning_rate': 0.01,
        'loss': 'binary_crossentropy',
        'metrics': ['binary_accuracy'],
        
        'epochs': 200, 
        'early_stopping': True,
        'batch_size': 256,

        'interpretation_dataset_size': 1000,
                
        'test_size': 50, #Float for fraction, Int for number 0
        
        'function_representation_type': 1, # 1=standard representation; 2=sparse representation, 3=vanilla_dt

        'optimize_decision_function': True, #False
        'function_value_loss': True, #False
                      
        'data_reshape_version': None, #default to 2 options:(None, 0,1 2)
        
        'nas': True,
        'nas_type': 'SEQUENTIAL', #options:(None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel')      
        'nas_trials': 2,
    },    
    
    'evaluation': {   
        #'inet_holdout_seed_evaluation': False,
            
        'random_evaluation_dataset_size': 500, 
        'per_network_optimization_dataset_size': 5000,

        'sklearn_dt_benchmark': False,
        'sdt_benchmark': False,
        
    },    
    
    'computation':{
        'load_model': False,
        
        'n_jobs': -3,
        'use_gpu': False,
        'gpu_numbers': '0',
        'RANDOM_SEED': 42,   
    }
}


## Imports

In [3]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [4]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import shutil

import logging

#from prettytable import PrettyTable
#import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold, ParameterGrid, ParameterSampler
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score, log_loss
from sklearn.tree import DecisionTreeClassifier, plot_tree



#from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import tensorflow as tf
#import tensorflow_addons as tfa
#import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


import tensorflow.keras.backend as K
from livelossplot import PlotLossesKerasTF
#from keras_tqdm import TQDMNotebookCallback

from matplotlib import pyplot as plt
import seaborn as sns


import random 



from IPython.display import Image
from IPython.display import display, Math, Latex, clear_output



In [5]:
tf.__version__

'2.5.1'

In [6]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [7]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################

config['i_net']['data_reshape_version'] = 2 if data_reshape_version == None and (convolution_layers != None or lstm_layers != None or (nas and nas_type != 'SEQUENTIAL')) else data_reshape_version

#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers if use_gpu else ''
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
#os.environ['XLA_FLAGS'] =  '--xla_gpu_cuda_data_dir=/usr/lib/cuda-10.1'

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
#np.set_printoptions(suppress=True)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if int(tf.__version__[0]) >= 2:
    tf.random.set_seed(RANDOM_SEED)
else:
    tf.set_random_seed(RANDOM_SEED)
    
    
pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 200)
np.set_printoptions(threshold=200)
np.set_printoptions(suppress=True)



In [8]:
from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *
from utilities.DecisionTree_BASIC import *

#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################

config['lambda_net']['number_of_lambda_weights'] = get_number_of_lambda_net_parameters(lambda_network_layers, number_of_variables, num_classes)
config['function_family']['basic_function_representation_length'] = (2 ** maximum_depth - 1) * number_of_variables + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes
config['function_family']['function_representation_length'] = ( 
       ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes  if function_representation_type == 1 and dt_type == 'SDT'
  else (2 ** maximum_depth - 1) * decision_sparsity + (2 ** maximum_depth - 1) + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) * num_classes if function_representation_type == 2 and dt_type == 'SDT'
  else ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth)  if function_representation_type == 1 and dt_type == 'vanilla'
  else (2 ** maximum_depth - 1) * decision_sparsity + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) if function_representation_type == 2 and dt_type == 'vanilla'
  else None
                                                            )
#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

#initialize_LambdaNet_config_from_curent_notebook(config)
#initialize_metrics_config_from_curent_notebook(config)
#initialize_utility_functions_config_from_curent_notebook(config)
#initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(config, path_type='interpretation_net'))
create_folders_inet(config)

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [9]:
print(path_identifier_interpretation_net)

print(path_identifier_lambda_net_data)


lNetSize1000_numLNets10000_var5_class2_make_classification_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42/inet_dense2048_drop0e200b256_adam
lNetSize1000_numLNets10000_var5_class2_make_classification_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42


In [10]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
print("Num XLA-GPUs Available: ", len(tf.config.experimental.list_physical_devices('XLA_GPU')))

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Load Data and Generate Datasets

In [11]:
def load_lambda_nets(config, no_noise=False, n_jobs=1):
    
    #def generate_lambda_net()
    
    #if psutil.virtual_memory().percent > 80:
        #raise SystemExit("Out of RAM!")
    
    if no_noise==True:
        config['noise_injected_level'] = 0
    path_dict = generate_paths(config, path_type='interpretation_net')        
        
    directory = './data/weights/' + 'weights_' + path_dict['path_identifier_lambda_net_data'] + '/'
    path_network_parameters = directory + 'weights' + '.txt'
    path_X_data = directory + 'X_test_lambda.txt'
    path_y_data = directory + 'y_test_lambda.txt'        
    
    network_parameters = pd.read_csv(path_network_parameters, sep=",", header=None)
    network_parameters = network_parameters.sort_values(by=0)
    if no_noise == False:
        network_parameters = network_parameters.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    X_test_lambda = pd.read_csv(path_X_data, sep=",", header=None)
    X_test_lambda = X_test_lambda.sort_values(by=0)
    if no_noise == False:
        X_test_lambda = X_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    y_test_lambda = pd.read_csv(path_y_data, sep=",", header=None)
    y_test_lambda = y_test_lambda.sort_values(by=0)
    if no_noise == False:
        y_test_lambda = y_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
        
        
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='loky') #loky

    lambda_nets = parallel(delayed(LambdaNet)(network_parameters_row, 
                                              X_test_lambda_row, 
                                              y_test_lambda_row, 
                                              config) for network_parameters_row, X_test_lambda_row, y_test_lambda_row in zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values))          
    del parallel
    
    base_model = generate_base_model(config)  
    
    def initialize_network_wrapper(config, lambda_net, base_model):
        lambda_net.initialize_network(config, base_model)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_network_wrapper)(config, lambda_net, base_model) for lambda_net in lambda_nets)   
    del parallel
    
    def initialize_target_function_wrapper(config, lambda_net):
        lambda_net.initialize_target_function(config)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_target_function_wrapper)(config, lambda_net) for lambda_net in lambda_nets)   
    del parallel
        
    
    #lambda_nets = [None] * network_parameters.shape[0]
    #for i, (network_parameters_row, X_test_lambda_row, y_test_lambda_row) in tqdm(enumerate(zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values)), total=network_parameters.values.shape[0]):        
    #    lambda_net = LambdaNet(network_parameters_row, X_test_lambda_row, y_test_lambda_row, config)
    #    lambda_nets[i] = lambda_net
                
    lambda_net_dataset = LambdaNetDataset(lambda_nets)
        
    return lambda_net_dataset
    

In [12]:
#LOAD DATA
if noise_injected_level > 0:
    lambda_net_dataset_training = load_lambda_nets(config, no_noise=True, n_jobs=n_jobs)
    lambda_net_dataset_evaluation = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_training, test_split=0.1)
    _, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset_evaluation, test_split=test_size)
    
else:
    lambda_net_dataset = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train_with_valid, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset, test_split=test_size)
    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_train_with_valid, test_split=0.1)

    

[Parallel(n_jobs=-3)]: Using backend LokyBackend with 14 concurrent workers.
[Parallel(n_jobs=-3)]: Done   4 tasks      | elapsed:    6.8s
[Parallel(n_jobs=-3)]: Done 116 tasks      | elapsed:    7.0s
[Parallel(n_jobs=-3)]: Done 1000 out of 1000 | elapsed:    7.4s finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 1000 out of 1000 | elapsed:   20.0s finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 1000 out of 1000 | elapsed:    2.0s finished


## Data Inspection

In [13]:
lambda_net_dataset_train.shape

(855, 573)

In [14]:
lambda_net_dataset_valid.shape

(95, 573)

In [15]:
lambda_net_dataset_test.shape

(50, 573)

In [16]:
lambda_net_dataset_train.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
9644,9644.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.037,0.177,-0.052,-0.073,-0.125,0.255,-0.091,0.0,-0.076,-0.016,0.0,-0.175,-0.029,-0.074,-0.062,-0.016,0.0,-0.176,0.222,-0.066,0.223,-0.099,0.0,0.0,0.0,0.218,-0.094,-0.173,0.235,0.112,-0.12,0.0,-0.03,-0.028,-0.135,-0.424,0.588,0.618,-0.281,-1.371,0.052,0.155,-0.672,-0.7,-0.272,0.422,-0.641,-0.528,-0.154,0.142,0.163,0.022,-0.34,0.322,-0.934,-0.237,-0.624,0.559,0.498,0.481,0.734,-0.646,0.765,-0.203,0.011,-0.662,0.182,0.947,0.503,-0.8,0.592,0.025,0.658,0.569,-0.294,0.363,0.096,0.173,0.162,0.232,0.187,0.606,-1.441,0.682,-0.727,0.159,-0.187,-0.257,-0.159,-0.69,0.141,0.469,-0.716,-0.153,0.922,-0.13,0.212,-0.192,0.434,-0.127
36,36.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.079,0.118,0.003,0.093,-0.063,0.0,-0.021,0.0,0.024,0.101,0.0,-0.136,0.015,0.027,0.027,0.19,0.064,-0.04,0.0,-0.035,0.087,0.016,0.0,0.0,0.0,0.124,-0.052,-0.113,-0.025,0.104,0.002,0.0,0.062,0.116,-0.074,-0.566,-0.778,0.535,-0.288,-0.256,0.89,1.28,-0.27,-0.194,-0.272,0.539,-0.136,0.718,-0.154,0.758,1.289,0.839,-0.929,-0.448,-0.354,-0.245,-0.081,-0.8,0.464,0.326,0.105,-0.614,1.157,-0.214,0.854,-0.626,0.997,1.155,0.655,-0.113,0.163,0.025,0.411,0.959,-0.294,0.451,0.575,0.986,0.966,1.389,1.412,0.615,-0.298,1.041,-0.576,0.642,-0.187,-0.257,-0.159,-0.735,0.669,0.346,-0.155,-0.556,0.682,-0.13,0.97,-0.936,0.575,-0.093
1568,1568.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.036,0.261,-0.013,-0.376,-0.121,0.286,-0.32,0.0,-0.289,-0.112,0.0,-0.201,0.237,-0.105,-0.089,-0.026,0.0,-0.146,0.211,0.149,0.011,-0.412,0.0,0.0,0.0,-0.028,-0.143,-0.197,0.255,0.222,-0.042,0.0,-0.038,0.058,-0.176,-1.695,1.853,0.421,-1.814,-1.88,2.177,0.155,-0.391,-1.801,-0.272,0.617,-1.764,-1.41,-0.154,0.474,0.163,0.015,-1.118,-0.145,-0.527,-1.894,-1.814,2.071,0.352,0.364,0.011,-0.429,0.146,-1.923,-0.021,-0.45,1.097,2.533,0.395,-0.379,2.038,0.025,1.941,2.047,-0.294,0.286,2.948,0.895,0.928,0.23,0.187,0.402,-1.899,2.454,-1.249,1.942,-0.187,-0.257,-0.159,-0.032,0.236,0.294,-0.328,-0.323,0.179,-0.13,0.205,-2.26,0.499,-0.22
6850,6850.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.307,-0.186,-0.036,0.26,0.074,-0.15,-0.031,0.0,-0.032,-0.079,0.0,0.193,0.201,0.193,0.255,0.354,0.0,-0.056,-0.105,-0.059,-0.033,0.194,0.0,0.0,0.0,0.007,0.284,0.041,-0.155,-0.162,0.255,-0.007,0.345,-0.049,0.123,-0.543,0.624,0.485,-1.128,-0.884,0.696,0.155,-0.156,-0.188,-0.272,0.521,-0.39,-0.379,-0.154,0.388,0.163,0.022,-0.654,-0.297,-0.609,-0.244,-0.068,-0.507,-0.31,0.104,0.64,-0.506,0.16,-0.194,0.777,-0.525,0.188,1.027,0.042,-0.745,-0.022,0.025,0.395,0.025,-0.294,0.127,0.69,0.752,0.777,0.947,0.187,0.182,-1.001,0.124,-0.274,0.622,-0.187,-0.257,-0.159,-0.633,0.571,0.012,-0.748,-0.384,1.023,-0.124,0.723,-0.179,0.211,0.142
1862,1862.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.051,-0.034,-0.067,-0.027,0.05,-0.019,0.091,0.0,-0.102,0.174,0.0,0.049,-0.033,-0.082,-0.067,-0.039,-0.008,0.034,-0.025,0.155,-0.001,-0.106,0.0,0.0,0.0,-0.005,-0.004,0.043,-0.027,0.009,-0.055,0.0,-0.051,-0.037,0.004,0.638,-1.019,0.783,-0.272,-0.234,0.805,0.155,-0.673,-1.09,-0.272,0.198,-0.101,-0.876,-0.154,0.493,0.163,0.017,-0.176,-0.546,-0.159,-0.733,-0.946,-1.015,0.736,0.517,0.005,-1.134,0.119,-1.093,0.006,-0.068,0.174,0.25,0.742,-0.096,0.828,0.025,0.242,1.14,-0.294,0.608,0.092,0.167,0.158,0.21,0.182,0.744,-0.279,1.136,-1.494,0.191,-0.187,-0.257,-0.159,-1.17,0.822,0.657,-1.36,-0.15,0.171,-0.13,0.189,-0.187,0.384,-0.0


In [17]:
lambda_net_dataset_valid.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
3501,3501.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.024,-0.067,-0.053,-0.037,0.029,-0.02,-0.257,0.0,-0.039,-0.074,0.0,0.071,0.169,-0.082,-0.063,0.0,0.0,0.046,0.07,-0.059,-0.041,0.086,0.0,0.0,0.0,-0.093,0.147,0.054,0.022,-0.019,-0.02,0.316,0.161,-0.109,0.024,-0.593,0.6,0.64,-1.495,-0.794,0.751,0.155,-0.145,-1.62,-0.272,0.602,-0.664,-0.58,-0.154,0.609,0.163,-1.449,-0.824,-0.407,-1.171,-1.676,-1.003,-0.737,-0.474,0.195,0.638,-0.083,0.16,-1.286,-0.544,-0.809,0.179,0.246,0.303,-0.794,0.533,0.025,0.598,0.001,-0.294,0.422,1.261,0.165,0.148,0.246,0.187,0.375,-0.789,0.124,-0.263,0.62,-0.187,-0.257,-0.159,-0.972,0.69,0.328,-0.676,-0.369,0.188,-1.209,0.958,-1.222,0.265,0.035
5748,5748.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.013,0.262,-0.057,-0.029,-0.103,0.28,-0.102,0.0,-0.096,-0.024,0.0,-0.174,-0.029,-0.062,-0.05,-0.027,-0.005,-0.101,0.0,-0.067,-0.023,-0.106,0.0,0.0,0.0,0.307,-0.091,-0.164,0.333,0.118,-0.038,0.0,-0.037,0.09,-0.141,-0.294,-0.764,0.549,-0.286,-1.435,0.905,0.15,-0.974,-0.93,-0.272,0.39,-0.136,0.768,-0.154,0.732,0.163,0.022,-0.156,0.615,-1.058,-0.87,-0.212,-0.841,0.684,0.686,0.12,-0.882,0.155,-0.215,0.886,-0.833,0.177,0.251,0.691,-1.021,0.798,0.025,0.235,0.868,-0.294,0.591,0.095,0.176,0.166,0.223,0.184,0.498,-0.298,0.112,-0.276,0.709,-0.187,-0.257,-0.159,-0.822,0.137,0.631,-1.024,-0.106,0.179,-0.13,0.2,-1.016,0.722,-0.169
8103,8103.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.038,0.165,-0.068,0.023,-0.124,-0.005,-0.113,0.0,-0.071,-0.092,0.0,-0.103,-0.025,-0.054,-0.048,-0.002,0.039,-0.111,0.213,-0.105,0.107,-0.077,0.0,0.0,0.0,0.193,-0.067,-0.07,0.078,0.15,-0.104,0.0,0.018,0.2,-0.138,-0.421,-0.588,0.081,-1.511,-1.032,0.919,0.819,-0.59,-0.207,-0.272,0.403,-0.859,0.637,-0.154,0.644,0.821,0.022,-0.729,-0.302,-0.611,-0.267,-0.09,-0.604,-0.13,0.277,0.644,-0.482,1.003,-0.227,0.744,-0.526,0.812,0.866,0.502,-0.108,0.294,0.025,0.263,0.765,-0.294,0.07,0.81,0.833,0.835,0.77,0.844,0.615,-1.318,0.775,-1.311,0.784,-0.187,-0.257,-0.159,-0.51,0.255,0.025,-1.223,-0.584,0.918,-0.13,0.865,-0.733,0.586,-0.124
7032,7032.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.013,0.146,0.007,0.008,-0.131,0.198,-0.107,0.0,-0.11,-0.029,0.0,-0.168,-0.04,-0.087,-0.084,-0.056,-0.005,-0.13,0.214,-0.084,0.213,-0.099,0.0,0.0,0.0,-0.035,-0.138,-0.172,0.111,0.168,-0.003,0.0,-0.042,-0.047,-0.146,-0.291,-0.598,0.486,-1.049,-1.004,0.016,1.089,-0.161,-0.98,-0.272,0.321,-0.657,-0.492,-0.154,0.526,0.163,0.903,-0.626,-0.315,-0.54,-1.022,-0.588,-0.551,0.366,0.246,-0.609,-0.205,0.149,-0.766,0.832,-0.176,0.332,1.028,0.514,-0.911,0.099,0.025,0.231,0.091,-0.294,0.378,0.709,0.66,0.724,0.204,0.182,0.57,-1.05,0.544,-1.06,0.282,-0.187,-0.257,-0.159,-0.027,0.734,0.317,-0.766,-0.455,0.557,-0.13,0.914,-0.177,0.362,-0.144
500,500.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.368,0.488,-0.396,-0.013,-0.526,0.267,0.381,0.0,-0.47,-0.07,0.0,-0.528,-0.376,-0.365,-0.364,0.0,0.0,-0.424,0.0,-0.492,0.285,-0.402,0.0,0.0,0.0,0.316,-0.379,-0.528,0.15,0.438,-0.38,0.0,-0.29,0.446,-0.428,-1.298,-0.939,1.303,-1.718,-2.003,1.534,0.155,-2.374,-2.441,-0.272,1.478,-1.704,-1.473,-0.154,0.888,0.163,0.022,-1.926,-0.412,-0.861,-1.054,-0.904,-1.834,0.838,0.674,1.0,-0.565,0.145,-1.091,0.841,-0.807,1.098,0.264,1.423,-2.859,-1.602,0.025,3.614,0.024,-0.294,0.749,0.967,0.93,0.925,0.246,0.187,1.087,-0.298,2.648,-2.873,1.591,-0.187,-0.257,-0.159,-2.521,0.874,0.751,-2.03,-0.66,1.041,-0.13,0.784,-1.668,0.908,-0.372


In [18]:
lambda_net_dataset_test.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
5206,5206.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.201,0.092,0.346,0.405,-0.076,0.0,-0.075,0.0,0.085,0.016,0.0,-0.224,0.42,0.294,0.34,0.0,0.0,-0.071,0.0,-0.022,-0.01,-0.097,0.0,0.0,0.0,0.176,-0.087,-0.232,0.213,0.098,0.406,0.0,0.465,0.19,-0.083,-1.087,-1.802,0.872,-0.292,-2.394,0.042,0.155,-0.7,-1.886,-0.272,0.735,-2.322,0.952,-0.154,0.603,0.163,0.022,-0.753,0.341,-1.086,-2.366,-2.06,-2.039,0.918,0.647,0.021,-0.838,2.015,-0.22,1.228,-0.706,2.014,1.944,0.306,-0.113,0.837,0.025,1.179,0.793,-0.294,0.42,2.297,1.684,1.887,0.246,0.187,0.447,-0.298,0.519,-0.275,0.824,-0.187,-0.257,-0.159,-1.921,0.124,0.759,-2.046,-0.197,2.259,-0.13,2.198,-0.86,0.665,-0.099
2771,2771.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.014,0.049,0.09,0.111,-0.034,0.067,0.025,0.169,0.045,0.027,0.0,-0.057,0.186,0.076,0.094,0.255,0.213,0.031,0.0,0.061,0.057,0.031,0.0,0.0,-0.04,-0.117,0.074,-0.042,0.066,0.04,0.187,0.193,0.247,-0.112,-0.029,-0.277,0.925,0.861,-1.845,-0.25,0.392,1.295,-0.418,-0.349,-0.272,0.761,-1.492,-0.053,-0.154,0.087,1.238,0.022,-0.188,0.13,-0.463,-1.709,-1.483,0.333,0.161,0.289,0.98,-0.576,1.219,-1.641,0.978,-0.339,1.198,1.165,0.016,-0.6,0.358,0.379,0.882,0.923,-0.294,0.049,1.395,1.236,1.225,1.227,1.349,0.856,-0.298,1.199,-0.547,0.402,-0.187,-0.257,-1.295,-1.21,0.763,0.15,-0.538,-0.19,1.283,-1.615,1.509,-1.66,0.341,-0.037
5928,5928.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.04,-0.005,0.132,-0.205,-0.095,-0.433,-0.047,0.0,-0.099,0.265,0.0,-0.155,0.201,-0.195,-0.203,0.0,0.0,-0.09,-0.375,0.286,0.634,-0.295,0.0,0.0,0.0,-0.022,-0.103,-0.054,-0.422,0.043,-0.033,0.375,-0.084,-0.029,-0.058,-0.66,-0.586,1.546,-2.526,-0.788,1.236,0.155,-0.152,-2.017,-0.272,1.087,-1.901,-0.395,-0.154,0.713,0.163,0.022,-0.671,-0.332,-0.831,-1.884,-1.558,-0.559,-0.595,0.294,1.847,-1.392,2.635,-1.713,0.706,-0.67,2.338,1.975,0.427,-1.795,0.316,0.025,0.227,1.933,-0.294,0.588,2.494,1.268,1.302,0.246,0.187,0.492,-2.43,1.916,-2.215,1.292,-0.187,-0.257,-0.159,-0.885,0.96,0.29,-1.648,-0.366,0.179,-1.563,0.915,-0.189,0.454,-0.058
103,103.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.233,0.18,0.211,-0.107,-0.174,0.25,-0.106,0.0,-0.094,-0.071,0.0,-0.203,-0.126,-0.209,-0.203,0.0,0.0,-0.162,0.255,-0.055,0.376,-0.209,0.0,0.0,0.0,0.328,-0.123,-0.186,0.313,0.17,-0.123,0.0,-0.078,0.297,-0.16,-0.697,-1.44,0.369,-1.134,-1.222,0.003,0.155,-1.587,-1.239,-0.272,0.405,-1.15,-0.805,-0.154,0.362,0.163,0.022,-0.556,-0.178,-1.557,-0.921,-0.571,-1.304,0.148,0.207,0.372,-1.638,0.16,-0.94,1.618,-0.244,1.541,1.32,0.388,-0.972,0.049,0.025,0.452,0.037,-0.294,0.21,0.646,1.497,1.387,0.246,0.187,0.513,-1.228,0.124,-1.838,0.674,-0.187,-0.257,-0.159,-1.577,0.496,0.173,-1.419,-0.372,0.756,-0.13,0.598,-1.49,0.342,-0.157
4367,4367.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.243,0.153,-0.053,-0.024,-0.067,-0.115,-0.124,0.0,-0.029,-0.023,0.0,-0.125,0.134,-0.039,-0.034,-0.035,0.0,-0.08,0.134,-0.069,-0.036,-0.047,0.0,0.0,0.0,0.215,-0.018,-0.117,-0.094,0.088,-0.262,0.0,-0.053,0.173,-0.089,-0.552,-2.574,0.57,-2.53,-0.865,0.636,0.155,-0.153,-0.2,-0.272,0.642,-0.132,-0.218,-0.154,0.426,0.163,0.022,-2.661,-0.142,-0.745,-0.258,-0.072,-2.657,-0.21,0.285,1.974,-0.725,0.154,-0.214,2.039,-0.574,0.176,0.254,0.059,-2.287,1.419,0.025,0.558,1.594,-0.294,0.325,1.58,0.376,0.497,0.222,0.187,0.204,-1.58,0.118,-0.27,0.68,-0.187,-0.257,-0.159,-0.736,0.621,0.291,-2.149,-0.181,2.158,-0.13,2.176,-0.855,0.325,-0.086


## Interpretation Network Training

In [19]:
#%load_ext autoreload

In [20]:
#%autoreload 2
((X_valid, y_valid), 
 (X_test, y_test),
 history,

 model) = interpretation_net_training(
                                      lambda_net_dataset_train, 
                                      lambda_net_dataset_valid, 
                                      lambda_net_dataset_test,
                                      config,
                                      #callback_names=['plot_losses']
                                     )



Trial 2 Complete [00h 00m 44s]
val_loss: 0.6931469440460205

Best val_loss So Far: 0.645199716091156
Total elapsed time: 00h 02m 11s
[<keras_tuner.engine.trial.Trial object at 0x7f7915763bb0>, <keras_tuner.engine.trial.Trial object at 0x7f78fc6b5bb0>]
[<keras_tuner.engine.trial.Trial object at 0x7f7915763bb0>, <keras_tuner.engine.trial.Trial object at 0x7f78fc6b5bb0>]
Training Time: 0:02:15
---------------------------------------------------------------------------------------------------------------------------
------------------------------------------------------ LOADING MODELS -----------------------------------------------------


JSONDecodeError: Expecting value: line 1 column 1 (char 0)

In [None]:
model.summary()

In [None]:
y_test_inet_vanilla_dt_list = []
y_test_distilled_sklearn_vanilla_dt_list = []

binary_crossentropy_distilled_sklearn_vanilla_dt_list =[]
accuracy_distilled_sklearn_vanilla_dt_list = []
f1_score_distilled_sklearn_vanilla_dt_list = []

binary_crossentropy_inet_vanilla_dt_list =[]
accuracy_inet_vanilla_dt_list = []
f1_score_inet_vanilla_dt_list = []

#inet_metric_function_list = []

number = lambda_net_dataset_test.y_test_lambda_array.shape[0]#10

for lambda_net_parameters, lambda_net, X_test_lambda, y_test_lambda in tqdm(zip(lambda_net_dataset_test.network_parameters_array[:number], lambda_net_dataset_test.network_list[:number], lambda_net_dataset_test.X_test_lambda_array[:number], lambda_net_dataset_test.y_test_lambda_array[:number]), total=lambda_net_dataset_test.y_test_lambda_array[:number].shape[0]):
    dt_inet = model.predict(np.array([lambda_net_parameters]))[0]
    
    X_data_random = generate_random_data_points_custom(config['data']['x_min'], config['data']['x_max'], config['evaluation']['random_evaluation_dataset_size'], config['data']['number_of_variables'])
    y_data_random_lambda_pred = lambda_net.predict(X_data_random)
    y_data_random_lambda_pred = np.round(y_data_random_lambda_pred).astype(np.int64)
    
    dt_sklearn_distilled = DecisionTreeClassifier(max_depth=config['function_family']['maximum_depth'])
    dt_sklearn_distilled.fit(X_data_random, y_data_random_lambda_pred)
    
    
    
    y_test_inet_vanilla_dt  = calculate_function_value_from_vanilla_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    y_test_distilled_sklearn_vanilla_dt = dt_sklearn_distilled.predict(X_test_lambda)
    
    y_test_lambda_pred = lambda_net.predict(X_test_lambda)
    y_test_lambda_pred = np.round(y_test_lambda_pred)
    
    
    #random_model = generate_base_model(config)        
    #random_network_parameters = random_model.get_weights()
    #network_parameters_structure = [network_parameter.shape for network_parameter in random_network_parameters]         
    
    #function_true_placeholder = np.array([0 for i in range(basic_function_representation_length)])
    #function_true_with_network_parameters = np.concatenate([function_true_placeholder, lambda_net_parameters])
    #inet_metric_function = inet_decision_function_fv_metric_wrapper(X_test_lambda, 
    #                                                                 random_model, 
    #                                                                 network_parameters_structure, 
    #                                                                 config, 
    #                                                                 'binary_accuracy')(np.array([function_true_with_network_parameters]), 
    #                                                                                     np.array([dt_inet]))    
    #inet_metric_function_list.append(inet_metric_function)
    
    
    binary_crossentropy_distilled_sklearn_vanilla_dt = log_loss(y_test_lambda_pred, y_test_distilled_sklearn_vanilla_dt)
    accuracy_distilled_sklearn_vanilla_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_vanilla_dt))
    f1_score_distilled_sklearn_vanilla_dt = f1_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_vanilla_dt))
    
    binary_crossentropy_inet_vanilla_dt = log_loss(y_test_lambda_pred, y_test_inet_vanilla_dt)
    accuracy_inet_vanilla_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_inet_vanilla_dt))
    f1_score_inet_vanilla_dt = f1_score(y_test_lambda_pred, np.round(y_test_inet_vanilla_dt))
    
    
    y_test_inet_vanilla_dt_list.append(y_test_inet_vanilla_dt)
    y_test_distilled_sklearn_vanilla_dt_list.append(y_test_distilled_sklearn_vanilla_dt)    

    binary_crossentropy_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(binary_crossentropy_distilled_sklearn_vanilla_dt))
    accuracy_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(accuracy_distilled_sklearn_vanilla_dt))
    f1_score_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(f1_score_distilled_sklearn_vanilla_dt))

    binary_crossentropy_inet_vanilla_dt_list.append(np.nan_to_num(binary_crossentropy_inet_vanilla_dt))
    accuracy_inet_vanilla_dt_list.append(np.nan_to_num(accuracy_inet_vanilla_dt))
    f1_score_inet_vanilla_dt_list.append(np.nan_to_num(f1_score_inet_vanilla_dt))
    
y_test_inet_vanilla_dt_list = np.array(y_test_inet_vanilla_dt_list)
y_test_distilled_sklearn_vanilla_dt_list = np.array(y_test_distilled_sklearn_vanilla_dt_list)

binary_crossentropy_distilled_sklearn_vanilla_dt_list = np.array(binary_crossentropy_distilled_sklearn_vanilla_dt_list)
accuracy_distilled_sklearn_vanilla_dt_list = np.array(accuracy_distilled_sklearn_vanilla_dt_list)
f1_score_distilled_sklearn_vanilla_dt_list = np.array(f1_score_distilled_sklearn_vanilla_dt_list)

binary_crossentropy_inet_vanilla_dt_list = np.array(binary_crossentropy_inet_vanilla_dt_list)
accuracy_inet_vanilla_dt_list = np.array(accuracy_inet_vanilla_dt_list)
f1_score_inet_vanilla_dt_list = np.array(f1_score_inet_vanilla_dt_list)

inet_metric_function_list = np.array(inet_metric_function_list)
    

    
print('Binary Crossentropy:\t\t', np.round(np.mean(binary_crossentropy_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(binary_crossentropy_inet_vanilla_dt_list), 3), '(I-Net DT)')
print('Accuracy:\t\t', np.round(np.mean(accuracy_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(accuracy_inet_vanilla_dt_list), 3), '(I-Net DT)')
print('F1 Score:\t\t', np.round(np.mean(f1_score_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(f1_score_inet_vanilla_dt_list), 3), '(I-Net DT)')
      

    array([ 3.7088814 ,  4.0115128 ,  5.1971545 ,  2.956121  ,  2.171942  ,
           -4.5159836 , -1.5440631 , -4.720496  , -8.691709  , -2.4547508 ,
            2.353416  ,  3.989924  ,  3.0750039 ,  2.6971464 ,  1.2753053 ,
            2.6641812 ,  2.9581091 ,  3.4725027 ,  5.3815365 ,  3.3167508 ,
            7.243947  ,  5.4603324 ,  6.0733647 ,  7.6625285 ,  5.032399  ,
            2.1799219 ,  3.8169808 ,  1.3696196 ,  3.826566  ,  3.2445428 ,
            3.546689  ,  2.413791  ,  3.192084  ,  0.43668652,  2.9079697 ,
            8.753475  ,  5.2893    ,  5.943552  ,  6.9104233 ,  7.7552915 ,
            5.706509  ,  4.3613763 ,  6.476285  ,  4.839129  ,  3.3252523 ,
           -1.9945749 ,  5.6768656 , -0.56589913, -5.339801  ,  0.3734794 ,
           -0.35990655,  0.14992855, -0.9953262 ,  0.9631527 ,  0.10870038,
            1.4074339 ,  1.5604519 ,  1.3444741 , -0.28689137, -1.3038124 ,
           -3.4533072 , -2.8390214 , -2.9115138 , -3.8113365 , -2.3977513 ,
           -0.6165216 , -0.4210023 , -1.4207656 , -0.56437695, -0.9168651 ,
           -3.467764  , -2.9678626 , -3.183811  ,  0.6534001 , -1.4332838 ,
            6.6093574 , -5.3566527 ,  3.5091665 ,  4.1031146 ,  6.8739233 ,
            1.7489419 ,  2.6479259 ,  8.11766   ,  4.0887527 ,  0.778695  ,
            0.0240611 , -0.33409914, -2.215649  , -2.0906925 , -1.306479  ,
           11.136937  , -1.7391894 , -3.99088   , -8.747282  ,  6.145732  ,
           -2.2608323 , -5.7336845 ,  0.44919446,  1.390956  , -2.143065  ,
           -3.352171  ,  1.288973  , -0.67188215, -1.359532  ,  0.16739012,
           -0.6902978 , -1.3088157 , -1.4101    , -0.9355023 ,  0.30086473,
           -0.763403  , -2.4984457 , -4.9625816 , -0.3064472 , -2.854495  ,
           -1.0730515 , -0.64003015,  1.1362306 ,  1.5821602 , -3.9027815 ,
           -3.6285934 ,  3.9028134 ], dtype=float32)

    (<tf.Tensor: shape=(15, 5), dtype=float32, numpy=
     array([[ 3.7088814 ,  4.0115128 ,  5.1971545 ,  2.956121  ,  2.171942  ],
            [-4.5159836 , -1.5440631 , -4.720496  , -8.691709  , -2.4547508 ],
            [ 2.353416  ,  3.989924  ,  3.0750039 ,  2.6971464 ,  1.2753053 ],
            [ 2.6641812 ,  2.9581091 ,  3.4725027 ,  5.3815365 ,  3.3167508 ],
            [ 7.243947  ,  5.4603324 ,  6.0733647 ,  7.6625285 ,  5.032399  ],
            [ 2.1799219 ,  3.8169808 ,  1.3696196 ,  3.826566  ,  3.2445428 ],
            [ 3.546689  ,  2.413791  ,  3.192084  ,  0.43668652,  2.9079697 ],
            [ 8.753475  ,  5.2893    ,  5.943552  ,  6.9104233 ,  7.7552915 ],
            [ 5.706509  ,  4.3613763 ,  6.476285  ,  4.839129  ,  3.3252523 ],
            [-1.9945749 ,  5.6768656 , -0.56589913, -5.339801  ,  0.3734794 ],
            [-0.35990655,  0.14992855, -0.9953262 ,  0.9631527 ,  0.10870038],
            [ 1.4074339 ,  1.5604519 ,  1.3444741 , -0.28689137, -1.3038124 ],
            [-3.4533072 , -2.8390214 , -2.9115138 , -3.8113365 , -2.3977513 ],
            [-0.6165216 , -0.4210023 , -1.4207656 , -0.56437695, -0.9168651 ],
            [-3.467764  , -2.9678626 , -3.183811  ,  0.6534001 , -1.4332838 ]],
           dtype=float32)>,
     array([ 6.6093574 , -5.3566527 ,  3.5091665 ,  4.1031146 ,  6.8739233 ,
             1.7489419 ,  2.6479259 ,  8.11766   ,  4.0887527 ,  0.778695  ,
             0.0240611 , -0.33409914, -2.215649  , -2.0906925 , -1.306479  ],
           dtype=float32),
     <tf.Tensor: shape=(2, 16), dtype=float32, numpy=
     array([[11.136937  , -3.99088   ,  6.145732  , -5.7336845 ,  1.390956  ,
             -3.352171  , -0.67188215,  0.16739012, -1.3088157 , -0.9355023 ,
             -0.763403  , -4.9625816 , -2.854495  , -0.64003015,  1.5821602 ,
             -3.6285934 ],
            [-1.7391894 , -8.747282  , -2.2608323 ,  0.44919446, -2.143065  ,
              1.288973  , -1.359532  , -0.6902978 , -1.4101    ,  0.30086473,
             -2.4984457 , -0.3064472 , -1.0730515 ,  1.1362306 , -3.9027815 ,
              3.9028134 ]], dtype=float32)>)

In [None]:
dt_inet

In [None]:
get_shaped_parameters_for_decision_tree(dt_inet, config)