# Inerpretation-Net Training

## Specitication of Experiment Settings

In [1]:
import sys
print(sys.version)


3.9.6 (default, Aug 18 2021, 19:38:01) 
[GCC 7.5.0]


In [2]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
sleep_time = 0 #minutes



config = {
    'function_family': {
        'maximum_depth': 4,
        'beta': 1,
        'decision_sparsity': 1,
        'fully_grown': True,   
        'dt_type': 'vanilla', #'vanilla', 'SDT'
    },
    'data': {
        'number_of_variables': 5, 
        'num_classes': 2,
        
        'function_generation_type': 'make_classification', # 'make_classification' 'random_decision_tree' 'random_vanilla_decision_tree_trained'
        'objective': 'classification', # 'regression'
        
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', #'normal', 'uniform',       
                
        'lambda_dataset_size': 1000, #number of samples per function
        #'number_of_generated_datasets': 10000,
        
        'noise_injected_level': 0, 
        'noise_injected_type': 'flip_percentage', # '' 'normal' 'uniform' 'normal_range' 'uniform_range'
    }, 
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True, 
        'early_stopping_min_delta_lambda': 1e-2,
        'batch_lambda': 64,
        'dropout_lambda': 0,
        'lambda_network_layers': [64],
        'optimizer_lambda': 'adam',
        'loss_lambda': 'binary_crossentropy', #categorical_crossentropy
        
        'number_of_lambda_weights': None,
        
        'number_initializations_lambda': 1, 
        
        'number_of_trained_lambda_nets': 10000,
    },     
    
    'i_net': {
        'dense_layers': [2048],
        'convolution_layers': None,
        'lstm_layers': None,
        'dropout': [0],
        
        'optimizer': 'adam', #adam
        'learning_rate': 0.01,
        'loss': 'binary_crossentropy',
        'metrics': ['binary_accuracy'],
        
        'epochs': 200, 
        'early_stopping': True,
        'batch_size': 256,

        'interpretation_dataset_size': 10000,
                
        'test_size': 50, #Float for fraction, Int for number 0
        
        'function_representation_type': 2, # 1=standard representation; 2=sparse representation, 3=vanilla_dt

        'optimize_decision_function': True, #False
        'function_value_loss': True, #False
                      
        'data_reshape_version': None, #default to 2 options:(None, 0,1 2)
        
        'nas': False,
        'nas_type': 'SEQUENTIAL', #options:(None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel')      
        'nas_trials': 100,
    },    
    
    'evaluation': {   
        #'inet_holdout_seed_evaluation': False,
            
        'random_evaluation_dataset_size': 500, 
        'per_network_optimization_dataset_size': 5000,

        'sklearn_dt_benchmark': False,
        'sdt_benchmark': False,
        
    },    
    
    'computation':{
        'load_model': False,
        
        'n_jobs': -3,
        'use_gpu': False,
        'gpu_numbers': '0',
        'RANDOM_SEED': 42,   
    }
}


## Imports

In [3]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [4]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 

from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import shutil

import logging

#from prettytable import PrettyTable
#import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold, ParameterGrid, ParameterSampler
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score, log_loss
from sklearn.tree import DecisionTreeClassifier, plot_tree



#from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import tensorflow as tf
#import tensorflow_addons as tfa
import keras
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau


import tensorflow.keras.backend as K
from livelossplot import PlotLossesKerasTF
#from keras_tqdm import TQDMNotebookCallback

from matplotlib import pyplot as plt
import seaborn as sns


import random 



from IPython.display import Image
from IPython.display import display, Math, Latex, clear_output



In [5]:
tf.__version__

'2.6.0'

In [6]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

In [7]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################

config['i_net']['data_reshape_version'] = 2 if data_reshape_version == None and (convolution_layers != None or lstm_layers != None or (nas and nas_type != 'SEQUENTIAL')) else data_reshape_version

#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers if use_gpu else ''
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
#os.environ['XLA_FLAGS'] =  '--xla_gpu_cuda_data_dir=/usr/lib/cuda-10.1'

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
#np.set_printoptions(suppress=True)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if int(tf.__version__[0]) >= 2:
    tf.random.set_seed(RANDOM_SEED)
else:
    tf.set_random_seed(RANDOM_SEED)
    
    
pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 200)
np.set_printoptions(threshold=200)
np.set_printoptions(suppress=True)



In [8]:
from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *
from utilities.DecisionTree_BASIC import *

#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################

config['lambda_net']['number_of_lambda_weights'] = get_number_of_lambda_net_parameters(lambda_network_layers, number_of_variables, num_classes)
config['function_family']['basic_function_representation_length'] = (2 ** maximum_depth - 1) * number_of_variables + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes
config['function_family']['function_representation_length'] = ( 
       ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth - 1) + (2 ** maximum_depth) * num_classes  if function_representation_type == 1 and dt_type == 'SDT'
  else (2 ** maximum_depth - 1) * decision_sparsity + (2 ** maximum_depth - 1) + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) * num_classes if function_representation_type == 2 and dt_type == 'SDT'
  else ((2 ** maximum_depth - 1) * decision_sparsity) * 2 + (2 ** maximum_depth)  if function_representation_type == 1 and dt_type == 'vanilla'
  else (2 ** maximum_depth - 1) * decision_sparsity + ((2 ** maximum_depth - 1)  * decision_sparsity * number_of_variables) + (2 ** maximum_depth) if function_representation_type == 2 and dt_type == 'vanilla'
  else None
                                                            )
