In [1]:
import pandas as pd
import numpy as np
from scipy.stats import gaussian_kde, entropy

In [2]:
# Candidate layer-size grids for the model search (names suggest the dense
# widths of an autoencoder's encoder, decoder and bottleneck -- confirm
# against the training notebook).

# Encoder candidates: [w, w - 2] for w = 10, 12, ..., 22.
encoder_dense_layers_trial = [[width, width - 2] for width in range(10, 23, 2)]

# Decoder candidates: four sizes rising in steps of 2, starting at 6, 8, ..., 18.
decoder_dense_layers_trial = [
    [start + 2 * offset for offset in range(4)]
    for start in range(6, 19, 2)
]

# Bottleneck candidates: 8, 10, ..., 18.
bottle_neck_trial = list(range(8, 19, 2))

In [3]:
def calculate(original_data, synthetic_data, support=(0.0, 1.0), num_points=1000):
    """Score how closely synthetic columns follow the original distributions.

    For each column of ``original_data`` a Gaussian KDE is fitted to the
    original values and to the synthetic values of the same column, and both
    densities are evaluated on an evenly spaced grid over ``support``.
    Two per-column scores are computed from the grid values:

    * highlighted area -- Riemann sum of ``max(kde_original - kde_synthetic, 0)``
      over the grid, i.e. the area where the original density exceeds the
      synthetic one;
    * KL divergence -- ``scipy.stats.entropy(kde_original, kde_synthetic)``
      on the grid values.

    Parameters
    ----------
    original_data : pandas.DataFrame
        Reference data; its columns define which columns are compared.
    synthetic_data : pandas.DataFrame
        Data to compare; must contain every column of ``original_data``.
        Extra columns are ignored.
    support : tuple of float, optional
        ``(low, high)`` bounds of the evaluation grid. The default ``(0, 1)``
        assumes the features are scaled to the unit interval.
    num_points : int, optional
        Number of grid points (must be at least 2).

    Returns
    -------
    tuple
        ``(total_highlighted_area, average_kl_divergence)`` where the first
        item is the sum of the per-column areas and the second is the mean of
        the per-column KL divergences over the compared columns.

    Raises
    ------
    ValueError
        If ``original_data`` has no columns.
    KeyError
        If a column of ``original_data`` is missing from ``synthetic_data``.
    """
    columns = list(original_data.columns)
    if not columns:
        raise ValueError("original_data has no columns to compare")

    # The evaluation grid is identical for every column, so build it once.
    x = np.linspace(support[0], support[1], num_points)
    step = np.diff(x)[0]

    highlighted_areas = []
    kl_divergences = []

    for column in columns:
        y1 = gaussian_kde(original_data[column])(x)
        y2 = gaussian_kde(synthetic_data[column])(x)

        # Area where the original density lies above the synthetic density.
        highlighted_areas.append(np.sum(np.maximum(y1 - y2, 0) * step))

        # KL divergence of the gridded densities.
        kl_divergences.append(entropy(y1, y2))

    total_highlighted_area = np.sum(highlighted_areas)
    total_kl_divergence = np.sum(kl_divergences)

    # Average over the columns actually compared, not over however many
    # columns the synthetic frame happens to carry.
    return total_highlighted_area, total_kl_divergence / len(columns)

In [5]:
# Evaluate every (bottleneck, encoder, decoder) combination: load the CSV pair
# saved under that model's name, drop the label column, score the pair with
# calculate(), and collect one [name, area, KL] row per model.
result = []

