# Train an M²VAE on a split MNIST data set and evaluate the hyperparameters

In [3]:
# --- vae_tools helpers and the multi-modal VAE implementation ---
import vae_tools.sanity
import vae_tools.viz
import vae_tools.callbacks
import vae_tools.loader
from vae_tools.mmvae import MmVae, ReconstructionLoss
from tensorflow.keras.optimizers import Adam, Nadam, RMSprop
# Environment check; presumably the source of the version / GPU summary
# printed below this cell -- TODO confirm against vae_tools.sanity.
vae_tools.sanity.check()
# --- TensorFlow / Keras building blocks used to assemble the networks ---
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Lambda, Layer
from tensorflow.keras.datasets import mnist
import numpy as np
from scipy.stats import norm
# Set the seed for reproducible results (run() seeds again before every model build)
import vae_tools.sampling
vae_tools.sampling.set_seed(0)
# resize the notebook if desired
#vae_tools.nb_tools.notebook_resize()
# --- Plotting, grid enumeration and result storage ---
import matplotlib
import matplotlib.pyplot as plt
from itertools import product
import pandas as pd

python version:  3.5.2
keras version: 2.2.4-tf
tensorflow version: 2.0.2
matplotlib uses:  module://ipykernel.pylab.backend_inline
No GPUs available


