# Interpretation-Net Training

# Experiment 1: I-Net Performance for Different Algebras and Complexities
# Experiment 2: I-Net Performance Comparison for λ-Nets with Different Training Levels
# Experiment 3: I-Net Performance Comparison for Different Training Data Sizes

## Specification of Experiment Settings

In [1]:
#######################################################################################################################################
###################################################### CONFIG FILE ####################################################################
#######################################################################################################################################
# Experiment configuration. Each section is copied into the notebook's global namespace with
# globals().update(...) and several None entries are filled in by the "CONFIG ADJUSTMENTS" cell.
# NOTE(review): 'dropout' appears in both 'lambda_net' and 'i_net'; since the sections are copied
# into globals in the order data, lambda_net, i_net, ..., the global `dropout` holds the i_net value.
sleep_time = 0 #minutes to wait (sleep_minutes) at the end of the config-adjustment cell


config = {
    'data': {
        'd': 5, #degree
        'n': 1, #number of variables
        'sparsity': None, #always recomputed later as nCr(n+d, d) (number of monomials); a value set here is overwritten
        'sample_sparsity': None, #None -> falls back to the computed 'sparsity'
        'x_max': 1,
        'x_min': 0,
        'x_distrib': 'uniform', #'normal', 'uniform', 'beta', 'Gamma', 'laplace'
        'a_max': 1,
        'a_min': -1,
        'lambda_nets_total': 10000,
        'noise': 0.1,
        'noise_distrib': 'normal', #'normal', 'uniform', 'beta', 'Gamma', 'laplace'
        
        'same_training_all_lambda_nets': False,

        'fixed_seed_lambda_training': True,
        'fixed_initialization_lambda_training': False,
        'number_different_lambda_trainings': 1,
    },
    'lambda_net': {
        'epochs_lambda': 1000,
        'early_stopping_lambda': True,  #if early stopping is used, multi_epoch_analysis is deactivated
        'early_stopping_min_delta_lambda': 1e-4,
        'batch_lambda': 64,
        'dropout': 0,
        'lambda_network_layers': [5*'sparsity'], #string entries are resolved later: each occurrence of 'sparsity' adds `sparsity` neurons (here 5*sparsity)
        'optimizer_lambda': 'adam',
        'loss_lambda': 'mae',
        'number_of_lambda_weights': None, #computed later from n, the layer sizes and a single output (weights + biases)
        'lambda_dataset_size': 5000,
    },
    'i_net': {
        'optimizer': 'custom',#adam
        'inet_loss': 'mae',
        'inet_metrics': ['r2'],
        'dropout': 0.25,
        'dropout_output': 0,
        'epochs': 500,
        'early_stopping': True,
        'batch_size': 512,
        'dense_layers': [512, 1024],
        'convolution_layers': None,
        'lstm_layers': None,
        'interpretation_dataset_size': 1000, #number of lambda-nets sampled from each weight file
                
        'interpretation_net_output_monomials': 3, #(None, int)
        'interpretation_net_output_shape': None, #calculated automatically later
        
        'evaluate_with_real_function': False,
        'consider_labels_training': False,
                      
        'data_reshape_version': None, #options: (None, 0, 1, 2); None becomes 2 when convolution/LSTM layers or a non-SEQUENTIAL NAS type are used
        'nas': True,
        'nas_type': 'SEQUENTIAL', #options:(None, 'SEQUENTIAL', 'CNN', 'LSTM', 'CNN-LSTM', 'CNN-LSTM-parallel')      
        'nas_trials': 1,
    },
    'evaluation': {   
        'inet_holdout_seed_evaluation': False, #True: split train/valid/test by lambda training seed instead of randomly
        
        #set if multi_epoch_analysis should be performed (forced to False when early_stopping_lambda is True)
        'multi_epoch_analysis': True,
        'each_epochs_save_lambda': 20,
        'epoch_start': 0, #use to skip first epochs in multi_epoch_analysis
        
        #set if samples analysis should be performed
        'samples_list': None,#[100, 500, 750, 1000, 2500, 5000, 7500, 10000, 15000, 20000, 25000, 28125] 
       
        'random_evaluation_dataset_size': 500,
    },
    'computation':{
        'n_jobs': 11, #worker count for joblib.Parallel when loading the lambda-net data
        'use_gpu': False, #False sets CUDA_VISIBLE_DEVICES to '' (no GPU visible)
        'gpu_numbers': '0', #value used for CUDA_VISIBLE_DEVICES when use_gpu is True
        'RANDOM_SEED': 42,   
    }
}

## Imports

In [2]:
#######################################################################################################################################
########################################### IMPORT GLOBAL VARIABLES FROM CONFIG #######################################################
#######################################################################################################################################
# Expose every config entry as a global variable. Sections are merged in this order, so a key that
# appears in several sections (e.g. 'dropout') takes the value of the last section listing it.
globals().update({
    setting_name: setting_value
    for section_name in ('data', 'lambda_net', 'i_net', 'evaluation', 'computation')
    for setting_name, setting_value in config[section_name].items()
})

In [3]:
#######################################################################################################################################
##################################################### IMPORT LIBRARIES ################################################################
#######################################################################################################################################
from itertools import product       
from tqdm import tqdm_notebook as tqdm
import pickle
import numpy as np
import pandas as pd
import scipy as sp
import timeit
import psutil

from functools import reduce
from more_itertools import random_product 
from sklearn.preprocessing import Normalizer

import sys
import os
import shutil

import logging

from prettytable import PrettyTable
import colored
import math

import time
from datetime import datetime
from collections.abc import Iterable


from joblib import Parallel, delayed

from scipy.integrate import quad

from sklearn.model_selection import cross_val_score, train_test_split, StratifiedKFold, KFold
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, f1_score, mean_absolute_error, r2_score
from similaritymeasures import frechet_dist, area_between_two_curves, dtw
import keras
from keras.models import Sequential
from keras.layers.core import Dense, Dropout
from keras.utils import plot_model
from IPython.display import Image

import keras.backend as K
from livelossplot import PlotLossesKerasTF
from keras_tqdm import TQDMNotebookCallback
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

from matplotlib import pyplot as plt
import seaborn as sns


import tensorflow as tf
import random 


import warnings


from IPython.display import display, Math, Latex, clear_output


In [4]:
sys.path.append('..')

from utilities.InterpretationNet import *
from utilities.LambdaNet import *
from utilities.metrics import *
from utilities.utility_functions import *

In [5]:
#######################################################################################################################################
################################################### VARIABLE ADJUSTMENTS ##############################################################
#######################################################################################################################################
# One letter per input variable: 'a', 'b', ... (n of them).
variables = 'abcdefghijklmnopqrstuvwxyz'[:n]

# Deactivate multi_epoch_analysis if early stopping is used. This must happen BEFORE anything that
# depends on the flag (n_jobs, each_epochs_save_lambda, epochs_save_range_lambda) is derived.
multi_epoch_analysis = False if early_stopping_lambda else multi_epoch_analysis

