# (Vanilla) Variational Autoencoder in Jax

### Dataset Stuff
- Load dataset
- Create dataloader

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load datasets
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    data_dir='data/',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [3]:
import numpy as np

# scale images from 0 to 255 to 0 to 1 (so that reconstruction job is easier)
ds_train_norm = ds_train.map(
    lambda x, y: (tf.cast(x, tf.float32) / 255., y)
)

# check if min and max are between 0 and 1
first_image = next(iter(ds_train_norm.take(1)))[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


0.0 1.0


In [4]:
batch_size = 32
ds_trainloader = ds_train_norm.batch(batch_size)
ds_testloader = ds_test.batch(batch_size)

### Build VAE Architecture

Instead of flattening the images and constructing the VAEs using linear layers, let's build a convolutional VAE. The Variational Autoencoder has the following components,
1. Encoder: predicts $\mu$ and $\log\sigma^2$ from images using conv layers and then linear layers.
2. Sampling: use the reparametrization trick, we sample from a gaussian of `latent_dim` size and then reparametrize using the $\mu$ and $\sigma$.
3. Decoder: Samples new images given the latent variable from the sampling process.

In [5]:
from flax import linen as nn
from jax import random
import jax.numpy as jnp
import jax.nn as jnn
import numpy as np

In [37]:
# using the setup way to create our network
class VAE(nn.Module):
    latent_dim: int
        
    @nn.compact
    def __call__(self, x, rng):
        '''Forward pass of the VAE
        
        The following things will be done in order,
        1. Encoder: Converts images to \mu and log (\sigma^2) (or log variance)
        2. Reparametrized Sampling: Samples latents using reparametrization trick
        3. Decoder: Samples images using latents
        '''        
        # -------- Encoder -----------
        mu, logvar = self.encoder(x)
        
        # -------- Reparametrized Sampling -----------
        z = self.reparametrize(mu, logvar, rng)
        
        # -------- Decode Images ---------
        gen_x = self.decoder(rng, z=z)
        
        return mu, logvar, gen_x
    
    def encoder(self, x):
        '''Encodes an image into \mu and \logvar with self.latent size'''        
        # use conv filters
        # since mnist images, input image size 28 x 28 x 1
        x = nn.Conv(32, kernel_size=(3, 3), strides=2, name='enc_conv_1')(x) # 28 x 28 -> 14 x 14
        x = nn.gelu(x)
        x = nn.Conv(32, kernel_size=(3, 3), strides=2, name='enc_conv_2')(x) # 14 x 14 -> 7 x 7
        x = x.reshape(x.shape[0], -1) # (batch_size, 7 x 7 x 32)
        
        # get \mu and \logvar of latent space
        mu = nn.Dense(self.latent_dim, name='enc_dense_1')(x) # 
        logvar = nn.Dense(self.latent_dim, name='enc_dense_2')(x)
        
        return mu, logvar
    
    def reparametrize(self, mu, logvar, rng):
        '''Samples from a Gaussian Distribution and Reparametrize using \mu and \logvar'''
        # sample from gaussian
        e = random.normal(rng, shape=(self.latent_dim,))
        
        # convert log-variance to standard deviation, std = \root(\exp log-variance)
        std = jnp.exp(0.5 * logvar)
        
        # reparametrization trick
        return mu + e * std
        
    def decoder(self, rng, z):
        '''Decodes from latent representation, generate z from gaussian and decode images
        '''
        # exactly similar to encoder but in reverse, Conv -> ConvTranspose
        gen_x = nn.Dense(7 * 7 * 32, name='dec_dense')(z).reshape(z.shape[0], 7, 7, 32) # (batch_size, 7, 7, 32)
        gen_x = nn.ConvTranspose(32, kernel_size=(3, 3), strides=(2, 2), name='dec_conv_t_1')(gen_x) # 7 x 7 -> 14 x 14
        gen_x = nn.gelu(gen_x)
        gen_x = nn.ConvTranspose(1, kernel_size=(3, 3), strides=(2, 2), name='dec_conv_t_2')(gen_x) # 14 x 14 -> 28 x 28
        gen_x = nn.sigmoid(gen_x)
        return gen_x

    def generate(self, rng, num_samples=10):
        '''Generates num_samples images'''
        z = random.normal(rng, shape=(num_samples, self.latent_dim))        
        return self.decoder(rng, z)

### State Function

In [21]:
from flax.training import train_state
import optax as opt

def create_state(model, key, rng, learning_rate=1e-3):
    x = jnp.array(np.random.randn(1, 28, 28, 1))
    params = model.init(key, x, rng)['params']
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=opt.adam(learning_rate=learning_rate)
    )

### Training Step

In the training step, we will do the forward pass and compute losses. Variational Autoencoders have two kinds of losses,
1. Negative Log Likelihood
2. KL-Divergence Loss between true prior $p(z)$ and approximate posterior $q(z|x)$ 

In [27]:
import jax

def training_step(state, imgs, rng):
    imgs = jnp.array(imgs)
    def loss_fn(params):
        mu, logvar, recon_imgs = model.apply({'params': params}, imgs, rng)
        # reconstruction loss: Using mean squared error
        recon_loss = ((recon_imgs - imgs) ** 2).mean(axis=0).sum()  # Mean over batch, sum over pixels
        # kl-divergence loss
        kl_loss = -0.5 * jnp.sum(1 + logvar - jnp.square(mu) - jnp.exp(logvar))
        loss = recon_loss + kl_loss
        log = {
            "loss": loss,
            "recon_loss": recon_loss,
            "kl_loss": kl_loss
        }
        return loss, log
    
    # get gradient function using jax.grad()
    (loss, log), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    
    # apply gradients to state
    state = state.apply_gradients(grads=grads)
    return state, log

### Start Training

In [33]:
from tqdm import tqdm

epoch_num = 1
epochs = tqdm(range(epoch_num), desc="Epochs", leave=True)
training_progress = tqdm(total=len(ds_trainloader), desc="Training progress", position=0, leave=True)

model = VAE(latent_dim = 10)
rng = random.PRNGKey(0)
rng, key = random.split(rng)
state = create_state(model, key, rng, learning_rate=1e-3)
history = []

for epoch in epochs:
    # reset training_progress
    training_progress.reset()
    
    # loop over batches
    for batch_id, batch in enumerate(ds_trainloader):
        imgs = batch[0]
        # train
        state, log = training_step(state, imgs, rng)
        training_progress.update()
        
        history.append(jax.tree_map(np.asarray, log))
        training_progress.set_postfix(loss=log['loss'], kl_loss=log['kl_loss'], reconstruction_loss=log['recon_loss'])


Training progress:   9%|▉         | 177/1875 [00:40<06:24,  4.42it/s, kl_loss=1.6461719, loss=48.06882]
Training progress: 100%|██████████| 1875/1875 [06:48<00:00,  4.63it/s, kl_loss=1.2289428, loss=24.680344, reconstruction_loss=23.4514]   
Epochs: 100%|██████████| 1/1 [06:48<00:00, 408.38s/it][A


### Plot Reconstructed Images Post Training

In [35]:
model.apply({'params': state.params}, , 'rng')

NameError: name 'x' is not defined