In [4]:
def run(seed = 0, loc = '/mnt/ssd_pcie/mmvae_mnist_split/'):
    """Grid-search M²VAE hyperparameters on split MNIST and store all results.

    For every combination of the hyperparameter grid a bi-modal ``MmVae`` is
    built, trained on the two image halves, and its encoder/decoder models
    are written to disk. Afterwards the hyperparameters and the training
    histories of all runs are collected in one pandas DataFrame that is
    stored as ``history.h5``.

    Args:
        seed: Seed handed to ``vae_tools.sampling.set_seed`` before each
            model is built. It is converted with ``int()`` and also names
            the sub-folder the results are written to.
        loc: Root folder for the results. The seed is appended by plain
            string concatenation, so ``loc`` has to end with a path
            separator.

    Returns:
        None. All results are written to ``loc + str(seed) + '/'``.
    """
    # Function-scope import: the notebook's import cell does not provide `os`.
    import os

    # Get the split MNIST digits
    (x_train_a, x_train_b), (x_test_a, x_test_b), y_train, y_test = vae_tools.loader.mnist_split(flatten = True, split = 'hor')

    # input image dimensions
    img_rows, img_cols, img_chns = 28, 28, 1
    original_dim = img_rows * img_cols * img_chns
    split_dim = int(original_dim / 2)

    # For a quick smoke test, slice the four image arrays and both label
    # arrays to a handful of samples here and lower 'epochs' in the grid.

    # Hyperparameter grid: every value is a list, the search runs over the
    # Cartesian product of all lists.
    # Earlier sweeps that this grid supersedes (kept for the record):
    #   1) tanh / RMSprop, batch_size 100, latent_intermediate_dim in
    #      {None, 125, 250, 500}, beta_mutual in {0.001, 0.01, 0.1, 1, 10},
    #      z_dim in {20, 40, 80}
    #   2) tanh / RMSprop, batch_size 100, latent_intermediate_dim None,
    #      beta_mutual in {10, 20, 30}, z_dim in {2, 5, 10, 15, 20}
    #   3) identical to the grid below, but without the 'shared_weights'
    #      entry that hp_process() requires
    p = {'lr': [1.],
         'intermediate_dim': [500],
         'activation':['relu'],
         'latent_intermediate_dim': [None],
         #'latent_activation':['tanh', 'relu', 'elu'],
         'latent_activation':['tanh'],
         'batch_size': [512],
         'epochs': [100],
         'optimizer': [Adam],
         'beta': [1.0],
         'beta_mutual': [0.001, 0.01, 0.1, 1.0, 10., 20., 30.],
         'reconstruction_loss_metrics': [ReconstructionLoss.BCE],
         'z_dim': [2, 5, 10, 15, 20, 40, 80],
         'seed': [int(seed)],
         'shared_weights': [False]}

    # Define the storage location for the networks and make sure it exists
    # before any model or the history file is written into it
    dump_loc = loc + str(p['seed'][0]) + '/'
    os.makedirs(dump_loc, exist_ok=True)

    ## Define the training loop
    def hp_process(x_train, y_train, x_val, y_val, params):
        """Build, train and store one M²VAE for a single grid point.

        Args:
            x_train: List with the two training modalities (image halves).
            y_train: Training labels; not used by the training itself.
            x_val: List with the two validation modalities.
            y_val: Validation labels; not used by the training itself.
            params: One hyperparameter combination, including its 'index'
                in the grid, which is used in the names of stored files.

        Returns:
            A copy of the Keras history dictionary of the training run.
        """
        # resetting the layer name generation counter
        tf.keras.backend.clear_session()
        # Seed before any layer is created so that every grid point starts
        # from the same random state
        vae_tools.sampling.set_seed(params['seed'])

        # One single-hidden-layer encoder per modality
        encoder = [
            [
                Input(shape=(split_dim,), name="input_a"),
                Dense(params['intermediate_dim'], activation=params['activation'], name="enc_a")
            ],
            [
                Input(shape=(split_dim,), name="input_b"),
                Dense(params['intermediate_dim'], activation=params['activation'], name="enc_b")
            ],
        ]

        # Mirrored decoders; sigmoid outputs to match the BCE reconstruction loss
        decoder = [
            [
                Dense(params['intermediate_dim'], activation=params['activation'], name="dec_a"),
                Dense(split_dim, activation='sigmoid', name="output_a")
            ],
            [
                Dense(params['intermediate_dim'], activation=params['activation'], name="dec_b"),
                Dense(split_dim, activation='sigmoid', name="output_b")
            ]
        ]

        # Optional extra layer in front of the latent space
        le = None
        if params['latent_intermediate_dim'] is not None:
            # NOTE(review): relies on vae_tools.vae having been loaded as a
            # side effect of importing vae_tools.mmvae -- confirm.
            le = vae_tools.vae.LatentEncoder(layer_dimensions=[params['latent_intermediate_dim']],
                                             is_relative=[False],
                                             activations=[params['latent_activation']])

        vae_obj = MmVae(params['z_dim'], encoder, decoder, [split_dim, split_dim], params['beta'],
                        latent_encoder = le, beta_mutual = params['beta_mutual'],
                        reconstruction_loss_metrics = [params['reconstruction_loss_metrics']],
                        shared_weights=params['shared_weights'], name='MMVAE')

        vae_model = vae_obj.get_model()
        # loss=None: presumably the loss terms are attached to the model by
        # MmVae itself -- TODO confirm
        learning_rate = vae_tools.sanity.lr_normalizer(params['lr'], params['optimizer'])
        vae_model.compile(optimizer=params['optimizer'](learning_rate), loss=None)
        #vae_tools.viz.plot_model(vae, file = 'myVAE', print_svg = False, verbose = True)

        # Train
        history = vae_model.fit(x_train,
                    shuffle=True,
                    epochs=params['epochs'],
                    batch_size=params['batch_size'],
                    validation_data=(x_val, None),
                    verbose = 2
                    )
        # Store the final models; the grid index keeps the files of the
        # individual runs apart
        run_id = str(params['index'])
        vae_obj.store_model_powerset(dump_loc + 'enc_mean_' + run_id + '_ab_', vae_obj.encoder_inputs, vae_obj.get_encoder_mean)
        vae_obj.store_model_powerset(dump_loc + 'enc_logvar_' + run_id + '_ab_', vae_obj.encoder_inputs, vae_obj.get_encoder_logvar)
        vae_obj.get_decoder().save(dump_loc + 'dec_' + run_id + "_a.h5")

        return history.history.copy()


    ## Hyperparameter (hp) search
    # Get all combinations of hp
    hp = [dict(zip(p, v)) for v in product(*p.values())]
    # add an index to the hyperparameters
    for idx, params in enumerate(hp):
        params['index'] = idx

    # Perform grid search and collect one history per grid point
    hp_h = [hp_process([x_train_a, x_train_b], y_train, [x_test_a, x_test_b], y_test, params)
            for params in hp]


    ## Create a pandas dataframe (df) and store it
    # One row per grid point, with a column per hyperparameter and two
    # columns per history key: the final value and the full list.

    data = {}

    # Prefixes for history and for the full history as a list
    h_prefix = 'h_'
    h_list_prefix = 'list_'

    # hyperparameter columns
    for k in hp[0].keys():
        data[k] = [params[k] for params in hp]

    # history columns
    for k in hp_h[0].keys():
        data[h_prefix + k] = [h[k][-1] for h in hp_h]
        data[h_prefix + h_list_prefix + k] = [h[k] for h in hp_h]

    # Create pandas dataframe and store it
    df = pd.DataFrame(data)
    df.to_hdf(dump_loc + 'history.h5', key='df')

