# Variational Autoencoders: From Probability Theory to Practice

## 1. Probabilistic Foundations
Let $(\Omega, \mathcal{F}, \mathbb{P})$ be a probability space. Our data points $\{x_i\}_{i=1}^n$ are realizations of i.i.d. random variables $X_i: \Omega \to \mathbb{R}^d$, whose common distribution is the push-forward measure $P_X = \mathbb{P} \circ X_i^{-1}$.
We aim to learn this distribution through a latent variable model. For each observation $x_i$, we posit an unobserved latent variable $z_i$, a realization of a random variable $Z_i: \Omega \to \mathbb{R}^k$. The $Z_i$ are i.i.d. with push-forward measure $P_Z = \mathbb{P} \circ Z_i^{-1}$. We assume $P_Z \ll \lambda$, where $\lambda$ is Lebesgue measure on $\mathbb{R}^k$, so by the Radon-Nikodym theorem $P_Z$ has a density $p_Z$ with respect to $\lambda$.

## 2. The Model Class
Let $p_X$ be the true (unknown) density of $P_X$ with respect to Lebesgue measure $\lambda$ on $\mathbb{R}^d$. It exists by the Radon-Nikodym theorem because we assume $P_X \ll \lambda$.
We aim to approximate $p_X$ using a latent variable model. Specifically, we consider a family of joint densities $p_{X,Z}(\cdot, \cdot; \theta)$ parameterized by $\theta$, which induce marginal densities:
$$p_\theta(x_i) = \int_{\mathbb{R}^k} p_{X,Z}(x_i, z; \theta) dz$$
The maximum likelihood objective is then:
$$\hat{\theta} = \operatorname*{arg\,max}_{\theta} \sum_{i=1}^n \log p_\theta(x_i)$$
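To make the integral in $p_\theta(x_i)$ concrete, here is a minimal sketch on a hypothetical one-dimensional toy model that is not part of the text: $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, 1)$, whose marginal is known to be $\mathcal{N}(0, 2)$. A midpoint Riemann sum over a grid of $z$ values recovers the log-likelihood of a few illustrative observations. In $k$ dimensions the same grid needs $m^k$ evaluations for $m$ points per axis, which is one way to see why this integral becomes impractical as $k$ grows.

```python
import math

def normal_logpdf(x, mean, var):
    """Log density of the univariate normal N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def joint_density(x, z):
    """Hypothetical toy joint p(x, z) = N(z; 0, 1) * N(x; z, 1)."""
    return math.exp(normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0))

def marginal_by_quadrature(x, lo=-10.0, hi=10.0, m=2001):
    """Approximate p(x) = integral of p(x, z) dz with a midpoint Riemann sum on [lo, hi]."""
    h = (hi - lo) / m
    return h * sum(joint_density(x, lo + (j + 0.5) * h) for j in range(m))

data = [-1.3, 0.2, 0.9, 2.4]  # illustrative observations, not real data
approx = sum(math.log(marginal_by_quadrature(x)) for x in data)
exact = sum(normal_logpdf(x, 0.0, 2.0) for x in data)  # closed form: x ~ N(0, 2)
print(f"log-likelihood by quadrature: {approx:.6f}, closed form: {exact:.6f}")
```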

## 3. The Evidence Lower Bound (ELBO)
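The integral in $p_\theta(x_i)$ is typically intractable. When the joint density involves a neural network it has no closed form, and numerical quadrature over $\mathbb{R}^k$ becomes expensive as $k$ grows. Instead, we can derive a lower bound using Jensen's inequality. We drop $\theta$ from the notation. Let $q$ be any probability density on $\mathbb{R}^k$ with $q(z) > 0$ wherever $p(x_i, z) > 0$. Then: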
$$\log p(x_i) = \log \int_{\mathbb{R}^k} p(x_i, z) dz = \log \int_{\mathbb{R}^k} q(z) \frac{p(x_i, z)}{q(z)} dz$$
By Jensen's inequality and the concavity of log:
$$\log \mathbb{E}_{Z \sim q}\left[\frac{p(x_i, Z)}{q(Z)}\right] \geq \mathbb{E}_{Z \sim q}\left[\log \frac{p(x_i, Z)}{q(Z)}\right]$$
This gives us our evidence lower bound (ELBO):
$$\log p(x_i) \geq \int_{\mathbb{R}^k} q(z) \log \frac{p(x_i, z)}{q(z)} dz$$
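The bound can be checked numerically on a model where $\log p(x)$ is known. The gap satisfies $\log p(x) - \text{ELBO}(q) = D_\text{KL}(q \,\|\, p(\cdot \mid x))$, so the bound is tight exactly when $q$ is the posterior. The following sketch uses the same hypothetical toy model as an illustration: $z \sim \mathcal{N}(0,1)$ and $x \mid z \sim \mathcal{N}(z, 1)$. For this model $p(x) = \mathcal{N}(x; 0, 2)$ and $p(z \mid x) = \mathcal{N}(x/2, 1/2)$. The code estimates the ELBO by Monte Carlo for a few Gaussian choices of $q$.

```python
import math
import random

def normal_logpdf(x, mean, var):
    """Log density of the univariate normal N(mean, var) at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def log_joint(x, z):
    """Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1)."""
    return normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)

def elbo_mc(x, q_mean, q_var, num_samples, rng):
    """Monte Carlo estimate of E_{z~q}[log p(x, z) - log q(z)] for q = N(q_mean, q_var)."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        total += log_joint(x, z) - normal_logpdf(z, q_mean, q_var)
    return total / num_samples

rng = random.Random(0)
x = 1.5
log_evidence = normal_logpdf(x, 0.0, 2.0)  # exact log p(x) for the toy model
# Two arbitrary choices of q, then the exact posterior N(x/2, 1/2)
for q_mean, q_var in [(0.0, 1.0), (1.5, 0.2), (x / 2, 0.5)]:
    est = elbo_mc(x, q_mean, q_var, 20_000, rng)
    print(f"q = N({q_mean:.2f}, {q_var:.2f}): ELBO ~ {est:.4f}, log p(x) = {log_evidence:.4f}")
```

When $q$ is the exact posterior, every sample of $\log p(x, z) - \log q(z)$ equals $\log p(x)$. The estimator then has zero variance and hits the bound exactly.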

## 4. Factorization and Amortization
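The chain rule gives $p(x_i, z) = p(x_i \mid z)\,p_Z(z)$. So the integrand of the ELBO splits as $\log \frac{p(x_i, z)}{q(z)} = \log p(x_i \mid z) - \log \frac{q(z)}{p_Z(z)}$. Taking expectations under $q$, we can rewrite the ELBO as
$$\mathbb{E}_{Z \sim q}[\log p(x_i \mid Z)] - D_\text{KL}(q \,\|\, p_Z)$$
where $p_Z$ is our prior density on the latent space.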
This form is computationally tractable because each term can be handled on its own. The first term can be estimated by Monte Carlo sampling:
$$\mathbb{E}_{Z \sim q}[\log p(x_i|Z)] \approx \frac{1}{S}\sum_{s=1}^S \log p(x_i|z^{(s)})$$
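Here $z^{(1)}, \dots, z^{(S)}$ are i.i.d. samples from $q$; with amortization (introduced below), $q = q(\cdot\,; g_\phi(x_i))$. The KL divergence term has a closed form when both $q$ and the prior $p_Z$ are normal distributions.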

By contrast, the original objective
$$\log p(x_i) = \log \int_{\mathbb{R}^k} p(x_i \mid z)\,p_Z(z)\,dz$$
requires an integral over all of $\mathbb{R}^k$ just to be evaluated once.
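The Monte Carlo estimator of $\mathbb{E}_{Z \sim q}[\log p(x_i \mid Z)]$ can be checked against a case where the expectation has a closed form. This sketch assumes a hypothetical Gaussian likelihood $x \mid z \sim \mathcal{N}(z, 1)$ and $q = \mathcal{N}(m, s^2)$. Because $\mathbb{E}[(x - Z)^2] = (x - m)^2 + s^2$, the exact value is $-\tfrac12\log(2\pi) - \tfrac12\big((x-m)^2 + s^2\big)$. The estimate is unbiased for every $S$, and its standard error shrinks like $1/\sqrt{S}$.

```python
import math
import random

def log_lik(x, z):
    """Hypothetical Gaussian likelihood: log N(x; z, 1)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - z) ** 2

def expected_log_lik_mc(x, q_mean, q_std, num_samples, rng):
    """(1/S) * sum_s log p(x | z_s), with z_s drawn i.i.d. from q = N(q_mean, q_std^2)."""
    return sum(log_lik(x, rng.gauss(q_mean, q_std)) for _ in range(num_samples)) / num_samples

def expected_log_lik_exact(x, q_mean, q_std):
    """Closed form using E[(x - Z)^2] = (x - q_mean)^2 + q_std^2 for Z ~ N(q_mean, q_std^2)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * ((x - q_mean) ** 2 + q_std ** 2)

rng = random.Random(0)
x, q_mean, q_std = 1.5, 0.5, 0.8  # illustrative values
exact = expected_log_lik_exact(x, q_mean, q_std)
for num_samples in [1, 10, 100, 10_000]:
    estimate = expected_log_lik_mc(x, q_mean, q_std, num_samples, rng)
    print(f"S = {num_samples:>6}: estimate {estimate:.4f}, exact {exact:.4f}")
```

The VAE code below uses a single sample ($S = 1$) per data point. This is a common choice in practice because the remaining noise is averaged over the mini-batch and over many gradient steps.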
For additional computational efficiency, we make two key assumptions:

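$\textbf{Mean-Field Assumption}$: The variational distribution over all latent variables factors across data points:
$$q(z_1, \dots, z_n) = \prod_{i=1}^n q_i(z_i)$$
The $z_i$ are independent under the prior, and each $x_i$ depends only on its own $z_i$. So the exact posterior already factors this way, and this particular factorization loses nothing.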
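$\textbf{Amortized Inference}$: Instead of learning a separate variational distribution for each data point, we learn a single mapping $g_\phi: \mathbb{R}^d \to \mathcal{P}(\mathbb{R}^k)$ shared across all data points. Here $\mathcal{P}(\mathbb{R}^k)$ denotes the space of probability measures on $\mathbb{R}^k$:
$$q_i(z_i) = q(z_i; g_\phi(x_i))$$
In practice $g_\phi(x)$ outputs the parameters of a distribution in a fixed family. The number of variational parameters therefore no longer grows with $n$, and inference for a new $x$ takes a single forward pass.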
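A toy case shows what a shared $g_\phi$ can achieve. The sketch assumes the hypothetical one-dimensional model $z \sim \mathcal{N}(0,1)$, $x \mid z \sim \mathcal{N}(z, 1)$, whose exact posterior is $\mathcal{N}(x/2, 1/2)$ for every $x$. A single linear "encoder" $g_\phi(x) = (a x + b,\ c)$, returning a mean and a log-variance, can therefore represent the exact posterior for all data points at once with $\phi = (1/2, 0, \log\tfrac12)$. With that $\phi$ the ELBO, computed here in closed form, equals $\log p(x)$ for every $x$. In general, a neural encoder with limited capacity cannot match the best per-point $q_i$ everywhere; the resulting loss is sometimes called the amortization gap.

```python
import math

def linear_encoder(phi, x):
    """Hypothetical amortized encoder: one shared phi maps any x to (mean, logvar) of q(z | x)."""
    a, b, c = phi
    return a * x + b, c

def elbo_closed_form(x, mean, logvar):
    """ELBO for x | z ~ N(z, 1), z ~ N(0, 1), q = N(mean, exp(logvar)): E_q[log p(x|z)] - KL(q || prior)."""
    var = math.exp(logvar)
    expected_log_lik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - mean) ** 2 + var)
    kl = 0.5 * (mean ** 2 + var - logvar - 1.0)
    return expected_log_lik - kl

def log_evidence(x):
    """Exact log p(x) = log N(x; 0, 2) for the toy model."""
    return -0.5 * math.log(4 * math.pi) - x ** 2 / 4

phi = (0.5, 0.0, math.log(0.5))  # reproduces the exact posterior N(x/2, 1/2)
for x in [-2.0, 0.3, 1.5, 4.0]:
    mean, logvar = linear_encoder(phi, x)
    print(f"x = {x:5.2f}: ELBO {elbo_closed_form(x, mean, logvar):.6f}, log p(x) {log_evidence(x):.6f}")
```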

## 5. Parameterization and Computation
We typically choose:

Prior: $p(z) = \mathcal{N}(0, I)$

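Variational family: $q(z; g_\phi(x)) = \mathcal{N}(\mu_\phi(x), \text{diag}(\sigma^2_\phi(x)))$. Here the encoder output $g_\phi(x)$ consists of the mean $\mu_\phi(x)$ and the per-dimension variances $\sigma^2_\phi(x)$. The code below outputs $\log \sigma^2_\phi(x)$ instead, which keeps the variance positive without a constraint.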

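Likelihood: $p_\theta(x \mid z) = \prod_{j=1}^d f_\theta(z)_j^{\,x_j}\,\big(1 - f_\theta(z)_j\big)^{1 - x_j}$. This is a product of Bernoulli distributions whose means come from a neural network decoder $f_\theta: \mathbb{R}^k \to (0,1)^d$. The code below uses this choice; treating grayscale MNIST intensities in $[0, 1]$ as Bernoulli means is a common approximation.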
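Under these choices, both terms of the per-example objective have simple formulas. This sketch follows the `vae_loss` function in the code below and assumes a Bernoulli likelihood whose means are the decoder outputs. Under that assumption, $-\log p(x \mid z)$ is exactly the binary cross-entropy. The sketch also checks the closed-form KL between $\mathcal{N}(\mu, \text{diag}(\sigma^2))$ and $\mathcal{N}(0, I)$ against a Monte Carlo estimate. The KL is written in the log-variance parameterization that the code uses. The pixel values, probabilities, and latent parameters are made up for illustration.

```python
import math
import random

def bernoulli_log_lik(x, probs):
    """log p(x | z) = sum_j x_j log f_j + (1 - x_j) log(1 - f_j), where f = f_theta(z)."""
    return sum(xj * math.log(fj) + (1 - xj) * math.log(1 - fj) for xj, fj in zip(x, probs))

def kl_to_std_normal(mean, logvar):
    """Closed form KL(N(mean, diag(exp(logvar))) || N(0, I)), the same expression as in vae_loss."""
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mean, logvar))

def kl_to_std_normal_mc(mean, logvar, num_samples, rng):
    """Monte Carlo estimate of E_q[log q(z) - log p(z)], one independent coordinate at a time."""
    total = 0.0
    for _ in range(num_samples):
        for m, lv in zip(mean, logvar):
            std = math.exp(0.5 * lv)
            z = rng.gauss(m, std)
            log_q = -0.5 * (math.log(2 * math.pi) + lv + ((z - m) / std) ** 2)
            log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
            total += log_q - log_p
    return total / num_samples

# Illustrative values only: 4 "pixels" and a 2-dimensional latent code
x = [0.0, 1.0, 1.0, 0.0]
probs = [0.1, 0.8, 0.6, 0.3]
mean, logvar = [0.5, -1.0], [-0.5, 0.2]

bce = -bernoulli_log_lik(x, probs)
print(f"binary cross-entropy: {bce:.4f}")
print(f"KL closed form: {kl_to_std_normal(mean, logvar):.4f}")
print(f"KL Monte Carlo: {kl_to_std_normal_mc(mean, logvar, 50_000, random.Random(0)):.4f}")
```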

With these choices, the final optimization objective is:
$$\max_{\theta, \phi} \sum_{i=1}^n \left[\mathbb{E}_{z \sim q(z; g_\phi(x_i))} \log p_\theta(x_i \mid z) - D_\text{KL}\big(q(z; g_\phi(x_i)) \,\|\, p(z)\big)\right]$$
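The KL divergence term has a closed form for these normal distributions:
$$D_\text{KL}\big(\mathcal{N}(\mu, \text{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2}\sum_{j=1}^k \left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)$$
The expectation term generally has no closed form, so we estimate it by Monte Carlo. A sample drawn directly from $q(z; g_\phi(x_i))$ gives gradients no path back to $\phi$. We therefore use the reparameterization trick. Instead of sampling directly from $q(z; g_\phi(x_i))$, we sample: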
$$\epsilon \sim \mathcal{N}(0, I)$$
$$z = \mu_\phi(x_i) + \sigma_\phi(x_i) \odot \epsilon$$
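The noise $\epsilon$ does not depend on $\phi$. Under mild regularity conditions, the gradient can therefore be moved inside the expectation:
$$\nabla_\phi \mathbb{E}_{z \sim q(z; g_\phi(x_i))}[\log p_\theta(x_i \mid z)] = \mathbb{E}_{\epsilon \sim \mathcal{N}(0, I)}\left[\nabla_\phi \log p_\theta\big(x_i \mid \mu_\phi(x_i) + \sigma_\phi(x_i) \odot \epsilon\big)\right]$$
A Monte Carlo estimate of the right-hand side can then be differentiated with ordinary backpropagation.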
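The pathwise gradient can be checked by hand in one dimension. This sketch assumes a hypothetical test function $f(z) = z^2$ in place of $\log p_\theta(x_i \mid z)$. For $z \sim \mathcal{N}(\mu, \sigma^2)$ we have $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the exact derivatives are $2\mu$ and $2\sigma$. Write $z = \mu + \sigma\epsilon$. Then $\partial z/\partial\mu = 1$ and $\partial z/\partial\sigma = \epsilon$, so the per-sample gradients are $f'(z)$ and $f'(z)\,\epsilon$. Automatic differentiation (`jax.value_and_grad` in the training code below) computes these chain-rule terms automatically; here they are written out explicitly.

```python
import random

def f_prime(z):
    """Derivative of the hypothetical test function f(z) = z**2."""
    return 2.0 * z

def reparam_grads(mu, sigma, num_samples, rng):
    """Pathwise estimates of d/dmu and d/dsigma of E_{z ~ N(mu, sigma^2)}[f(z)]."""
    g_mu = g_sigma = 0.0
    for _ in range(num_samples):
        eps = rng.gauss(0.0, 1.0)  # noise distribution does not depend on (mu, sigma)
        z = mu + sigma * eps       # z is a differentiable function of mu and sigma
        g_mu += f_prime(z)         # chain rule with dz/dmu = 1
        g_sigma += f_prime(z) * eps  # chain rule with dz/dsigma = eps
    return g_mu / num_samples, g_sigma / num_samples

mu, sigma = 0.7, 1.3  # illustrative parameters
g_mu, g_sigma = reparam_grads(mu, sigma, 100_000, random.Random(0))
print(f"d/dmu:    estimate {g_mu:.3f}, exact {2 * mu:.3f}")
print(f"d/dsigma: estimate {g_sigma:.3f}, exact {2 * sigma:.3f}")
```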

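## Notes
This is the standard variational inference setup with local (per-data-point) latent variables $z_i$. The decoder parameters $\theta$ are shared across all data points. However, they receive a point estimate from maximizing the ELBO rather than a posterior distribution, so this is not a fully Bayesian treatment with global latent variables.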

```python
import jax
import jax.numpy as jnp
from jax import random, jit
import tensorflow as tf  # needed for tf.cast in get_mnist
import tensorflow_datasets as tfds
import optax

def get_mnist(batch_size=32):
    """Load MNIST training images scaled to [0, 1] as a batched tf.data pipeline."""
    ds = tfds.load('mnist', split='train', as_supervised=True)
    ds = ds.map(lambda x, y: (tf.cast(x, tf.float32) / 255., y))
    return ds.batch(batch_size).prefetch(1)

def init_network_params(key, input_dim, hidden_dim, latent_dim):
    """Initialize neural network parameters."""
    # Encoder: input_dim -> hidden_dim -> latent_dim*2 (mean and logvar)
    key1, key2, key3, key4 = random.split(key, 4)
    
    encoder_params = {
        'h1': {
            'w': random.normal(key1, (input_dim, hidden_dim)) / jnp.sqrt(input_dim),
            'b': jnp.zeros(hidden_dim)
        },
        'h2': {
            'w': random.normal(key2, (hidden_dim, latent_dim * 2)) / jnp.sqrt(hidden_dim),
            'b': jnp.zeros(latent_dim * 2)
        }
    }
    
    # Decoder: latent_dim -> hidden_dim -> input_dim
    decoder_params = {
        'h1': {
            'w': random.normal(key3, (latent_dim, hidden_dim)) / jnp.sqrt(latent_dim),
            'b': jnp.zeros(hidden_dim)
        },
        'h2': {
            'w': random.normal(key4, (hidden_dim, input_dim)) / jnp.sqrt(hidden_dim),
            'b': jnp.zeros(input_dim)
        }
    }
    
    return {'encoder': encoder_params, 'decoder': decoder_params}

def encoder(params, x):
    """Encoder network mapping x to mean and logvar of q(z|x)."""
    # First layer with tanh activation
    h = jnp.tanh(x @ params['h1']['w'] + params['h1']['b'])
    # Output layer with no activation (mean and logvar)
    h = h @ params['h2']['w'] + params['h2']['b']
    # Split into mean and logvar
    mean, logvar = jnp.split(h, 2, axis=-1)
    return mean, logvar

def reparameterize(key, mean, logvar):
    """Reparameterization trick: z = mean + std * epsilon."""
    eps = random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * logvar) * eps

def decoder(params, z):
    """Decoder network mapping z to Bernoulli means (pixel probabilities)."""
    # First layer with tanh activation
    h = jnp.tanh(z @ params['h1']['w'] + params['h1']['b'])
    # Output layer with sigmoid for pixel values
    x_recon = jax.nn.sigmoid(h @ params['h2']['w'] + params['h2']['b'])
    return x_recon

def vae_loss(params, key, batch):
    """Negative ELBO summed over the batch, using one Monte Carlo sample of z per example."""
    # Encode
    mean, logvar = encoder(params['encoder'], batch)
    
    # Sample z using reparameterization trick
    z = reparameterize(key, mean, logvar)
    
    # Decode
    x_recon = decoder(params['decoder'], z)
    
    # Reconstruction loss: negative Bernoulli log-likelihood (binary cross-entropy);
    # the 1e-8 terms guard against log(0)
    recon_loss = -jnp.sum(
        batch * jnp.log(x_recon + 1e-8) + 
        (1 - batch) * jnp.log(1 - x_recon + 1e-8)
    )
    
    # KL(q(z|x) || N(0, I)) in closed form for diagonal Gaussians
    kl_loss = -0.5 * jnp.sum(
        1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    )
    
    return recon_loss + kl_loss

# JIT compile for faster execution; `optimizer` is the global defined in the setup below
@jit
def train_step(params, opt_state, key, batch):
    """Single training step."""
    loss_val, grads = jax.value_and_grad(vae_loss)(params, key, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val

# Training setup
batch_size = 32
input_dim = 784  # flattened MNIST
hidden_dim = 512
latent_dim = 2
learning_rate = 1e-3

# Initialize parameters and optimizer (split the key so the init key is not reused)
key = random.PRNGKey(0)
key, init_key = random.split(key)
params = init_network_params(init_key, input_dim, hidden_dim, latent_dim)
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

# Training loop
def train_epoch(params, opt_state, key, train_ds):
    """Train for one epoch; returns the updated state and the mean loss per example."""
    total_loss, num_examples = 0.0, 0
    # as_numpy_iterator yields NumPy arrays that JAX consumes directly
    for batch, _ in train_ds.as_numpy_iterator():
        key, subkey = random.split(key)
        # Flatten (B, 28, 28, 1) images to (B, 784) using the actual batch size B
        batch = jnp.reshape(batch, (batch.shape[0], -1))
        params, opt_state, loss = train_step(params, opt_state, subkey, batch)
        total_loss += float(loss)
        num_examples += batch.shape[0]
    return params, opt_state, total_loss / num_examples

# To use:
# train_ds = get_mnist(batch_size)
# num_epochs = 10  # illustrative value
# for epoch in range(num_epochs):
#     key, subkey = random.split(key)
#     params, opt_state, loss = train_epoch(params, opt_state, subkey, train_ds)
#     print(f"epoch {epoch}: mean negative ELBO per example {loss:.2f}")
```