#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['function_family'])
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])

#initialize_LambdaNet_config_from_curent_notebook(config)
#initialize_metrics_config_from_curent_notebook(config)
#initialize_utility_functions_config_from_curent_notebook(config)
#initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(config, path_type='interpretation_net'))
create_folders_inet(config)

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [9]:
print(path_identifier_interpretation_net)

print(path_identifier_lambda_net_data)


lNetSize1000_numLNets10000_var5_class2_make_classification_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42/inet_dense2048_drop0e200b256_adam
lNetSize1000_numLNets10000_var5_class2_make_classification_xMax1_xMin0_xDistuniform_depth4_beta1_decisionSpars1_fullyGrown/64_e1000ES0.01_b64_drop0_adam_binary_crossentropy_fixedInit1-seed42


In [10]:
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
print("Num XLA-GPUs Available: ", len(tf.config.experimental.list_physical_devices('XLA_GPU')))

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Load Data and Generate Datasets

In [11]:
def load_lambda_nets(config, no_noise=False, n_jobs=1):
    
    #def generate_lambda_net()
    
    #if psutil.virtual_memory().percent > 80:
        #raise SystemExit("Out of RAM!")
    
    if no_noise==True:
        config['noise_injected_level'] = 0
    path_dict = generate_paths(config, path_type='interpretation_net')        
        
    directory = './data/weights/' + 'weights_' + path_dict['path_identifier_lambda_net_data'] + '/'
    path_network_parameters = directory + 'weights' + '.txt'
    path_X_data = directory + 'X_test_lambda.txt'
    path_y_data = directory + 'y_test_lambda.txt'        
    
    network_parameters = pd.read_csv(path_network_parameters, sep=",", header=None)
    network_parameters = network_parameters.sort_values(by=0)
    if no_noise == False:
        network_parameters = network_parameters.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    X_test_lambda = pd.read_csv(path_X_data, sep=",", header=None)
    X_test_lambda = X_test_lambda.sort_values(by=0)
    if no_noise == False:
        X_test_lambda = X_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
    
    y_test_lambda = pd.read_csv(path_y_data, sep=",", header=None)
    y_test_lambda = y_test_lambda.sort_values(by=0)
    if no_noise == False:
        y_test_lambda = y_test_lambda.sample(n=config['i_net']['interpretation_dataset_size'], random_state=config['computation']['RANDOM_SEED'])
        
        
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='loky') #loky

    lambda_nets = parallel(delayed(LambdaNet)(network_parameters_row, 
                                              X_test_lambda_row, 
                                              y_test_lambda_row, 
                                              config) for network_parameters_row, X_test_lambda_row, y_test_lambda_row in zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values))          
    del parallel
    
    base_model = generate_base_model(config)  
    
    def initialize_network_wrapper(config, lambda_net, base_model):
        lambda_net.initialize_network(config, base_model)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_network_wrapper)(config, lambda_net, base_model) for lambda_net in lambda_nets)   
    del parallel
    
    def initialize_target_function_wrapper(config, lambda_net):
        lambda_net.initialize_target_function(config)
    
    parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='sequential')
    _ = parallel(delayed(initialize_target_function_wrapper)(config, lambda_net) for lambda_net in lambda_nets)   
    del parallel
        
    
    #lambda_nets = [None] * network_parameters.shape[0]
    #for i, (network_parameters_row, X_test_lambda_row, y_test_lambda_row) in tqdm(enumerate(zip(network_parameters.values, X_test_lambda.values, y_test_lambda.values)), total=network_parameters.values.shape[0]):        
    #    lambda_net = LambdaNet(network_parameters_row, X_test_lambda_row, y_test_lambda_row, config)
    #    lambda_nets[i] = lambda_net
                
    lambda_net_dataset = LambdaNetDataset(lambda_nets)
        
    return lambda_net_dataset
    

In [12]:
#LOAD DATA
if noise_injected_level > 0:
    lambda_net_dataset_training = load_lambda_nets(config, no_noise=True, n_jobs=n_jobs)
    lambda_net_dataset_evaluation = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_training, test_split=0.1)
    _, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset_evaluation, test_split=test_size)
    
else:
    lambda_net_dataset = load_lambda_nets(config, n_jobs=n_jobs)

    lambda_net_dataset_train_with_valid, lambda_net_dataset_test = split_LambdaNetDataset(lambda_net_dataset, test_split=test_size)
    lambda_net_dataset_train, lambda_net_dataset_valid = split_LambdaNetDataset(lambda_net_dataset_train_with_valid, test_split=0.1)

    

[Parallel(n_jobs=-3)]: Using backend LokyBackend with 14 concurrent workers.
[Parallel(n_jobs=-3)]: Done   4 tasks      | elapsed:    9.2s
[Parallel(n_jobs=-3)]: Done 170 tasks      | elapsed:    9.3s
[Parallel(n_jobs=-3)]: Done 5006 tasks      | elapsed:   10.9s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:   12.6s finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.4s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:  3.5min finished
[Parallel(n_jobs=-3)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=-3)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=-3)]: Done 10000 out of 10000 | elapsed:  1.3min finished


## Data Inspection

In [13]:
lambda_net_dataset_train.shape

(8955, 573)

In [14]:
lambda_net_dataset_valid.shape

(995, 573)

In [15]:
lambda_net_dataset_test.shape

(50, 573)

