# Interpretation-Net Training

# Experiment 1: I-Net Performance for Different Algebras and Complexities
# Experiment 2: I-Net Performance Comparison for λ-Nets with Different Training Levels
# Experiment 3: I-Net Performance Comparison for Different Training Data Sizes

## Specification of Experiment Settings

In [1]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
sleep_time = 0 # minutes passed to sleep_minutes() at the end of the config-adjustment cell


# Experiment configuration. Every section is later copied into module globals via
# globals().update(config[<section>]) in the order data, lambda_net, i_net, evaluation, computation,
# so a key that appears in several sections takes the value of the LAST section (see 'dropout').
config = {
    'data': {
        'd': 5, # polynomial degree
        'n': 1, # number of variables
        'sparsity': None, # computed later as nCr(n + d, d) (number of monomials)
        'sample_sparsity': None, # None -> falls back to the computed sparsity
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', # options: 'normal', 'uniform', 'beta', 'Gamma', 'laplace'
        'a_max': 1,
        'a_min': -1,
        'lambda_nets_total': 10000,
        'noise': 0.1,
        'noise_distrib': 'normal', # options: 'normal', 'uniform'
        
        'same_training_all_lambda_nets': False,

        'fixed_seed_lambda_training': True,
        'fixed_initialization_lambda_training': False,
        'number_different_lambda_trainings': 1,
    },
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True,  # if early stopping is used, multi_epoch_analysis is deactivated
        'early_stopping_min_delta_lambda': 1e-4,
        'batch_lambda': 64,
        'dropout': 0, # NOTE: 'i_net' also defines 'dropout'; the global `dropout` ends up with the i_net value
        # A string entry is resolved later to (occurrences of the word 'sparsity') * sparsity units;
        # 5*'sparsity' repeats the word five times -> a hidden layer with 5 * sparsity units.
        'lambda_network_layers': [5*'sparsity'],
        'optimizer_lambda': 'adam',
        'loss_lambda': 'mae',
        'number_of_lambda_weights': None, # calculated automatically later from the layer sizes
        'lambda_dataset_size': 5000,
    },
    'i_net': {
        'optimizer': 'custom', # alternative: 'adam'
        'inet_loss': 'mae',
        'inet_metrics': ['r2'],
        'dropout': 0.25,
        'dropout_output': 0,
        'epochs': 500,
        'early_stopping': True,
        'batch_size': 512,
        'dense_layers': [512, 1024],
        'convolution_layers': None,
        'lstm_layers': None,
        # number of lambda-nets sampled from each weight file; must not exceed the rows in those files
        'interpretation_dataset_size': 10000,
                
        'interpretation_net_output_monomials': 3, # None or int; None -> output shape equals sparsity
        'interpretation_net_output_shape': None, # calculated automatically later
        
        'evaluate_with_real_function': False,
        'consider_labels_training': False,
                      
        # None, 0, 1 or 2; None is resolved to 2 when convolution/LSTM layers or non-SEQUENTIAL NAS are used
        'data_reshape_version': None,
        'nas': False,
        'nas_type': 'SEQUENTIAL', # options: None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel'
        'nas_trials': 100,
    },
    'evaluation': {   
        'inet_holdout_seed_evaluation': False, # True -> train/valid/test split by lambda-net training seed
        
        # set if multi_epoch_analysis should be performed (forced to False when early_stopping_lambda is True)
        'multi_epoch_analysis': True,
        'each_epochs_save_lambda': 20, # replaced by epochs_lambda when multi_epoch_analysis is off
        'epoch_start': 0, # use to skip first epochs in multi_epoch_analysis
        
        # set if samples analysis should be performed, e.g. [100, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 20000, 25000, 28125]
        'samples_list': None,
       
        'random_evaluation_dataset_size': 500,
    },
    'computation':{
        'n_jobs': 5, # upper bound for parallel workers
        'use_gpu': False, # False -> CUDA_VISIBLE_DEVICES is set to ''
        'gpu_numbers': '0', # value for CUDA_VISIBLE_DEVICES when use_gpu is True
        'RANDOM_SEED': 42, # seeds random, numpy and tensorflow, and all pandas sampling
    }
}

## Imports

In [2]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
# Expose every config entry as a module-level variable (e.g. `d` instead of config['data']['d']).
# Sections are applied in this order, so a key present in several sections keeps the value of the
# last one (both 'lambda_net' and 'i_net' define 'dropout' -> the i_net value wins).
for _config_section in ('data', 'lambda_net', 'i_net', 'evaluation', 'computation'):
    globals().update(config[_config_section])
del _config_section

In [3]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import os
import shutil

import logging

from prettytable import PrettyTable
import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score
from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import keras
from keras.models import Sequential
from keras.layers.core import Dense, Dropout
from keras.utils import plot_model
from IPython.display import Image

import keras.backend as K
from livelossplot import PlotLossesKerasTF
from keras_tqdm import TQDMNotebookCallback
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

from matplotlib import pyplot as plt
import seaborn as sns


import tensorflow as tf
import random 


import warnings


from IPython.display import display, Math, Latex, clear_output


In [4]:
# Make the parent directory importable so the project's `utilities` package can be found.
sys.path.append('..')

# Project helpers. NOTE(review): the star imports hide where names used below come from
# (LambdaNet, LambdaNetDataset, split_LambdaNetDataset, nCr, flatten, dec_to_base, generate_paths,
# sleep_minutes, ...) -- presumably these modules; confirm before refactoring.
from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *

In [5]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################
# Names of the n input variables: 'a', 'b', ...
variables = 'abcdefghijklmnopqrstuvwxyz'[:n]

# Deactivate multi_epoch_analysis if early stopping is used. This must happen BEFORE anything below
# is derived from the flag; otherwise n_jobs is sized for a multi-epoch load that never happens.
multi_epoch_analysis = False if early_stopping_lambda else multi_epoch_analysis

# Cap the number of parallel workers by the amount of parallel work.
# NOTE(review): the config-adjustment cell re-imports config['computation'] into globals, which resets
# n_jobs to the configured value unless this result is also written back to config.
if multi_epoch_analysis:
    n_jobs = min(epochs_lambda//each_epochs_save_lambda + 1, n_jobs)
elif samples_list is not None:
    n_jobs = min(len(samples_list), n_jobs)
else:
    n_jobs = 1

# Without multi-epoch analysis, a single weights file (epoch index == epochs_lambda) is loaded.
each_epochs_save_lambda = each_epochs_save_lambda if multi_epoch_analysis else epochs_lambda

# Positions of the saved weight files to load (translated into epoch numbers in the LOAD DATA cell).
if each_epochs_save_lambda == 1:
    epochs_save_range_lambda = range(epoch_start//each_epochs_save_lambda, epochs_lambda//each_epochs_save_lambda)
elif multi_epoch_analysis:
    epochs_save_range_lambda = range(epoch_start//each_epochs_save_lambda, epochs_lambda//each_epochs_save_lambda + 1)
else:
    epochs_save_range_lambda = range(1, 2)

# Non-dense I-Net architectures (CNN/LSTM layers or non-sequential NAS) default to reshape version 2.
if data_reshape_version is None and (convolution_layers is not None or lstm_layers is not None or (nas and nas_type != 'SEQUENTIAL')):
    data_reshape_version = 2
#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers if use_gpu else ''

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
#np.set_printoptions(suppress=True)

# Seed every RNG in use for reproducibility; TF 1.x and 2.x expose the seed setter under different names.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
if int(tf.__version__.split('.')[0]) >= 2:  # parse the full major version, not just its first character
    tf.random.set_seed(RANDOM_SEED)
else:
    tf.set_random_seed(RANDOM_SEED)


pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 500)

warnings.filterwarnings('ignore')

In [6]:
#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################
# Write the values adjusted in the previous cell back into config, so the globals re-imported below
# and the config handed to the utilities agree with them.
config['evaluation']['multi_epoch_analysis'] = multi_epoch_analysis
config['evaluation']['each_epochs_save_lambda'] = each_epochs_save_lambda
config['i_net']['data_reshape_version'] = data_reshape_version

# Number of monomials of a polynomial in n variables with degree <= d: nCr(n + d, d).
config['data']['sparsity'] = nCr(config['data']['n']+config['data']['d'], config['data']['d'])
if config['data']['sample_sparsity'] is None:
    config['data']['sample_sparsity'] = config['data']['sparsity']

# I-Net output size: `sparsity` values when interpretation_net_output_monomials is None, otherwise
# (sparsity + 1) values per output monomial (== sparsity*k + k). Presumably a sparsity-sized selection
# vector plus one coefficient per monomial -- TODO confirm against the I-Net output decoding.
if config['i_net']['interpretation_net_output_monomials'] is None:
    config['i_net']['interpretation_net_output_shape'] = config['data']['sparsity']
else:
    config['i_net']['interpretation_net_output_shape'] = (
        (config['data']['sparsity'] + 1) * config['i_net']['interpretation_net_output_monomials']
    )

# Resolve symbolic layer sizes: a string entry becomes (occurrences of the word 'sparsity') * sparsity
# units, e.g. 5*'sparsity' -> 5 * sparsity; integer entries are kept unchanged.
transformed_layers = [
    layer.count('sparsity')*config['data']['sparsity'] if isinstance(layer, str) else layer
    for layer in config['lambda_net']['lambda_network_layers']
]
config['lambda_net']['lambda_network_layers'] = transformed_layers

# Total number of lambda-net parameters: (fan_in + 1) * fan_out per layer (weights plus one bias per
# output unit), for input size n, the hidden layers above and a single output unit.
layers_with_input_output = list(flatten([[config['data']['n']], config['lambda_net']['lambda_network_layers'], [1]]))
number_of_lambda_weights = sum(
    (fan_in+1)*fan_out
    for fan_in, fan_out in zip(layers_with_input_output[:-1], layers_with_input_output[1:])
)
config['lambda_net']['number_of_lambda_weights'] = number_of_lambda_weights

#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])



initialize_LambdaNet_config_from_curent_notebook(config)
initialize_metrics_config_from_curent_notebook(config)
initialize_utility_functions_config_from_curent_notebook(config)
initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(path_type='interpretation_net'))
create_folders_inet()

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [7]:
# Show the I-Net path identifier followed by the lambda-net data path identifier, one per line.
print(path_identifier_interpretation_net_data, path_identifier_lambda_net_data, sep='\n')


inet_dense512-1024-output_21_drop0.25e500b512_custom/lnets_10000_30-1000e_ES0.0001_64b_adam_mae_train_5000_diffX_1-FixSeed_42/var_1_d_5_sparsity_6_amin_-1_amax_1_xdist_uniform_noise_normal_0.1
lnets_10000_30-1000e_ES0.0001_64b_adam_mae_train_5000_diffX_1-FixSeed_42/var_1_d_5_sparsity_6_amin_-1_amax_1_xmin_0_xmax_1_xdist_uniform_noise_normal_0.1


In [8]:
# Report how many GPU / XLA_GPU devices TensorFlow can see (CUDA_VISIBLE_DEVICES is set to '' when use_gpu is False).
gpu_device_list = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_device_list))
xla_gpu_device_list = tf.config.experimental.list_physical_devices('XLA_GPU')
print("Num XLA-GPUs Available: ", len(xla_gpu_device_list))

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Utility functions

