In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

## 1) Load & preprocess MNIST

In [None]:
def load_and_preprocess_data():
    """Load MNIST and return images scaled to [0, 1] with an explicit channel axis.

    Labels are discarded: the VAE is trained without supervision.

    Returns:
        (x_train, x_test): float32 arrays of shape (num_images, 28, 28, 1).
    """
    (train_images, _), (test_images, _) = mnist.load_data()

    def _prep(images):
        # uint8 [0, 255] -> float32 [0, 1], then add the trailing channel axis.
        scaled = images.astype('float32') / 255.0
        return np.reshape(scaled, (len(scaled), 28, 28, 1))

    return _prep(train_images), _prep(test_images)

## 2) Build Encoder with CNN 

In [None]:
def build_encoder(latent_dim):
    """Build the convolutional encoder q(z|x).

    Two stride-2 convolutions downsample a 28x28x1 image to 7x7x64 feature maps.
    These are flattened into a 128-unit dense layer, and two linear heads produce
    the posterior mean and log-variance. A reparameterised sample ``z`` is drawn
    inside the model.

    Args:
        latent_dim: dimensionality of the latent space.

    Returns:
        Keras ``Model`` named "encoder" with outputs ``[z_mean, z_log_var, z]``.
    """

    def _reparameterize(params):
        # z = mu + sigma * eps, sigma = exp(0.5 * log_var), eps ~ N(0, I)
        mu, log_var = params
        noise = tf.random.normal(shape=tf.shape(mu))
        return mu + tf.exp(0.5 * log_var) * noise

    image_in = layers.Input(shape=(28, 28, 1))
    features = image_in
    for n_filters in (32, 64):
        features = layers.Conv2D(
            n_filters, 3, activation="relu", strides=2, padding="same"
        )(features)
    features = layers.Flatten()(features)
    features = layers.Dense(128, activation="relu")(features)

    z_mean = layers.Dense(latent_dim, name="z_mean")(features)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(features)
    z = layers.Lambda(_reparameterize, name="z")([z_mean, z_log_var])

    return Model(image_in, [z_mean, z_log_var, z], name="encoder")

## 3) Build Decoder with CNN

In [None]:
def build_decoder(latent_dim):
    """Build the convolutional decoder p(x|z).

    A dense layer projects ``z`` to a 7x7x64 grid. Two stride-2 transposed
    convolutions upsample it to 28x28, and a final stride-1 transposed
    convolution with a sigmoid produces one channel of pixel values in (0, 1).

    Args:
        latent_dim: dimensionality of the latent input.

    Returns:
        Keras ``Model`` named "decoder" mapping (batch, latent_dim) to
        (batch, 28, 28, 1).
    """
    z_in = layers.Input(shape=(latent_dim,))
    h = layers.Dense(7 * 7 * 64, activation="relu")(z_in)
    h = layers.Reshape((7, 7, 64))(h)
    for n_filters in (64, 32):
        h = layers.Conv2DTranspose(
            n_filters, 3, activation="relu", strides=2, padding="same"
        )(h)
    pixels = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(h)
    return Model(z_in, pixels, name="decoder")

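## 4) Custom Layer for VAE Loss

The layer passes the reconstruction through unchanged. It registers the negative ELBO (reconstruction cross-entropy + KL divergence) with `add_loss`, so `fit` needs no targets.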

In [None]:
def vae_loss_layer(inputs, outputs, z_mean, z_log_var):
    """Attach the VAE objective to the graph and return the reconstruction.

    Wraps the four tensors in an identity layer whose ``call`` registers
    ``recon_loss + kl_loss`` with ``add_loss``, so the model can be fit
    without targets.

    Args:
        inputs: images fed to the encoder; values assumed in [0, 1].
        outputs: decoder reconstructions, same shape as ``inputs``; assumed in
            (0, 1) (the decoder in this notebook ends in a sigmoid).
        z_mean: posterior means, shape (batch, latent_dim).
        z_log_var: posterior log-variances, shape (batch, latent_dim).

    Returns:
        ``outputs`` passed through unchanged.
    """

    class VAELossLayer(layers.Layer):
        """Identity layer that registers reconstruction BCE + KL divergence."""

        def call(self, inputs):
            x, reconstructed, z_mean, z_log_var = inputs

            # Flatten each sample to (batch, pixels) with a plain reshape;
            # instantiating Flatten layers inside call() creates new layers
            # on every invocation.
            batch = tf.shape(x)[0]
            x_flat = tf.reshape(x, (batch, -1))

            # Clip like Keras does internally (default fuzz factor 1e-7) so
            # log() never sees exactly 0 or 1.
            eps = 1e-7
            recon_flat = tf.clip_by_value(
                tf.reshape(reconstructed, (batch, -1)), eps, 1.0 - eps
            )

            # Element-wise binary cross-entropy, shape (batch, pixels).
            # tf.keras.losses.binary_crossentropy averages over the last axis,
            # which here would collapse the pixels to a per-sample *mean* and
            # make the following reduce_sum run over the batch instead.
            pixel_bce = -(
                x_flat * tf.math.log(recon_flat)
                + (1.0 - x_flat) * tf.math.log(1.0 - recon_flat)
            )
            # Sum over pixels, mean over the batch.
            recon_loss = tf.reduce_mean(tf.reduce_sum(pixel_bce, axis=-1))

            # KL(q(z|x) || N(0, I)): summed over latent dims, mean over batch.
            kl_loss = -0.5 * tf.reduce_mean(
                tf.reduce_sum(
                    1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                    axis=-1,
                )
            )
            self.add_loss(recon_loss + kl_loss)
            return reconstructed

    return VAELossLayer()([inputs, outputs, z_mean, z_log_var])

## 5) Build VAE

In [None]:
def build_vae(encoder, decoder):
    """Chain encoder and decoder into an end-to-end VAE model.

    The reconstruction is routed through ``vae_loss_layer``, which attaches the
    training loss in-graph.

    Args:
        encoder: model returning ``[z_mean, z_log_var, z]``.
        decoder: model mapping ``z`` to a reconstruction.

    Returns:
        Keras ``Model`` named "vae" from the encoder's input to the reconstruction.
    """
    image_in = encoder.inputs[0]
    mean, log_var, latent = encoder(image_in)
    recon = vae_loss_layer(image_in, decoder(latent), mean, log_var)
    return Model(image_in, recon, name="vae")

## 6) Train VAE

In [None]:
def train_vae(vae, x_train, x_test, epochs=15, batch_size=128):
    """Compile ``vae`` with Adam and fit it on ``x_train`` with no targets.

    ``y`` is passed as ``None`` for both training and validation data, so the
    objective must come from losses the model registers itself (no loss is given
    to ``compile``). ``x_test`` is used only for validation.
    """
    fit_kwargs = dict(
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, None),
        verbose=1,
    )
    vae.compile(optimizer="adam")
    vae.fit(x_train, None, **fit_kwargs)

## 7) Generate images

In [None]:
def plot_generated(decoder, latent_dim, n=16, seed=None):
    """Decode ``n`` draws from the N(0, I) prior and show them in a grid.

    Args:
        decoder: object with a Keras-style ``predict(z, verbose=0)`` mapping
            (n, latent_dim) latents to images.
        latent_dim: size of each latent vector.
        n: number of images to generate; must be >= 1.
        seed: optional seed for a dedicated ``np.random.default_rng``. When
            None, the global NumPy RNG is used (the previous behaviour).

    Raises:
        ValueError: if ``n`` < 1 (previously ``n=0`` crashed with
            ZeroDivisionError when computing the grid).
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")

    rng = np.random if seed is None else np.random.default_rng(seed)
    z_sample = rng.normal(size=(n, latent_dim)).astype("float32")
    gen = decoder.predict(z_sample, verbose=0)

    # Near-square grid: cols = floor(sqrt(n)), enough rows to fit all n.
    cols = int(np.sqrt(n))
    rows = int(np.ceil(n / cols))
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 2, rows * 2), squeeze=False
    )
    for i, ax in enumerate(axes.flat):
        if i < n:
            ax.imshow(gen[i].squeeze(), cmap="gray")
        # Unused trailing cells (when n is not rows * cols) are left blank.
        ax.axis("off")
    fig.tight_layout()
    plt.show()

## 8) Main execution 

In [None]:
def main(seed=42):
    """Run the full pipeline: load MNIST, build and train the VAE, plot samples.

    Args:
        seed: seed passed to ``tf.keras.utils.set_random_seed``, which seeds
            Python's ``random``, NumPy and TensorFlow. This makes weight init,
            batch shuffling, latent sampling and the generated grid repeatable.
            Pass None to skip seeding. GPU kernels may still be
            non-deterministic.
    """
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

    latent_dim = 16
    x_train, x_test = load_and_preprocess_data()
    encoder = build_encoder(latent_dim)
    decoder = build_decoder(latent_dim)
    vae = build_vae(encoder, decoder)
    train_vae(vae, x_train, x_test, epochs=30, batch_size=128)
    plot_generated(decoder, latent_dim, n=16)


if __name__ == "__main__":
    main()

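## **Encoder**

Shapes in these tables are shown for one training batch of 128 images (`batch_size=128` in `main`). In the Keras models the batch dimension is `None`. The last batch of an epoch is smaller, because 60,000 training images is not a multiple of 128.
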
| Layer / Operation | Output Shape |
|-------------------|--------------|
| **Input** | `(128, 28, 28, 1)` |
| Conv2D(32, 3, strides=2, same) | `(128, 14, 14, 32)` |
| Conv2D(64, 3, strides=2, same) | `(128, 7, 7, 64)` |
| Flatten | `(128, 3136)` |
| Dense(128) | `(128, 128)` |
| Dense(latent_dim=16) → **z_mean** | `(128, 16)` |
| Dense(latent_dim=16) → **z_log_var** | `(128, 16)` |
| Lambda(sampling) → **z** | `(128, 16)` |

---

## **Decoder**
| Layer / Operation | Output Shape |
|-------------------|--------------|
| **Input z** | `(128, 16)` |
| Dense(7×7×64) | `(128, 3136)` |
| Reshape | `(128, 7, 7, 64)` |
| Conv2DTranspose(64, 3, strides=2, same) | `(128, 14, 14, 64)` |
| Conv2DTranspose(32, 3, strides=2, same) | `(128, 28, 28, 32)` |
| Conv2DTranspose(1, 3, strides=1, same, sigmoid) | `(128, 28, 28, 1)` |

---

## **Loss Layer**
| Step | Shape |
|------|-------|
| **Inputs**: `x` | `(128, 28, 28, 1)` |
| **Inputs**: `reconstructed` | `(128, 28, 28, 1)` |
| **Inputs**: `z_mean` | `(128, 16)` |
| **Inputs**: `z_log_var` | `(128, 16)` |
| Flatten `x` | `(128, 784)` |
| Flatten `reconstructed` | `(128, 784)` |
| Binary crossentropy (per pixel) | `(128, 784)` |
| Sum over axis=-1 | `(128,)` |
| Mean over batch | `scalar` |
| KL divergence sum over axis=-1 | `(128,)` |
| KL divergence mean over batch | `scalar` |

---

## **Final Output**
| Output | Shape |
|--------|-------|
| **Reconstructed batch** | `(128, 28, 28, 1)` |
| **Loss** (registered via `add_loss`, not a model output) | `scalar` *(Recon Loss + KL Loss)* |