In [16]:
lambda_net_dataset_train.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
6671,6671.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.163,0.318,-0.236,0.261,-0.274,-0.005,-0.258,0.0,-0.167,-0.332,0.0,-0.332,-0.34,-0.303,-0.071,-0.105,0.0,-0.277,0.0,-0.387,0.341,-0.226,0.0,0.0,0.0,0.36,-0.23,-0.344,0.335,0.252,-0.039,0.0,-0.133,0.309,-0.262,-0.452,1.161,1.263,-0.292,-0.256,1.396,0.155,-0.896,-1.175,-0.272,0.836,-1.525,-0.45,-0.154,0.248,0.163,0.011,-1.814,0.27,-0.68,-1.444,-1.472,1.217,0.358,0.538,1.292,-0.639,0.16,-0.227,-1.073,-0.627,0.784,2.061,0.387,-0.106,1.006,0.025,1.257,0.849,-0.294,0.276,1.431,0.87,0.156,0.838,0.187,0.742,-0.298,1.202,-0.958,0.555,-0.187,-0.257,-0.159,-0.562,1.305,0.419,-0.606,-0.432,0.181,-0.13,0.876,-1.761,0.612,-0.224
3274,3274.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.182,0.325,-0.067,-0.025,-0.244,-0.023,-0.225,0.0,-0.104,-0.022,0.0,-0.363,-0.038,-0.2,-0.064,-0.018,0.0,-0.234,-0.01,-0.158,-0.035,-0.25,0.0,0.0,0.0,0.302,-0.262,-0.315,0.182,0.244,-0.034,0.0,-0.056,0.311,-0.251,0.328,-0.528,1.127,-0.269,-0.249,0.048,0.155,-0.443,-0.205,-0.272,0.2,-0.144,-1.061,-0.154,0.7,0.163,0.022,-2.433,-0.263,-2.267,-0.262,-1.877,-0.645,0.594,0.507,0.022,-2.113,1.385,-0.227,0.939,-0.405,0.171,0.254,0.626,-0.094,0.586,0.025,0.232,1.664,-0.294,0.529,0.089,0.631,0.156,0.232,0.187,0.591,-0.29,0.697,-0.264,0.403,-0.187,-0.257,-0.159,-2.013,1.244,0.522,-1.923,-0.27,1.117,-0.13,0.189,-2.251,0.568,-0.25
3095,3095.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.044,0.192,-0.13,-0.026,-0.165,-0.007,-0.168,0.0,-0.117,-0.109,0.0,-0.209,-0.034,-0.123,-0.075,-0.012,0.0,-0.153,0.017,-0.106,0.263,-0.027,0.0,0.0,0.0,0.02,-0.094,-0.207,0.087,0.164,-0.033,0.0,-0.054,-0.049,-0.169,0.063,-0.78,0.327,-1.036,-0.229,0.053,0.155,-0.755,-0.899,-0.272,0.262,-0.798,-0.577,-0.154,0.267,0.163,0.022,-0.228,0.13,-0.89,-0.978,-0.756,-0.808,0.434,0.467,-0.685,-0.591,0.16,-0.919,0.018,-0.457,0.732,0.252,0.568,-0.107,0.575,0.025,0.451,0.698,-0.294,0.258,0.089,0.407,0.118,0.237,0.187,0.65,-0.849,0.737,-0.989,-0.12,-0.187,-0.257,-0.159,-0.459,0.141,0.356,-0.653,-0.399,0.182,-0.13,0.196,-0.18,0.556,-0.146
8379,8379.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.07,0.124,-0.026,0.06,-0.081,-0.023,-0.105,0.0,-0.075,0.006,0.0,-0.068,0.018,-0.011,-0.009,-0.03,0.0,-0.075,-0.009,0.003,0.107,-0.081,0.0,0.0,0.0,0.164,-0.12,-0.06,0.14,0.109,-0.049,0.0,0.058,0.163,-0.103,-0.409,-0.441,0.088,-0.272,-0.243,0.55,0.155,-0.553,-0.751,-0.272,0.455,-0.494,0.367,-0.154,0.486,0.163,0.529,-0.61,-0.248,-0.574,-0.692,-0.081,-0.447,-0.102,0.272,0.016,-0.464,0.139,-0.21,0.524,-0.493,0.703,0.735,0.415,-0.096,0.342,0.025,0.353,0.578,-0.294,0.032,0.641,0.68,0.673,0.295,0.187,0.556,-0.292,0.644,-0.724,0.523,-0.187,-0.257,-0.159,-0.476,0.346,0.022,-0.349,-0.509,0.173,-0.13,0.706,-0.661,0.559,-0.095
3043,3043.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.042,-0.136,0.171,-0.031,0.148,0.0,0.213,0.0,0.216,0.254,0.0,0.03,0.215,0.009,0.008,0.0,0.0,0.121,-0.031,0.209,-0.054,-0.094,0.0,0.0,0.0,-0.028,0.135,0.14,-0.139,-0.111,0.246,0.0,0.221,-0.05,0.103,-0.536,0.636,0.539,-0.927,-0.22,0.063,0.155,-0.659,-0.76,-0.272,0.464,-0.732,-0.561,-0.154,0.024,0.163,0.022,-0.119,-0.318,-0.964,-0.826,-0.648,0.699,0.236,0.308,0.677,-0.587,0.16,-0.838,0.01,-0.59,0.825,0.25,0.488,-0.113,0.557,0.025,0.799,0.708,-0.294,0.017,0.834,0.488,0.47,0.246,0.187,0.698,-0.278,0.764,-0.83,0.092,-0.187,-0.257,-0.159,-0.03,0.581,0.194,-0.91,-0.59,0.809,-0.13,0.884,-0.155,0.41,0.09