# Cap the number of loader workers by the amount of work actually scheduled.
if multi_epoch_analysis:
    n_jobs = min(epochs_lambda//each_epochs_save_lambda+1, n_jobs)
elif samples_list is not None:
    n_jobs = min(len(samples_list), n_jobs)
else:
    n_jobs = 1

# Without multi-epoch analysis only the final snapshot (saved at epoch `epochs_lambda`) is used.
each_epochs_save_lambda = each_epochs_save_lambda if multi_epoch_analysis else epochs_lambda

# Positions of the weight snapshots to load (mapped to epoch numbers in the data-loading cell).
if each_epochs_save_lambda == 1:
    epochs_save_range_lambda = range(epoch_start//each_epochs_save_lambda, epochs_lambda//each_epochs_save_lambda)
elif multi_epoch_analysis:
    epochs_save_range_lambda = range(epoch_start//each_epochs_save_lambda, epochs_lambda//each_epochs_save_lambda+1)
else:
    epochs_save_range_lambda = range(1,2)

# Non-dense architectures (convolution, LSTM, or a non-SEQUENTIAL NAS search) default to reshape version 2.
if data_reshape_version is None and (convolution_layers is not None or lstm_layers is not None or (nas and nas_type != 'SEQUENTIAL')):
    data_reshape_version = 2
#######################################################################################################################################
###################################################### SET VARIABLES + DESIGN #########################################################
#######################################################################################################################################

# Optional (disabled): os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Expose the configured GPUs to CUDA only when use_gpu is set; otherwise hide all of them.
if use_gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_numbers
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

logging.getLogger('tensorflow').disabled = True

sns.set_style("darkgrid")
# Optional (disabled): np.set_printoptions(suppress=True)

# Seed the Python, NumPy and TensorFlow global RNGs. TensorFlow 2.x provides tf.random.set_seed,
# TensorFlow 1.x provides tf.set_random_seed; only the matching attribute is looked up.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
(tf.random.set_seed if int(tf.__version__[0]) >= 2 else tf.set_random_seed)(RANDOM_SEED)

# DataFrame display: floats with three decimals, up to 500 columns.
pd.set_option('display.float_format', lambda x: '%.3f' % x)
pd.set_option('display.max_columns', 500)

warnings.filterwarnings('ignore')

In [6]:
#######################################################################################################################################
####################################################### CONFIG ADJUSTMENTS ############################################################
#######################################################################################################################################
# Write the values adjusted in the variable-adjustment cell back into the config, so the
# globals().update(...) calls below do not reset them to their raw config values.
config['evaluation']['multi_epoch_analysis'] = multi_epoch_analysis
config['evaluation']['each_epochs_save_lambda'] = each_epochs_save_lambda
config['i_net']['data_reshape_version'] = data_reshape_version
# NOTE(review): n_jobs is also adjusted in the variable-adjustment cell but is not written back here,
# so globals().update(config['computation']) below resets it to config['computation']['n_jobs'].
# Confirm whether later cells depend on the uncapped value before adding it to the list above.

# Number of monomials of a degree-d polynomial in n variables (always recomputed).
config['data']['sparsity'] = nCr(config['data']['n']+config['data']['d'], config['data']['d'])
if config['data']['sample_sparsity'] is None:
    config['data']['sample_sparsity'] = config['data']['sparsity']

# I-Net output size: `sparsity` values, or k*sparsity + k values when the output is restricted to
# k = interpretation_net_output_monomials monomials.
if config['i_net']['interpretation_net_output_monomials'] is None:
    config['i_net']['interpretation_net_output_shape'] = config['data']['sparsity']
else:
    config['i_net']['interpretation_net_output_shape'] = config['data']['sparsity']*config['i_net']['interpretation_net_output_monomials']+config['i_net']['interpretation_net_output_monomials']

# Resolve symbolic layer widths: a string entry contributes `sparsity` neurons per occurrence of the
# substring 'sparsity' (e.g. 5*'sparsity' -> 5*sparsity); non-string entries are kept unchanged.
transformed_layers = [
    layer.count('sparsity')*config['data']['sparsity'] if isinstance(layer, str) else layer
    for layer in config['lambda_net']['lambda_network_layers']
]
config['lambda_net']['lambda_network_layers'] = transformed_layers

# Parameter count (weights + one bias per neuron) of the dense lambda-net:
# n inputs -> hidden layers -> 1 output.
layers_with_input_output = list(flatten([[config['data']['n']], config['lambda_net']['lambda_network_layers'], [1]]))
number_of_lambda_weights = sum(
    (fan_in+1)*fan_out
    for fan_in, fan_out in zip(layers_with_input_output, layers_with_input_output[1:])
)
config['lambda_net']['number_of_lambda_weights'] = number_of_lambda_weights

#######################################################################################################################################
################################################## UPDATE VARIABLES ###################################################################
#######################################################################################################################################
globals().update(config['data'])
globals().update(config['lambda_net'])
globals().update(config['i_net'])
globals().update(config['evaluation'])
globals().update(config['computation'])



initialize_LambdaNet_config_from_curent_notebook(config)
initialize_metrics_config_from_curent_notebook(config)
initialize_utility_functions_config_from_curent_notebook(config)
initialize_InterpretationNet_config_from_curent_notebook(config)


#######################################################################################################################################
###################################################### PATH + FOLDER CREATION #########################################################
#######################################################################################################################################
globals().update(generate_paths(path_type='interpretation_net'))
create_folders_inet()

#######################################################################################################################################
############################################################ SLEEP TIMER ##############################################################
#######################################################################################################################################
sleep_minutes(sleep_time)

In [7]:
# Show the path identifiers (from generate_paths) for the I-Net data and the lambda-net data, one per line.
print(path_identifier_interpretation_net_data, path_identifier_lambda_net_data, sep='\n')


inet_dense512-1024-output_21_drop0.25e500b512_custom/lnets_1000_30-1000e_ES0.0001_64b_adam_mae_train_5000_diffX_1-FixSeed_42/var_1_d_5_sparsity_6_amin_-1_amax_1_xdist_uniform_noise_normal_0.1
lnets_10000_30-1000e_ES0.0001_64b_adam_mae_train_5000_diffX_1-FixSeed_42/var_1_d_5_sparsity_6_amin_-1_amax_1_xmin_0_xmax_1_xdist_uniform_noise_normal_0.1


In [8]:
# Report how many GPU and XLA_GPU devices TensorFlow can see.
for _label, _device_type in (("Num GPUs Available: ", 'GPU'), ("Num XLA-GPUs Available: ", 'XLA_GPU')):
    print(_label, len(tf.config.experimental.list_physical_devices(_device_type)))
del _label, _device_type

Num GPUs Available:  0
Num XLA-GPUs Available:  0


## Utility functions

### Generate List of Monomial Identifiers

In [9]:
# Candidate identifiers: every integer below (d+1)**n written via dec_to_base in base d+1 and
# left-padded with zeros to n characters (presumably one exponent digit per variable).
list_of_monomial_identifiers_extended = [
    dec_to_base(candidate, base = (d+1)).zfill(n)
    for candidate in tqdm(range((d+1)**n))
]

print(f'List length: {len(list_of_monomial_identifiers_extended)}')
print(f'Number of monomials in a polynomial with {n} variables and degree {d}: {nCr(n+d, d)}')
print(f'Sparsity: {sparsity}')
print(list_of_monomial_identifiers_extended)

# Keep only identifiers whose digit sum (the total degree) does not exceed d.
list_of_monomial_identifiers = [
    identifier
    for identifier in tqdm(list_of_monomial_identifiers_extended)
    if sum(int(digit) for digit in identifier) <= d
]

print(f'List length: {len(list_of_monomial_identifiers)}')
print(f'Number of monomials in a polynomial with {n} variables and degree {d}: {nCr(n+d, d)}')
print(f'Sparsity: {sparsity}')
print(list_of_monomial_identifiers)


# Parameter count (weights + biases) of the lambda-net: n inputs -> hidden layers -> 1 output.
layers_with_input_output = list(flatten([[n], lambda_network_layers, [1]]))
number_of_lambda_weights = sum(
    (units_in+1)*units_out
    for units_in, units_out in zip(layers_with_input_output[:-1], layers_with_input_output[1:])
)

  0%|          | 0/6 [00:00<?, ?it/s]

List length: 6
Number of monomials in a polynomial with 1 variables and degree 5: 6
Sparsity: 6
['0', '1', '2', '3', '4', '5']


  0%|          | 0/6 [00:00<?, ?it/s]

List length: 6
Number of monomials in a polynomial with 1 variables and degree 5: 6
Sparsity: 6
['0', '1', '2', '3', '4', '5']


## Load Data and Generate Datasets

In [10]:
def _read_sorted_sample(path):
    """Read a header-less CSV file and return a seeded random sample of its rows.

    Rows are ordered by column 0 before sampling, so the drawn rows do not depend on the row order
    in the file (assuming column 0 holds unique values). The sample has
    ``interpretation_dataset_size`` rows and uses ``RANDOM_SEED`` as random state. Files with the
    same number of rows and the same column-0 ordering therefore yield row-aligned samples.

    Raises:
        ValueError: (from pandas) if the file has fewer than ``interpretation_dataset_size`` rows.
    """
    frame = pd.read_csv(path, sep=",", header=None)
    # No separate shuffle is needed: sample() already draws the rows at random.
    return frame.sort_values(by=0).sample(n=interpretation_dataset_size, random_state=RANDOM_SEED)


def load_lambda_nets(index):
    """Load the lambda-net weight snapshot with epoch number ``index`` as a LambdaNetDataset.

    The weights file ``weights_epoch_<index, zero-padded to 3>.txt`` and the shared
    ``lambda_X_test_data.txt`` / ``lambda_y_test_data.txt`` files are read from
    ``./data/weights/weights_<path_identifier_lambda_net_data>/``. Each file is sampled identically
    (see _read_sorted_sample) and row i of the three samples builds the i-th LambdaNet.

    Raises:
        SystemExit: if system RAM usage is already above 80 %.
    """
    if psutil.virtual_memory().percent > 80:
        raise SystemExit("Out of RAM!")

    directory = os.path.join('./data/weights', 'weights_' + path_identifier_lambda_net_data)
    weight_data = _read_sorted_sample(os.path.join(directory, 'weights_epoch_' + str(index).zfill(3) + '.txt'))
    lambda_X_test_data = _read_sorted_sample(os.path.join(directory, 'lambda_X_test_data.txt'))
    lambda_y_test_data = _read_sorted_sample(os.path.join(directory, 'lambda_y_test_data.txt'))

    lambda_nets = [
        LambdaNet(row_weights, row_lambda_X_test_data, row_lambda_y_test_data)
        for row_weights, row_lambda_X_test_data, row_lambda_y_test_data
        in zip(weight_data.values, lambda_X_test_data.values, lambda_y_test_data.values)
    ]

    return LambdaNetDataset(lambda_nets)
    

In [11]:
#LOAD DATA

# Load every scheduled snapshot in parallel. Position i of epochs_save_range_lambda maps to the epoch
# number in the weight file name: i+1 when every epoch is saved, otherwise i*each_epochs_save_lambda,
# with position 0 (or below) mapped to epoch 1.
lambda_net_dataset_list = Parallel(n_jobs=n_jobs, verbose=3, backend='multiprocessing')(
    delayed(load_lambda_nets)(
        i + 1 if each_epochs_save_lambda == 1
        else i * each_epochs_save_lambda if i > 0
        else 1
    )
    for i in epochs_save_range_lambda
)

# The last snapshot in the range is the one inspected and split below.
lambda_net_dataset = lambda_net_dataset_list[-1]


[Parallel(n_jobs=11)]: Using backend MultiprocessingBackend with 11 concurrent workers.
[Parallel(n_jobs=11)]: Done   1 out of   1 | elapsed:    7.9s finished


## Data Inspection

In [12]:
lambda_net_dataset.as_pandas().head(n=5)  # first five rows of the dataset as a DataFrame

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
6252,1373158606,0.924,0.492,-0.142,-0.402,-0.741,0.255,0.941,0.421,-1.533,6.634,-11.5,5.498,0.911,0.605,-0.683,0.915,-2.125,0.759,-0.009,-0.227,0.414,0.34,0.169,0.103,0.419,-0.297,0.334,0.062,0.544,-0.088,-0.019,0.451,0.09,0.523,-0.042,0.202,0.265,0.229,-0.054,0.043,0.008,-0.042,-0.404,-0.09,0.333,-0.291,-0.209,-0.233,0.0,0.0,-0.212,-0.233,0.234,0.234,-0.286,0.0,-0.234,0.296,-0.243,0.0,0.0,-0.255,0.244,-0.205,0.232,0.254,0.227,0.242,0.288,-0.046,-0.023,0.0,0.0,0.0,-0.233,0.0,0.0,0.0,-0.254,-0.359,-0.566,-0.69,0.341,0.4,-0.479,-0.408,-0.789,0.12,-0.376,-0.352,-0.364,-0.443,0.341,-0.452,0.687,0.215,0.348,0.244,0.337,-0.311,-0.044,-0.261,-0.383,-0.059,-0.969,0.144,-0.258,-0.315,0.202
4684,1373158606,-0.794,0.241,-0.127,-0.188,-0.535,-0.939,-0.768,1.42,-14.454,48.213,-63.525,27.058,-0.794,0.232,0.112,-0.997,0.376,-1.277,-0.009,-0.227,0.231,0.122,0.313,0.08,0.185,-0.297,0.089,0.253,0.358,-0.088,-0.019,0.234,0.078,0.342,0.008,1.474,0.405,0.393,0.033,0.035,-0.043,-0.042,-0.404,-0.09,0.048,-0.291,-0.209,-0.233,0.0,0.0,0.172,0.175,0.001,-0.081,0.178,0.0,0.175,0.19,0.173,0.0,0.0,0.175,-0.079,0.172,-0.023,-0.949,0.001,0.0,-0.04,0.176,0.2,0.0,0.0,0.0,0.175,0.0,0.0,0.0,-0.254,-0.359,-0.339,-0.36,0.278,0.187,-0.191,-0.408,-0.396,-0.074,-0.188,-0.352,-0.364,-0.199,0.121,-0.265,0.393,-2.337,0.33,0.208,0.049,-0.458,-0.229,-0.261,-0.383,-0.059,-0.479,0.144,-0.258,-0.315,-0.168
1731,1373158606,-0.739,0.951,0.737,0.786,0.773,0.083,-0.721,0.346,4.731,-9.701,12.714,-4.814,-0.744,0.972,0.672,0.758,0.98,-0.046,-0.009,-0.227,-0.173,0.082,0.724,0.671,0.528,-0.297,0.067,0.916,0.416,-0.088,-0.019,0.349,0.689,-0.003,0.528,0.849,0.892,0.883,0.649,0.039,0.007,-0.042,-0.404,-0.09,0.046,-0.291,-0.209,-0.233,0.0,0.0,0.437,-0.092,-0.412,-0.438,-0.438,0.0,-0.078,-0.404,0.001,0.0,0.0,0.001,-0.463,0.398,-0.438,-0.28,-0.0,-0.14,-0.525,-0.054,-0.024,0.0,0.0,0.0,-0.06,0.0,0.0,0.0,-0.254,-0.359,-0.466,-0.215,0.917,1.043,0.791,-0.408,-0.251,0.617,0.148,-0.352,-0.364,0.19,1.01,-0.24,1.515,0.684,0.819,0.7,1.093,-0.306,-0.042,-0.261,-0.383,-0.059,-0.331,0.144,-0.258,-0.315,-0.437
4742,1373158606,0.173,-0.607,-0.085,0.746,-0.828,-0.939,0.161,-0.337,-2.491,8.142,-10.118,3.139,0.181,-0.787,0.885,-1.338,1.146,-1.629,-0.009,-0.227,0.511,0.469,0.05,-0.029,0.515,-0.297,0.433,0.595,0.678,-0.088,-0.019,0.54,-0.019,0.653,0.006,0.278,0.134,0.182,0.03,0.396,0.433,-0.042,-0.404,-0.09,0.397,-0.291,-0.209,-0.233,0.0,0.0,-0.264,-0.331,0.09,0.099,-0.363,0.0,-0.339,-0.404,-0.007,0.0,0.0,-0.279,0.089,-0.006,-0.025,-0.002,0.094,0.015,-0.044,-0.335,-0.366,0.0,0.0,0.0,-0.337,0.0,0.0,0.0,-0.254,-0.359,-0.691,-0.997,0.091,0.166,-0.746,-0.408,-1.085,-0.558,-0.418,-0.352,-0.364,-0.58,0.091,-0.495,0.391,-0.077,0.106,-0.017,0.043,-1.272,-1.217,-0.261,-0.383,-0.059,-1.255,0.144,-0.258,-0.315,0.121
4521,1373158606,0.131,0.208,0.871,0.323,0.573,0.164,0.124,0.308,0.697,-0.264,2.311,-0.934,0.132,0.241,0.293,2.685,-2.779,1.703,-0.009,-0.227,0.128,0.083,0.634,0.562,0.108,-0.297,0.067,0.951,0.4,-0.088,-0.019,0.126,0.521,0.243,0.398,0.706,0.723,0.705,0.612,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.128,-0.084,-0.13,-0.215,-0.108,0.0,-0.07,-0.709,-0.337,0.0,0.0,-0.126,-0.436,-0.173,-0.186,-0.442,0.066,-0.33,-0.462,-0.047,-0.023,0.0,0.0,0.0,-0.053,0.0,0.0,0.0,-0.254,-0.359,-0.193,-0.225,0.614,0.717,-0.034,-0.408,-0.258,0.87,0.361,-0.352,-0.364,-0.008,0.835,0.164,0.961,0.674,0.627,0.624,0.986,-0.309,-0.044,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.078


In [13]:
lambda_net_dataset.as_pandas().describe(percentiles=[.25, .5, .75])  # summary statistics per column

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
count,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0,1000.0
mean,1373158606.0,0.002,-0.013,0.004,-0.006,0.004,-0.015,0.006,-0.031,0.087,-0.396,0.611,-0.292,0.002,-0.018,0.028,-0.055,0.046,-0.028,-0.009,-0.227,0.289,0.196,0.282,0.221,0.278,-0.297,0.166,0.449,0.476,-0.088,-0.019,0.324,0.221,0.405,0.08,0.415,0.371,0.393,0.147,0.115,0.061,-0.042,-0.404,-0.09,0.127,-0.291,-0.209,-0.233,0.0,0.0,-0.017,-0.028,-0.011,-0.017,-0.054,0.0,-0.029,-0.077,-0.045,0.0,0.0,-0.049,-0.025,-0.005,0.011,-0.069,0.018,-0.033,-0.018,-0.017,0.003,0.0,0.0,0.0,-0.021,0.0,0.0,0.0,-0.254,-0.359,-0.373,-0.425,0.354,0.418,-0.216,-0.408,-0.469,0.037,-0.115,-0.352,-0.364,-0.197,0.363,-0.269,0.627,0.117,0.359,0.177,0.327,-0.522,-0.251,-0.261,-0.383,-0.059,-0.549,0.144,-0.258,-0.315,0.004
std,0.0,0.569,0.591,0.58,0.574,0.563,0.585,0.536,0.653,3.592,10.852,14.163,5.881,0.569,0.633,1.564,3.645,3.985,1.704,0.0,0.0,0.142,0.142,0.162,0.162,0.16,0.0,0.141,0.263,0.204,0.0,0.0,0.161,0.167,0.144,0.152,0.246,0.166,0.193,0.205,0.141,0.178,0.0,0.0,0.0,0.14,0.0,0.0,0.0,0.0,0.0,0.146,0.157,0.171,0.18,0.186,0.0,0.158,0.287,0.212,0.0,0.0,0.182,0.194,0.143,0.157,0.249,0.135,0.214,0.214,0.152,0.174,0.0,0.0,0.0,0.153,0.0,0.0,0.0,0.0,0.0,0.174,0.231,0.218,0.254,0.323,0.0,0.258,0.427,0.438,0.0,0.0,0.33,0.27,0.218,0.315,0.503,0.179,0.437,0.35,0.29,0.333,0.0,0.0,0.0,0.284,0.0,0.0,0.0,0.138
min,1373158606.0,-1.0,-1.0,-0.998,-0.997,-1.0,-0.998,-1.038,-1.997,-14.454,-73.162,-64.644,-41.14,-1.002,-1.431,-4.763,-10.661,-11.225,-4.626,-0.009,-0.227,-0.173,-0.289,-0.134,-0.199,-0.158,-0.297,-0.254,-0.132,0.099,-0.088,-0.019,-0.12,-0.209,-0.003,-0.28,-0.098,-0.022,-0.063,-0.317,-0.286,-0.411,-0.042,-0.404,-0.09,-0.273,-0.291,-0.209,-0.233,0.0,0.0,-0.401,-0.467,-0.545,-0.601,-0.705,0.0,-0.491,-1.001,-1.18,0.0,0.0,-0.88,-0.587,-0.818,-0.551,-0.959,-0.329,-0.907,-0.704,-0.486,-0.761,0.0,0.0,0.0,-0.486,0.0,0.0,0.0,-0.254,-0.359,-1.131,-1.352,0.003,0.092,-1.185,-0.408,-1.452,-1.35,-0.915,-0.352,-0.364,-1.058,0.025,-0.982,0.329,-2.407,-0.001,-2.723,-0.308,-1.758,-2.052,-0.261,-0.383,-0.059,-1.72,0.144,-0.258,-0.315,-0.476
25%,1373158606.0,-0.474,-0.557,-0.5,-0.501,-0.469,-0.527,-0.423,-0.487,-1.948,-4.897,-6.805,-2.656,-0.481,-0.551,-1.06,-2.48,-2.638,-1.197,-0.009,-0.227,0.191,0.083,0.174,0.087,0.168,-0.297,0.067,0.284,0.36,-0.088,-0.019,0.225,0.084,0.317,0.008,0.258,0.269,0.264,0.033,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.13,-0.093,-0.108,-0.082,-0.185,0.0,-0.079,-0.325,-0.114,0.0,0.0,-0.167,-0.083,-0.077,-0.023,-0.269,-0.036,-0.138,-0.041,-0.049,-0.023,0.0,0.0,0.0,-0.056,0.0,0.0,0.0,-0.254,-0.359,-0.417,-0.441,0.209,0.252,-0.278,-0.408,-0.496,-0.113,-0.264,-0.352,-0.364,-0.281,0.187,-0.346,0.394,0.079,0.246,0.124,0.049,-0.543,-0.275,-0.261,-0.383,-0.059,-0.573,0.144,-0.258,-0.315,-0.118
50%,1373158606.0,0.021,-0.008,0.009,-0.03,-0.003,-0.038,0.01,-0.019,0.002,-0.004,0.005,-0.001,0.024,0.002,0.045,0.014,0.018,-0.081,-0.009,-0.227,0.294,0.191,0.261,0.202,0.263,-0.297,0.155,0.375,0.433,-0.088,-0.019,0.308,0.199,0.4,0.013,0.353,0.347,0.351,0.079,0.071,0.008,-0.042,-0.404,-0.09,0.094,-0.291,-0.209,-0.233,0.0,0.0,-0.003,-0.008,-0.001,-0.001,-0.008,0.0,-0.016,0.011,-0.004,0.0,0.0,-0.006,-0.001,-0.002,-0.018,-0.004,-0.0,-0.002,-0.008,-0.046,-0.022,0.0,0.0,0.0,-0.051,0.0,0.0,0.0,-0.254,-0.359,-0.339,-0.356,0.311,0.354,-0.189,-0.408,-0.386,-0.023,-0.193,-0.352,-0.364,-0.199,0.291,-0.271,0.544,0.182,0.329,0.218,0.216,-0.425,-0.122,-0.261,-0.383,-0.059,-0.456,0.144,-0.258,-0.315,0.011
75%,1373158606.0,0.488,0.508,0.481,0.5,0.457,0.495,0.435,0.392,1.864,5.383,6.678,2.523,0.484,0.505,1.099,2.27,2.749,1.068,-0.009,-0.227,0.364,0.272,0.371,0.312,0.356,-0.297,0.245,0.563,0.525,-0.088,-0.019,0.39,0.314,0.473,0.145,0.501,0.454,0.477,0.246,0.179,0.084,-0.042,-0.404,-0.09,0.2,-0.291,-0.209,-0.233,0.0,0.0,0.11,0.109,0.128,0.131,0.105,0.0,0.109,0.152,0.102,0.0,0.0,0.104,0.132,0.108,0.135,0.127,0.128,0.129,0.153,0.111,0.123,0.0,0.0,0.0,0.111,0.0,0.0,0.0,-0.254,-0.359,-0.26,-0.26,0.381,0.431,-0.098,-0.408,-0.27,0.112,-0.102,-0.352,-0.364,-0.108,0.373,-0.186,0.675,0.263,0.41,0.296,0.354,-0.31,-0.044,-0.261,-0.383,-0.059,-0.335,0.144,-0.258,-0.315,0.128
max,1373158606.0,1.0,0.992,0.998,0.999,0.999,0.999,0.99,1.983,21.905,48.213,96.505,28.522,1.005,1.415,4.774,10.3,10.71,4.676,-0.009,-0.227,0.761,0.66,0.847,0.784,1.057,-0.297,0.621,1.611,1.834,-0.088,-0.019,1.442,0.833,1.07,0.649,1.544,1.017,1.133,0.848,0.568,0.931,-0.042,-0.404,-0.09,0.568,-0.291,-0.209,-0.233,0.0,0.0,0.437,0.332,0.241,0.242,0.245,0.0,0.331,0.306,0.221,0.0,0.0,0.235,0.252,0.398,0.283,0.26,0.233,0.249,0.32,0.287,0.411,0.0,0.0,0.0,0.275,0.0,0.0,0.0,-0.254,-0.359,-0.057,-0.145,1.217,1.479,2.119,-0.408,-0.18,2.272,2.854,-0.352,-0.364,2.402,1.467,2.148,1.856,1.101,1.096,1.035,1.72,-0.254,0.029,-0.261,-0.383,-0.059,-0.247,0.144,-0.258,-0.315,0.326


In [14]:
lambda_net_dataset.X_test_data_list[0][0:10]  # first ten test inputs of the first lambda-net

array([[0.04991591],
       [0.55931928],
       [0.41215429],
       [0.47204459],
       [0.50702493],
       [0.36030918],
       [0.48300684],
       [0.91632547],
       [0.27265516],
       [0.88145597]])

In [15]:
lambda_net_dataset.y_test_data_list[0][0:10]  # first ten test targets of the first lambda-net

array([[1.0879322 ],
       [0.92128287],
       [0.94887895],
       [1.06356541],
       [1.04696906],
       [0.91303792],
       [0.84312583],
       [0.62768837],
       [1.02064446],
       [0.61265825]])

## Generate Datasets for Interpretation-Net training

In [16]:
# Build the train / valid / test splits used for I-Net training, one triple per
# lambda-net dataset.
#  * inet_holdout_seed_evaluation: split by seed, so all lambda nets of one seed end
#    up in the same split (test ~10% of the seeds, valid ~10% of the remaining ones).
#  * otherwise: split_LambdaNetDataset holds out 10% for test, then 10% of the
#    remainder for validation.

def _sample_holdout_seeds(seed_pool):
    """Return a RANDOM_SEED-reproducible random sample of ~10% of *seed_pool*."""
    holdout_count = int(len(seed_pool) - len(seed_pool) / (1 / 0.9))
    random.seed(RANDOM_SEED)
    return random.sample(seed_pool, holdout_count)


lambda_net_train_dataset_list = []
lambda_net_valid_dataset_list = []
lambda_net_test_dataset_list = []

for lambda_net_dataset in lambda_net_dataset_list:
    if inet_holdout_seed_evaluation:
        complete_seed_list = list(set(lambda_net_dataset.train_settings_list['seed']))

        test_seeds = _sample_holdout_seeds(complete_seed_list)
        lambda_net_test_dataset = lambda_net_dataset.get_lambda_nets_by_seed(test_seeds)
        complete_seed_list = list(set(complete_seed_list) - set(test_seeds))

        valid_seeds = _sample_holdout_seeds(complete_seed_list)
        lambda_net_valid_dataset = lambda_net_dataset.get_lambda_nets_by_seed(valid_seeds)
        complete_seed_list = list(set(complete_seed_list) - set(valid_seeds))

        # every seed that was not held out goes into the training split
        train_seeds = complete_seed_list
        lambda_net_train_dataset = lambda_net_dataset.get_lambda_nets_by_seed(train_seeds)
    else:
        lambda_net_train_with_valid_dataset, lambda_net_test_dataset = split_LambdaNetDataset(lambda_net_dataset, test_split=0.1)
        lambda_net_train_dataset, lambda_net_valid_dataset = split_LambdaNetDataset(lambda_net_train_with_valid_dataset, test_split=0.1)
        del lambda_net_train_with_valid_dataset

    lambda_net_train_dataset_list.append(lambda_net_train_dataset)
    lambda_net_valid_dataset_list.append(lambda_net_valid_dataset)
    lambda_net_test_dataset_list.append(lambda_net_test_dataset)

    del lambda_net_dataset

del lambda_net_dataset_list

In [17]:
# (rows, columns) of the last training split's tabular view
np.shape(lambda_net_train_dataset_list[-1].as_pandas())

(810, 110)

In [18]:
# (rows, columns) of the last validation split's tabular view
np.shape(lambda_net_valid_dataset_list[-1].as_pandas())

(90, 110)

In [19]:
# (rows, columns) of the last test split's tabular view
np.shape(lambda_net_test_dataset_list[-1].as_pandas())

(100, 110)

In [20]:
# First five rows of the last training split
lambda_net_train_dataset_list[-1].as_pandas().head(5)

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
3105,1373158606,0.83,-0.392,0.484,0.428,-0.411,0.102,0.724,0.24,-0.119,0.268,-0.268,0.098,0.839,-0.515,1.365,-2.339,3.144,-1.464,-0.009,-0.227,0.284,0.19,0.268,0.205,0.247,-0.297,0.068,0.251,0.399,-0.088,-0.019,0.288,0.199,0.384,0.063,0.314,0.356,0.332,0.095,0.041,0.009,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.003,-0.001,0.191,0.191,-0.002,0.0,-0.069,0.204,-0.008,0.0,0.0,-0.003,0.197,-0.008,0.188,0.205,0.187,0.198,0.224,-0.047,-0.022,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.3,-0.307,0.336,0.38,-0.144,-0.408,-0.258,0.083,-0.154,-0.352,-0.364,-0.156,0.316,-0.232,0.61,0.214,0.361,0.249,0.261,-0.31,-0.045,-0.261,-0.383,-0.059,-0.335,0.144,-0.258,-0.315,0.17
483,1373158606,-0.15,-0.925,-0.744,0.531,-1.0,-0.167,-0.13,-0.901,-2.645,8.657,-12.88,5.51,-0.153,-0.868,-1.191,1.878,-2.575,0.45,-0.009,-0.227,0.509,0.404,0.108,0.08,0.586,-0.297,0.369,0.814,0.682,-0.088,-0.019,0.638,0.079,0.622,0.008,0.553,0.139,0.391,0.033,0.315,0.571,-0.042,-0.404,-0.09,0.325,-0.291,-0.209,-0.233,0.0,0.0,0.03,0.024,-0.108,-0.081,-0.372,0.0,0.03,-0.522,-0.288,0.0,0.0,-0.292,-0.079,-0.045,-0.023,-0.437,-0.139,-0.307,-0.039,0.036,-0.439,0.0,0.0,0.0,0.039,0.0,0.0,0.0,-0.254,-0.359,-0.5,-0.496,0.135,0.187,-0.646,-0.408,-0.521,-0.654,-0.499,-0.352,-0.364,-0.595,0.121,-0.442,0.393,-0.661,0.13,-0.444,0.048,-0.562,-1.013,-0.261,-0.383,-0.059,-0.589,0.144,-0.258,-0.315,-0.044
381,1373158606,0.943,-0.337,-0.036,-0.608,0.222,-0.439,0.897,0.57,-5.309,12.721,-14.935,5.854,0.946,-0.451,1.099,-4.081,4.435,-2.219,-0.009,-0.227,0.454,0.358,0.065,-0.004,0.437,-0.297,0.334,-0.017,0.605,-0.088,-0.019,0.464,-0.02,0.576,-0.153,0.096,0.167,0.127,-0.177,0.317,0.011,-0.042,-0.404,-0.09,0.31,-0.291,-0.209,-0.233,0.0,0.0,-0.174,-0.227,0.223,0.226,-0.275,0.0,-0.233,0.201,-0.106,0.0,0.0,-0.252,0.236,-0.095,0.23,0.239,0.214,0.227,0.291,-0.246,-0.021,0.0,0.0,0.0,-0.241,0.0,0.0,0.0,-0.254,-0.359,-0.561,-0.693,0.331,0.411,-0.51,-0.408,-0.772,0.049,-0.366,-0.352,-0.364,-0.46,0.357,-0.443,0.757,0.196,0.315,0.218,0.402,-0.947,-0.046,-0.261,-0.383,-0.059,-0.931,0.144,-0.258,-0.315,0.198
6765,1373158606,-0.901,-0.598,-0.416,0.189,0.62,-0.879,-0.81,-1.05,0.361,-0.788,0.772,-0.279,-0.898,-0.803,1.508,-5.479,7.188,-3.511,-0.009,-0.227,0.361,0.254,0.192,0.081,0.317,-0.297,0.221,0.387,0.485,-0.088,-0.019,0.364,0.079,0.47,0.008,0.236,0.276,0.254,0.033,0.169,0.109,-0.042,-0.404,-0.09,0.181,-0.291,-0.209,-0.233,0.0,0.0,0.165,0.166,-0.007,-0.081,0.176,0.0,0.165,0.206,0.172,0.0,0.0,0.174,-0.079,0.167,-0.023,-0.021,-0.014,-0.015,-0.039,0.165,0.196,0.0,0.0,0.0,0.164,0.0,0.0,0.0,-0.254,-0.359,-0.415,-0.421,0.16,0.187,-0.26,-0.408,-0.449,-0.148,-0.274,-0.352,-0.364,-0.274,0.121,-0.351,0.393,0.042,0.204,0.085,0.048,-0.498,-0.241,-0.261,-0.383,-0.059,-0.523,0.144,-0.258,-0.315,-0.151
3999,1373158606,-0.691,0.827,-0.143,-0.244,-0.537,0.272,-0.56,0.233,0.002,-0.005,0.006,-0.002,-0.676,0.26,3.634,-9.601,9.302,-3.448,-0.009,-0.227,0.227,0.118,0.318,0.256,0.181,-0.297,0.085,0.244,0.353,-0.088,-0.019,0.229,0.256,0.337,0.009,0.377,0.4,0.387,0.191,0.03,-0.08,-0.042,-0.404,-0.09,0.043,-0.291,-0.209,-0.233,0.0,0.0,0.134,0.138,-0.0,-0.0,0.136,0.0,0.14,0.128,0.13,0.0,0.0,0.134,-0.0,0.132,-0.022,-0.0,-0.0,-0.0,0.0,0.143,0.208,0.0,0.0,0.0,0.142,0.0,0.0,0.0,-0.254,-0.359,-0.302,-0.327,0.304,0.339,-0.152,-0.408,-0.367,-0.03,-0.151,-0.352,-0.364,-0.16,0.274,-0.229,0.394,0.184,0.338,0.222,0.208,-0.448,-0.297,-0.261,-0.383,-0.059,-0.463,0.144,-0.258,-0.315,-0.135


In [21]:
# First five rows of the last validation split
lambda_net_valid_dataset_list[-1].as_pandas().head(5)

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
5459,1373158606,0.697,-0.796,0.848,0.249,0.738,0.369,0.625,-1.026,10.588,-39.312,56.247,-25.293,0.692,-0.775,0.962,-0.417,1.734,-0.1,-0.009,-0.227,0.119,0.083,0.181,0.115,0.105,-0.297,0.067,0.333,0.517,-0.088,-0.019,1.442,0.113,0.508,-0.036,0.243,0.268,0.255,0.02,0.041,0.008,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.119,-0.084,0.182,0.184,-0.105,0.0,-0.07,0.05,0.001,0.0,0.0,-0.88,0.185,0.001,0.189,0.179,0.18,0.179,0.194,-0.047,-0.023,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.176,-0.225,0.279,0.328,-0.018,-0.408,-0.258,-0.025,-0.228,-0.352,-0.364,2.354,0.264,-0.324,0.579,0.149,0.299,0.185,0.209,-0.309,-0.044,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.183
7942,1373158606,0.404,-0.399,-0.503,0.509,0.467,-0.702,0.396,-0.056,-2.466,5.243,-5.05,1.805,0.412,-0.497,0.169,-1.526,2.993,-1.764,-0.009,-0.227,0.334,0.234,0.194,0.131,0.298,-0.297,0.203,0.384,0.461,-0.088,-0.019,0.342,0.126,0.444,-0.014,0.242,0.281,0.259,0.02,0.154,0.121,-0.042,-0.404,-0.09,0.165,-0.291,-0.209,-0.233,0.0,0.0,-0.037,-0.009,0.122,0.124,-0.014,0.0,-0.003,-0.017,-0.028,0.0,0.0,-0.024,0.125,-0.029,0.133,0.119,0.12,0.119,0.149,-0.002,-0.002,0.0,0.0,0.0,-0.002,0.0,0.0,0.0,-0.254,-0.359,-0.348,-0.346,0.242,0.29,-0.189,-0.408,-0.372,-0.076,-0.213,-0.352,-0.364,-0.206,0.226,-0.29,0.586,0.115,0.266,0.152,0.2,-0.419,-0.168,-0.261,-0.383,-0.059,-0.445,0.144,-0.258,-0.315,0.122
8142,1373158606,0.086,0.024,-0.819,-0.19,-0.591,0.067,0.087,0.313,-3.919,10.0,-13.687,5.853,0.111,-0.44,1.298,-3.998,2.36,-0.754,-0.009,-0.227,0.509,0.408,0.079,0.002,0.507,-0.297,0.374,0.617,0.656,-0.088,-0.019,0.533,0.017,0.63,0.004,0.265,0.16,0.2,0.028,0.337,0.366,-0.042,-0.404,-0.09,0.343,-0.291,-0.209,-0.233,0.0,0.0,-0.228,-0.268,0.052,0.054,-0.249,0.0,-0.273,-0.303,-0.085,0.0,0.0,-0.251,0.045,-0.08,-0.023,0.014,0.055,0.027,-0.04,-0.255,-0.276,0.0,0.0,0.0,-0.256,0.0,0.0,0.0,-0.254,-0.359,-0.612,-0.711,0.071,0.128,-0.505,-0.408,-0.769,-0.39,-0.397,-0.352,-0.364,-0.488,0.055,-0.473,0.39,-0.075,0.096,-0.025,0.041,-0.872,-0.73,-0.261,-0.383,-0.059,-0.881,0.144,-0.258,-0.315,0.074
6599,1373158606,0.957,0.365,0.795,0.692,0.164,0.787,0.88,0.031,12.196,-44.799,63.61,-28.474,0.939,0.713,-1.299,5.794,-5.286,2.919,-0.009,-0.227,0.131,0.083,0.395,0.329,0.11,-0.297,0.067,0.519,1.652,-0.088,-0.019,0.131,0.334,0.178,0.177,0.468,0.476,0.473,0.268,0.041,0.009,-0.042,-0.404,-0.09,0.047,-0.291,-0.209,-0.233,0.0,0.0,-0.131,-0.083,0.198,0.198,-0.11,0.0,-0.069,0.156,-1.012,0.0,0.0,-0.131,0.2,-0.178,0.197,0.203,0.196,0.2,0.211,-0.047,-0.022,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.201,-0.226,0.406,0.443,-0.043,-0.408,-0.258,0.16,2.296,-0.352,-0.364,-0.043,0.377,-0.095,0.638,0.284,0.438,0.322,0.295,-0.31,-0.045,-0.261,-0.383,-0.059,-0.335,0.144,-0.258,-0.315,0.19
3570,1373158606,0.581,-0.597,-0.03,0.215,-0.68,0.557,0.549,-0.204,-1.913,4.173,-4.09,1.48,0.574,-0.457,-0.615,1.096,-1.191,0.657,-0.009,-0.227,0.345,0.244,0.193,0.13,0.311,-0.297,0.213,0.411,0.469,-0.088,-0.019,0.353,0.124,0.451,-0.017,0.24,0.282,0.258,0.007,0.165,0.009,-0.042,-0.404,-0.09,0.175,-0.291,-0.209,-0.233,0.0,0.0,-0.019,-0.006,0.154,0.157,-0.008,0.0,-0.005,-0.009,-0.03,0.0,0.0,-0.014,0.161,-0.033,0.167,0.157,0.15,0.153,0.201,-0.003,-0.022,0.0,0.0,0.0,-0.004,0.0,0.0,0.0,-0.254,-0.359,-0.359,-0.362,0.273,0.328,-0.202,-0.408,-0.391,-0.086,-0.221,-0.352,-0.364,-0.217,0.266,-0.298,0.641,0.144,0.291,0.179,0.261,-0.446,-0.045,-0.261,-0.383,-0.059,-0.468,0.144,-0.258,-0.315,0.147


In [22]:
# First five rows of the last test split
lambda_net_test_dataset_list[-1].as_pandas().head(5)

Unnamed: 0,seed,0-target,1-target,2-target,3-target,4-target,5-target,0-lstsq_lambda,1-lstsq_lambda,2-lstsq_lambda,3-lstsq_lambda,4-lstsq_lambda,5-lstsq_lambda,0-lstsq_target,1-lstsq_target,2-lstsq_target,3-lstsq_target,4-lstsq_target,5-lstsq_target,wb_0,wb_1,wb_2,wb_3,wb_4,wb_5,wb_6,wb_7,wb_8,wb_9,wb_10,wb_11,wb_12,wb_13,wb_14,wb_15,wb_16,wb_17,wb_18,wb_19,wb_20,wb_21,wb_22,wb_23,wb_24,wb_25,wb_26,wb_27,wb_28,wb_29,wb_30,wb_31,wb_32,wb_33,wb_34,wb_35,wb_36,wb_37,wb_38,wb_39,wb_40,wb_41,wb_42,wb_43,wb_44,wb_45,wb_46,wb_47,wb_48,wb_49,wb_50,wb_51,wb_52,wb_53,wb_54,wb_55,wb_56,wb_57,wb_58,wb_59,wb_60,wb_61,wb_62,wb_63,wb_64,wb_65,wb_66,wb_67,wb_68,wb_69,wb_70,wb_71,wb_72,wb_73,wb_74,wb_75,wb_76,wb_77,wb_78,wb_79,wb_80,wb_81,wb_82,wb_83,wb_84,wb_85,wb_86,wb_87,wb_88,wb_89,wb_90
5206,1373158606,-0.399,-0.329,-0.815,0.24,-0.306,0.258,-0.345,-1.045,2.35,-6.134,5.577,-1.735,-0.393,-0.347,-0.912,0.58,-0.667,0.388,-0.009,-0.227,0.336,0.23,0.1,0.08,0.294,-0.297,0.197,0.367,0.461,-0.088,-0.019,0.34,0.078,0.445,0.008,0.623,0.277,0.247,0.033,0.145,0.091,-0.042,-0.404,-0.09,0.157,-0.291,-0.209,-0.233,0.0,0.0,0.086,0.086,-0.1,-0.08,0.086,0.0,0.086,0.087,0.086,0.0,0.0,0.086,-0.079,0.086,-0.023,-0.279,-0.002,-0.002,-0.04,0.086,0.087,0.0,0.0,0.0,0.086,0.0,0.0,0.0,-0.254,-0.359,-0.357,-0.353,0.121,0.187,-0.197,-0.408,-0.377,-0.085,-0.221,-0.352,-0.364,-0.215,0.121,-0.298,0.394,-0.686,0.194,0.066,0.049,-0.418,-0.142,-0.261,-0.383,-0.059,-0.445,0.144,-0.258,-0.315,-0.086
2771,1373158606,-0.016,0.467,0.133,-0.457,-0.709,0.212,0.006,0.324,-0.163,1.912,-4.323,1.916,-0.005,0.303,1.063,-3.034,2.487,-1.201,-0.009,-0.227,0.362,0.277,0.253,0.191,0.351,-0.297,0.261,0.414,0.458,-0.088,-0.019,0.376,0.186,0.446,0.05,0.298,0.34,0.316,0.085,0.232,0.25,-0.042,-0.404,-0.09,0.236,-0.291,-0.209,-0.233,0.0,0.0,-0.199,-0.214,0.001,0.001,-0.203,0.0,-0.202,-0.304,-0.209,0.0,0.0,-0.207,0.001,-0.2,0.002,-0.0,0.001,0.0,-0.001,-0.184,-0.198,0.0,0.0,0.0,-0.187,0.0,0.0,0.0,-0.254,-0.359,-0.55,-0.679,0.229,0.254,-0.437,-0.408,-0.748,-0.351,-0.332,-0.352,-0.364,-0.414,0.188,-0.416,0.377,0.113,0.272,0.153,0.072,-0.876,-0.791,-0.261,-0.383,-0.059,-0.88,0.144,-0.258,-0.315,0.004
5928,1373158606,0.298,0.78,0.572,-0.944,-0.869,0.051,0.308,0.727,0.111,1.926,-5.599,2.46,0.302,0.663,1.273,-2.535,0.573,-0.376,-0.009,-0.227,0.48,0.388,0.35,0.286,0.492,-0.297,0.372,0.355,0.602,-0.088,-0.019,0.499,0.282,0.569,0.144,0.399,0.437,0.415,0.181,0.367,0.009,-0.042,-0.404,-0.09,0.361,-0.291,-0.209,-0.233,0.0,0.0,-0.246,-0.293,0.085,0.085,-0.317,0.0,-0.298,0.018,-0.277,0.0,0.0,-0.321,0.087,-0.262,0.086,0.089,0.084,0.087,0.098,-0.299,-0.022,0.0,0.0,0.0,-0.29,0.0,0.0,0.0,-0.254,-0.359,-0.746,-0.957,0.331,0.361,-0.717,-0.408,-1.113,0.076,-0.537,-0.352,-0.364,-0.664,0.295,-0.615,0.544,0.212,0.37,0.252,0.202,-1.396,-0.045,-0.261,-0.383,-0.059,-1.351,0.144,-0.258,-0.315,0.081
103,1373158606,0.612,0.497,-0.631,-0.581,-0.259,-0.031,0.652,0.132,-0.66,3.376,-7.821,3.986,0.627,0.329,0.048,-1.766,0.678,-0.316,-0.009,-0.227,0.455,0.367,0.107,0.039,0.448,-0.297,0.336,0.516,0.602,-0.088,-0.019,0.478,0.026,0.568,-0.109,0.144,0.204,0.17,-0.118,0.316,0.012,-0.042,-0.404,-0.09,0.314,-0.291,-0.209,-0.233,0.0,0.0,-0.229,-0.24,0.18,0.182,-0.293,0.0,-0.246,-0.16,-0.24,0.0,0.0,-0.279,0.188,-0.226,0.184,0.191,0.174,0.183,0.227,-0.235,-0.02,0.0,0.0,0.0,-0.233,0.0,0.0,0.0,-0.254,-0.359,-0.62,-0.716,0.295,0.365,-0.532,-0.408,-0.783,-0.204,-0.444,-0.352,-0.364,-0.504,0.306,-0.519,0.679,0.162,0.293,0.19,0.322,-0.913,-0.046,-0.261,-0.383,-0.059,-0.912,0.144,-0.258,-0.315,0.163
4367,1373158606,0.502,-0.106,0.268,-0.095,0.221,0.645,0.496,-0.917,9.58,-30.508,38.591,-15.914,0.503,0.003,-0.605,2.472,-2.97,2.052,-0.009,-0.227,0.285,0.081,0.246,0.183,1.057,-0.297,0.067,0.261,0.41,-0.088,-0.019,0.283,0.179,0.396,0.04,0.296,0.333,0.312,0.085,0.04,0.008,-0.042,-0.404,-0.09,0.046,-0.291,-0.209,-0.233,0.0,0.0,0.001,-0.083,0.143,0.143,-0.705,0.0,-0.069,0.072,0.001,0.0,0.0,0.001,0.144,0.001,0.143,0.146,0.142,0.144,0.152,-0.047,-0.023,0.0,0.0,0.0,-0.052,0.0,0.0,0.0,-0.254,-0.359,-0.275,-0.224,0.287,0.325,1.983,-0.408,-0.257,0.021,-0.147,-0.352,-0.364,-0.126,0.26,-0.226,0.533,0.165,0.317,0.202,0.186,-0.309,-0.044,-0.261,-0.383,-0.059,-0.334,0.144,-0.258,-0.315,0.139


## Interpretation Network Training

In [None]:
# Train / evaluate the interpretation net on the prepared splits and unpack every
# returned result list into its own notebook variable.
inet_results = calculate_interpretation_net_results(lambda_net_train_dataset_list,
                                                    lambda_net_valid_dataset_list,
                                                    lambda_net_test_dataset_list)
(history_list, scores_list,
 function_values_complete_list, function_values_valid_list, function_values_test_list,
 inet_preds_list, inet_preds_valid_list, inet_preds_test_list,
 per_network_preds_list,
 distrib_dict_list,
 model_list) = inet_results
del inet_results

Trial 1 Complete [00h 00m 18s]
val_loss: 0.1308407187461853

Best val_loss So Far: 0.1308407187461853
Total elapsed time: 00h 00m 18s
Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500
Epoch 7/500
Epoch 8/500
Epoch 9/500
Epoch 10/500
Epoch 11/500
Epoch 12/500
Epoch 13/500
Epoch 14/500
Epoch 15/500
Epoch 16/500
Epoch 17/500
Epoch 18/500
Epoch 19/500
Epoch 20/500
Epoch 21/500
Epoch 22/500
Epoch 23/500
Epoch 24/500
Epoch 25/500
Epoch 26/500
Epoch 27/500
Epoch 28/500
Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoc

## Evaluate Interpretation Net

In [None]:
# Dummy output vector 0.0, 1.0, ..., interpretation_net_output_shape - 1.
poly_optimize = tf.constant([float(i) for i in range(interpretation_net_output_shape)])

# When the output carries monomial identifiers, the first
# `interpretation_net_output_monomials` entries are kept as coefficients and each
# following block of `sparsity` entries is softmax-normalized, then all blocks are
# flattened and appended back behind the coefficients.
if interpretation_net_output_monomials is not None:
    poly_optimize_coeffs = poly_optimize[:interpretation_net_output_monomials]

    poly_optimize_identifiers_list = []
    for i in range(interpretation_net_output_monomials):
        identifier_start = interpretation_net_output_monomials + sparsity * i
        poly_optimize_identifiers = tf.math.softmax(poly_optimize[identifier_start:identifier_start + sparsity])
        poly_optimize_identifiers_list.append(poly_optimize_identifiers)
    poly_optimize_identifiers_list = tf.keras.backend.flatten(poly_optimize_identifiers_list)
    poly_optimize = tf.concat([poly_optimize_coeffs, poly_optimize_identifiers_list], axis=0)


In [None]:
# Display the vector built in the cell above
poly_optimize

In [None]:
if nas:
    # With NAS enabled, the last history entry is iterated as tuner trials.
    # Trial.summary() prints its own report and returns None (keras-tuner behavior —
    # confirm for the tuner in use), so it is called directly instead of being
    # wrapped in print(), which only added a stray "None" line.
    for trial in history_list[-1]:
        trial.summary()

In [None]:
if len(model_list) >= 1:
    # Keras' Model.summary() writes the table itself and returns None; wrapping it in
    # print() only appended a stray "None" line. (Assumes the entries are Keras models,
    # as their summary()/get_config() API suggests.)
    model_list[-1].summary()
    print(model_list[-1].get_config())

In [None]:
# Row labels of the test comparisons to inspect. Against the real target function
# every candidate is compared to targetPoly, otherwise to the lambda-net outputs.
if not evaluate_with_real_function:
    keys = ['inetPoly_VS_predLambda_test', 'inetPoly_VS_lstsqLambda_test', 'perNetworkPoly_VS_predLambda_test', 'perNetworkPoly_VS_lstsqLambda_test', 'lstsqLambda_VS_predLambda_test', 'predLambda_VS_targetPoly_test']
else:
    keys = ['inetPoly_VS_targetPoly_test', 'perNetworkPoly_VS_targetPoly_test', 'predLambda_VS_targetPoly_test', 'lstsqLambda_VS_targetPoly_test', 'lstsqTarget_VS_targetPoly_test']

In [None]:
# Scores of the most recent I-Net run
scores_list[-1]

In [None]:
# Per-network MAE table of the most recent I-Net run
last_distrib_dict = distrib_dict_list[-1]
last_distrib_dict['MAE']

In [None]:
# Per-network R2 table of the most recent I-Net run
last_distrib_dict = distrib_dict_list[-1]
last_distrib_dict['R2']

In [None]:
# Worst test case = lowest R2 of lambda-net predictions vs. the lstsq target polynomial.
# Show the I-Net R2 for that network and the polynomial the I-Net predicted for it.
index_min = int(np.argmin(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
polynomial_inet = inet_preds_test_list[-1][index_min]

# np.argmin yields a position, so select positionally (.iloc) rather than by label.
print(distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test'].iloc[index_min])

print_polynomial_from_coefficients(polynomial_inet)

In [None]:
# lstsq fit of the lambda net for the worst test case (lowest R2 of lambda-net
# predictions vs. the lstsq target polynomial), plus its R2 against the target polynomial.
index_min = int(np.argmin(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))

# np.argmin yields a position, so select positionally (.iloc) rather than by label.
print(distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test'].iloc[index_min])

# Read the last test split explicitly (presumably the split distrib_dict_list[-1] was
# computed on) instead of the name left over from the dataset-splitting loop.
polynomial_lambda = lambda_net_test_dataset_list[-1].lstsq_lambda_pred_polynomial_list[index_min]
print_polynomial_from_coefficients(polynomial_lambda, force_complete_poly_representation=True)

In [None]:
# Target polynomial of the worst test case (lowest R2 of lambda-net predictions vs. the
# lstsq target polynomial), read from the last test split explicitly.
index_min = int(np.argmin(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
polynomial_target = lambda_net_test_dataset_list[-1].target_polynomial_list[index_min]
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True)

In [None]:
# Best test case = highest R2 of lambda-net predictions vs. the lstsq target polynomial.
# Show the I-Net R2 for that network and the polynomial the I-Net predicted for it.
index_max = int(np.argmax(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
polynomial_inet = inet_preds_test_list[-1][index_max]

# np.argmax yields a position, so select positionally (.iloc) rather than by label.
print(distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test'].iloc[index_max])

print_polynomial_from_coefficients(polynomial_inet)

In [None]:
# lstsq fit of the lambda net for the best test case (highest R2 of lambda-net
# predictions vs. the lstsq target polynomial), plus its R2 against the target polynomial.
index_max = int(np.argmax(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))

# np.argmax yields a position, so select positionally (.iloc) rather than by label.
print(distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test'].iloc[index_max])

# Read the last test split explicitly (presumably the split distrib_dict_list[-1] was
# computed on) instead of the name left over from the dataset-splitting loop.
polynomial_lambda = lambda_net_test_dataset_list[-1].lstsq_lambda_pred_polynomial_list[index_max]
print_polynomial_from_coefficients(polynomial_lambda, force_complete_poly_representation=True)

In [None]:
# Target polynomial of the best test case (highest R2 of lambda-net predictions vs. the
# lstsq target polynomial), read from the last test split explicitly.
index_max = int(np.argmax(distrib_dict_list[-1]['R2'].loc['predLambda_VS_lstsqTarget_test']))
polynomial_target = lambda_net_test_dataset_list[-1].target_polynomial_list[index_max]
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True)

In [None]:
# Mean R2 of I-Net polynomial vs. target polynomial over all test networks, and over
# only the networks whose R2 is positive.
r2_values_inet = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
print(f"Mean: {np.mean(r2_values_inet)!s} ({r2_values_inet.shape[0]!s} Samples)")

r2_values_positive_inet = r2_values_inet[r2_values_inet > 0]
print(f"Mean (only positive): {np.mean(r2_values_positive_inet)!s} ({r2_values_positive_inet.shape[0]!s} Samples)")



In [None]:
# Mean R2 of the lambda-net lstsq polynomial vs. target polynomial over all test
# networks, and over only the networks whose R2 is positive.
r2_values_lstsq_lambda = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
# The sample count must come from this series (it previously reported the I-Net count).
print(f"Mean: {np.mean(r2_values_lstsq_lambda)!s} ({r2_values_lstsq_lambda.shape[0]!s} Samples)")

r2_values_positive_lstsq_lambda = r2_values_lstsq_lambda[r2_values_lstsq_lambda > 0]
print(f"Mean (only positive): {np.mean(r2_values_positive_lstsq_lambda)!s} ({r2_values_positive_lstsq_lambda.shape[0]!s} Samples)")



In [None]:
# Histogram of I-Net MAE, excluding networks whose I-Net R2 is -50 or below
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
inet_mae_test = distrib_dict_list[-1]['MAE'].loc['inetPoly_VS_targetPoly_test']
sns.histplot(inet_mae_test[inet_r2_test > -50])


In [None]:
# Histogram of lambda-net lstsq MAE, excluding networks whose lstsq R2 is -50 or below
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
lstsq_mae_test = distrib_dict_list[-1]['MAE'].loc['lstsqLambda_VS_targetPoly_test']
sns.histplot(lstsq_mae_test[lstsq_r2_test > -50])


In [None]:
# I-Net R2 histogram (R2 > -50 only), bins of 0.2, x-axis limited to [-30, 1]
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
p = sns.histplot(inet_r2_test[inet_r2_test > -50], binwidth=0.2)
p.set(xlim=(-30, 1))

In [None]:
# I-Net R2 histogram (R2 > -50 only), bins of 0.1, x-axis limited to [0, 1]
inet_r2_test = distrib_dict_list[-1]['R2'].loc['inetPoly_VS_targetPoly_test']
p = sns.histplot(inet_r2_test[inet_r2_test > -50], binwidth=0.1)
p.set(xlim=(0, 1))

In [None]:
# lstsq-lambda R2 histogram (R2 > -50 only), bins of 0.2, x-axis limited to [-10, 1]
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
p = sns.histplot(lstsq_r2_test[lstsq_r2_test > -50], binwidth=0.2)
p.set(xlim=(-10, 1))

In [None]:
# lstsq-lambda R2 histogram (R2 > -50 only), bins of 0.1, x-axis limited to [0, 1]
lstsq_r2_test = distrib_dict_list[-1]['R2'].loc['lstsqLambda_VS_targetPoly_test']
p = sns.histplot(lstsq_r2_test[lstsq_r2_test > -50], binwidth=0.1)
p.set(xlim=(0, 1))

In [None]:
if not nas:
    history = history_list[-1]
    # history is read like a dict of per-epoch values. Keys are assumed to be ordered
    # [loss, metrics..., val_loss, val_metrics...] (Keras History.history layout — confirm),
    # so key 1 is the first training metric and key len//2 + 1 its validation counterpart.
    history_keys = list(history.keys())

    # Open a fresh figure. Otherwise a non-inline backend (e.g. the notebook run as a
    # script) keeps drawing onto the figure still open from an earlier plot.
    plt.figure()
    plt.plot(history[history_keys[1]])
    if consider_labels_training or evaluate_with_real_function:
        plt.plot(history[history_keys[len(history_keys)//2+1]])
    plt.title('model ' + history_keys[1])
    plt.ylabel('metric')
    plt.xlabel('epoch')
    plt.legend(['train', 'valid'], loc='upper left')
    # NOTE(review): the file name is tagged with epochs_lambda (the lambda-net epoch
    # setting), not the I-Net epoch count — confirm this is intended.
    plt.savefig('./data/results/' + path_identifier_interpretation_net_data + '/metric_' + '_epoch_' + str(epochs_lambda).zfill(3) + '.png')


In [None]:
if not nas:
    history = history_list[-1]

    # Open a fresh figure. Otherwise a non-inline backend (e.g. the notebook run as a
    # script) draws the loss curves onto the still-open metric figure from the cell above.
    plt.figure()
    plt.plot(history['loss'])
    if consider_labels_training or evaluate_with_real_function:
        plt.plot(history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'valid'], loc='upper left')
    # NOTE(review): the file name is tagged with epochs_lambda (the lambda-net epoch
    # setting), not the I-Net epoch count — confirm this is intended.
    plt.savefig('./data/results/' + path_identifier_interpretation_net_data + '/loss_' + '_epoch_' + str(epochs_lambda).zfill(3) + '.png')


### Multi Epoch/Samples Analysis

### Generate Comparison Plots

In [None]:
# Comparison plot across I-Net runs for all function-value metrics (needs 2+ runs).
if len(scores_list) >= 2:
    plot_metric_list = ['MAE FV', 'RMSE FV', 'MAPE FV', 'R2 FV', 'RAAE FV', 'RMAE FV']
    generate_inet_comparison_plot(scores_list, plot_metric_list)

In [None]:
# Comparison plot across I-Net runs for the function-value MAE only (needs 2+ runs).
if len(scores_list) >= 2:
    plot_metric_list = ['MAE FV']
    generate_inet_comparison_plot(scores_list, plot_metric_list)

In [None]:
# Comparison plot across I-Net runs for the function-value R2 only (needs 2+ runs),
# passing ylim=(-5, 1) through to the plotting helper.
if len(scores_list) >= 2:
    plot_metric_list = ['R2 FV']
    generate_inet_comparison_plot(scores_list, plot_metric_list, ylim=(-5, 1))

#### Generate and Analyze Predictions for Random Function

In [None]:
index = 6

# Read the polynomials from the last test split explicitly (presumably the split that
# inet_preds_test_list[-1] and per_network_preds_list[-1] refer to). This avoids
# relying on the `lambda_net_test_dataset` name left behind by the splitting loop.
selected_test_dataset = lambda_net_test_dataset_list[-1]
polynomial_target = selected_test_dataset.target_polynomial_list[index]
polynomial_lstsq_target = selected_test_dataset.lstsq_target_polynomial_list[index]
polynomial_lstsq_lambda = selected_test_dataset.lstsq_lambda_pred_polynomial_list[index]
polynomial_inet = inet_preds_test_list[-1][index]
per_network_poly = per_network_preds_list[-1][index]

print('Target Poly:')
print_polynomial_from_coefficients(polynomial_target, force_complete_poly_representation=True, round_digits=4)
print('LSTSQ Target Poly:')
print_polynomial_from_coefficients(polynomial_lstsq_target, force_complete_poly_representation=True, round_digits=4)
print('LSTSQ Lambda Poly:')
print_polynomial_from_coefficients(polynomial_lstsq_lambda, force_complete_poly_representation=True, round_digits=4)
print('I-Net Poly:')
print_polynomial_from_coefficients(polynomial_inet, round_digits=4)
print('Per Network Optimization Poly:')
print_polynomial_from_coefficients(per_network_poly, round_digits=4)



In [None]:
# First ten evaluation inputs of the first lambda net in the first test split
lambda_net_test_dataset_list[0].X_test_data_list[0][0:10]

In [None]:
# First ten evaluation targets of the first lambda net in the first test split
lambda_net_test_dataset_list[0].y_test_data_list[0][0:10]

In [None]:
# Prediction-evaluation plot (type 1) for the test function selected by `index`
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    per_network_preds_list,
    plot_type=1,
    rand_index=index,
)

In [None]:
# Prediction-evaluation plot (type 2) for the test function selected by `index`
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    per_network_preds_list,
    plot_type=2,
    rand_index=index,
)

In [None]:
# Prediction-evaluation plot (type 3) for the test function selected by `index`
plot_and_save_single_polynomial_prediction_evaluation(
    lambda_net_test_dataset_list,
    function_values_test_list,
    inet_preds_test_list,
    per_network_preds_list,
    plot_type=3,
    rand_index=index,
)

# BENCHMARK (RANDOM GUESS) EVALUATION

In [None]:
# Random-guess baseline: one uniformly drawn coefficient vector of length `sparsity`
# per lambda net in the last test split.
# NOTE(review): coefficients are drawn from [-10, 10], while the data config uses
# a_min=-1 / a_max=1 for target coefficients — confirm the wider range is intended.
random_guess_shape = (len(lambda_net_test_dataset_list[-1]), sparsity)
list_of_random_polynomials = np.random.uniform(low=-10, high=10, size=random_guess_shape)

In [None]:
# Function values of the true target polynomials and of the random guesses, both
# evaluated on the last test split's evaluation inputs.
last_test_dataset = lambda_net_test_dataset_list[-1]
true_fv_test = parallel_fv_calculation_from_polynomial(last_test_dataset.target_polynomial_list,
                                                       last_test_dataset.X_test_data_list,
                                                       force_complete_poly_representation=True)
random_fv_test = parallel_fv_calculation_from_polynomial(list_of_random_polynomials,
                                                         last_test_dataset.X_test_data_list,
                                                         force_complete_poly_representation=True)

In [None]:
# Coefficient-space MAE of the random guesses, rounded to 4 decimals
random_guess_coefficient_error = np.round(mean_absolute_error(lambda_net_test_dataset_list[-1].target_polynomial_list, list_of_random_polynomials), 4)
print(f'Random Guess Error Coefficients: {random_guess_coefficient_error!s}')

In [None]:
# Function-value MAE of the random guesses, rounded to 4 decimals
random_guess_fv_error = np.round(mean_absolute_error_function_values(true_fv_test, random_fv_test), 4)
print(f'Random Guess Error FVs: {random_guess_fv_error!s}')

# BENCHMARK (EDUCATED GUESS/MEAN PREDICTION) EVALUATION

In [None]:
# Educated-guess baseline: predict one constant, the mean target function value of the
# TRAINING lambda nets. It is computed from the training split so the benchmark does
# not peek at the test targets it is scored against.
last_train_dataset = lambda_net_train_dataset_list[-1]
true_fv_train = parallel_fv_calculation_from_polynomial(last_train_dataset.target_polynomial_list,
                                                        last_train_dataset.X_test_data_list,
                                                        force_complete_poly_representation=True)

mean_fv = np.mean(true_fv_train)
# one copy of the constant prediction per row of true_fv_test
mean_fv_pred_test = [mean_fv] * true_fv_test.shape[0]

In [None]:
# Function-value MAE of the constant mean prediction, rounded to 4 decimals
mean_guess_fv_error = np.round(mean_absolute_error_function_values(true_fv_test, mean_fv_pred_test), 4)
print(f'Educated Guess/Mean Prediction Error FVs: {mean_guess_fv_error!s}')

In [None]:
%%script false --no-raise-error
# Disabled cell: the %%script magic hands this body to the shell command `false`
# instead of Python, and --no-raise-error suppresses the resulting failure.
# It evaluates the lambda-extended FV loss twice: once with the current
# seed_in_inet_training value, and once with it set to False and one more leading
# X_train column dropped.

base_model = generate_base_model()
random_evaluation_dataset = np.random.uniform(low=x_min, high=x_max, size=(random_evaluation_dataset_size, n))
#random_evaluation_dataset = lambda_train_input_train_split[0]#lambda_train_input[0] #JUST [0] HERE BECAUSE EVALUATION ALWAYS ON THE SAME DATASET FOR ALL!!
# NOTE(review): numpy astype() receives a TensorFlow dtype here; confirm this works before re-enabling.
list_of_monomial_identifiers_numbers = np.array([list(monomial_identifiers) for monomial_identifiers in list_of_monomial_identifiers]).astype(tf.float32)


loss_function = mean_absolute_error_tf_fv_lambda_extended_wrapper(random_evaluation_dataset, list_of_monomial_identifiers_numbers, base_model)      

# First pass: drop column 0 of X and columns 0-1 of y (presumably index/seed columns — confirm).
X_train = X_train_list[-1].values[:,1:]
y_train = y_train_list[-1].values[:,2:]

#X_train = X_train[:,1:]
y_train_model = np.hstack((y_train, X_train))

# The label shows whatever seed_in_inet_training currently holds.
print('seed_in_inet_training = ' + str(seed_in_inet_training), loss_function(y_train_model, y_train))


# Second pass: flag off and one more leading X column removed.
seed_in_inet_training = False

loss_function = mean_absolute_error_tf_fv_lambda_extended_wrapper(random_evaluation_dataset, list_of_monomial_identifiers_numbers, base_model)      

X_train = X_train_list[-1].values[:,1:]
y_train = y_train_list[-1].values[:,2:]

X_train = X_train[:,1:]
y_train_model = np.hstack((y_train, X_train))

print('seed_in_inet_training = ' + str(seed_in_inet_training), loss_function(y_train_model, y_train))

# Flag is set back to True afterwards.
seed_in_inet_training = True



In [None]:
if use_gpu:
    # Reset the current CUDA device through numba. Per numba's documentation, reset()
    # destroys the device context together with its memory allocations.
    # NOTE(review): confirm no later cell still relies on TensorFlow state on this GPU.
    from numba import cuda 
    device = cuda.get_current_device()
    device.reset()

In [None]:
# NOTE(review): RANDOM_SEED is reused as a position into the test split, so it must be
# smaller than the number of test lambda nets.
random_index = RANDOM_SEED

per_network_hyperparams = {
    'optimizer':  'Powell',
    'max_steps': 5000,#100,
    'restarts': 3,
    'per_network_dataset_size': 500,
}

lambda_net_test_dataset = lambda_net_test_dataset_list[-1]
lambda_network_weights_list = np.array(lambda_net_test_dataset.weight_list)
lambda_network_weights = lambda_network_weights_list[random_index]

list_of_monomial_identifiers_numbers = np.array([list(monomial_identifiers) for monomial_identifiers in list_of_monomial_identifiers]).astype(float)

printing = True

# Settings handed to the per-network optimizer. They are kept under their own name so
# the experiment-wide `config` dict defined in the first cell is not overwritten.
per_network_config = {
         'n': n,
         'inet_loss': inet_loss,
         'sparsity': sparsity,
         'lambda_network_layers': lambda_network_layers,
         'interpretation_net_output_shape': interpretation_net_output_shape,
         'RANDOM_SEED': RANDOM_SEED,
         'nas': nas,
         'number_of_lambda_weights': number_of_lambda_weights,
         'interpretation_net_output_monomials': interpretation_net_output_monomials,
         'x_min': x_min,
         'x_max': x_max,
         }


per_network_optimization_error, per_network_optimization_polynomial = per_network_poly_optimization_scipy(
    per_network_hyperparams['per_network_dataset_size'],
    lambda_network_weights,
    list_of_monomial_identifiers_numbers,
    per_network_config,
    optimizer=per_network_hyperparams['optimizer'],
    max_steps=per_network_hyperparams['max_steps'],
    restarts=per_network_hyperparams['restarts'],
    printing=printing,
    return_error=True,
)

print('\n\nError: ' + str(per_network_optimization_error))
print_polynomial_from_coefficients(per_network_optimization_polynomial)



In [None]:
from sklearn.model_selection import ParameterGrid, ParameterSampler
pd.set_option('display.max_colwidth', 100)

# Number of test lambda-nets (taken from the front of the list) each
# hyperparameter combination is scored on.
evaluation_size = 10

# Search space for the SciPy per-network baseline. ParameterGrid expands the
# value lists into their cartesian product.
# NOTE(review): in scipy.optimize.minimize, 'Newton-CG', 'dogleg', 'trust-ncg',
# 'trust-exact' and 'trust-krylov' require a Jacobian (the latter also a
# Hessian). Confirm per_network_poly_optimization_scipy supplies them;
# otherwise those grid points will raise and abort the search.
per_network_hyperparams = {
    'optimizer':  ['Nelder-Mead', 'Powell', 'CG', 'BFGS', 'Newton-CG', 'L-BFGS-B', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr', 'dogleg', 'trust-ncg', 'trust-exact', 'trust-krylov'],
    'max_steps': [5000],#100,
    'restarts': [3],
    'per_network_dataset_size': [500],
}

# Alternative: random search instead of the full grid.
#param_iterator = ParameterSampler(per_network_hyperparams, n_iter=60, random_state=RANDOM_SEED)
param_iterator = ParameterGrid(per_network_hyperparams)


# Evaluate on the last test dataset in the list.
lambda_net_test_dataset = lambda_net_test_dataset_list[-1]
lambda_network_weights_list = np.array(lambda_net_test_dataset.weight_list)

list_of_monomial_identifiers_numbers = np.array([list(monomial_identifiers) for monomial_identifiers in list_of_monomial_identifiers]).astype(float)
# Per-network output only for serial runs (presumably so parallel workers
# don't interleave their prints).
printing = n_jobs == 1

# Settings forwarded to every optimizer call; all values are notebook globals.
config = {
         'n': n,
         'inet_loss': inet_loss,
         'sparsity': sparsity,
         'lambda_network_layers': lambda_network_layers,
         'interpretation_net_output_shape': interpretation_net_output_shape,
         'RANDOM_SEED': RANDOM_SEED,
         'nas': nas,
         'number_of_lambda_weights': number_of_lambda_weights,
         'interpretation_net_output_monomials': interpretation_net_output_monomials,
         'x_min': x_min,
         'x_max': x_max,
         }

In [None]:
# Grid search for the SciPy baseline. Each hyperparameter combination is scored
# by its mean per-network optimization error over the first `evaluation_size`
# test lambda-nets. The results are shown best (lowest score) first.
params_error_list = []
# Built once and reused for every grid point.
parallel_per_network = Parallel(n_jobs=n_jobs, verbose=0, backend='loky')
for params in tqdm(param_iterator):
    result_list = parallel_per_network(
        delayed(per_network_poly_optimization_scipy)(
            params['per_network_dataset_size'],
            lambda_network_weights,
            list_of_monomial_identifiers_numbers,
            config,
            optimizer=params['optimizer'],
            max_steps=params['max_steps'],
            restarts=params['restarts'],
            printing=printing,
            return_error=True,
        )
        for lambda_network_weights in lambda_network_weights_list[:evaluation_size]
    )

    # return_error=True makes each result an (error, polynomial) pair.
    per_network_optimization_errors = [result[0] for result in result_list]
    per_network_optimization_polynomials = [result[1] for result in result_list]

    params_score = np.mean(per_network_optimization_errors)

    # Each row is a dict, so the column labels stay attached to their values.
    # Nothing depends on the loop variable leaking out of the for-loop.
    params_error_list.append({**params, 'score': params_score})

del parallel_per_network

# Explicit columns keep the frame well-formed and sortable even for an empty grid.
columns = [*per_network_hyperparams, 'score']
params_error_df = pd.DataFrame(data=params_error_list, columns=columns).sort_values(by='score')
params_error_df.head(10)

In [None]:
# Sanity check for the TensorFlow baseline: fit a polynomial to one test
# lambda-net with a Keras optimizer and display the fit.
# The random seed is reused as the index of the network to inspect.
random_index = RANDOM_SEED

per_network_hyperparams = dict(
    optimizer=tf.keras.optimizers.RMSprop,
    lr=0.02,
    max_steps=500,
    early_stopping=10,
    restarts=3,
    per_network_dataset_size=5000,
)

# Take the network at `random_index` from the last test dataset in the list.
lambda_net_test_dataset = lambda_net_test_dataset_list[-1]
lambda_network_weights_list = np.array(lambda_net_test_dataset.weight_list)
lambda_network_weights = lambda_network_weights_list[random_index]

# Turn every monomial identifier into a list of its parts and stack them as a float matrix.
list_of_monomial_identifiers_numbers = np.array(
    [list(identifier) for identifier in list_of_monomial_identifiers]
).astype(float)

printing = True

# Settings forwarded to the optimizer; every value is a notebook global from earlier cells.
config = dict(
    n=n,
    inet_loss=inet_loss,
    sparsity=sparsity,
    lambda_network_layers=lambda_network_layers,
    interpretation_net_output_shape=interpretation_net_output_shape,
    RANDOM_SEED=RANDOM_SEED,
    nas=nas,
    number_of_lambda_weights=number_of_lambda_weights,
    interpretation_net_output_monomials=interpretation_net_output_monomials,
    x_min=x_min,
    x_max=x_max,
)

# With return_error=True the helper returns an (error, polynomial coefficients) pair.
# The error exposes .numpy() and is converted before printing.
(per_network_optimization_error,
 per_network_optimization_polynomial) = per_network_poly_optimization_tf(
    per_network_hyperparams['per_network_dataset_size'],
    lambda_network_weights,
    list_of_monomial_identifiers_numbers,
    config,
    optimizer=per_network_hyperparams['optimizer'],
    lr=per_network_hyperparams['lr'],
    max_steps=per_network_hyperparams['max_steps'],
    early_stopping=per_network_hyperparams['early_stopping'],
    restarts=per_network_hyperparams['restarts'],
    printing=True,
    return_error=True,
)

print('\n\nError: ' + str(per_network_optimization_error.numpy()))
print_polynomial_from_coefficients(per_network_optimization_polynomial)



In [None]:
from sklearn.model_selection import ParameterGrid, ParameterSampler
pd.set_option('display.max_colwidth', 100)

# Number of test lambda-nets (taken from the front of the list) each
# hyperparameter combination is scored on.
evaluation_size = 100

# Search space for the TensorFlow per-network baseline. ParameterGrid expands
# the value lists into their cartesian product. Trailing comments record
# previously tried values.
per_network_hyperparams = {
    'optimizer': [tf.keras.optimizers.RMSprop], #[tf.keras.optimizers.SGD, tf.optimizers.Adam, tf.keras.optimizers.RMSprop, tf.keras.optimizers.Adadelta]
    'lr': [0.02], #[0.5, 0.25, 0.1, 0.05, 0.025]
    'max_steps': [500],#100,
    'early_stopping': [10],
    'restarts': [3],
    'per_network_dataset_size': [5000, 10000, 20000],
}

# Alternative: random search instead of the full grid.
#param_iterator = ParameterSampler(per_network_hyperparams, n_iter=60, random_state=RANDOM_SEED)
param_iterator = ParameterGrid(per_network_hyperparams)


# Evaluate on the last test dataset in the list.
lambda_net_test_dataset = lambda_net_test_dataset_list[-1]
lambda_network_weights_list = np.array(lambda_net_test_dataset.weight_list)

list_of_monomial_identifiers_numbers = np.array([list(monomial_identifiers) for monomial_identifiers in list_of_monomial_identifiers]).astype(float)
# Per-network output only for serial runs (presumably so parallel workers
# don't interleave their prints).
printing = n_jobs == 1

# Settings forwarded to every optimizer call; all values are notebook globals.
config = {
         'n': n,
         'inet_loss': inet_loss,
         'sparsity': sparsity,
         'lambda_network_layers': lambda_network_layers,
         'interpretation_net_output_shape': interpretation_net_output_shape,
         'RANDOM_SEED': RANDOM_SEED,
         'nas': nas,
         'number_of_lambda_weights': number_of_lambda_weights,
         'interpretation_net_output_monomials': interpretation_net_output_monomials,
         'x_min': x_min,
         'x_max': x_max,
         }

In [None]:
# Grid search for the TensorFlow baseline. Each hyperparameter combination is
# scored by its mean per-network optimization error over the first
# `evaluation_size` test lambda-nets. The results are shown best (lowest score)
# first.
params_error_list = []
# Built once and reused for every grid point.
parallel_per_network = Parallel(n_jobs=n_jobs, verbose=0, backend='loky')
for params in tqdm(param_iterator):
    result_list = parallel_per_network(
        delayed(per_network_poly_optimization_tf)(
            params['per_network_dataset_size'],
            lambda_network_weights,
            list_of_monomial_identifiers_numbers,
            config,
            optimizer=params['optimizer'],
            lr=params['lr'],
            max_steps=params['max_steps'],
            early_stopping=params['early_stopping'],
            restarts=params['restarts'],
            printing=printing,
            return_error=True,
        )
        for lambda_network_weights in lambda_network_weights_list[:evaluation_size]
    )

    # return_error=True makes each result an (error, polynomial) pair.
    per_network_optimization_errors = [result[0] for result in result_list]
    per_network_optimization_polynomials = [result[1] for result in result_list]

    params_score = np.mean(per_network_optimization_errors)

    # Each row is a dict, so the column labels stay attached to their values.
    # Nothing depends on the loop variable leaking out of the for-loop.
    params_error_list.append({**params, 'score': params_score})

del parallel_per_network

# Explicit columns keep the frame well-formed and sortable even for an empty grid.
columns = [*per_network_hyperparams, 'score']
params_error_df = pd.DataFrame(data=params_error_list, columns=columns).sort_values(by='score')
params_error_df.head(10)