# Calculate the Jensen-Shannon Divergence

In [3]:
import vae_tools
import vae_tools.sanity
import vae_tools.viz
import vae_tools.callbacks
vae_tools.sanity.check()
from tensorflow.keras.datasets import mnist
import tensorflow as tf
import numpy as np
import glob
import pickle


python version:  3.5.2
keras version: 2.2.4-tf
tensorflow version: 2.0.2
matplotlib uses:  module://ipykernel.pylab.backend_inline
No GPUs available


In [4]:
def js_calc(z_ab, z_a, z_b, z_logvar_ab, z_logvar_a, z_logvar_b, n_z, n_samples):
    """Evaluate `vae_tools.metrics.js_loss` between pairs of encoder outputs.

    Every input is truncated to its first `n_z` entries before use.

    Args:
        z_ab, z_a, z_b: encoder mean outputs for the joint input and for the
            two single inputs.
        z_logvar_ab, z_logvar_a, z_logvar_b: matching log-variance outputs.
        n_z: number of leading entries kept from each array.
        n_samples: forwarded unchanged to `js_loss`
            (presumably the number of Monte Carlo samples -- confirm there).

    Returns:
        Tuple `(js_d_ab, js_d_ab_a, js_d_ab_b, js_d_b_a)`: the `js_loss`
        results for ab vs. ab, ab vs. a, ab vs. b and a vs. b.
    """
    head = slice(None, n_z)
    mean_ab, mean_a, mean_b = z_ab[head], z_a[head], z_b[head]
    logvar_ab, logvar_a, logvar_b = (
        z_logvar_ab[head],
        z_logvar_a[head],
        z_logvar_b[head],
    )

    js_loss = vae_tools.metrics.js_loss
    # The calls are kept in this fixed order in case js_loss draws random samples.
    joint_vs_joint = js_loss(mean_ab, mean_ab, logvar_ab, logvar_ab, n_samples)
    joint_vs_a = js_loss(mean_ab, mean_a, logvar_ab, logvar_a, n_samples)
    joint_vs_b = js_loss(mean_ab, mean_b, logvar_ab, logvar_b, n_samples)
    a_vs_b = js_loss(mean_a, mean_b, logvar_a, logvar_b, n_samples)
    return joint_vs_joint, joint_vs_a, joint_vs_b, a_vs_b


def _encode_powerset(models, x_a, x_b):
    """Run the three encoders in `models` and return their outputs as (ab, a, b).

    `models[0]` is fed `x_a`, `models[1]` is fed `x_b`, and `models[2]` is fed
    both as `[x_a, x_b]`.
    NOTE(review): this index order mirrors how the loaded list is used here;
    confirm it against `load_model_powerset`.
    """
    out_ab = models[2].predict([x_a, x_b])
    out_a = models[0].predict(x_a)
    out_b = models[1].predict(x_b)
    return out_ab, out_a, out_b


