In [2]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

2.15.1


In [3]:
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    
class Encoder(keras.Model):
    """Maps MNIST digits to a triplet (z_mean, z_log_var, z)."""
    
    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z
    
class Decoder(keras.Model):
    """Converts z, the encoded digit vector, back into a readable digit."""
    
    def __init__(self, original_dim=784, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)
    
class VariationalAutoEncoder(keras.Model):
    """Combines the encoder and decoder into an end-to-end model for training."""
    
    def __init__(self, original_dim=784, intermediate_dim=64, latent_dim=32, name="autoencoder", **kwargs):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim=original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed

In [4]:
# Simple training loop on MNIST digits

original_dim = 784
vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
mse_loss_fn = tf.keras.losses.MeanSquaredError()
loss_metric = tf.keras.metrics.Mean()

(x_train, _), _ = keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, original_dim).astype("float32") / 255

train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

epochs = 2

# Iterate over epochs.
for epoch in range(epochs):
    print("Start of epoch %d" % (epoch,))
    
    # Iterate over the batches of the dataset.
    for step, x_batch_train in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            reconstructed = vae(x_batch_train)
            mse_loss = mse_loss_fn(x_batch_train, reconstructed)
            loss = mse_loss + sum(vae.losses)  # Add KLD regularization loss

        grads = tape.gradient(loss, vae.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae.trainable_weights))

        loss_metric(loss)

        if step % 100 == 0:
            print("step %d: mean loss = %.4f" % (step, loss_metric.result()))

2026-02-24 14:16:13.350593: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2026-02-24 14:16:13.477472: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2026-02-24 14:16:13.477751: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

Start of epoch 0


2026-02-24 14:16:15.258337: I external/local_xla/xla/service/service.cc:168] XLA service 0x114c0520 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2026-02-24 14:16:15.258376: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA GeForce GTX 1050 Ti, Compute Capability 6.1
2026-02-24 14:16:15.268931: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2026-02-24 14:16:15.292543: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8904
I0000 00:00:1771960575.382808   22087 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


step 0: mean loss = 0.3300
step 100: mean loss = 0.1251
step 200: mean loss = 0.0990
step 300: mean loss = 0.0890
step 400: mean loss = 0.0841
step 500: mean loss = 0.0808
step 600: mean loss = 0.0787
step 700: mean loss = 0.0771
step 800: mean loss = 0.0759
step 900: mean loss = 0.0749
Start of epoch 1
step 0: mean loss = 0.0746
step 100: mean loss = 0.0740
step 200: mean loss = 0.0735
step 300: mean loss = 0.0730
step 400: mean loss = 0.0727
step 500: mean loss = 0.0723
step 600: mean loss = 0.0720
step 700: mean loss = 0.0717
step 800: mean loss = 0.0715
step 900: mean loss = 0.0712


In [5]:
vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=2, batch_size=64)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x76a64f27c160>

In [8]:
original_dim = 784
intermediate_dim = 64
latent_dim = 32

# Define the Encoder Model
orignal_inputs = tf.keras.Input(shape=(original_dim,), name="encoder_input")
x = layers.Dense(intermediate_dim, activation="relu")(orignal_inputs)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()((z_mean, z_log_var))
encoder = keras.Model(orignal_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

# Define the Decoder Model
latent_inputs = tf.keras.Input(shape=(latent_dim,), name="z_sampling")
x = layers.Dense(intermediate_dim, activation="relu")(latent_inputs)
outputs = layers.Dense(original_dim, activation="sigmoid")(x)
decoder = keras.Model(latent_inputs, outputs, name="decoder")
decoder.summary()

# Define the VAE Model
outputs = decoder(encoder(orignal_inputs)[2])
vae = keras.Model(orignal_inputs, outputs, name="vae")
vae.summary()

# Add KL divergence regularization loss
kl_loss = -0.5 * tf.reduce_mean(
    z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
)
vae.add_loss(kl_loss)

# Train the VAE
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
vae.fit(x_train, x_train, epochs=5, batch_size=64)

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 encoder_input (InputLayer)  [(None, 784)]                0         []                            
                                                                                                  
 dense_16 (Dense)            (None, 64)                   50240     ['encoder_input[0][0]']       
                                                                                                  
 z_mean (Dense)              (None, 32)                   2080      ['dense_16[0][0]']            
                                                                                                  
 z_log_var (Dense)           (None, 32)                   2080      ['dense_16[0][0]']            
                                                                                            

<keras.src.callbacks.History at 0x76a647fbd7b0>