### Generate List of Monomial Identifiers

In [9]:
# Candidate exponent strings: every integer in [0, (d+1)**n) written in base d+1 and left-padded with
# zeros to n characters (presumably one exponent digit per variable).
list_of_monomial_identifiers_extended = [
    dec_to_base(number, base = (d+1)).zfill(n)
    for number in tqdm(range((d+1)**n))
]

print('List length: ' + str(len(list_of_monomial_identifiers_extended)))
print('Number of monomials in a polynomial with ' + str(n) + ' variables and degree ' + str(d) + ': ' + str(nCr(n+d, d)))
print('Sparsity: ' + str(sparsity))
print(list_of_monomial_identifiers_extended)

# Keep only identifiers whose digit sum (total degree) does not exceed d.
list_of_monomial_identifiers = [
    identifier
    for identifier in tqdm(list_of_monomial_identifiers_extended)
    if sum(int(digit) for digit in identifier) <= d
]

print('List length: ' + str(len(list_of_monomial_identifiers)))
print('Number of monomials in a polynomial with ' + str(n) + ' variables and degree ' + str(d) + ': ' + str(nCr(n+d, d)))
print('Sparsity: ' + str(sparsity))
print(list_of_monomial_identifiers)


# Recompute the lambda-net layer sizes (n inputs, hidden layers, one output) and the resulting number of
# parameters, (fan_in + 1) * fan_out per layer, from the current globals.
layers_with_input_output = list(flatten([[n], lambda_network_layers, [1]]))
number_of_lambda_weights = sum(
    (fan_in + 1) * fan_out
    for fan_in, fan_out in zip(layers_with_input_output[:-1], layers_with_input_output[1:])
)

  0%|          | 0/6 [00:00<?, ?it/s]

List length: 6
Number of monomials in a polynomial with 1 variables and degree 5: 6
Sparsity: 6
['0', '1', '2', '3', '4', '5']


  0%|          | 0/6 [00:00<?, ?it/s]

List length: 6
Number of monomials in a polynomial with 1 variables and degree 5: 6
Sparsity: 6
['0', '1', '2', '3', '4', '5']


## Load Data and Generate Datasets

In [10]:
def load_lambda_nets(index):
    """Load the lambda-nets saved for one epoch and wrap them in a LambdaNetDataset.

    Reads the weight file of epoch ``index`` and the X/y test-data files from
    ./data/weights/weights_<path_identifier_lambda_net_data>/, draws the same
    ``interpretation_dataset_size`` rows from each file and builds one LambdaNet per row triple.

    Args:
        index: Epoch number used in the weight file name (zero-padded to three digits).

    Returns:
        LambdaNetDataset built from the sampled lambda-nets.

    Raises:
        SystemExit: If system RAM usage is already above 80 % before loading.
        ValueError: (from pandas) if a file has fewer rows than interpretation_dataset_size.
    """
    if psutil.virtual_memory().percent > 80:
        raise SystemExit("Out of RAM!")

    directory = './data/weights/' + 'weights_' + path_identifier_lambda_net_data + '/'
    path_weights = directory + 'weights_epoch_' + str(index).zfill(3) + '.txt'
    path_X_data = directory + 'lambda_X_test_data.txt'
    path_y_data = directory + 'lambda_y_test_data.txt'

    weight_data = _read_sorted_sample(path_weights)
    lambda_X_test_data = _read_sorted_sample(path_X_data)
    lambda_y_test_data = _read_sorted_sample(path_y_data)

    lambda_nets = [
        LambdaNet(row_weights, row_lambda_X_test_data, row_lambda_y_test_data)
        for row_weights, row_lambda_X_test_data, row_lambda_y_test_data
        in zip(weight_data.values, lambda_X_test_data.values, lambda_y_test_data.values)
    ]

    return LambdaNetDataset(lambda_nets)


def _read_sorted_sample(path):
    """Read a headerless CSV file, sort it by column 0 and draw interpretation_dataset_size rows."""
    data = pd.read_csv(path, sep=",", header=None)
    # Sorting by column 0 puts every file into the same canonical row order, so sampling with the same
    # random_state selects the same row positions in each file. A full shuffle before this sort would be
    # discarded by the sort, so it is skipped. The stable sort keeps file order for tied keys, which keeps
    # the files aligned. NOTE(review): assumes all files share the same column-0 keys -- confirm.
    data = data.sort_values(by=0, kind='mergesort')
    return data.sample(n=interpretation_dataset_size, random_state=RANDOM_SEED)
    

In [11]:
#LOAD DATA

# Epoch number (as used in the weights_epoch_XXX.txt file names) for each position i of
# epochs_save_range_lambda:
#   each_epochs_save_lambda == 1 -> epoch i+1
#   otherwise                    -> epoch i*each_epochs_save_lambda for i >= 1, epoch 1 for i == 0
epochs_to_load = [
    (i+1)*each_epochs_save_lambda if each_epochs_save_lambda == 1
    else i*each_epochs_save_lambda if i >= 1
    else 1
    for i in epochs_save_range_lambda
]

parallel = Parallel(n_jobs=n_jobs, verbose=3, backend='multiprocessing')
lambda_net_dataset_list = parallel(delayed(load_lambda_nets)(epoch) for epoch in epochs_to_load)
del parallel

# The dataset of the last (latest) epoch in the range becomes the working dataset.
lambda_net_dataset = lambda_net_dataset_list[-1]


[Parallel(n_jobs=5)]: Using backend MultiprocessingBackend with 5 concurrent workers.
[Parallel(n_jobs=5)]: Done   1 out of   1 | elapsed:   19.5s finished


## Data Inspection