def run(seed='0', loc='/mnt/ssd_pcie/mmvae_mnist_split/'):
    """Compute divergences for every stored encoder set and pickle them.

    For each file matching `<loc + seed>/enc_mean*11.h5`, the matching
    `logvar` encoders are loaded as well, MNIST train and test images (split
    into two halves) are encoded, `js_calc` is evaluated on both splits, and
    the eight results are written as a dict to a `jsd*.p` pickle next to the
    model file.

    Args:
        seed: string appended directly to `loc` to form the model directory.
        loc: path prefix of the model directory (plain string concatenation,
            so include a trailing separator if one is needed).

    Returns:
        None. Results are written to disk; progress is printed.
    """
    import os
    import time

    dump_loc = loc + seed
    n_z = 10000  # Samples in latent space
    n_samples = 100  # Monte Carlo Samples

    # Get the MNIST digits, scaled to [0, 1] and flattened to one row per image.
    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = x_train.astype('float32') / 255.
    x_test = x_test.astype('float32') / 255.
    x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
    x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

    # input image dimensions
    img_rows, img_cols, img_chns = 28, 28, 1
    original_dim = img_rows * img_cols * img_chns
    split_dim = original_dim // 2

    # Split each flattened image into its first and second half of pixels
    # (with row-major flattening: upper and lower image halves).
    x_train_a = x_train[:, :split_dim]
    x_train_b = x_train[:, split_dim:]
    x_test_a = x_test[:, :split_dim]
    x_test_b = x_test[:, split_dim:]

    for net_str_m in glob.glob(dump_loc + '/enc_mean*11.h5'):
        # Derive sibling file names from the basename only, so that a
        # directory containing 'mean' or '_11.h5' cannot corrupt the paths.
        net_dir, net_file = os.path.split(net_str_m)
        net_str_lv = os.path.join(net_dir, net_file.replace('mean', 'logvar'))

        # Load the models
        print("load models")
        # Remove the suffix (like '11.h5')
        model_enc_mean, _ = vae_tools.vae.GenericVae.load_model_powerset(net_str_m[:-5], 2)
        model_enc_logvar, _ = vae_tools.vae.GenericVae.load_model_powerset(net_str_lv[:-5], 2)

        # Predict the train and test data
        print("predict data")
        z_train_ab, z_train_a, z_train_b = _encode_powerset(
            model_enc_mean, x_train_a, x_train_b)
        z_test_ab, z_test_a, z_test_b = _encode_powerset(
            model_enc_mean, x_test_a, x_test_b)
        z_train_logvar_ab, z_train_logvar_a, z_train_logvar_b = _encode_powerset(
            model_enc_logvar, x_train_a, x_train_b)
        z_test_logvar_ab, z_test_logvar_a, z_test_logvar_b = _encode_powerset(
            model_enc_logvar, x_test_a, x_test_b)

        # Clean up the models: drop the references held by this function
        # before resetting the Keras state, so they do not pile up across
        # loop iterations.
        del model_enc_mean, model_enc_logvar
        tf.keras.backend.clear_session()

        # Get the JS divergence
        print("calc js")
        start = time.time()
        js_d_ab, js_d_ab_a, js_d_ab_b, js_d_b_a = js_calc(
            z_test_ab, z_test_a, z_test_b,
            z_test_logvar_ab, z_test_logvar_a, z_test_logvar_b,
            n_z, n_samples)
        js_d_ab_train, js_d_ab_a_train, js_d_ab_b_train, js_d_b_a_train = js_calc(
            z_train_ab, z_train_a, z_train_b,
            z_train_logvar_ab, z_train_logvar_a, z_train_logvar_b,
            n_z, n_samples)
        end = time.time()
        print(end - start)
        js = {'test_ab_vs_ab': js_d_ab,
              'test_ab_vs_a': js_d_ab_a,
              'test_ab_vs_b': js_d_ab_b,
              'test_b_vs_a': js_d_b_a,
              'train_ab_vs_ab': js_d_ab_train,
              'train_ab_vs_a': js_d_ab_a_train,
              'train_ab_vs_b': js_d_ab_b_train,
              'train_b_vs_a': js_d_b_a_train}

        fn = os.path.join(
            net_dir,
            net_file.replace('enc_mean', 'jsd').replace('_11.h5', '.p'))
        print("write " + fn)
        with open(fn, 'wb') as handle:
            pickle.dump(js, handle, protocol=pickle.HIGHEST_PROTOCOL)

load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4921205043792725
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_45_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4639711380004883
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_57_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4385802745819092
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_22_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4484925270080566
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_24_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4449970722198486
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_39_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
1.4356253147125244
write /mnt/ssd_pcie/mmvae_mnist_split_1/jsd_2_a.p
load models
ffffffffffffffffffffffffffffffffaaaaaaaaaaaaaaaaaaaaaaaaaaaake
calc js
