# Variational Autoencoder Z=30

In [None]:
from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.losses import mse

class KLDivergenceLayer(Layer):
    """Pass-through layer that registers a KL divergence penalty.

    The layer receives a pair ``[mu, log_var]``, hands the very same pair
    back to the caller and, as a side effect, adds the batch-averaged KL
    term computed from that pair to the model loss through ``add_loss``.
    """

    def __init__(self, *args, **kwargs):
        # Set the flag first, then run the base-class constructor.
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        """Register the KL loss for ``inputs`` and return ``inputs`` as is."""
        mean, log_variance = inputs

        # Summand of the KL term: 1 + log_var - mu^2 - exp(log_var).
        summand = 1 + log_variance - K.square(mean) - K.exp(log_variance)

        # Reduce over the last axis, then average the result into a scalar.
        kl_values = -0.5 * K.sum(summand, axis=-1)
        kl_loss = K.mean(kl_values)

        self.add_loss(kl_loss, inputs=inputs)
        return inputs

def reconstruction_loss(y_true, y_pred):
    """Return ``mse(y_true, y_pred)`` multiplied by a fixed factor of 300."""
    # NOTE(review): 300 presumably equals 100 points x 3 coordinates of the
    # (100, 3) model input -- confirm before reusing with other shapes.
    scale = 300
    mean_sq_err = mse(y_true, y_pred)
    return mean_sq_err * scale


def VAE_Z30():
    """Build a dense variational autoencoder with a 30-dimensional latent code.

    The network takes a ``(100, 3)`` input plus an auxiliary noise input
    that is backed by a ``K.random_normal`` tensor.

    Returns:
        A ``(vae, encoder)`` tuple. Both models share the same two inputs;
        ``vae`` outputs the reconstruction reshaped to ``(100, 3)`` and
        ``encoder`` outputs the sampled latent code.
    """
    latent_dim = 30
    curve_input = Input(shape=(100, 3))

    # ---- Encoder ----
    hidden = Flatten()(curve_input)
    hidden = Dense(150, activation='relu')(hidden)
    #hidden = Dense(150, activation='relu')(hidden)
    #hidden = Dense(150, activation='relu')(hidden)

    latent_mean = Dense(latent_dim)(hidden)
    latent_log_var = Dense(latent_dim)(hidden)

    # Identity layer: passes both tensors through and adds the KL loss term.
    latent_mean, latent_log_var = KLDivergenceLayer()(
        [latent_mean, latent_log_var]
    )

    # Convert the log variance into a standard deviation: exp(0.5 * log_var).
    latent_std = Lambda(lambda t: K.exp(0.5 * t))(latent_log_var)

    # Noise tensor with one row per element of the current batch.
    batch_size = K.shape(hidden)[0]
    noise = K.random_normal(shape=(batch_size, latent_dim))
    noise_input = Input(tensor=noise)

    # Sampled latent code: mean + std * noise.
    scaled_noise = Multiply()([latent_std, noise_input])
    latent_code = Add()([latent_mean, scaled_noise])

    # ---- Decoder ----
    hidden = Dense(150, activation='relu')(latent_code)
    #hidden = Dense(150, activation='relu')(hidden)
    #hidden = Dense(150, activation='relu')(hidden)
    hidden = Dense(300)(hidden)
    reconstruction = Reshape((100, 3))(hidden)

    model_inputs = [curve_input, noise_input]
    vae = Model(inputs=model_inputs, outputs=reconstruction)
    encoder = Model(inputs=model_inputs, outputs=latent_code)

    return vae, encoder

#resetRNG(0)
#x_train_3D=np.expand_dims(x_train, 4)
#x_test_3D=np.expand_dims(x_test, 4)
#AE,E,train_data =AE_analysis(VAE_Z30, 10, x_train, x_test, filename='VAE_Z30')
# Build the VAE (the encoder model is discarded here) and print its layers.
vae, _ = VAE_Z30()
vae.summary()

# Train the network to reproduce its own input, silently, for 500 epochs.
vae.compile(loss=reconstruction_loss, optimizer='rmsprop')
train_data = vae.fit(
    x_train,
    x_train,
    validation_data=(x_test, x_test),
    batch_size=128,
    epochs=500,
    shuffle=True,
    verbose=0,
)

# Hand the trained model, the fit result and the test set to the plot helper.
visualize_AE(vae, train_data, x_test)

In [None]:
# Conditional VAE

In [None]:
# Disentangling VAE

In [None]:
# Use domain knowledge: add multiple parallel input Lambda layers for FDs, CAD, wavelets, etc.

In [None]:
# Deep feature consistent variational auto-encoder