Train on 10 samples, validate on 10 samples
Epoch 1/2
10/10 - 3s - loss: 1112.3313 - loss_reconstruction_0_0: 274.8330 - loss_reconstruction_1_0: 280.5603 - loss_reconstruction_2_0: 276.5545 - loss_reconstruction_2_1: 273.7543 - loss_prior_0: 1.6796 - loss_prior_1: 2.5866 - loss_prior_2: 2.3533 - loss_mutual_0: 0.0048 - loss_mutual_1: 0.0050 - val_loss: 1092.6062 - val_loss_reconstruction_0_0: 272.1688 - val_loss_reconstruction_1_0: 268.4044 - val_loss_reconstruction_2_0: 266.4499 - val_loss_reconstruction_2_1: 274.4526 - val_loss_prior_0: 2.5793 - val_loss_prior_1: 2.9332 - val_loss_prior_2: 5.6025 - val_loss_mutual_0: 0.0078 - val_loss_mutual_1: 0.0077
Epoch 2/2
10/10 - 0s - loss: 1099.5500 - loss_reconstruction_0_0: 272.6115 - loss_reconstruction_1_0: 268.1393 - loss_reconstruction_2_0: 269.1646 - loss_reconstruction_2_1: 277.3934 - loss_prior_0: 2.6363 - loss_prior_1: 2.8535 - loss_prior_2: 6.7363 - loss_mutual_0: 0.0079 - loss_mutual_1: 0.0071 - val_loss: 954.4225 - val_loss_recon

your performance may suffer as PyTables will pickle object types that it cannot
map directly to c-types [inferred_type->mixed,key->block2_values] [items->['activation', 'h_list_loss', 'h_list_loss_mutual_0', 'h_list_loss_mutual_1', 'h_list_loss_prior_0', 'h_list_loss_prior_1', 'h_list_loss_prior_2', 'h_list_loss_reconstruction_0_0', 'h_list_loss_reconstruction_1_0', 'h_list_loss_reconstruction_2_0', 'h_list_loss_reconstruction_2_1', 'h_list_val_loss', 'h_list_val_loss_mutual_0', 'h_list_val_loss_mutual_1', 'h_list_val_loss_prior_0', 'h_list_val_loss_prior_1', 'h_list_val_loss_prior_2', 'h_list_val_loss_reconstruction_0_0', 'h_list_val_loss_reconstruction_1_0', 'h_list_val_loss_reconstruction_2_0', 'h_list_val_loss_reconstruction_2_1', 'latent_activation', 'latent_intermediate_dim', 'optimizer', 'reconstruction_loss_metrics']]

  return pytables.to_hdf(path_or_buf, key, self, **kwargs)
