# Deterministic autoencoder and variational autoencoder on CelebA

This notebook uses pre-trained checkpoints to make the demo faster.
However, the code to train the models from scratch is also included.

In [None]:
from functools import partial
from itertools import islice
import subprocess
from typing import Sequence, Tuple

import flax
import flax.linen as nn
from flax.training import checkpoints, train_state
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import CelebA

# Hide every GPU from TensorFlow. tf is not used anywhere else in this notebook;
# presumably it is imported only so this call stops it from claiming GPU memory
# that JAX needs.
tf.config.set_visible_devices([], device_type="GPU")


The encoder has convolutional layers followed by fully connected layers.
The convolutional layers are shared between the mean and the log-variance, while each of them has its own fully connected layer.

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder returning the mean and log-variance of q(z | x).

    Each width in ``hidden_channels`` adds a 3x3, stride-2 conv followed by
    BatchNorm and ReLU. The flattened feature map then feeds two separate Dense
    heads of size ``latent_dim``: one for the mean and one for the log-variance.
    """

    latent_dim: int
    hidden_channels: Sequence[int]

    @nn.compact
    def __call__(self, X, training):
        h = X
        for width in self.hidden_channels:
            h = nn.Conv(width, kernel_size=(3, 3), strides=2, padding=1)(h)
            # Batch statistics while training, running averages otherwise.
            h = nn.BatchNorm(use_running_average=not training)(h)
            h = jax.nn.relu(h)

        # Flatten the trailing (H, W, C) axes. The -1 absorbs the batch axis; an
        # unbatched (H, W, C) input becomes a batch of one.
        flat = h.reshape((-1, np.prod(h.shape[-3:])))
        # Evaluated left to right: the mean head is created before the log-variance head.
        mu, logvar = nn.Dense(self.latent_dim)(flat), nn.Dense(self.latent_dim)(flat)
        return mu, logvar


class Decoder(nn.Module):
    """Mirror of ``Encoder``: maps latent codes back to images in [0, 1].

    A Dense layer produces an (H / 2**k, W / 2**k, hidden_channels[-1]) feature
    map, where k = len(hidden_channels). Then:

    - Each remaining hidden width, in reverse order, adds a stride-2
      ConvTranspose + BatchNorm + ReLU stage.
    - A final stride-2 ConvTranspose to ``C`` channels, followed by a sigmoid,
      restores the full (H, W, C) resolution.

    Raises:
        ValueError: if H or W is not divisible by 2 ** len(hidden_channels).
    """

    output_dim: Tuple[int, int, int]
    hidden_channels: Sequence[int]

    @nn.compact
    def __call__(self, X, training):
        H, W, C = self.output_dim

        # There are len(hidden_channels) stride-2 ConvTranspose layers and each
        # doubles the spatial size, so H and W must be divisible by this factor.
        # Raise a real exception rather than using ``assert``, so the check is not
        # stripped under ``python -O``.
        # TODO: relax this restriction
        factor = 2 ** len(self.hidden_channels)
        if H % factor != 0 or W % factor != 0:
            raise ValueError(
                f"output_dim must be a multiple of {factor}, got {self.output_dim}"
            )
        H, W = H // factor, W // factor

        X = nn.Dense(H * W * self.hidden_channels[-1])(X)
        X = jax.nn.relu(X)
        # The -1 keeps any leading batch axis; a single latent vector becomes a batch of one.
        X = X.reshape((-1, H, W, self.hidden_channels[-1]))

        for hidden_channel in reversed(self.hidden_channels[:-1]):
            X = nn.ConvTranspose(
                hidden_channel, (3, 3), strides=(2, 2), padding=((1, 2), (1, 2))
            )(X)
            X = nn.BatchNorm(use_running_average=not training)(X)
            X = jax.nn.relu(X)

        X = nn.ConvTranspose(C, (3, 3), strides=(2, 2), padding=((1, 2), (1, 2)))(X)
        # Sigmoid keeps outputs in [0, 1], the range the input pixels are scaled to.
        X = jax.nn.sigmoid(X)

        return X


def reparameterize(key, mean, logvar):
    """Sample ``mean + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``.

    Drawing the noise separately keeps the sample differentiable with respect
    to ``mean`` and ``logvar`` (the reparameterization trick). ``key`` is the
    JAX PRNG key that supplies ``eps``.
    """
    noise = jax.random.normal(key, logvar.shape)
    sigma = jnp.exp(0.5 * logvar)
    return mean + noise * sigma


class VAE(nn.Module):
    """Autoencoder that optionally samples its latent code.

    With ``variational=True`` the latent is sampled with ``reparameterize``.
    Otherwise the encoder mean is used directly, giving a deterministic AE.
    Calling the module returns ``(reconstruction, mean, logvar)``.
    """

    # Field order matters: create_train_state builds VAE positionally.
    variational: bool
    latent_dim: int
    output_dim: Tuple[int, int, int]
    hidden_channels: Sequence[int]

    def setup(self):
        # In Flax setup(), the attribute names become the submodule names in the
        # parameter tree ("encoder"/"decoder"), so keep them stable for checkpoints.
        self.encoder = Encoder(
            latent_dim=self.latent_dim, hidden_channels=self.hidden_channels
        )
        self.decoder = Decoder(
            output_dim=self.output_dim, hidden_channels=self.hidden_channels
        )

    def __call__(self, key, X, training):
        mean, logvar = self.encoder(X, training)
        # key is consumed only in the variational case.
        Z = reparameterize(key, mean, logvar) if self.variational else mean
        return self.decoder(Z, training), mean, logvar

    def decode(self, Z, training):
        """Decode latent codes ``Z`` without running the encoder."""
        return self.decoder(Z, training)

In [None]:

class TrainState(train_state.TrainState):
    """Flax TrainState extended with BatchNorm statistics and the KL weight.

    NOTE(review): the field names presumably act as keys when checkpoints are
    serialized and restored. Renaming them would likely break loading existing
    checkpoints; confirm before changing.
    """

    batch_stats: flax.core.FrozenDict[str, jnp.ndarray]  # BatchNorm running statistics ("batch_stats" collection); replaced on every train_step
    beta: float  # weight of the KL term in train_step's loss; 0 for the plain autoencoder config


def create_train_state(
    key, variational, beta, latent_dim, hidden_channels, learning_rate, specimen
):
    """Build a VAE for images shaped like ``specimen`` and wrap it in a TrainState.

    Args:
        key: PRNG key used to initialize the parameters. The key used for the
            init-time latent sample is fixed.
        specimen: a single (H, W, C) example. It is passed through the model
            once at init, and its shape becomes the decoder's ``output_dim``.

    Returns:
        A TrainState with params, BatchNorm statistics, an Adam optimizer with
        ``learning_rate``, and the KL weight ``beta``.
    """
    model = VAE(
        variational=variational,
        latent_dim=latent_dim,
        output_dim=specimen.shape,
        hidden_channels=hidden_channels,
    )
    init_sample_key = jax.random.PRNGKey(42)
    (recon, _mean, _logvar), variables = model.init_with_output(
        key, init_sample_key, specimen, True
    )
    # Sanity check: the decoder must reproduce the specimen's (H, W, C).
    assert (
        recon.shape[-3:] == specimen.shape
    ), f"{recon.shape} = recon.shape != specimen.shape = {specimen.shape}"

    return TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=optax.adam(learning_rate),
        batch_stats=variables["batch_stats"],
        beta=beta,
    )


def _kl_to_standard_normal(mean, logvar):
    """KL(N(mean, diag(exp(logvar))) || N(0, I)) for a single latent vector."""
    terms = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    return -0.5 * jnp.sum(terms)


# Mapped over the leading (batch) axis: one KL value per example.
kl_divergence = jax.vmap(_kl_to_standard_normal)


@jax.jit
def train_step(state, key, image):
    """Run one Adam step on a batch and return ``(new_state, loss)``.

    The loss is the sum of squared reconstruction errors over the whole batch
    plus ``state.beta`` times the summed per-example KL. BatchNorm statistics
    computed on this batch are written back into the new state.
    """

    def compute_loss(params):
        variables = {"params": params, "batch_stats": state.batch_stats}
        (recon, mean, logvar), updates = state.apply_fn(
            variables, key, image, True, mutable=["batch_stats"]
        )
        total = jnp.sum((recon - image) ** 2) + state.beta * jnp.sum(
            kl_divergence(mean, logvar)
        )
        return total, updates

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, updates), grads = grad_fn(state.params)

    new_state = state.apply_gradients(grads=grads, batch_stats=updates["batch_stats"])
    return new_state, loss


@jax.jit
def test_step(state, key, image):
    """Forward pass with BatchNorm in inference mode; returns ``(recon, mean, logvar)``.

    For variational models the latent code is still sampled using ``key``.
    """
    return state.apply_fn(
        {"params": state.params, "batch_stats": state.batch_stats}, key, image, False
    )


@jax.jit
def decode(state, Z):
    """Decode latent codes ``Z`` into images with BatchNorm in inference mode."""
    return state.apply_fn(
        {"params": state.params, "batch_stats": state.batch_stats},
        Z,
        False,
        method=VAE.decode,
    )



In [None]:
from torch import Generator
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision.datasets import MNIST


# A short training run on MNIST (resized to 32x32, one channel).
batch_size = 256
latent_dim = 20
hidden_channels = (32, 64, 128, 256, 512)
lr = 1e-3
specimen = jnp.empty((32, 32, 1))
variational = True
beta = 1
target_epoch = 2

# ToTensor already yields float pixels in [0, 1], so no extra scaling is needed.
transform = T.Compose([T.Resize((32, 32)), T.ToTensor()])
mnist_train = MNIST("data/vae/mnist", train=True, download=True, transform=transform)
generator = Generator()
generator.manual_seed(42)
loader = DataLoader(mnist_train, batch_size, shuffle=True, generator=generator)

key = jax.random.PRNGKey(42)
state = create_train_state(key, variational, beta, latent_dim, hidden_channels, lr, specimen)

for epoch in range(target_epoch):
    loss_train = 0
    for X, _ in loader:
        key, key_Z = jax.random.split(key)
        # (B, 1, 32, 32) -> (B, 32, 32, 1). With a single channel, this reshape
        # is the same as moving the channel axis last.
        image = jnp.array(X).reshape((-1, *specimen.shape))
        state, loss = train_step(state, key_Z, image)
        loss_train += loss

    print(f"Epoch {epoch + 1}: train loss {loss_train}")

## Create directories

In [None]:
!mkdir -p data/vae/celeba data/vae/ckpts data/vae/plots

## Download dataset

In [None]:
# Download the dataset from an internal mirror, since the Google Drive copy has a bandwidth limit.
#
# The files were manually downloaded from the Google Drive at https://docs.google.com/uc?export=download&id={{ FILE_ID }}
# and then uploaded to the mirror server, where FILE_ID can be found at https://github.com/pytorch/vision/blob/3db30442a52e250c569ba131485fc432c8300f16/torchvision/datasets/celeba.py#L47-L58
# torchvision's CelebA(root) looks for the files in <root>/celeba. With -nH --cut-dirs=1,
# wget keeps the trailing "celeba/" component of the mirror path, so the files land in
# data/vae/celeba/celeba/. img_align_celeba.zip is unzipped in that same folder; no other
# preprocessing is needed. -nc (wget) and -n (unzip) skip existing files, so re-running
# this cell neither re-downloads nor waits on an overwrite prompt.

!wget -r -np -nc -nH -q --cut-dirs=1 -R "index.html*" https://internal-use.adroits.xyz/datasets/celeba/ -P data/vae/celeba
!unzip -n -qq -d data/vae/celeba/celeba data/vae/celeba/celeba/img_align_celeba.zip

In [None]:
torch.manual_seed(42)

# Center-crop each aligned face to 128x128, downsample to 64x64 and convert the
# PIL image into an (H, W, C) uint8 numpy array. Random horizontal flips are a
# training-time augmentation only.
transform = T.Compose([T.RandomHorizontalFlip(), T.CenterCrop((128, 128)), T.Resize((64, 64)), np.array])
# The test split feeds the reconstruction, interpolation and attribute-arithmetic
# cells, so it is kept free of random augmentation.
transform_test = T.Compose([T.CenterCrop((128, 128)), T.Resize((64, 64)), np.array])

# No need to specify download=True, since the dataset files have already been
# placed in data/vae/celeba/celeba (torchvision looks for them under <root>/celeba).
celeba_train = CelebA("data/vae/celeba", split="train", transform=transform)
celeba_test = CelebA("data/vae/celeba", split="test", transform=transform_test)

## Hyperparameters

In [None]:
# Where checkpoints are written and read.
ckpt_dir = "data/vae/ckpts"

# Optimization schedule.
epochs = 5
batch_size = 256
lr = 1e-3

# Architecture. specimen is only a shape template for a single (H, W, C) image.
latent_dim = 256
hidden_channels = (32, 64, 128, 256, 512)
specimen = jnp.empty((64, 64, 3))

# Model name -> (variational, beta). "ae" is the deterministic autoencoder;
# the others are VAEs with different KL weights.
configs = dict(
    [
        ("ae", (False, 0)),
        ("vae_0.1", (True, 0.1)),
        ("vae_0.5", (True, 0.5)),
        ("vae_1.0", (True, 1)),
    ]
)

## Prepare checkpoints

Either by downloading pre-trained checkpoints or by training the models from scratch.

In [None]:
def train(name, variational, beta, loader_train, target_epoch):
    """Train one AE/VAE configuration from scratch and checkpoint it after every epoch.

    Args:
        name: config name, also used as the checkpoint prefix ``{name}_celeba_``.
        variational: whether the latent code is sampled (VAE) or not (AE).
        beta: weight of the KL term in the loss.
        loader_train: iterable of ``(images, attrs)`` batches of uint8 (H, W, C) images.
        target_epoch: number of epochs to train.

    Also reads the notebook-level ``latent_dim``, ``hidden_channels``, ``lr``,
    ``specimen`` and ``ckpt_dir``.
    """
    print(f"=== Training {name} ===")
    key = jax.random.PRNGKey(42)
    state = create_train_state(key, variational, beta, latent_dim, hidden_channels, lr, specimen)

    for epoch in range(target_epoch):
        loss_train = 0
        for X, _ in loader_train:
            # uint8 pixels -> floats in [0, 1], matching the decoder's sigmoid output.
            image = jnp.array(X).reshape((-1, *specimen.shape)) / 255.0
            key, key_Z = jax.random.split(key)
            state, loss = train_step(state, key_Z, image)
            loss_train += loss

        print(f"Epoch {epoch + 1}: train loss {loss_train}")

        # Save a checkpoint after every epoch, using the epoch number as the step.
        # keep=target_epoch retains all of them, so any epoch can be evaluated
        # later. No lowest-loss selection is performed here.
        checkpoints.save_checkpoint(
            ckpt_dir, state, epoch + 1, prefix=f"{name}_celeba_", keep=target_epoch, overwrite=True
        )


def prepare_checkpoints(target_epoch, download=True):
    """Ensure every model in ``configs`` has a checkpoint in ``ckpt_dir``.

    With ``download=True`` (the default, to save time), the pre-trained checkpoint
    ``{name}_celeba_{target_epoch}`` is fetched for each config. At the time of
    writing, checkpoints were prepared for the first 5 epochs only, so
    ``target_epoch`` must be less than or equal to 5 in that case.

    With ``download=False``, every config is instead trained from scratch for
    ``target_epoch`` epochs, e.g. to tune the hyperparameters. That takes about
    4 minutes per epoch per model on a GPU runtime with a Tesla P100 GPU.

    NOTE: you may want to delete existing checkpoints in ``ckpt_dir`` before
    downloading, e.g. leftovers from an earlier run with a different epoch count.

    Raises:
        subprocess.CalledProcessError: if a download fails.
    """
    import os

    if download:
        for name in configs:
            filename = f"{name}_celeba_{target_epoch}"
            # Skip checkpoints that are already present. Without this, wget stores
            # an extra copy such as "<file>.1" next to the original on every re-run.
            if os.path.exists(os.path.join(ckpt_dir, filename)):
                print("Already downloaded", filename)
                continue
            url = f"https://internal-use.adroits.xyz/ckpts/celeba_vae_ae_comparison/{filename}"
            print("Downloading", url)
            subprocess.run(["wget", "-q", url, "-P", ckpt_dir], check=True)
        return

    loader_train = DataLoader(celeba_train, batch_size, shuffle=True, generator=torch.Generator().manual_seed(42))

    for name, (variational, beta) in configs.items():
        train(name, variational, beta, loader_train, target_epoch)


prepare_checkpoints(epochs, download=True)

## Load checkpoints

In [None]:
def _restore_or_fail(name, variational, beta):
    """Restore the checkpoint with prefix ``{name}_celeba_`` into a fresh TrainState.

    restore_checkpoint hands back its target object unchanged when nothing was
    restored, so an identity check detects a missing checkpoint.
    """
    template = create_train_state(
        jax.random.PRNGKey(42), variational, beta, latent_dim, hidden_channels, lr, specimen
    )
    restored = checkpoints.restore_checkpoint(ckpt_dir, template, prefix=f"{name}_celeba_")
    if restored is template:
        raise FileNotFoundError(f"Cannot load checkpoint from {ckpt_dir}/{name}_celeba_X")
    return restored


states = {}
for name, (variational, beta) in configs.items():
    states[name] = _restore_or_fail(name, variational, beta)

## Visualization of reconstructed images

In [None]:
loader_test = DataLoader(celeba_test, batch_size, shuffle=True, generator=torch.Generator().manual_seed(42))
X, y = next(iter(loader_test))
image = jnp.array(X).reshape((-1, *specimen.shape)) / 255.0

# The originals first, then one reconstruction per model (dict order = row order later).
recons = {
    "original": image,
}
key = jax.random.PRNGKey(42)
# One subkey per model; the count follows ``configs`` instead of being hard-coded.
key, *key_Z = jax.random.split(key, len(configs) + 1)
for i, name in enumerate(configs):
    recon, _, _ = test_step(states[name], key_Z[i], image)
    recons[name] = recon

Show a single figure montage

In [None]:
# One row per entry of ``recons`` (the originals, then one row per model) and
# one column per test image. squeeze=False keeps ``axes`` 2-D for any row count.
fig, axes = plt.subplots(
    len(recons), 6, squeeze=False, constrained_layout=True, figsize=plt.figaspect(1)
)
for row, (name, recon) in enumerate(recons.items()):
    for col in range(6):
        axes[row, col].imshow(recon[col], aspect=218 / 178)
        axes[row, col].axis("off")

fig.suptitle(f"Original (upper) vs Reconstructed (lower {len(recons) - 1})")
fig.show()
# Save this figure explicitly rather than whatever pyplot considers current.
fig.savefig("data/vae/plots/celeba_recon_montage.pdf")

Plot each row as a separate figure and then save them with meaningful filenames

In [None]:
# One 1x6 strip per model (and one for the originals), each saved to its own PDF.
for name, recon in recons.items():
    fig, axes = plt.subplots(1, 6, constrained_layout=True, figsize=plt.figaspect(0.2))
    for ax, picture in zip(axes, recon[:6]):
        ax.imshow(picture, aspect=218 / 178)
        ax.axis("off")

    fig.show()
    fig.savefig(f"data/vae/plots/celeba_recon_{name}.pdf")

## Sampling

In [None]:
key = jax.random.PRNGKey(42)
# One subkey per model; the count follows ``configs`` instead of being hard-coded.
key, *key_Z = jax.random.split(key, len(configs) + 1)

# Decode 8 draws from the N(0, I) prior with every model.
generated_images = {}
for i, name in enumerate(configs):
    Z = jax.random.normal(key_Z[i], (8, latent_dim))
    generated_image = decode(states[name], Z)
    generated_images[name] = generated_image

Show a single figure montage

In [None]:
# One row per model and one column per sample. squeeze=False keeps ``axes``
# 2-D even if only one model is configured.
fig, axes = plt.subplots(
    len(generated_images), 6, squeeze=False, constrained_layout=True, figsize=plt.figaspect(1)
)
for row, (name, generated_image) in enumerate(generated_images.items()):
    for col in range(6):
        axes[row, col].imshow(generated_image[col], aspect=218 / 178)
        axes[row, col].axis("off")

fig.suptitle("Generated", fontsize="xx-large")
fig.show()
# Save this figure explicitly rather than whatever pyplot considers current.
fig.savefig("data/vae/plots/celeba_gen_montage.pdf")

Plot each row as a separate figure and then save them with meaningful filenames

In [None]:
# One 1x6 strip of prior samples per model, each saved to its own PDF.
for name, generated_image in generated_images.items():
    fig, axes = plt.subplots(1, 6, constrained_layout=True, figsize=plt.figaspect(0.2))
    for ax, sample in zip(axes, generated_image[:6]):
        ax.imshow(sample, aspect=218 / 178)
        ax.axis("off")

    fig.show()
    fig.savefig(f"data/vae/plots/celeba_gen_{name}.pdf")

## Interpolation

In [None]:
X, y = next(iter(loader_test))
image = jnp.array(X).reshape((-1, *specimen.shape)) / 255.0

# Per model: (latent means of this test batch, the batch itself), used by interp().
means = {}
key = jax.random.PRNGKey(42)
# One subkey per model; the count follows ``configs`` instead of being hard-coded.
key, *key_Z = jax.random.split(key, len(configs) + 1)
for i, name in enumerate(configs):
    _, mean, _ = test_step(states[name], key_Z[i], image)
    means[name] = mean, image

Show interpolation results

In [None]:
def slerp(val, low, high):
    """Spherical interpolation. val has a range of 0 to 1.

    Returns ``low`` for ``val <= 0`` and ``high`` for ``val >= 1``. The great
    circle is undefined when the directions of ``low`` and ``high`` are parallel
    or antiparallel, or when either vector is zero. In those cases this falls
    back to linear interpolation instead of dividing by zero.
    """
    if val <= 0:
        return low
    elif val >= 1:
        return high
    elif jnp.allclose(low, high):
        return low

    low_norm = jnp.linalg.norm(low)
    high_norm = jnp.linalg.norm(high)
    if low_norm == 0 or high_norm == 0:
        return (1.0 - val) * low + val * high

    # Rounding can push the cosine slightly outside [-1, 1], where arccos is NaN.
    cos_omega = jnp.clip(jnp.dot(low / low_norm, high / high_norm), -1.0, 1.0)
    omega = jnp.arccos(cos_omega)
    so = jnp.sin(omega)
    if jnp.abs(so) < 1e-6:
        return (1.0 - val) * low + val * high
    return jnp.sin((1.0 - val) * omega) / so * low + jnp.sin(val * omega) / so * high


def interp(start_index, end_index):
    """For every model, plot and save a 1x6 strip interpolating between two test images.

    The outer panels show the two input images. The four inner panels decode
    slerp(t) between their latent means for t = 0.2, 0.4, 0.6 and 0.8. Reads
    the notebook-level ``means`` and ``states``.
    """
    for name, (mean, image) in means.items():
        fig, axes = plt.subplots(1, 6, constrained_layout=True, figsize=plt.figaspect(0.2))

        anchors = (
            (axes[0], start_index, "Start Image"),
            (axes[5], end_index, "End Image"),
        )
        for ax, idx, title in anchors:
            ax.imshow(image[idx], aspect=218 / 178)
            ax.set_title(title)
            ax.axis("off")

        for col in range(1, 5):
            t = col / 5
            recon = decode(states[name], slerp(t, mean[start_index], mean[end_index]))
            ax = axes[col]
            ax.imshow(recon[0], aspect=218 / 178)
            ax.set_title(f"{t}")
            ax.axis("off")

        fig.show()
        fig.savefig(f"data/vae/plots/celeba_interp_start{start_index}_end{end_index}_{name}.pdf")

In [None]:
interp(start_index=0, end_index=1)

In [None]:
interp(start_index=1, end_index=2)

In [None]:
interp(start_index=2, end_index=3)

## Vector Arithmetic

Calculate an "eyeglasses" direction in latent space: the mean latent vector of test images whose subject wears eyeglasses minus the mean latent vector of those whose subject does not.

In [None]:
# Per model: running (sum of means with eyeglasses, sum of means without,
# count with, count without) over the whole test split.
eyeglasses_delta = {}
for name in configs:
    eyeglasses_delta[name] = (jnp.zeros((latent_dim,)), jnp.zeros((latent_dim,)), 0, 0)

# Column of the "Eyeglasses" attribute in the CelebA attribute tensor ``y``.
# NOTE(review): assumes torchvision's attribute order; verify against celeba_test.attr_names.
eyeglasses_attr_idx = 15
key = jax.random.PRNGKey(42)
for X, y in loader_test:
    image = jnp.array(X).reshape((-1, *specimen.shape)) / 255.0
    # The mask depends only on the batch, not on the model, so compute it once per batch.
    eyeglass_mask = jnp.array(y[:, eyeglasses_attr_idx] == 1)
    # One subkey per model; the count follows ``configs`` instead of being hard-coded.
    key, *key_Z = jax.random.split(key, len(configs) + 1)
    for i, name in enumerate(configs):
        _, mean, _ = test_step(states[name], key_Z[i], image)
        sum_pos, sum_neg, count_pos, count_neg = eyeglasses_delta[name]
        sum_pos += mean[eyeglass_mask, :].sum(axis=0)
        sum_neg += mean[~eyeglass_mask, :].sum(axis=0)
        count_pos += eyeglass_mask.sum()
        count_neg += (~eyeglass_mask).sum()
        eyeglasses_delta[name] = sum_pos, sum_neg, count_pos, count_neg

In [None]:
X, y = next(iter(loader_test))
image = jnp.array(X).reshape((-1, *specimen.shape)) / 255.0

# Per model: (latent means of this test batch, the batch itself), used by the arithmetic plots.
va_means = {}
key = jax.random.PRNGKey(42)
# One subkey per model; the count follows ``configs`` instead of being hard-coded.
key, *key_Z = jax.random.split(key, len(configs) + 1)
for i, name in enumerate(configs):
    _, mean, _ = test_step(states[name], key_Z[i], image)
    va_means[name] = mean, image

Show a single figure montage

In [None]:
def arith_montage(image_index, multiplier):
    """Plot every model's eyeglasses latent arithmetic in one montage and save it as a PDF.

    Each row shows test image ``image_index`` next to 5 decodes of its latent
    mean shifted along the eyeglasses direction by ``multiplier * k`` for
    k = -2, -1, 0, 1, 2. Reads the notebook-level ``eyeglasses_delta``,
    ``va_means`` and ``states``.
    """
    # One row per model. squeeze=False keeps ``axes`` 2-D for any number of models.
    fig, axes = plt.subplots(
        len(eyeglasses_delta), 6, squeeze=False, constrained_layout=True, figsize=plt.figaspect(1)
    )
    for row, (name, (sum_pos, sum_neg, count_pos, count_neg)) in enumerate(eyeglasses_delta.items()):
        # Mean latent with eyeglasses minus mean latent without.
        delta = sum_pos / count_pos - sum_neg / count_neg
        mean, image = va_means[name]
        axes[row, 0].set_title("Original")
        axes[row, 0].axis("off")
        axes[row, 0].imshow(image[image_index], aspect=218 / 178)
        for col in range(1, 6):
            coef = multiplier * (col - 3)
            Z = mean[image_index] + coef * delta
            recon = decode(states[name], Z)
            axes[row, col].set_title(f"{coef}")
            axes[row, col].axis("off")
            axes[row, col].imshow(recon[0], aspect=218 / 178)

    fig.show()
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(f"data/vae/plots/celeba_arith_img{image_index}_mul{multiplier}_montage.pdf")

In [None]:
arith_montage(image_index=0, multiplier=1)

In [None]:
arith_montage(image_index=1, multiplier=1)

In [None]:
arith_montage(image_index=2, multiplier=1)

In [None]:
arith_montage(image_index=0, multiplier=2)

In [None]:
arith_montage(image_index=1, multiplier=2)

In [None]:
arith_montage(image_index=2, multiplier=2)

Plot each row as a separate figure and then save them with meaningful filenames

In [None]:
def arith(image_index, multiplier):
    """Save one 1x6 eyeglasses-arithmetic strip per model.

    The first panel is test image ``image_index``. The other five decode its
    latent mean shifted along the eyeglasses direction by ``multiplier * k``
    for k = -2, -1, 0, 1, 2. Reads the notebook-level ``eyeglasses_delta``,
    ``va_means`` and ``states``.
    """
    for name, (sum_pos, sum_neg, count_pos, count_neg) in eyeglasses_delta.items():
        fig, axes = plt.subplots(1, 6, constrained_layout=True, figsize=plt.figaspect(0.2))
        direction = sum_pos / count_pos - sum_neg / count_neg
        latent, image = va_means[name]

        first = axes[0]
        first.set_title("Original")
        first.axis("off")
        first.imshow(image[image_index], aspect=218 / 178)

        for ax, step in zip(axes[1:], range(-2, 3)):
            coef = multiplier * step
            recon = decode(states[name], latent[image_index] + coef * direction)
            ax.set_title(f"{coef}")
            ax.axis("off")
            ax.imshow(recon[0], aspect=218 / 178)

        fig.show()
        fig.savefig(f"data/vae/plots/celeba_arith_img{image_index}_mul{multiplier}_{name}.pdf")

In [None]:
arith(image_index=0, multiplier=1)

In [None]:
arith(image_index=1, multiplier=1)

In [None]:
arith(image_index=2, multiplier=1)

In [None]:
arith(image_index=0, multiplier=2)

In [None]:
arith(image_index=1, multiplier=2)

In [None]:
arith(image_index=2, multiplier=2)