In [12]:
# Preview the first rows of the dataset loaded for the last epoch in epochs_save_range_lambda.
lambda_net_dataset.as_pandas().head()

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
6252,1373158606,0.924,0.492,-0.142,-0.402,-0.741,0.255,0.941,0.421,-1.533,6.634,-11.5,5.498,0.941,0.421,-1.533,6.634,-11.5,5.498,-0.009,-0.227,0.414,0.34,0.169,0.103,0.419,-0.297,0.334,0.062,0.544,-0.088,-0.019,0.451,0.09,0.523,-0.042,0.202,0.265,0.229,-0.054,0.043,0.008,-0.042,-0.404,-0.09,0.333,-0.291,-0.209,-0.233,0.0,0.0,-0.212,-0.233,0.234,0.234,-0.286,0.0,-0.234,0.296,-0.243,0.0,0.0,-0.255,0.244,-0.205,0.232,0.254,0.227,0.242,0.288,-0.046,-0.023,0.0,0.0,0.0,-0.233,0.0,0.0,0.0,-0.254,-0.359,-0.566,-0.69,0.341,0.4,-0.479,-0.408,-0.789,0.12,-0.376,-0.352,-0.364,-0.443,0.341,-0.452,0.687,0.215,0.348,0.244,0.337,-0.311,-0.044,-0.261,-0.383,-0.059,-0.969,0.144,-0.258,-0.315,0.202
4684,1373158606,-0.794,0.241,-0.127,-0.188,-0.535,-0.939,-0.768,1.42,-14.454,48.213,-63.525,27.058,-0.768,1.42,-14.454,48.213,-63.525,27.058,-0.009,-0.227,0.231,0.122,0.313,0.08,0.185,-0.297,0.089,0.253,0.358,-0.088,-0.019,0.234,0.078,0.342,0.008,1.474,0.405,0.393,0.033,0.035,-0.043,-0.042,-0.404,-0.09,0.048,-0.291,-0.209,-0.233,0.0,0.0,0.172,0.175,0.001,-0.081,0.178,0.0,0.175,0.19,0.173,0.0,0.0,0.175,-0.079,0.172,-0.023,-0.949,0.001,0.0,-0.04,0.176,0.2,0.0,0.0,0.0,0.175,0.0,0.0,0.0,-0.254,-0.359,-0.339,-0.36,0.278,0.187,-0.191,-0.408,-0.396,-0.074,-0.188,-0.352,-0.364,-0.199,0.121,-0.265,0.393,-2.337,0.33,0.208,0.049,-0.458,-0.229,-0.261,-0.383,-0.059,-0.479,0.144,-0.258,-0.315,-0.168
1731,1373158606,-0.739,0.951,0.737,0.786,0.773,0.083,-0.721,0.346,4.731,-9.701,12.714,-4.814,-0.721,0.346,4.731,-9.701,12.714,-4.814,-0.009,-0.227,-0.173,0.082,0.724,0.671,0.528,-0.297,0.067,0.916,0.416,-0.088,-0.019,0.349,0.689,-0.003,0.528,0.849,0.892,0.883,0.649,0.039,0.007,-0.042,-0.404,-0.09,0.046,-0.291,-0.209,-0.233,0.0,0.0,0.437,-0.092,-0.412,-0.438,-0.438,0.0,-0.078,-0.404,0.001,0.0,0.0,0.001,-0.463,0.398,-0.438,-0.28,-0.0,-0.14,-0.525,-0.054,-0.024,0.0,0.0,0.0,-0.06,0.0,0.0,0.0,-0.254,-0.359,-0.466,-0.215,0.917,1.043,0.791,-0.408,-0.251,0.617,0.148,-0.352,-0.364,0.19,1.01,-0.24,1.515,0.684,0.819,0.7,1.093,-0.306,-0.042,-0.261,-0.383,-0.059,-0.331,0.144,-0.258,-0.315,-0.437
4742,1373158606,0.173,-0.607,-0.085,0.746,-0.828,-0.939,0.161,-0.337,-2.491,8.142,-10.118,3.139,0.161,-0.337,-2.491,8.142,-10.118,3.139,-0.009,-0.227,0.511,0.469,0.05,-0.029,0.515,-0.297,0.433,0.595,0.678,-0.088,-0.019,0.54,-0.019,0.653,0.006,0.278,0.134,0.182,0.03,0.396,0.433,-0.042,-0.404,-0.09,0.397,-0.291,-0.209,-0.233,0.0,0.0,-0.264,-0.331,0.09,0.099,-0.363,0.0,-0.339,-0.404,-0.007,0.0,0.0,-0.279,0.089,-0.006,-0.025,-0.002,0.094,0.015,-0.044,-0.335,-0.366,0.0,0.0,0.0,-0.337,0.0,0.0,0.0,-0.254,-0.359,-0.691,-0.997,0.091,0.166,-0.746,-0.408,-1.085,-0.558,-0.418,-0.352,-0.364,-0.58,0.091,-0.495,0.391,-0.077,0.106,-0.017,0.043,-1.272,-1.217,-0.261,-0.383,-0.059,-1.255,0.144,-0.258,-0.315,0.121
4521,1373158606,0.131,0.208,0.871,0.323,0.573,0.164,0.124,0.308,0.697,-0.264,2.311,-0.934,0.124,0.308,0.697,-0.264,2.311,-0.934,-0.009,-0.227,0.128,0.083,0.634,0.562,0.108,-0.297,0.067,0.951,0.4,-0.088,-0.019,0.126,0.521,0.243,0.398,0.706,0.723,0.705,0.612,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.128,-0.084,-0.13,-0.215,-0.108,0.0,-0.07,-0.709,-0.337,0.0,0.0,-0.126,-0.436,-0.173,-0.186,-0.442,0.066,-0.33,-0.462,-0.047,-0.023,0.0,0.0,0.0,-0.053,0.0,0.0,0.0,-0.254,-0.359,-0.193,-0.225,0.614,0.717,-0.034,-0.408,-0.258,0.87,0.361,-0.352,-0.364,-0.008,0.835,0.164,0.961,0.674,0.627,0.624,0.986,-0.309,-0.044,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.078


In [13]:
# Summary statistics (count, mean, std, quantiles) for every column of lambda_net_dataset.
lambda_net_dataset.as_pandas().describe()

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
count,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0,10000.0
mean,1373158606.0,0.001,-0.005,-0.0,0.004,-0.003,-0.003,0.004,-0.025,0.136,-0.544,0.793,-0.364,0.004,-0.025,0.136,-0.544,0.793,-0.364,-0.009,-0.227,0.287,0.192,0.282,0.221,0.275,-0.297,0.162,0.454,0.475,-0.088,-0.019,0.322,0.222,0.404,0.079,0.415,0.37,0.393,0.149,0.112,0.058,-0.042,-0.404,-0.09,0.123,-0.291,-0.209,-0.233,0.0,0.0,-0.015,-0.027,-0.009,-0.015,-0.051,0.0,-0.026,-0.082,-0.046,0.0,0.0,-0.045,-0.023,-0.004,0.013,-0.066,0.018,-0.031,-0.016,-0.016,0.004,0.0,0.0,0.0,-0.018,0.0,0.0,0.0,-0.254,-0.359,-0.372,-0.424,0.354,0.419,-0.21,-0.408,-0.465,0.04,-0.106,-0.352,-0.364,-0.189,0.364,-0.265,0.627,0.118,0.358,0.176,0.332,-0.521,-0.258,-0.261,-0.383,-0.059,-0.544,0.144,-0.258,-0.315,0.004
std,0.0,0.574,0.577,0.577,0.577,0.575,0.58,0.542,0.65,3.538,10.664,13.939,5.786,0.542,0.65,3.538,10.664,13.939,5.786,0.0,0.0,0.142,0.143,0.158,0.156,0.171,0.0,0.142,0.277,0.201,0.0,0.0,0.168,0.163,0.145,0.148,0.243,0.162,0.187,0.202,0.14,0.182,0.0,0.0,0.0,0.139,0.0,0.0,0.0,0.0,0.0,0.147,0.158,0.17,0.179,0.189,0.0,0.158,0.293,0.218,0.0,0.0,0.185,0.192,0.145,0.155,0.249,0.137,0.214,0.212,0.152,0.176,0.0,0.0,0.0,0.152,0.0,0.0,0.0,0.0,0.0,0.177,0.234,0.212,0.248,0.342,0.0,0.257,0.446,0.45,0.0,0.0,0.347,0.263,0.236,0.311,0.502,0.179,0.439,0.346,0.287,0.334,0.0,0.0,0.0,0.278,0.0,0.0,0.0,0.139
min,1373158606.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.038,-2.548,-19.92,-73.162,-89.155,-43.752,-1.038,-2.548,-19.92,-73.162,-89.155,-43.752,-0.009,-0.227,-0.173,-0.289,-0.227,-0.278,-0.181,-0.297,-0.291,-0.132,0.099,-0.088,-0.019,-0.129,-0.29,-0.017,-0.301,-0.172,-0.094,-0.137,-0.64,-0.303,-0.633,-0.042,-0.404,-0.09,-0.291,-0.291,-0.209,-0.233,0.0,0.0,-0.469,-0.494,-0.56,-0.601,-0.958,0.0,-0.495,-1.158,-1.224,0.0,0.0,-1.01,-0.619,-0.962,-0.551,-1.123,-0.557,-1.056,-0.804,-0.496,-0.812,0.0,0.0,0.0,-0.496,0.0,0.0,0.0,-0.254,-0.359,-1.75,-1.862,-0.03,0.066,-1.36,-0.408,-2.213,-2.456,-0.918,-0.352,-0.364,-1.083,0.002,-0.982,0.26,-2.767,-1.802,-3.101,-0.695,-2.082,-3.002,-0.261,-0.383,-0.059,-2.17,0.144,-0.258,-0.315,-0.578
25%,1373158606.0,-0.491,-0.505,-0.506,-0.494,-0.497,-0.505,-0.444,-0.453,-1.84,-5.518,-6.649,-2.946,-0.444,-0.453,-1.84,-5.518,-6.649,-2.946,-0.009,-0.227,0.182,0.083,0.174,0.093,0.154,-0.297,0.067,0.275,0.355,-0.088,-0.019,0.214,0.087,0.314,0.008,0.262,0.271,0.268,0.033,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.13,-0.093,-0.108,-0.084,-0.171,0.0,-0.077,-0.331,-0.11,0.0,0.0,-0.159,-0.088,-0.071,-0.023,-0.25,-0.039,-0.142,-0.044,-0.05,-0.023,0.0,0.0,0.0,-0.055,0.0,0.0,0.0,-0.254,-0.359,-0.417,-0.441,0.222,0.261,-0.272,-0.408,-0.484,-0.115,-0.264,-0.352,-0.364,-0.278,0.196,-0.344,0.394,0.094,0.253,0.135,0.05,-0.549,-0.309,-0.261,-0.383,-0.059,-0.569,0.144,-0.258,-0.315,-0.125
50%,1373158606.0,0.005,-0.005,0.002,0.003,-0.005,0.001,0.002,-0.005,0.012,-0.028,0.022,-0.008,0.002,-0.005,0.012,-0.028,0.022,-0.008,-0.009,-0.227,0.288,0.183,0.268,0.206,0.254,-0.297,0.144,0.376,0.43,-0.088,-0.019,0.302,0.203,0.397,0.012,0.355,0.352,0.354,0.081,0.055,0.008,-0.042,-0.404,-0.09,0.079,-0.291,-0.209,-0.233,0.0,0.0,-0.002,-0.006,-0.001,-0.001,-0.005,0.0,-0.015,0.012,-0.004,0.0,0.0,-0.004,-0.002,-0.002,-0.017,-0.004,-0.0,-0.002,-0.007,-0.047,-0.022,0.0,0.0,0.0,-0.051,0.0,0.0,0.0,-0.254,-0.359,-0.335,-0.358,0.31,0.356,-0.188,-0.408,-0.389,-0.024,-0.186,-0.352,-0.364,-0.194,0.293,-0.266,0.552,0.185,0.331,0.22,0.231,-0.432,-0.129,-0.261,-0.383,-0.059,-0.46,0.144,-0.258,-0.315,0.002
75%,1373158606.0,0.5,0.49,0.496,0.507,0.495,0.502,0.459,0.395,1.952,5.079,7.322,2.59,0.459,0.395,1.952,5.079,7.322,2.59,-0.009,-0.227,0.367,0.275,0.375,0.316,0.354,-0.297,0.245,0.57,0.528,-0.088,-0.019,0.395,0.322,0.476,0.149,0.5,0.456,0.481,0.265,0.184,0.086,-0.042,-0.404,-0.09,0.199,-0.291,-0.209,-0.233,0.0,0.0,0.115,0.113,0.136,0.138,0.108,0.0,0.113,0.154,0.108,0.0,0.0,0.109,0.139,0.115,0.142,0.132,0.135,0.133,0.157,0.114,0.131,0.0,0.0,0.0,0.113,0.0,0.0,0.0,-0.254,-0.359,-0.253,-0.255,0.391,0.437,-0.088,-0.408,-0.259,0.113,-0.099,-0.352,-0.364,-0.104,0.38,-0.186,0.669,0.271,0.415,0.305,0.362,-0.31,-0.044,-0.261,-0.383,-0.059,-0.335,0.144,-0.258,-0.315,0.137
max,1373158606.0,1.0,1.0,1.0,1.0,0.999,1.0,1.031,2.526,21.905,66.969,99.952,38.263,1.031,2.526,21.905,66.969,99.952,38.263,-0.009,-0.227,0.794,0.68,0.884,0.819,1.505,-0.297,0.653,1.81,1.956,-0.088,-0.019,1.502,0.851,1.179,0.667,1.779,1.017,1.308,0.934,0.684,1.111,-0.042,-0.404,-0.09,0.723,-0.291,-0.209,-0.233,0.0,0.0,0.437,0.332,0.281,0.278,0.32,0.0,0.331,0.318,0.289,0.0,0.0,0.331,0.291,0.486,0.3,0.267,0.449,0.256,0.408,0.303,0.428,0.0,0.0,0.0,0.291,0.0,0.0,0.0,-0.254,-0.359,-0.052,-0.115,1.466,1.564,3.033,-0.408,-0.174,2.744,2.887,-0.352,-0.364,3.059,1.528,2.871,2.646,1.39,1.242,1.154,2.765,-0.207,0.441,-0.261,-0.383,-0.059,-0.236,0.144,-0.258,-0.315,0.443