In [17]:
lambda_net_dataset_valid.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
3466,3466.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.04,0.049,-0.094,-0.023,-0.072,0.116,-0.074,0.0,-0.011,-0.073,0.0,-0.08,-0.006,-0.101,-0.088,-0.0,-0.014,-0.077,0.187,-0.057,0.115,-0.062,0.0,0.0,0.0,0.061,-0.047,-0.08,0.063,0.076,-0.009,0.0,-0.004,-0.061,-0.07,-0.342,-0.666,0.399,-0.738,-0.587,0.038,0.155,-0.386,-0.632,-0.272,0.49,-0.464,-0.294,-0.154,0.069,0.163,0.022,-0.11,-0.04,-0.653,-0.651,-0.449,0.276,0.053,0.237,0.396,-0.567,0.16,-0.227,0.029,-0.427,0.471,0.352,0.179,-0.7,0.282,0.025,0.597,0.368,-0.294,0.096,0.802,0.27,0.453,0.46,0.603,0.402,-0.675,0.123,-0.792,0.212,-0.187,-0.257,-0.159,-0.692,0.468,0.135,-0.722,-0.299,0.584,-0.13,0.876,-0.139,0.312,-0.07
689,689.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.179,0.118,0.124,0.174,-0.24,0.071,-0.098,0.0,0.002,0.169,0.0,-0.231,0.128,0.085,0.116,0.0,0.0,-0.155,0.126,0.065,0.01,0.142,0.0,0.0,0.0,0.272,-0.116,-0.205,0.17,0.105,-0.024,-0.009,0.117,-0.021,-0.108,-0.277,-1.763,0.908,-1.894,-0.256,1.857,0.155,-0.166,-1.77,-0.272,0.4,-1.718,-0.498,-0.154,0.253,0.163,0.008,-0.232,-0.024,-0.278,-1.803,-1.708,-1.803,-0.344,0.285,2.116,-0.237,0.16,-1.833,1.831,-0.207,1.92,1.87,0.522,-1.633,0.339,0.025,0.303,1.905,-0.294,0.465,2.045,1.829,1.873,0.246,0.187,0.399,-1.962,1.869,-1.558,1.645,-0.187,-0.257,-0.159,-1.79,0.732,0.457,-0.4,-0.148,0.19,-0.124,2.108,-0.197,0.232,-0.117
4148,4148.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.042,0.025,0.097,-0.028,0.021,-0.012,0.062,0.0,0.057,0.079,0.0,0.03,0.134,-0.093,-0.061,-0.02,-0.009,0.054,0.0,0.123,0.084,-0.108,0.0,0.0,0.0,-0.028,0.031,0.033,-0.045,-0.083,0.144,0.142,-0.054,-0.065,0.018,0.042,-0.697,0.474,-1.008,-0.251,0.061,0.15,-0.765,-0.769,-0.272,0.24,-0.807,-0.562,-0.154,0.283,0.163,0.005,-0.958,-0.268,-0.156,-0.868,-0.647,-0.678,0.518,0.482,0.663,-0.68,0.787,-0.814,0.02,-0.638,0.775,0.251,0.656,-0.102,0.66,0.025,0.539,0.751,-0.294,0.321,0.711,0.173,0.155,0.233,0.181,0.735,-0.298,0.735,-0.862,0.113,-0.187,-0.257,-0.159,-0.031,0.293,0.418,-0.137,-0.695,0.72,-0.793,0.189,-0.169,0.564,0.012
2815,2815.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.088,0.276,-0.06,-0.018,-0.213,-0.009,-0.179,0.0,-0.104,-0.135,0.0,-0.261,-0.034,-0.144,-0.139,-0.029,0.0,-0.182,0.0,-0.052,0.201,-0.163,0.0,0.0,0.0,0.303,-0.193,-0.242,0.269,0.16,-0.051,0.0,-0.047,0.293,-0.198,-0.501,-0.646,0.511,-0.292,-0.239,0.702,0.155,-0.156,-0.21,-0.272,0.493,-0.156,0.688,-0.154,0.636,0.163,0.022,-0.837,-0.247,-0.944,-0.272,-0.097,-0.728,0.56,0.558,0.05,-0.879,0.141,-0.227,0.755,-0.655,0.173,0.259,0.559,-0.107,0.68,0.025,0.223,0.882,-0.294,0.478,0.09,0.736,0.706,0.228,0.187,0.467,-0.298,0.119,-0.902,0.658,-0.187,-0.257,-0.159,-0.733,0.429,0.503,-0.909,-0.154,0.172,-0.13,0.197,-0.82,0.627,-0.191
5185,5185.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.358,-0.134,0.025,-0.109,0.238,-0.016,-0.219,0.0,-0.246,-0.114,0.0,0.069,0.509,0.263,0.266,-0.086,0.0,0.224,-0.012,0.006,-0.073,0.191,0.0,0.0,-0.005,-0.471,-0.074,0.18,-0.585,-0.112,-0.038,0.0,0.535,0.401,0.163,-0.24,-0.793,0.06,-0.268,-0.251,0.047,0.155,-0.494,-2.302,-0.272,0.955,-2.252,-0.295,-0.154,0.472,0.163,0.022,-2.153,-0.276,-2.572,-2.231,-1.792,-0.532,0.356,0.274,0.032,-0.456,1.832,-2.392,0.728,-0.481,1.571,1.992,0.5,-0.094,1.591,0.025,1.785,1.519,-0.294,0.044,2.156,0.592,0.584,2.002,0.187,0.502,-0.289,1.6,-0.246,0.29,-0.187,-0.257,-0.154,-2.226,0.15,0.175,-2.336,-0.431,0.185,-0.13,2.365,-2.393,0.423,0.1


