In [None]:
import keras.backend as K
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Layer sizes: original input, intermediate (hidden) layer, and latent layer
original_dim, intermediate_dim, latent_dim = 784, 256, 2

# Encoder half of the variational autoencoder: one hidden ReLU layer feeding
# two activation-free Dense heads that produce the latent mean and the latent
# log-variance
input_img = keras.Input(shape=(original_dim,))
encoded = layers.Dense(units=intermediate_dim, activation='relu')(input_img)
x_mean = layers.Dense(units=latent_dim)(encoded)
x_log_var = layers.Dense(units=latent_dim)(encoded)

def sampling(args):
    """Draw a latent sample using the reparameterization trick.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of tensors with matching shapes,
            as forwarded by the ``Lambda`` layer that wraps this function.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
        standard-normal noise with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from the input itself instead of the module-level
    # ``latent_dim`` so the function works for any latent size.
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.0)
    # exp(0.5 * log_var) converts the log-variance into a standard deviation.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Latent sample drawn from the encoder's mean / log-variance outputs
x = layers.Lambda(function=sampling, output_shape=(latent_dim,))([x_mean, x_log_var])

# Decoder half: hidden ReLU layer followed by a sigmoid output layer
hidden_decoded = layers.Dense(units=intermediate_dim, activation='relu')(x)
decoded = layers.Dense(units=original_dim, activation='sigmoid')(hidden_decoded)

# End-to-end model mapping an input to its reconstruction
vae = keras.Model(inputs=input_img, outputs=decoded, name='vae')

# Reconstruction term: binary cross-entropy between the model input and its
# reconstruction, scaled by original_dim.
# NOTE(review): this is computed on the model's symbolic tensors at build time,
# not on the y_true / y_pred that Keras passes to the loss function — confirm
# this pattern is supported by the installed Keras version.
reconstruction_loss = original_dim * keras.metrics.binary_crossentropy(input_img, decoded)

# Define the KL divergence calculation within a Lambda layer
def kl_loss_layer(x_mean, x_log_var=None):
    """Compute the KL-divergence term from the latent mean and log-variance.

    Callable either as ``kl_loss_layer(x_mean, x_log_var)`` or with a single
    ``[x_mean, x_log_var]`` sequence. The second form is how a ``Lambda``
    layer applied to a list of tensors invokes its function (it forwards the
    whole list as one positional argument).

    Args:
        x_mean: Latent mean tensor, or a two-element sequence
            ``(x_mean, x_log_var)`` when ``x_log_var`` is omitted.
        x_log_var: Latent log-variance tensor, or ``None`` when both tensors
            are packed into the first argument.

    Returns:
        ``-0.5 * sum(1 + x_log_var - x_mean**2 - exp(x_log_var))`` with the
        sum taken over the last axis.
    """
    if x_log_var is None:
        # Single-argument call: unpack the pair forwarded by the Lambda layer.
        x_mean, x_log_var = x_mean
    return -0.5 * tf.reduce_sum(1 + x_log_var - tf.square(x_mean) - tf.exp(x_log_var), axis=-1)

# kl_loss_layer sums over the last (latent) axis, so each sample yields a
# scalar: the per-sample output shape is (), not (latent_dim,).
kl_loss = layers.Lambda(kl_loss_layer, output_shape=())([x_mean, x_log_var])

# Combine the reconstruction loss and KL loss directly in the loss function
def vae_loss(y_true, y_pred):
    """Return the mean of the reconstruction term plus the KL term.

    Both terms are the module-level ``reconstruction_loss`` and ``kl_loss``
    tensors. ``y_true`` and ``y_pred`` are accepted so the function has the
    two-argument shape passed to ``compile(loss=...)``, but they are not read.
    """
    total_loss = reconstruction_loss + kl_loss
    return K.mean(total_loss)

vae.compile(loss=vae_loss, optimizer='adam')

# Stand-alone encoder: input -> [latent mean, latent log-variance, sampled code]
encoder = keras.Model(inputs=input_img, outputs=[x_mean, x_log_var, x], name='encoder')

# Stand-alone decoder: apply the last two layers of `vae` to a fresh
# latent-sized input, so the decoder uses the same layer objects as the VAE
encoded_input = keras.Input(shape=(latent_dim,))
decoder_hidden, decoder_output = vae.layers[-2:]
decoder_layer = decoder_output(decoder_hidden(encoded_input))
decoder = keras.Model(inputs=encoded_input, outputs=decoder_layer, name='decoder')


In [None]:
# Fit the VAE with the same array as input and target; the test array is
# used the same way for validation

vae.fit(
    x=x_train,
    y=x_train,
    batch_size=256,
    epochs=50,
    shuffle=True,
    validation_data=(x_test, x_test),
)

Epoch 1/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 22ms/step - loss: 0.3497 - val_loss: 0.2357
Epoch 2/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 27ms/step - loss: 0.2335 - val_loss: 0.2246
Epoch 3/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 21ms/step - loss: 0.2234 - val_loss: 0.2156
Epoch 4/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 25ms/step - loss: 0.2157 - val_loss: 0.2117
Epoch 5/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 20ms/step - loss: 0.2123 - val_loss: 0.2092
Epoch 6/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 30ms/step - loss: 0.2094 - val_loss: 0.2070
Epoch 7/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 23ms/step - loss: 0.2069 - val_loss: 0.2050
Epoch 8/50
[1m235/235[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 27ms/step - loss: 0.2045 - val_loss: 0.2033
Epoch 9/50
[1m235/235[0m [3

<keras.src.callbacks.history.History at 0x7c9bb0bc1ad0>