In [14]:
# First 10 X test points of entry 0 (presumably X_test_data_list holds one array per lambda-net).
lambda_net_dataset.X_test_data_list[0][:10]

array([[0.04991591],
       [0.55931928],
       [0.41215429],
       [0.47204459],
       [0.50702493],
       [0.36030918],
       [0.48300684],
       [0.91632547],
       [0.27265516],
       [0.88145597]])

In [15]:
# First 10 y test values of entry 0 (presumably y_test_data_list holds one array per lambda-net).
lambda_net_dataset.y_test_data_list[0][:10]

array([[1.0879322 ],
       [0.92128287],
       [0.94887895],
       [1.06356541],
       [1.04696906],
       [0.91303792],
       [0.84312583],
       [0.62768837],
       [1.02064446],
       [0.61265825]])

## Generate Datasets for Interpretation-Net training

In [16]:
#generate train, test and validation data for training
# Split every dataset in lambda_net_dataset_list into train / valid / test parts and collect
# them in three parallel lists (index i of each list belongs to lambda_net_dataset_list[i]).
# After the loop, lambda_net_train_dataset, lambda_net_valid_dataset and lambda_net_test_dataset
# remain bound to the splits of the LAST dataset; some later cells read these names directly.

lambda_net_train_dataset_list = []
lambda_net_valid_dataset_list = []
lambda_net_test_dataset_list = []

for lambda_net_dataset in lambda_net_dataset_list:
    
    
    if inet_holdout_seed_evaluation:
        # Split by seed: test, valid and train seeds are made disjoint via the set differences below.
        complete_seed_list = list(set(lambda_net_dataset.train_settings_list['seed']))#list(weight_data.iloc[:,1].unique())

        # Re-seed right before sampling so the chosen seeds are reproducible.
        random.seed(RANDOM_SEED)
        # len - len/(1/0.9) == len - 0.9*len, i.e. roughly 10% of the seeds (truncated by int()).
        test_seeds = random.sample(complete_seed_list, int(len(complete_seed_list)-len(complete_seed_list)/(1/0.9)))
        lambda_net_test_dataset = lambda_net_dataset.get_lambda_nets_by_seed(test_seeds)
        complete_seed_list = list(set(complete_seed_list) - set(test_seeds))#complete_seed_list.remove(test_seeds)
        
        # Same ~10% rule applied to the seeds that were not chosen for testing.
        random.seed(RANDOM_SEED)
        valid_seeds = random.sample(complete_seed_list, int(len(complete_seed_list)-len(complete_seed_list)/(1/0.9)))
        lambda_net_valid_dataset = lambda_net_dataset.get_lambda_nets_by_seed(valid_seeds)
        complete_seed_list = list(set(complete_seed_list) - set(valid_seeds))

        # Every remaining seed goes to training.
        train_seeds = complete_seed_list
        lambda_net_train_dataset = lambda_net_dataset.get_lambda_nets_by_seed(train_seeds)       
        
        lambda_net_train_dataset_list.append(lambda_net_train_dataset)
        lambda_net_valid_dataset_list.append(lambda_net_valid_dataset)
        lambda_net_test_dataset_list.append(lambda_net_test_dataset)
        
        # Drop the name bound to the unsplit dataset once its parts have been extracted.
        del lambda_net_dataset
    else:
        # Plain split: 10% held out for test, then 10% of the remainder for validation
        # (e.g. 10000 lambda-nets -> 8100 train / 900 valid / 1000 test).
        lambda_net_train_with_valid_dataset, lambda_net_test_dataset = split_LambdaNetDataset(lambda_net_dataset, test_split=0.1)
        lambda_net_train_dataset, lambda_net_valid_dataset = split_LambdaNetDataset(lambda_net_train_with_valid_dataset, test_split=0.1)

        lambda_net_train_dataset_list.append(lambda_net_train_dataset)
        lambda_net_valid_dataset_list.append(lambda_net_valid_dataset)
        lambda_net_test_dataset_list.append(lambda_net_test_dataset)
    
        # Drop names bound to the intermediate/unsplit datasets.
        del lambda_net_dataset, lambda_net_train_with_valid_dataset

        
# Only the split lists are needed from here on.
del lambda_net_dataset_list

In [17]:
# (rows, columns) of the last training split rendered as a DataFrame.
lambda_net_train_dataset_list[-1].as_pandas().shape

(8100, 110)

In [18]:
# (rows, columns) of the last validation split rendered as a DataFrame.
lambda_net_valid_dataset_list[-1].as_pandas().shape

(900, 110)

In [19]:
# (rows, columns) of the last test split rendered as a DataFrame.
lambda_net_test_dataset_list[-1].as_pandas().shape

(1000, 110)