In [18]:
lambda_net_dataset_test.as_pandas(config).head()

Unnamed: 0,index,seed,f0v0,f0v1,f0v2,f0v3,f0v4,f1v0,f1v1,f1v2,f1v3,f1v4,f2v0,f2v1,f2v2,f2v3,f2v4,f3v0,f3v1,f3v2,f3v3,f3v4,f4v0,f4v1,f4v2,f4v3,f4v4,f5v0,f5v1,f5v2,f5v3,f5v4,f6v0,f6v1,f6v2,f6v3,f6v4,f7v0,f7v1,f7v2,f7v3,f7v4,f8v0,f8v1,f8v2,f8v3,f8v4,f9v0,f9v1,f9v2,f9v3,f9v4,f10v0,f10v1,f10v2,f10v3,f10v4,f11v0,f11v1,f11v2,f11v3,f11v4,f12v0,f12v1,f12v2,f12v3,f12v4,f13v0,f13v1,f13v2,f13v3,f13v4,f14v0,f14v1,f14v2,f14v3,f14v4,b0,b1,b2,b3,b4,b5,b6,b7,b8,b9,b10,b11,b12,b13,b14,lp0c0,lp0c1,lp1c0,lp1c1,lp2c0,lp2c1,lp3c0,lp3c1,...,wb_349,wb_350,wb_351,wb_352,wb_353,wb_354,wb_355,wb_356,wb_357,wb_358,wb_359,wb_360,wb_361,wb_362,wb_363,wb_364,wb_365,wb_366,wb_367,wb_368,wb_369,wb_370,wb_371,wb_372,wb_373,wb_374,wb_375,wb_376,wb_377,wb_378,wb_379,wb_380,wb_381,wb_382,wb_383,wb_384,wb_385,wb_386,wb_387,wb_388,wb_389,wb_390,wb_391,wb_392,wb_393,wb_394,wb_395,wb_396,wb_397,wb_398,wb_399,wb_400,wb_401,wb_402,wb_403,wb_404,wb_405,wb_406,wb_407,wb_408,wb_409,wb_410,wb_411,wb_412,wb_413,wb_414,wb_415,wb_416,wb_417,wb_418,wb_419,wb_420,wb_421,wb_422,wb_423,wb_424,wb_425,wb_426,wb_427,wb_428,wb_429,wb_430,wb_431,wb_432,wb_433,wb_434,wb_435,wb_436,wb_437,wb_438,wb_439,wb_440,wb_441,wb_442,wb_443,wb_444,wb_445,wb_446,wb_447,wb_448
7217,7217.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.038,0.294,-0.205,-0.046,-0.222,0.322,-0.202,0.0,-0.134,-0.184,0.0,-0.297,-0.035,-0.099,-0.073,0.0,0.0,-0.228,0.313,-0.057,0.238,-0.201,0.0,0.0,0.0,0.328,-0.259,-0.289,0.263,0.15,-0.245,-0.009,-0.047,0.296,-0.227,-0.612,0.626,0.722,-0.926,-0.867,0.057,0.155,-0.811,-0.916,-0.272,0.79,-0.088,-0.6,-0.154,0.241,0.163,0.022,-1.177,0.524,-0.787,-1.099,-0.877,0.664,0.494,0.612,0.718,-0.68,0.16,-0.187,0.739,-0.73,0.713,0.536,0.322,-0.716,0.687,0.025,0.82,0.646,-0.294,0.417,0.091,0.142,0.155,0.246,0.187,0.629,-0.947,0.148,-0.925,0.678,-0.187,-0.257,-0.159,-0.645,0.837,0.526,-0.688,-0.145,1.033,-0.123,0.187,-0.933,0.621,-0.192
8291,8291.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.046,-0.332,-0.07,0.226,0.337,-0.011,-0.565,0.0,0.185,0.346,0.0,0.286,-0.032,0.326,-0.071,0.422,0.0,0.363,-0.022,0.286,-0.016,0.116,0.0,0.0,0.0,-0.02,0.226,-0.206,-0.331,-0.449,0.445,0.0,-0.024,-0.031,0.252,-0.646,1.345,1.43,-0.284,-0.224,1.48,0.155,-0.154,-1.427,-0.272,1.528,-1.021,-0.81,-0.154,0.202,0.163,0.022,-2.223,-0.472,-0.162,-1.183,-1.029,-1.492,-0.57,0.16,1.37,-0.703,0.149,-0.227,0.028,-0.77,0.203,2.005,0.375,-0.049,-1.864,0.025,1.733,1.469,-0.294,0.14,0.093,0.337,0.145,1.627,0.187,0.776,-0.283,1.393,-0.266,1.079,-0.187,-0.257,-0.159,-0.032,1.455,-0.133,-0.67,-1.106,1.545,-0.13,0.226,-0.183,0.258,0.286
4607,4607.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,-0.048,-0.188,0.317,0.153,0.258,-0.005,-0.12,0.0,0.112,0.012,0.0,-0.3,-0.03,0.307,-0.045,0.0,0.0,0.27,-0.01,0.356,-0.116,-0.08,0.0,0.0,0.0,-0.009,0.367,0.255,-0.022,-0.209,-0.034,0.0,-0.04,-0.014,0.19,-0.146,-1.13,-0.655,-1.849,-2.077,0.035,0.155,-0.601,-0.921,-0.272,0.214,-1.914,-1.07,-0.154,0.224,0.163,0.022,-1.679,-0.23,-1.782,-1.429,-1.541,-1.308,0.346,0.388,2.273,-0.533,0.16,-1.815,0.02,-0.534,0.746,1.289,0.497,-1.812,1.144,0.025,0.541,0.941,-0.294,-0.247,0.092,0.57,0.343,0.246,0.187,0.594,-2.104,0.695,-0.691,-0.12,-0.187,-0.257,-0.159,-1.745,1.079,0.202,-1.693,-0.76,0.164,-0.13,0.194,-1.985,0.522,0.189
5114,5114.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.467,-0.321,0.152,0.437,0.288,0.079,0.206,0.0,0.237,0.245,0.0,-0.374,-0.027,0.348,-0.059,0.3,0.0,0.084,0.068,0.146,-0.178,0.249,0.0,0.0,0.0,-0.346,-0.081,0.29,-0.222,-0.265,0.197,0.0,-0.034,-0.353,0.26,-0.516,-1.202,0.077,-1.528,-0.222,1.21,0.155,-0.751,-0.724,-0.272,0.197,-2.037,-0.467,-0.154,0.249,0.163,0.022,-1.882,-0.28,-0.902,-0.88,-1.925,-0.886,0.274,0.428,2.247,-0.784,1.669,-2.11,0.725,-0.822,1.616,0.985,0.387,-1.765,1.297,0.025,1.208,0.975,-0.294,-0.41,0.096,0.691,0.159,1.512,0.187,1.742,-2.068,1.553,-0.834,1.857,-0.187,-0.257,-0.159,-0.977,0.149,0.181,-0.76,-0.636,1.931,-0.13,0.212,-1.548,0.579,0.218
1859,1859.0,42,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.213,-0.003,0.134,0.166,0.079,-0.02,-0.072,0.0,-0.082,0.184,0.0,-0.054,0.197,0.119,0.124,0.0,0.0,0.076,-0.068,0.183,-0.015,-0.098,0.0,0.0,0.0,-0.03,-0.079,-0.048,-0.052,0.002,-0.026,-0.03,0.209,-0.063,0.003,0.018,-0.642,0.077,-0.851,-0.237,0.575,0.155,-0.786,-0.792,-0.272,0.203,-0.734,-0.518,-0.147,0.571,0.163,0.755,-0.774,-0.399,-0.14,-0.873,-0.657,-0.633,0.358,0.239,0.026,-0.072,0.16,-0.471,0.711,-0.513,0.9,1.062,0.658,-0.099,0.038,0.025,0.248,0.743,-0.294,-0.121,0.814,0.849,0.859,0.246,0.187,0.703,-0.841,0.82,-1.016,0.126,-0.187,-0.257,-0.159,0.015,0.147,-0.093,-0.124,-0.647,0.189,-0.101,0.939,-0.167,0.526,-0.008