for bn in bottle_neck_trial:
    for enc_layers in encoder_dense_layers_trial:
        for dec_layers in decoder_dense_layers_trial:
            # File stem encodes the architecture: L27_E<enc>_B<bottleneck>_D<dec>.
            encoder_tag = f"E{enc_layers[0]}_{enc_layers[1]}"
            decoder_tag = f"D{dec_layers[0]}_{dec_layers[1]}_{dec_layers[2]}_{dec_layers[3]}"
            model_name = f"L27_{encoder_tag}_B{bn}_{decoder_tag}"

            # Only the feature columns are compared, so the label is removed.
            original_df = pd.read_csv(f"{model_name}_Original_minority_data.csv").drop(columns='class')
            synthetic_df = pd.read_csv(f"{model_name}_Synthetic_minority_data.csv").drop(columns='class')

            total_highlighted_area, average_kl_divergence = calculate(original_df, synthetic_df)

            result.append([model_name, total_highlighted_area, average_kl_divergence])
            print(model_name, total_highlighted_area, average_kl_divergence, sep=" , ")

L27_E10_8_B8_D6_8_10_12 , 3.8272325565403387 , 1.7083816431775392
L27_E10_8_B8_D8_10_12_14 , 3.673100926856023 , 2.1358453017970827
L27_E10_8_B8_D10_12_14_16 , 3.1693662481310936 , 0.41302605281925764
L27_E10_8_B8_D12_14_16_18 , 3.537099125004338 , 0.49387790819414806
L27_E10_8_B8_D14_16_18_20 , 3.6093443422521325 , 1.047685698008539
L27_E10_8_B8_D16_18_20_22 , 3.5500563750745244 , 0.7825460042440991
L27_E10_8_B8_D18_20_22_24 , 3.5139272304935067 , 0.5767074781626624
L27_E12_10_B8_D6_8_10_12 , 4.5476381412038 , 3.137564134509851
L27_E12_10_B8_D8_10_12_14 , 3.4979666061808277 , 1.9000476136324234
L27_E12_10_B8_D10_12_14_16 , 4.246057036509662 , 1.1329527635612777
L27_E12_10_B8_D12_14_16_18 , 3.546430947661488 , 1.2679494680953938
L27_E12_10_B8_D14_16_18_20 , 2.5288565203214954 , 0.1883197125079781
L27_E12_10_B8_D16_18_20_22 , 3.0090978201343175 , 0.3942589633603109
L27_E12_10_B8_D18_20_22_24 , 3.306720832305252 , 0.5025389253423005
L27_E14_12_B8_D6_8_10_12 , 3.5862304429092524 , 1.48275

In [6]:
# Display the collected rows: [model_name, total_highlighted_area, average_kl_divergence].
result

[['L27_E10_8_B8_D6_8_10_12', 3.8272325565403387, 1.7083816431775392],
 ['L27_E10_8_B8_D8_10_12_14', 3.673100926856023, 2.1358453017970827],
 ['L27_E10_8_B8_D10_12_14_16', 3.1693662481310936, 0.41302605281925764],
 ['L27_E10_8_B8_D12_14_16_18', 3.537099125004338, 0.49387790819414806],
 ['L27_E10_8_B8_D14_16_18_20', 3.6093443422521325, 1.047685698008539],
 ['L27_E10_8_B8_D16_18_20_22', 3.5500563750745244, 0.7825460042440991],
 ['L27_E10_8_B8_D18_20_22_24', 3.5139272304935067, 0.5767074781626624],
 ['L27_E12_10_B8_D6_8_10_12', 4.5476381412038, 3.137564134509851],
 ['L27_E12_10_B8_D8_10_12_14', 3.4979666061808277, 1.9000476136324234],
 ['L27_E12_10_B8_D10_12_14_16', 4.246057036509662, 1.1329527635612777],
 ['L27_E12_10_B8_D12_14_16_18', 3.546430947661488, 1.2679494680953938],
 ['L27_E12_10_B8_D14_16_18_20', 2.5288565203214954, 0.1883197125079781],
 ['L27_E12_10_B8_D16_18_20_22', 3.0090978201343175, 0.3942589633603109],
 ['L27_E12_10_B8_D18_20_22_24', 3.306720832305252, 0.5025389253423005],