In [20]:
# First rows of the last training split (seed, target/lstsq coefficients, flattened lambda-net weights wb_*).
lambda_net_train_dataset_list[-1].as_pandas().head()

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
3003,1373158606,-0.34,-0.876,-0.048,-0.46,0.43,0.914,-0.318,-1.219,2.333,-7.451,9.054,-2.808,-0.318,-1.219,2.333,-7.451,9.054,-2.808,-0.009,-0.227,0.365,0.259,0.453,0.402,0.318,-0.297,0.226,0.374,0.488,-0.088,-0.019,0.366,0.424,0.474,0.009,0.532,0.531,0.522,0.506,0.174,0.102,-0.042,-0.404,-0.09,0.186,-0.291,-0.209,-0.233,0.0,0.0,0.075,0.076,-0.309,-0.326,0.076,0.0,0.076,0.083,0.075,0.0,0.0,0.076,-0.357,0.075,-0.023,-0.38,-0.265,-0.338,-0.425,0.076,0.085,0.0,0.0,0.0,0.076,0.0,0.0,0.0,-0.254,-0.359,-0.374,-0.366,0.847,1.064,-0.212,-0.408,-0.389,-0.101,-0.241,-0.352,-0.364,-0.231,1.094,-0.318,0.394,0.733,0.753,0.697,1.517,-0.428,-0.152,-0.261,-0.383,-0.059,-0.456,0.144,-0.258,-0.315,-0.075
5985,1373158606,-0.496,-0.607,0.252,0.645,-0.276,-0.135,-0.514,-0.25,-2.073,6.844,-7.191,2.565,-0.514,-0.25,-2.073,6.844,-7.191,2.565,-0.009,-0.227,0.251,0.142,0.285,0.225,0.201,-0.297,0.109,0.25,0.375,-0.088,-0.019,0.251,0.224,0.361,0.008,0.341,0.366,0.35,0.198,0.055,-0.05,-0.042,-0.404,-0.09,0.068,-0.291,-0.209,-0.233,0.0,0.0,0.126,0.128,-0.103,-0.101,0.132,0.0,0.128,0.151,0.128,0.0,0.0,0.13,-0.109,0.126,-0.023,-0.121,-0.107,-0.117,-0.111,0.13,0.169,0.0,0.0,0.0,0.129,0.0,0.0,0.0,-0.254,-0.359,-0.316,-0.332,0.31,0.358,-0.165,-0.408,-0.366,-0.052,-0.171,-0.352,-0.364,-0.175,0.296,-0.248,0.393,0.186,0.336,0.222,0.298,-0.431,-0.231,-0.261,-0.383,-0.059,-0.451,0.144,-0.258,-0.315,-0.122
6925,1373158606,-0.316,0.287,0.323,-0.293,-0.713,0.105,-0.199,-0.106,0.0,-0.0,0.0,-0.0,-0.199,-0.106,0.0,-0.0,0.0,-0.0,-0.009,-0.227,0.268,0.162,0.255,0.193,0.223,-0.297,0.129,0.282,0.391,-0.088,-0.019,0.27,0.189,0.376,0.012,0.303,0.341,0.319,0.097,0.077,0.009,-0.042,-0.404,-0.09,0.089,-0.291,-0.209,-0.233,0.0,0.0,0.058,0.059,0.001,0.001,0.058,0.0,0.059,0.052,0.058,0.0,0.0,0.058,0.001,0.058,-0.018,0.001,0.001,0.001,0.002,0.059,0.065,0.0,0.0,0.0,0.059,0.0,0.0,0.0,-0.254,-0.359,-0.289,-0.286,0.237,0.266,-0.13,-0.408,-0.312,-0.017,-0.153,-0.352,-0.364,-0.147,0.2,-0.23,0.395,0.12,0.277,0.159,0.098,-0.358,-0.105,-0.261,-0.383,-0.059,-0.383,0.144,-0.258,-0.315,-0.059
9619,1373158606,0.123,-0.67,0.322,-0.94,0.676,0.843,0.126,-1.059,4.63,-15.522,19.791,-7.659,0.126,-1.059,4.63,-15.522,19.791,-7.659,-0.009,-0.227,0.328,0.222,0.356,0.392,0.282,-0.297,0.189,0.333,0.45,-0.088,-0.019,0.329,0.423,0.436,0.456,0.234,0.275,0.252,0.564,0.137,0.062,-0.042,-0.404,-0.09,0.149,-0.291,-0.209,-0.233,0.0,0.0,-0.004,-0.007,-0.19,-0.316,-0.01,0.0,-0.006,-0.013,-0.015,0.0,0.0,-0.012,-0.339,-0.014,-0.338,0.054,0.066,0.06,-0.351,-0.005,-0.002,0.0,0.0,0.0,-0.005,0.0,0.0,0.0,-0.254,-0.359,-0.333,-0.321,0.535,1.018,-0.17,-0.408,-0.341,-0.06,-0.203,-0.352,-0.364,-0.191,1.024,-0.279,1.886,0.067,0.225,0.107,1.356,-0.371,-0.052,-0.261,-0.383,-0.059,-0.402,0.144,-0.258,-0.315,0.073
6916,1373158606,-0.533,-0.333,0.694,-0.708,0.255,-0.974,-0.413,-0.67,0.032,-0.073,0.074,-0.027,-0.413,-0.67,0.032,-0.073,0.074,-0.027,-0.009,-0.227,0.336,0.23,0.198,0.133,0.291,-0.297,0.197,0.355,0.459,-0.088,-0.019,0.338,0.129,0.444,0.008,0.244,0.282,0.26,0.034,0.144,0.082,-0.042,-0.404,-0.09,0.156,-0.291,-0.209,-0.233,0.0,0.0,0.097,0.097,-0.002,-0.001,0.102,0.0,0.097,0.119,0.1,0.0,0.0,0.101,-0.001,0.098,-0.023,-0.004,-0.003,-0.004,-0.039,0.097,0.112,0.0,0.0,0.0,0.097,0.0,0.0,0.0,-0.254,-0.359,-0.363,-0.361,0.17,0.189,-0.204,-0.408,-0.387,-0.092,-0.226,-0.352,-0.364,-0.221,0.121,-0.303,0.393,0.051,0.211,0.092,0.049,-0.431,-0.166,-0.261,-0.383,-0.059,-0.457,0.144,-0.258,-0.315,-0.092


In [21]:
# First rows of the last validation split.
lambda_net_valid_dataset_list[-1].as_pandas().head()

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
8323,1373158606,0.31,0.452,0.977,-0.064,-0.261,0.704,0.268,0.991,0.838,-6.198,11.956,-5.805,0.268,0.991,0.838,-6.198,11.956,-5.805,-0.009,-0.227,0.13,0.082,0.396,0.332,0.109,-0.297,0.067,1.235,0.735,-0.088,-0.019,0.128,0.335,0.165,0.183,0.465,0.479,0.472,0.265,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.13,-0.083,0.069,0.071,-0.109,0.0,-0.069,-0.679,-0.581,0.0,0.0,-0.128,0.067,-0.165,0.076,0.06,0.071,0.064,0.058,-0.047,-0.023,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.198,-0.225,0.356,0.385,-0.04,-0.408,-0.258,1.138,1.249,-0.352,-0.364,-0.032,0.319,-0.049,0.551,0.238,0.396,0.278,0.22,-0.31,-0.045,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.083
5392,1373158606,0.502,-0.666,-0.803,0.182,0.483,-0.702,0.496,-0.26,-3.834,8.733,-9.909,3.801,0.496,-0.26,-3.834,8.733,-9.909,3.801,-0.009,-0.227,0.484,0.346,0.059,-0.012,0.442,-0.297,0.309,0.528,0.625,-0.088,-0.019,0.486,-0.017,0.603,-0.171,0.148,0.151,0.144,-0.158,0.281,0.271,-0.042,-0.404,-0.09,0.286,-0.291,-0.209,-0.233,0.0,0.0,-0.03,-0.198,0.162,0.171,-0.149,0.0,-0.218,-0.171,-0.002,0.0,0.0,-0.088,0.17,-0.001,0.188,0.108,0.159,0.139,0.21,-0.199,-0.205,0.0,0.0,0.0,-0.203,0.0,0.0,0.0,-0.254,-0.359,-0.495,-0.566,0.183,0.269,-0.373,-0.408,-0.632,-0.25,-0.363,-0.352,-0.364,-0.353,0.205,-0.44,0.671,0.016,0.18,0.061,0.264,-0.731,-0.549,-0.261,-0.383,-0.059,-0.743,0.144,-0.258,-0.315,0.173
4718,1373158606,0.39,0.568,-0.964,0.19,0.177,-0.243,0.407,0.144,1.337,-4.38,3.622,-0.967,0.407,0.144,1.337,-4.38,3.622,-0.967,-0.009,-0.227,0.333,0.232,0.199,0.136,0.303,-0.297,0.198,0.311,0.453,-0.088,-0.019,0.344,0.129,0.434,-0.01,0.243,0.289,0.263,0.008,0.174,0.008,-0.042,-0.404,-0.09,0.177,-0.291,-0.209,-0.233,0.0,0.0,-0.126,-0.125,0.13,0.13,-0.137,0.0,-0.129,0.016,-0.14,0.0,0.0,-0.133,0.133,-0.135,0.132,0.136,0.128,0.132,0.155,-0.113,-0.022,0.0,0.0,0.0,-0.115,0.0,0.0,0.0,-0.254,-0.359,-0.399,-0.428,0.254,0.299,-0.253,-0.408,-0.473,-0.041,-0.244,-0.352,-0.364,-0.258,0.236,-0.322,0.562,0.13,0.279,0.166,0.2,-0.581,-0.045,-0.261,-0.383,-0.059,-0.588,0.144,-0.258,-0.315,0.122
990,1373158606,-0.178,-0.531,-0.961,0.759,0.931,-0.317,-0.157,-1.101,3.285,-11.782,16.44,-7.057,-0.157,-1.101,3.285,-11.782,16.44,-7.057,-0.009,-0.227,0.338,0.232,0.39,0.323,0.292,-0.297,0.199,0.35,0.461,-0.088,-0.019,0.34,0.337,0.447,0.01,0.443,0.435,0.432,0.367,0.147,0.077,-0.042,-0.404,-0.09,0.159,-0.291,-0.209,-0.233,0.0,0.0,0.042,0.042,-0.219,-0.241,0.041,0.0,0.042,0.039,0.041,0.0,0.0,0.041,-0.253,0.042,-0.022,-0.27,-0.231,-0.263,-0.276,0.042,0.044,0.0,0.0,0.0,0.042,0.0,0.0,0.0,-0.254,-0.359,-0.344,-0.334,0.595,0.712,-0.182,-0.408,-0.357,-0.071,-0.212,-0.352,-0.364,-0.202,0.667,-0.289,0.394,0.45,0.559,0.468,0.896,-0.393,-0.109,-0.261,-0.383,-0.059,-0.421,0.144,-0.258,-0.315,-0.042
6498,1373158606,0.861,-0.387,-0.55,0.46,0.491,-0.056,0.779,-0.195,-0.003,0.007,-0.007,0.003,0.779,-0.195,-0.003,0.007,-0.007,0.003,-0.009,-0.227,0.328,0.237,0.221,0.157,0.301,-0.297,0.208,0.178,0.449,-0.088,-0.019,0.339,0.15,0.431,0.013,0.263,0.311,0.284,0.03,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.0,0.0,0.205,0.206,-0.0,0.0,0.0,0.223,-0.001,0.0,0.0,-0.0,0.213,-0.001,0.207,0.219,0.199,0.21,0.254,-0.047,-0.023,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.35,-0.372,0.325,0.377,-0.198,-0.408,-0.405,0.072,-0.203,-0.352,-0.364,-0.208,0.315,-0.281,0.647,0.2,0.342,0.233,0.291,-0.31,-0.045,-0.261,-0.383,-0.059,-0.335,0.144,-0.258,-0.315,0.183