## Interpretation Network Training

In [19]:
#%load_ext autoreload

In [20]:
#%autoreload 2
((X_valid, y_valid), 
 (X_test, y_test),
 history,

 model) = interpretation_net_training(
                                      lambda_net_dataset_train, 
                                      lambda_net_dataset_valid, 
                                      lambda_net_dataset_test,
                                      config,
                                      #callback_names=['plot_losses']
                                     )



----------------------------------------------- TRAINING INTERPRETATION NET -----------------------------------------------
Epoch 1/200
Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200
Epoch 56/200
Epoch 57/200
Epoch 58/200
Epoch 59/200
Epoch 60/200
Epoch 61/200
Epoch 62/200
Epoch 63/200
Epoch 64/200
Training Time: 0:17:34
------------------------------

<Figure size 432x288 with 0 Axes>

In [21]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input (InputLayer)              [(None, 449)]        0                                            
__________________________________________________________________________________________________
hidden1_2048 (Dense)            (None, 2048)         921600      input[0][0]                      
__________________________________________________________________________________________________
activation1_relu (Activation)   (None, 2048)         0           hidden1_2048[0][0]               
__________________________________________________________________________________________________
output_coeff_15 (Dense)         (None, 15)           30735       activation1_relu[0][0]           
______________________________________________________________________________________________

In [22]:
y_test_inet_vanilla_dt_list = []
y_test_distilled_sklearn_vanilla_dt_list = []

binary_crossentropy_distilled_sklearn_vanilla_dt_list =[]
accuracy_distilled_sklearn_vanilla_dt_list = []
f1_score_distilled_sklearn_vanilla_dt_list = []

binary_crossentropy_inet_vanilla_dt_list =[]
accuracy_inet_vanilla_dt_list = []
f1_score_inet_vanilla_dt_list = []

#inet_metric_function_list = []

number = lambda_net_dataset_test.y_test_lambda_array.shape[0]#10

