# Variational Autoencoder

Go from input dim to a small tensor of gaussians parameterized by mu and sigma and then build out to the larger output dim (equal to input dim).

Later, sample from the gaussian distribution to interpolate within the distribution learn by the VAE.

---

Example: Train on emojis and then sample novel emojis.

# MNIST

In [16]:
import os
os.environ["JAX_PLATFORMS"] = "cpu"

In [17]:
# fmnist_data_check.py
import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset
from einops import rearrange

print("JAX devices:", jax.devices())
print("Default device:", jax.default_backend())

JAX devices: [CpuDevice(id=0)]
Default device: cpu


In [18]:
def to_float01(example):
    # convert PIL image -> float32 [0,1], add channel dim (1, 28, 28)
    arr = np.array(example["image"], dtype=np.float32) / 255.0
    arr = rearrange(arr, "h w -> 1 h w")
    example["x"] = arr
    return example

In [19]:
# 1) load HF dataset
ds = load_dataset("fashion_mnist")

In [20]:
# 2) map -> float32, add channel dim
ds = ds.map(to_float01, remove_columns=["image"])

In [21]:
# 3) train/test numpy arrays
x_train = np.stack([ex["x"] for ex in ds["train"]], axis=0)  # (60000,1,28,28)
y_train = np.array(ds["train"]["label"])
x_test  = np.stack([ex["x"] for ex in ds["test"]], axis=0)   # (10000,1,28,28)

print("x_train:", x_train.shape, x_train.dtype, x_train.min(), x_train.max())
print("x_test :", x_test.shape,  x_test.dtype)

x_train: (60000, 1, 28, 28) float64 0.0 1.0
x_test : (10000, 1, 28, 28) float64


In [22]:
# 4) simple dataloader that yields JAX device arrays
def make_batches(x, batch_size=128, drop_last=True, shuffle=True, seed=0):
    rng = np.random.default_rng(seed)
    n = x.shape[0]
    idx = np.arange(n)
    while True:
        if shuffle:
            rng.shuffle(idx)
        for i in range(0, n, batch_size):
            j = i + batch_size
            if j > n and drop_last:
                break
            batch = x[idx[i:j]]
            # move to device (GPU) once here
            yield jax.device_put(jnp.array(batch))

In [23]:
# quick smoke test
batches = make_batches(x_train, batch_size=256)
xb = next(batches)
print("one batch:", xb.shape, xb.dtype, xb.device)

one batch: (256, 1, 28, 28) float32 TFRT_CPU_0


In [24]:
from dataclasses import dataclass
import jax, jax.numpy as jnp
import equinox as eqx
from jaxtyping import Float, Array, jaxtyped
from typing import Tuple
from beartype import beartype

In [25]:
# KeyArray = jax.Array

In [26]:
def flatten_images(x: Float[Array, "B H W C"]) -> Float[Array, "B (H W C)"]:
    b, h, w, c = x.shape
    return x.reshape(b, h*w*c)

In [27]:
x = jnp.zeros((100, 28, 28, 1))

In [34]:
y = flatten_images(x)
y.shape

(100, 784)

In [29]:
def unflatten_images(x: Float[Array, "B (H W C)"]) -> Float[Array, "B H W C"]:
    b, hwc = x.shape
    return x.reshape(b, 28, 28, 1)

In [33]:
unflatten_images(y).shape

(100, 28, 28, 1)

In [36]:
28*28

784

In [95]:
class VAE(eqx.Module):
    encoder: eqx.nn.MLP
    mu_head: eqx.nn.Linear
    logvar_head: eqx.nn.Linear
    decoder: eqx.nn.MLP
    latent_dim: int

    def __init__(self, key: jax.Array, input_dim: int = 28*28, latent_dim: int = 16, width: int = 256, depth: int = 2):
        # input_dim -> width -> width (depth times) -> latent_dim -> 
        #  -> width -> width (depth times) -> input_dim
        k1, k2, k3, k4 = jax.random.split(key, 4)
        self.encoder = eqx.nn.MLP(in_size=input_dim, out_size=width, width_size=width, depth=depth, activation=jax.nn.gelu, key=k1)
        self.mu_head = eqx.nn.Linear(width, latent_dim, use_bias=True, key=k2)
        self.logvar_head = eqx.nn.Linear(width, latent_dim, use_bias=True, key=k3)
        self.decoder = eqx.nn.MLP(in_size=latent_dim, out_size=input_dim, width_size=width, depth=depth, activation=jax.nn.gelu, key=k4)
        self.latent_dim = latent_dim

    def encode(self, x: Float[Array, "B H W C"]) -> Tuple[
        Float[Array, "B latent"], Float[Array, "B latent"]]:
        # B W H C -> (B Ldim, B Ldim)
        h = self.encoder(flatten_images(x))
        mu = self.mu_head(h)
        logvar = self.logvar_head(h)
        return mu, logvar

    def sample_z(self, mu: Array, logvar: Array, key: Array) -> Array:
        # reparameterization trick: z = mu + std * eps
        eps = jax.random.normal(key, shape=mu.shape, dtype=mu.dtype)
        std = jnp.exp(0.5 * logvar)
        return mu + std * eps

    def decode(self, z: Float[Array, "B latent"]) -> Float[Array, "B (H W C)"]:
        return self.decode(z)

    def __call__(self, x: Float[Array, "B H W C"], key: Array) -> Tuple[Float[Array, "B (W H C)"], Tuple[Array, Array, Array]]:
        mu, logvar = self.encode(x)
        key, key_z = jnp.random.split(key, 1)
        z = self.sample_z(mu, logvar, key_z)
        logits = self.decode(z)
        return logits, (mu, logvar, z)

In [100]:
SEED = 42
key = jax.random.PRNGKey(SEED)
model = VAE(key, latent_dim=8, width=128, depth=2)

In [101]:
import jax
import equinox as eqx

def count_params(tree) -> int:
    leaves = jax.tree_util.tree_leaves(eqx.filter(tree, eqx.is_inexact_array))
    return sum(int(x.size) for x in leaves)

def print_param_summary(model):
    parts = {
        "encoder": model.encoder,
        "mu_head": model.mu_head,
        "logvar_head": model.logvar_head,
        "decoder": model.decoder,
    }
    for name, part in parts.items():
        print(f"{name:12s}: {count_params(part):8d} params")
    print(f"{'TOTAL':12s}: {count_params(model):8d} params")

print_param_summary(model)


encoder     :   133504 params
mu_head     :     1032 params
logvar_head :     1032 params
decoder     :   118800 params
TOTAL       :   254368 params


In [103]:
def bce_with_logits(logits: Array, targets: Array) -> Array:
    # per-element BCE: max(l,0) - l*x + lop1p(exp(-|l|))
    l = logits
    x = targets
    return jnp.maximum(l, 0.0) - l * x + jnp.log1p(jnp.exp(-jnp.abs(l)))

In [105]:
def kl_standard_normal(mu: Array, logvar: Array) -> Array:
    # KL[q(z|x) || p(z)] for diagonal Gaussians where p(z)=N(0,I), q has (mu, logvar)
    return 0.5 * (jnp.exp(logvar) + jnp.square(mu) - 1.0 - logvar)

In [None]:
def elbo_loss(model, x: 