In [22]:
# First rows of the last test split.
lambda_net_test_dataset_list[-1].as_pandas().head()

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
7217,1373158606,0.715,0.432,0.981,-0.843,-0.74,-0.831,0.712,0.906,-3.423,12.732,-17.553,6.388,0.712,0.906,-3.423,12.732,-17.553,6.388,-0.009,-0.227,0.551,0.483,0.261,0.198,0.594,-0.297,0.456,0.193,0.669,-0.088,-0.019,0.588,0.188,0.624,0.057,0.299,0.353,0.322,0.066,0.478,0.008,-0.042,-0.404,-0.09,0.457,-0.291,-0.209,-0.233,0.0,0.0,-0.299,-0.368,0.185,0.185,-0.419,0.0,-0.398,0.207,-0.363,0.0,0.0,-0.408,0.191,-0.338,0.183,0.199,0.18,0.191,0.222,-0.388,-0.022,0.0,0.0,0.0,-0.386,0.0,0.0,0.0,-0.254,-0.359,-0.91,-1.276,0.346,0.391,-1.018,-0.408,-1.491,0.093,-0.683,-0.352,-0.364,-0.898,0.329,-0.754,0.625,0.223,0.369,0.257,0.278,-1.739,-0.045,-0.261,-0.383,-0.059,-1.706,0.144,-0.258,-0.315,0.164
8291,1373158606,-0.213,0.402,-0.761,-0.041,-0.451,-0.243,-0.18,-0.014,0.537,-0.987,-1.331,0.714,-0.18,-0.014,0.537,-0.987,-1.331,0.714,-0.009,-0.227,0.286,0.453,0.253,0.189,0.516,-0.297,0.36,0.711,0.416,-0.088,-0.019,0.547,0.188,0.397,0.008,0.307,0.337,0.323,0.032,0.35,0.471,-0.042,-0.404,-0.09,0.354,-0.291,-0.209,-0.233,0.0,0.0,0.093,-0.168,-0.002,-0.002,-0.293,0.0,-0.286,-0.405,0.085,0.0,0.0,-0.205,-0.002,0.095,-0.023,-0.003,-0.002,-0.002,-0.037,-0.281,-0.373,0.0,0.0,0.0,-0.253,0.0,0.0,0.0,-0.254,-0.359,-0.323,-0.733,0.222,0.248,-0.644,-0.408,-0.88,-0.641,-0.18,-0.352,-0.364,-0.542,0.181,-0.261,0.393,0.103,0.263,0.144,0.047,-1.077,-1.181,-0.261,-0.383,-0.059,-1.001,0.144,-0.258,-0.315,-0.108
4607,1373158606,0.514,-0.41,-0.766,-0.485,-0.225,-0.896,0.504,-0.223,-2.133,3.54,-5.265,1.293,0.504,-0.223,-2.133,3.54,-5.265,1.293,-0.009,-0.227,0.663,0.56,-0.115,-0.166,0.64,-0.297,0.53,0.752,0.788,-0.088,-0.019,0.691,-0.167,0.791,-0.254,0.468,0.018,0.278,-0.281,0.481,0.535,-0.042,-0.404,-0.09,0.485,-0.291,-0.209,-0.233,0.0,0.0,-0.327,-0.412,0.115,0.166,-0.442,0.0,-0.424,-0.484,-0.294,0.0,0.0,-0.369,0.167,-0.166,0.253,0.137,0.117,0.149,0.28,-0.411,-0.428,0.0,0.0,0.0,-0.414,0.0,0.0,0.0,-0.254,-0.359,-0.857,-1.084,0.122,0.305,-0.857,-0.408,-1.183,-0.686,-0.608,-0.352,-0.364,-0.755,0.244,-0.657,0.805,-0.222,0.051,-0.113,0.405,-1.307,-1.185,-0.261,-0.383,-0.059,-1.311,0.144,-0.258,-0.315,0.132
5114,1373158606,-0.373,0.948,0.872,0.284,-0.963,0.497,-0.337,0.233,4.836,-8.996,8.775,-3.282,-0.337,0.233,4.836,-8.996,8.775,-3.282,-0.009,-0.227,0.137,0.087,0.503,0.43,0.117,-0.297,0.07,0.643,0.305,-0.088,-0.019,0.138,0.439,0.243,0.258,0.581,0.582,0.582,0.367,0.042,0.008,-0.042,-0.404,-0.09,0.048,-0.291,-0.209,-0.233,0.0,0.0,-0.138,-0.091,-0.011,-0.058,-0.117,0.0,-0.075,-0.182,-0.05,0.0,0.0,-0.136,-0.059,-0.05,-0.155,-0.044,-0.015,-0.038,-0.158,-0.051,-0.024,0.0,0.0,0.0,-0.057,0.0,0.0,0.0,-0.254,-0.359,-0.19,-0.225,0.474,0.509,-0.032,-0.408,-0.257,0.262,0.036,-0.352,-0.364,-0.006,0.445,-0.043,0.778,0.354,0.505,0.39,0.407,-0.309,-0.044,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,-0.328
1859,1373158606,0.664,-0.38,0.886,-0.641,0.325,0.344,0.522,0.397,-0.001,0.002,-0.002,0.001,0.522,0.397,-0.001,0.002,-0.002,0.001,-0.009,-0.227,0.25,0.082,0.29,0.227,0.203,-0.297,0.067,0.285,0.367,-0.088,-0.019,0.25,0.223,0.354,0.086,0.338,0.377,0.355,0.127,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,0.0,-0.083,0.148,0.147,0.0,0.0,-0.069,0.136,-0.001,0.0,0.0,-0.0,0.151,-0.0,0.145,0.157,0.145,0.153,0.167,-0.047,-0.022,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.251,-0.225,0.319,0.356,-0.086,-0.408,-0.258,0.062,-0.116,-0.352,-0.364,-0.108,0.291,-0.194,0.561,0.198,0.351,0.236,0.216,-0.31,-0.045,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.134


## Interpretation Network Training

In [None]:
# Train the interpretation network(s) on the prepared splits. The helper returns, per run:
# training history, scores, function values (all / valid / test), I-Net predictions
# (all / valid / test), metric distributions and the trained models.
(
    history_list,
    scores_list,
    function_values_complete_list,
    function_values_valid_list,
    function_values_test_list,
    inet_preds_list,
    inet_preds_valid_list,
    inet_preds_test_list,
    distrib_dict_list,
    model_list,
) = calculate_interpretation_net_results(
    lambda_net_train_dataset_list,
    lambda_net_valid_dataset_list,
    lambda_net_test_dataset_list,
)

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.