for lambda_net_parameters, lambda_net, X_test_lambda, y_test_lambda in tqdm(zip(lambda_net_dataset_test.network_parameters_array[:number], lambda_net_dataset_test.network_list[:number], lambda_net_dataset_test.X_test_lambda_array[:number], lambda_net_dataset_test.y_test_lambda_array[:number]), total=lambda_net_dataset_test.y_test_lambda_array[:number].shape[0]):
    dt_inet = model.predict(np.array([lambda_net_parameters]))[0]
    
    X_data_random = generate_random_data_points_custom(config['data']['x_min'], config['data']['x_max'], config['evaluation']['random_evaluation_dataset_size'], config['data']['number_of_variables'])
    y_data_random_lambda_pred = lambda_net.predict(X_data_random)
    y_data_random_lambda_pred = np.round(y_data_random_lambda_pred).astype(np.int64)
    
    dt_sklearn_distilled = DecisionTreeClassifier(max_depth=config['function_family']['maximum_depth'])
    dt_sklearn_distilled.fit(X_data_random, y_data_random_lambda_pred)
    
    
    
    y_test_inet_vanilla_dt  = calculate_function_value_from_vanilla_decision_tree_parameters_wrapper(X_test_lambda, config)(dt_inet).numpy()
    y_test_distilled_sklearn_vanilla_dt = dt_sklearn_distilled.predict(X_test_lambda)
    
    y_test_lambda_pred = lambda_net.predict(X_test_lambda)
    y_test_lambda_pred = np.round(y_test_lambda_pred)
    
    
    #random_model = generate_base_model(config)        
    #random_network_parameters = random_model.get_weights()
    #network_parameters_structure = [network_parameter.shape for network_parameter in random_network_parameters]         
    
    #function_true_placeholder = np.array([0 for i in range(basic_function_representation_length)])
    #function_true_with_network_parameters = np.concatenate([function_true_placeholder, lambda_net_parameters])
    #inet_metric_function = inet_decision_function_fv_metric_wrapper(X_test_lambda, 
    #                                                                 random_model, 
    #                                                                 network_parameters_structure, 
    #                                                                 config, 
    #                                                                 'binary_accuracy')(np.array([function_true_with_network_parameters]), 
    #                                                                                     np.array([dt_inet]))    
    #inet_metric_function_list.append(inet_metric_function)
    
    
    binary_crossentropy_distilled_sklearn_vanilla_dt = log_loss(y_test_lambda_pred, y_test_distilled_sklearn_vanilla_dt)
    accuracy_distilled_sklearn_vanilla_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_vanilla_dt))
    f1_score_distilled_sklearn_vanilla_dt = f1_score(y_test_lambda_pred, np.round(y_test_distilled_sklearn_vanilla_dt))
    
    binary_crossentropy_inet_vanilla_dt = log_loss(y_test_lambda_pred, y_test_inet_vanilla_dt)
    accuracy_inet_vanilla_dt = accuracy_score(y_test_lambda_pred, np.round(y_test_inet_vanilla_dt))
    f1_score_inet_vanilla_dt = f1_score(y_test_lambda_pred, np.round(y_test_inet_vanilla_dt))
    
    
    y_test_inet_vanilla_dt_list.append(y_test_inet_vanilla_dt)
    y_test_distilled_sklearn_vanilla_dt_list.append(y_test_distilled_sklearn_vanilla_dt)    

    binary_crossentropy_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(binary_crossentropy_distilled_sklearn_vanilla_dt))
    accuracy_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(accuracy_distilled_sklearn_vanilla_dt))
    f1_score_distilled_sklearn_vanilla_dt_list.append(np.nan_to_num(f1_score_distilled_sklearn_vanilla_dt))

    binary_crossentropy_inet_vanilla_dt_list.append(np.nan_to_num(binary_crossentropy_inet_vanilla_dt))
    accuracy_inet_vanilla_dt_list.append(np.nan_to_num(accuracy_inet_vanilla_dt))
    f1_score_inet_vanilla_dt_list.append(np.nan_to_num(f1_score_inet_vanilla_dt))
    
y_test_inet_vanilla_dt_list = np.array(y_test_inet_vanilla_dt_list)
y_test_distilled_sklearn_vanilla_dt_list = np.array(y_test_distilled_sklearn_vanilla_dt_list)

binary_crossentropy_distilled_sklearn_vanilla_dt_list = np.array(binary_crossentropy_distilled_sklearn_vanilla_dt_list)
accuracy_distilled_sklearn_vanilla_dt_list = np.array(accuracy_distilled_sklearn_vanilla_dt_list)
f1_score_distilled_sklearn_vanilla_dt_list = np.array(f1_score_distilled_sklearn_vanilla_dt_list)

binary_crossentropy_inet_vanilla_dt_list = np.array(binary_crossentropy_inet_vanilla_dt_list)
accuracy_inet_vanilla_dt_list = np.array(accuracy_inet_vanilla_dt_list)
f1_score_inet_vanilla_dt_list = np.array(f1_score_inet_vanilla_dt_list)