[Parallel(n_jobs=15)]: Using backend LokyBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   1 tasks      | elapsed:  4.0min
[Parallel(n_jobs=15)]: Done   2 tasks      | elapsed:  4.5min
[Parallel(n_jobs=15)]: Done   3 tasks      | elapsed:  5.5min
[Parallel(n_jobs=15)]: Done   4 tasks      | elapsed:  5.8min
[Parallel(n_jobs=15)]: Done   5 tasks      | elapsed:  7.4min
[Parallel(n_jobs=15)]: Done   6 tasks      | elapsed:  7.5min
[Parallel(n_jobs=15)]: Done   7 tasks      | elapsed:  7.6min
[Parallel(n_jobs=15)]: Done   8 tasks      | elapsed:  8.6min
[Parallel(n_jobs=15)]: Done   9 tasks      | elapsed:  8.7min
[Parallel(n_jobs=15)]: Done  10 tasks      | elapsed:  8.8min
[Parallel(n_jobs=15)]: Done  11 tasks      | elapsed:  9.6min
[Parallel(n_jobs=15)]: Done  12 tasks      | elapsed: 10.3min
[Parallel(n_jobs=15)]: Done  13 tasks      | elapsed: 11.4min
[Parallel(n_jobs=15)]: Done  14 tasks      | elapsed: 11.7min
[Parallel(n_jobs=15)]: Done  15 tasks      | elapsed: 1

## Evaluate Interpretation Net

In [None]:
# Done   1 out of   1 | elapsed:  1.6min remaining:    0.0s

In [None]:
# After a neural architecture search, history_list[-1] is iterated as the search trials
# (presumably KerasTuner Trial objects -- confirm against calculate_interpretation_net_results).
if nas:
    for trial in history_list[-1]:
        # Trial.summary() prints the summary itself and returns None; wrapping it in
        # print() only appended a stray "None" line per trial.
        trial.summary()

In [None]:
# Show the architecture and configuration of the most recently trained I-Net, if there is one.
if model_list:
    # Keras' Model.summary() prints the table itself and returns None, so it is called
    # directly instead of inside print() (which added a stray "None" line).
    model_list[-1].summary()
    print(model_list[-1].get_config())

In [None]:
# Scores of the most recent I-Net run.
scores_list[-1]

In [None]:
# Per-sample MAE values of the most recent run; rows are comparisons selected below via .loc
# (e.g. 'inetPoly_VS_targetPoly_test').
distrib_dict_list[-1]['MAE']

In [None]:
# Per-sample R2 values of the most recent run, laid out like the MAE table.
distrib_dict_list[-1]['R2']

In [None]:
# I-Net polynomial for the test lambda-net whose predictions fit the least-squares target worst
# (lowest R2 in 'predLambda_VS_lstsqTarget_test'), together with the I-Net's own R2 for it.
r2_distrib = distrib_dict_list[-1]['R2']
index_min = int(np.argmin(r2_distrib.loc['predLambda_VS_lstsqTarget_test']))
polynomial_inet = inet_preds_test_list[-1][index_min]

# np.argmin yields a position, so look the value up positionally with .iloc; plain [] on a
# Series is label-based for integer labels and only falls back to positions otherwise.
print(r2_distrib.loc['inetPoly_VS_targetPoly_test'].iloc[index_min])

print_polynomial_from_coefficients(polynomial_inet)

In [None]:
# Least-squares polynomial fitted to the lambda-net predictions for the same worst-case test net,
# with its R2 against the target polynomial.
r2_distrib = distrib_dict_list[-1]['R2']
index_min = int(np.argmin(r2_distrib.loc['predLambda_VS_lstsqTarget_test']))

# Positional lookup to match the position returned by np.argmin.
print(r2_distrib.loc['lstsqLambda_VS_targetPoly_test'].iloc[index_min])

# Read from the last test split explicitly rather than the loop variable left over from the
# dataset-splitting cell; distrib_dict_list[-1] belongs to this split.
polynomial_lambda = lambda_net_test_dataset_list[-1].lstsq_lambda_pred_polynomial_list[index_min]
print_polynomial_from_coefficients(polynomial_lambda, force_complete_poly_representation=True)

In [None]:
# Target polynomial of the same worst-case test lambda-net.
index_min = int(np.argmin(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
# Read from the last test split explicitly rather than the loop variable left over from the
# dataset-splitting cell.
polynomial_target = lambda_net_test_dataset_list[-1].target_polynomial_list[index_min]
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True)

In [None]:
# I-Net polynomial for the test lambda-net whose predictions fit the least-squares target best
# (highest R2 in 'predLambda_VS_lstsqTarget_test'), together with the I-Net's own R2 for it.
r2_distrib = distrib_dict_list[-1]['R2']
index_max = int(np.argmax(r2_distrib.loc['predLambda_VS_lstsqTarget_test']))
polynomial_inet = inet_preds_test_list[-1][index_max]

# np.argmax yields a position, so look the value up positionally with .iloc.
print(r2_distrib.loc['inetPoly_VS_targetPoly_test'].iloc[index_max])

print_polynomial_from_coefficients(polynomial_inet)

In [None]:
# Least-squares polynomial fitted to the lambda-net predictions for the same best-case test net,
# with its R2 against the target polynomial.
r2_distrib = distrib_dict_list[-1]['R2']
index_max = int(np.argmax(r2_distrib.loc['predLambda_VS_lstsqTarget_test']))

# Positional lookup to match the position returned by np.argmax.
print(r2_distrib.loc['lstsqLambda_VS_targetPoly_test'].iloc[index_max])

# Read from the last test split explicitly rather than the loop variable left over from the
# dataset-splitting cell.
polynomial_lambda = lambda_net_test_dataset_list[-1].lstsq_lambda_pred_polynomial_list[index_max]
print_polynomial_from_coefficients(polynomial_lambda, force_complete_poly_representation=True)

In [None]:
# Target polynomial of the same best-case test lambda-net.
index_max = int(np.argmax(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
# Read from the last test split explicitly rather than the loop variable left over from the
# dataset-splitting cell.
polynomial_target = lambda_net_test_dataset_list[-1].target_polynomial_list[index_max]
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True)

In [None]:
# Mean test R2 of the I-Net polynomials vs. the target polynomials: over all samples, then
# only over the samples whose R2 is positive.
r2_values_inet = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
print(f'Mean: {np.mean(r2_values_inet)!s} ({r2_values_inet.shape[0]!s} Samples)')

positive_mask = r2_values_inet > 0
r2_values_positive_inet = r2_values_inet[positive_mask]
print(f'Mean (only positive): {np.mean(r2_values_positive_inet)!s} ({r2_values_positive_inet.shape[0]!s} Samples)')



In [None]:
# Mean test R2 of the least-squares fit to the lambda-net predictions vs. the target polynomials:
# over all samples, then only over the samples whose R2 is positive.
r2_values_lstsq_lambda = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
# Report this series' own sample count (not r2_values_inet's), so the cell is self-contained.
print('Mean: ' + str(np.mean(r2_values_lstsq_lambda)) + ' (' + str(r2_values_lstsq_lambda.shape[0]) + ' Samples)')

r2_values_positive_lstsq_lambda = r2_values_lstsq_lambda[r2_values_lstsq_lambda > 0]
print('Mean (only positive): ' + str(np.mean(r2_values_positive_lstsq_lambda)) + ' (' + str(r2_values_positive_lstsq_lambda.shape[0]) + ' Samples)')



In [None]:
# Histogram of the I-Net vs. target MAE, leaving out test samples whose R2 is not above -50.
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
inet_mae_test = distrib_dict_list[-1]['MAE'].loc['inetPoly_VS_targetPoly_test']
sns.histplot(inet_mae_test[inet_r2_test > -50])


In [None]:
# Histogram of the lstsq-lambda vs. target MAE, leaving out test samples whose R2 is not above -50.
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
lstsq_mae_test = distrib_dict_list[-1]['MAE'].loc['lstsqLambda_VS_targetPoly_test']
sns.histplot(lstsq_mae_test[lstsq_r2_test > -50])


In [None]:
# I-Net R2 distribution (samples with R2 <= -50 dropped), shown on the range [-30, 1].
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
p = sns.histplot(inet_r2_test[inet_r2_test > -50], binwidth=0.2)
p.set(xlim=(-30, 1))

In [None]:
# Same I-Net R2 distribution with finer bins, zoomed to [0, 1].
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
p = sns.histplot(inet_r2_test[inet_r2_test > -50], binwidth=0.1)
p.set(xlim=(0, 1))

In [None]:
# lstsq-lambda R2 distribution (samples with R2 <= -50 dropped), shown on the range [-10, 1].
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
p = sns.histplot(lstsq_r2_test[lstsq_r2_test > -50], binwidth=0.2)
p.set(xlim=(-10, 1))

In [None]:
# Same lstsq-lambda R2 distribution with finer bins, zoomed to [0, 1].
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
p = sns.histplot(lstsq_r2_test[lstsq_r2_test > -50], binwidth=0.1)
p.set(xlim=(0, 1))

In [None]:
# Plot the tracked metric of the last (non-NAS) I-Net training run and save the figure.
if not nas:
    history = history_list[-1]
    history_keys = list(history.keys())
    # Key at position 1 is used as the metric (presumably the first entry after 'loss').
    metric_name = history_keys[1]

    plt.plot(history[metric_name])
    if consider_labels_training or evaluate_with_real_function:
        # Assumes Keras-style ordering: training keys first, then their val_ counterparts.
        plt.plot(history[history_keys[len(history_keys) // 2 + 1]])
    plt.title('model ' + metric_name)
    plt.xlabel('epoch')
    plt.ylabel('metric')
    plt.legend(['train', 'valid'], loc='upper left')
    metric_plot_path = ('./data/results/' + path_identifier_interpretation_net_data
                        + '/metric__epoch_' + str(epochs_lambda).zfill(3) + '.png')
    plt.savefig(metric_plot_path)


In [None]:
# Plot the training (and, when tracked, validation) loss of the last non-NAS I-Net run and save it.
if not nas:
    history = history_list[-1]

    plt.plot(history['loss'])
    if consider_labels_training or evaluate_with_real_function:
        plt.plot(history['val_loss'])
    plt.title('model loss')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend(['train', 'valid'], loc='upper left')
    loss_plot_path = ('./data/results/' + path_identifier_interpretation_net_data
                      + '/loss__epoch_' + str(epochs_lambda).zfill(3) + '.png')
    plt.savefig(loss_plot_path)


### Multi Epoch/Samples Analysis

### Generate Comparison Plots

In [None]:
# Compare all function-value (FV) metrics across I-Net runs; only drawn when there are two or more runs.
if len(scores_list) >= 2:
    plot_metric_list = [
        'MAE FV',
        'RMSE FV',
        'MAPE FV',
        'R2 FV',
        'RAAE FV',
        'RMAE FV',
    ]
    generate_inet_comparison_plot(scores_list, plot_metric_list)

In [None]:
# Compare only the function-value MAE across I-Net runs (two or more runs required).
if len(scores_list) >= 2:
    plot_metric_list = ['MAE FV']
    generate_inet_comparison_plot(scores_list, plot_metric_list)

In [None]:
# Compare only the function-value R2 across I-Net runs, with the y-axis clipped to [-5, 1].
if len(scores_list) >= 2:
    plot_metric_list = ['R2 FV']
    generate_inet_comparison_plot(scores_list, plot_metric_list, ylim=(-5, 1))

#### Generate and Analyze Predictions for Random Function

In [None]:
index = 6

# Pull every polynomial for test lambda-net `index` from the last test split explicitly
# (consistent with inet_preds_test_list[-1]) instead of relying on the loop variable left
# over from the dataset-splitting cell.
last_test_dataset = lambda_net_test_dataset_list[-1]
polynomial_target = last_test_dataset.target_polynomial_list[index]
polynomial_lstsq_target = last_test_dataset.lstsq_target_polynomial_list[index]
polynomial_lstsq_lambda = last_test_dataset.lstsq_lambda_pred_polynomial_list[index]
polynomial_inet = inet_preds_test_list[-1][index]

print('Target Poly:')
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True, round_digits=4)
print('LSTSQ Target Poly:')
print_polynomial_from_coefficients(polynomial_lstsq_target, force_complete_poly_representation=True, round_digits=4)
print('LSTSQ Lambda Poly:')
print_polynomial_from_coefficients(polynomial_lstsq_lambda, force_complete_poly_representation=True, round_digits=4)
print('I-Net Poly:')
print_polynomial_from_coefficients(polynomial_inet, round_digits=4)


In [None]:
# First ten test inputs of the first lambda-net in the first test split.
lambda_net_test_dataset_list[0].X_test_data_list[0][:10]

In [None]:
# Matching first ten test targets of the first lambda-net in the first test split.
lambda_net_test_dataset_list[0].y_test_data_list[0][:10]

In [None]:
# Prediction-evaluation plot (type 1) for test function `index`.
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    rand_index=index,
    plot_type=1,
)

In [None]:
# Prediction-evaluation plot (type 2) for test function `index`.
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    rand_index=index,
    plot_type=2,
)

In [None]:
# Prediction-evaluation plot (type 3) for test function `index`.
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    rand_index=index,
    plot_type=3,
)

# BENCHMARK (RANDOM GUESS) EVALUATION

In [None]:
# Random-guess benchmark: one coefficient vector of length `sparsity`, drawn uniformly
# from [-10, 10), for every lambda-net in the last test split.
n_test_lambda_nets = len(lambda_net_test_dataset_list[-1])
list_of_random_polynomials = np.random.uniform(-10, 10, (n_test_lambda_nets, sparsity))

In [None]:
# Evaluate the true target polynomials and the random guesses on each test lambda-net's test inputs.
last_test_dataset = lambda_net_test_dataset_list[-1]
test_inputs = last_test_dataset.X_test_data_list
true_fv_test = parallel_fv_calculation_from_polynomial(last_test_dataset.target_polynomial_list, test_inputs, force_complete_poly_representation=True)
random_fv_test = parallel_fv_calculation_from_polynomial(list_of_random_polynomials, test_inputs, force_complete_poly_representation=True)

In [None]:
# MAE between the true and the randomly guessed coefficient vectors, rounded to 4 decimals.
random_guess_coefficient_mae = mean_absolute_error(lambda_net_test_dataset_list[-1].target_polynomial_list, list_of_random_polynomials)
print('Random Guess Error Coefficients: ' + str(np.round(random_guess_coefficient_mae, 4)))

In [None]:
# Function-value MAE of the random guesses, rounded to 4 decimals.
random_guess_fv_mae = mean_absolute_error_function_values(true_fv_test, random_fv_test)
print('Random Guess Error FVs: ' + str(np.round(random_guess_fv_mae, 4)))

# BENCHMARK (EDUCATED GUESS/MEAN PREDICTION) EVALUATION

In [None]:
# Educated-guess baseline: predict one constant, the mean function value of the TRAINING
# lambda-nets' target polynomials, for every test function. Computing this mean from the
# test split would leak test targets into the baseline (and just recompute true_fv_test).
last_train_dataset = lambda_net_train_dataset_list[-1]
true_fv_train = parallel_fv_calculation_from_polynomial(last_train_dataset.target_polynomial_list, last_train_dataset.X_test_data_list, force_complete_poly_representation=True)

mean_fv = np.mean(true_fv_train)
# One constant prediction per test function (true_fv_test is computed in the random-guess cell).
mean_fv_pred_test = [mean_fv] * true_fv_test.shape[0]

In [None]:
# Function-value MAE of the constant mean prediction, rounded to 4 decimals.
mean_prediction_fv_mae = mean_absolute_error_function_values(true_fv_test, mean_fv_pred_test)
print('Educated Guess/Mean Prediction Error FVs: ' + str(np.round(mean_prediction_fv_mae, 4)))

In [None]:
%%script false --no-raise-error
# NOTE: the `%%script false` cell magic above (which must stay on the first line) hands this
# cell to the `false` command, so none of the code below is executed; it is a disabled
# debugging snippet.
# It evaluates the lambda-extended function-value MAE loss on the last training split twice:
# once with X_train as loaded, and once with seed_in_inet_training=False and one more leading
# column dropped from X_train (presumably the seed column -- confirm against X_train_list).
# X_train_list, y_train_list and random_evaluation_dataset_size are not defined in this part
# of the notebook; re-enabling the cell requires them.

base_model = generate_base_model()
# Shared random evaluation points in [x_min, x_max) with n input variables.
random_evaluation_dataset = np.random.uniform(low=x_min, high=x_max, size=(random_evaluation_dataset_size, n))
#random_evaluation_dataset = lambda_train_input_train_split[0]#lambda_train_input[0] #JUST [0] HERE BECAUSE EVALUATION ALWAYS ON THE SAME DATASET FOR ALL!!
list_of_monomial_identifiers_numbers = np.array([list(monomial_identifiers) for monomial_identifiers in list_of_monomial_identifiers]).astype(float)


loss_function = mean_absolute_error_tf_fv_lambda_extended_wrapper(random_evaluation_dataset, list_of_monomial_identifiers_numbers, base_model)      

# First pass: drop column 0 of X_train / columns 0-1 of y_train only.
X_train = X_train_list[-1].values[:,1:]
y_train = y_train_list[-1].values[:,2:]

#X_train = X_train[:,1:]
# The loss expects y_true with the network inputs appended column-wise.
y_train_model = np.hstack((y_train, X_train))

print('seed_in_inet_training = ' + str(seed_in_inet_training), loss_function(y_train_model, y_train))


# Second pass with the flag switched off.
seed_in_inet_training = False

loss_function = mean_absolute_error_tf_fv_lambda_extended_wrapper(random_evaluation_dataset, list_of_monomial_identifiers_numbers, base_model)      

X_train = X_train_list[-1].values[:,1:]
y_train = y_train_list[-1].values[:,2:]

# Here one more leading column is removed from X_train than in the first pass.
X_train = X_train[:,1:]
y_train_model = np.hstack((y_train, X_train))

print('seed_in_inet_training = ' + str(seed_in_inet_training), loss_function(y_train_model, y_train))

# Restore the flag.
seed_in_inet_training = True


In [None]:
# Final clean-up when running on a GPU.
if use_gpu:
    # Reset numba's current CUDA device. Per numba's Device.reset documentation this destroys
    # the device's context and the memory allocations made in it.
    from numba import cuda 
    device = cuda.get_current_device()
    device.reset()