#inet_metric_function_list = np.array(inet_metric_function_list)
    

    
print('Binary Crossentropy:\t\t', np.round(np.mean(binary_crossentropy_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(binary_crossentropy_inet_vanilla_dt_list), 3), '(I-Net DT)')
print('Accuracy:\t\t', np.round(np.mean(accuracy_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(accuracy_inet_vanilla_dt_list), 3), '(I-Net DT)')
print('F1 Score:\t\t', np.round(np.mean(f1_score_distilled_sklearn_vanilla_dt_list), 3), '(Sklearn DT)' , '\t', np.round(np.mean(f1_score_inet_vanilla_dt_list), 3), '(I-Net DT)')
      

  0%|          | 0/50 [00:00<?, ?it/s]

Binary Crossentropy:		 6.634 (Sklearn DT) 	 0.58 (I-Net DT)
Accuracy:		 0.808 (Sklearn DT) 	 0.694 (I-Net DT)
F1 Score:		 0.813 (Sklearn DT) 	 0.689 (I-Net DT)


In [23]:
def get_shaped_parameters_for_decision_tree(flat_parameters, config):
    
    input_dim = config['data']['number_of_variables']
    output_dim = config['data']['num_classes']
    internal_node_num_ = 2 ** config['function_family']['maximum_depth'] - 1 
    leaf_node_num_ = 2 ** config['function_family']['maximum_depth']

    split_values_num_params = internal_node_num_ * config['function_family']['decision_sparsity']
    split_index_num_params = config['data']['number_of_variables'] *  config['function_family']['decision_sparsity'] * internal_node_num_
    leaf_classes_num_params = leaf_node_num_ #* config['data']['num_classes']

    split_values = flat_parameters[:split_values_num_params]
    split_values_list_by_internal_node = tf.split(split_values, internal_node_num_)

    split_index_array = flat_parameters[split_values_num_params:split_values_num_params+split_index_num_params]    
    split_index_list_by_internal_node = tf.split(split_index_array, internal_node_num_)
    split_index_list_by_internal_node_by_decision_sparsity = []
    for tensor in split_index_list_by_internal_node:
        split_tensor = tf.split(tensor, config['function_family']['decision_sparsity'])
        split_index_list_by_internal_node_by_decision_sparsity.append(split_tensor)
    split_index_list_by_internal_node_by_decision_sparsity_argmax = tf.split(tf.argmax(split_index_list_by_internal_node_by_decision_sparsity, axis=2), internal_node_num_)
    split_index_list_by_internal_node_by_decision_sparsity_argmax_new = []
    for tensor in split_index_list_by_internal_node_by_decision_sparsity_argmax:
        tensor_squeeze = tf.squeeze(tensor, axis=0)
        split_index_list_by_internal_node_by_decision_sparsity_argmax_new.append(tensor_squeeze)
    split_index_list_by_internal_node_by_decision_sparsity_argmax = split_index_list_by_internal_node_by_decision_sparsity_argmax_new    
    dense_tensor_list = []
    for indices_node, values_node in zip(split_index_list_by_internal_node_by_decision_sparsity_argmax,  split_values_list_by_internal_node):
        sparse_tensor = tf.sparse.SparseTensor(indices=tf.expand_dims(indices_node, axis=1), values=values_node, dense_shape=[input_dim])
        dense_tensor = tf.sparse.to_dense(sparse_tensor)
        dense_tensor_list.append(dense_tensor) 
    splits = tf.stack(dense_tensor_list)

    leaf_classes_array = flat_parameters[split_values_num_params+split_index_num_params:]  
    split_index_list_by_leaf_node = tf.split(leaf_classes_array, leaf_node_num_)
    #leaf_classes_list = []
    #for tensor in split_index_list_by_leaf_node:
        #argmax = tf.argmax(tensor)
        #argsort = tf.argsort(tensor, direction='DESCENDING')
        #leaf_classes_list.append(argsort[0])
        #leaf_classes_list.append(argsort[1])

    leaf_classes = tf.squeeze(tf.stack(split_index_list_by_leaf_node))#tf.stack(leaf_classes_list)
    return splits, leaf_classes



In [24]:
dt_inet

array([0.45284158, 0.32821262, 0.5029834 , 0.46723327, 0.50334644,
       0.5107664 , 0.4788803 , 0.5286971 , 0.5330286 , 0.44671458,
       0.5951074 , 0.6098474 , 0.5804453 , 0.5890433 , 0.6047059 ,
       0.21950963, 0.1907012 , 0.10660789, 0.14148287, 0.3416984 ,
       0.22137895, 0.23314933, 0.271722  , 0.10512332, 0.16862637,
       0.09619734, 0.17953692, 0.2249763 , 0.29107037, 0.20821905,
       0.19006312, 0.22789945, 0.29852223, 0.16333182, 0.12018336,
       0.1530848 , 0.10679807, 0.22248009, 0.33663225, 0.18100473,
       0.22383659, 0.1375861 , 0.17533286, 0.33283713, 0.13040733,
       0.34083217, 0.16291459, 0.17132531, 0.21951392, 0.10541398,
       0.31181073, 0.0980735 , 0.17653432, 0.20887682, 0.2047046 ,
       0.1521662 , 0.32716733, 0.10945123, 0.29110935, 0.1201059 ,
       0.2363556 , 0.16630751, 0.2725188 , 0.14974466, 0.17507344,
       0.08592523, 0.15281053, 0.25172088, 0.25335568, 0.25618774,
       0.18832608, 0.13647796, 0.2585782 , 0.24039364, 0.17622

In [25]:
get_shaped_parameters_for_decision_tree(dt_inet, config)

(<tf.Tensor: shape=(15, 5), dtype=float32, numpy=
 array([[0.        , 0.        , 0.        , 0.        , 0.45284158],
        [0.        , 0.        , 0.32821262, 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.5029834 , 0.        ],
        [0.        , 0.        , 0.46723327, 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.50334644, 0.        ],
        [0.        , 0.        , 0.        , 0.5107664 , 0.        ],
        [0.4788803 , 0.        , 0.        , 0.        , 0.        ],
        [0.5286971 , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.5330286 , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.44671458, 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.5951074 ],
        [0.        , 0.        , 0.6098474 , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.5804453 ],
        [0.        , 0.        , 0.     