# $\beta$-VAE

In this notebook, we're going to implement a $\beta$-VAE from the paper [beta-VAE: Learning Basic Visual Concepts with a Constrained Variational Framework](https://openreview.net/forum?id=Sy2fzU9gl). This will require a basic understanding of the VAE, so make sure to check out that notebook first.

The $\beta$-VAE is a modification of the VAE that leads to a more *disentangled* latent space. Ideally, each dimension of the latent space then controls a single, independent feature of the output (for example hair color or head pose) while leaving the other features unchanged. We'll demonstrate this with some experiments on the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset.

All right, let's go!

In [None]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms, datasets
from torchsummary import summary # TODO: Remove

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Encoder

For the encoder, we'll use a convolutional neural network, which is usually a good choice for image data.

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of q(z|x).

    Three 3x3, stride-2 convolutions without padding shrink a 3x64x64 image
    to a 64x7x7 feature map (64 -> 31 -> 15 -> 7). Two linear heads then map
    the flattened features to the mean and the log-variance of the latent
    distribution.

    Args:
        latent_size: dimensionality of the latent space.
    """

    def __init__(self, latent_size):
        super().__init__()
        flat_features = 7 * 7 * 64
        # Attribute names are kept stable so saved state_dicts still load.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2)
        self.fc_mu = nn.Linear(flat_features, latent_size)
        self.fc_logvar = nn.Linear(flat_features, latent_size)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_size)."""
        h = x
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h))
        h = torch.flatten(h, start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

## Sampler

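We sample using the reparameterization trick. Instead of sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly, we draw $\epsilon \sim \mathcal{N}(0, I)$ and compute $z = \mu + \sigma \odot \epsilon$, where $\sigma = \exp\left(\tfrac{1}{2}\log\sigma^2\right)$. This keeps the sampling step differentiable with respect to $\mu$ and $\log\sigma^2$.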

In [None]:
def sample(mu, logvar):
    """Draw z ~ N(mu, exp(logvar)) using the reparameterization trick.

    The randomness is isolated in eps ~ N(0, I), so z = mu + std * eps is
    differentiable with respect to mu and logvar.

    Args:
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior (same shape as mu).

    Returns:
        A sample tensor with the broadcast shape of mu and logvar.
    """
    std = torch.exp(logvar / 2)
    # Standard-normal noise (randn_like), not uniform noise (rand_like):
    # uniform [0, 1) noise would shift every sample by +0.5 * std and give
    # the wrong variance, breaking the Gaussian posterior assumption.
    eps = torch.randn_like(std)
    return mu + std * eps

## Decoder

For the decoder, we'll use transposed convolutions (often called deconvolutions) to upsample the latent vector back into an image.

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to images.

    A linear layer expands z into a 64x7x7 feature map. Three 3x3, stride-2
    transposed convolutions then upsample it (7 -> 15 -> 31 -> 64) into a
    3x64x64 image. A final sigmoid keeps every pixel in (0, 1).

    Args:
        latent_size: dimensionality of the latent space.
    """

    def __init__(self, latent_size):
        super().__init__()
        # Attribute names are kept stable so saved state_dicts still load.
        self.fc = nn.Linear(latent_size, 7 * 7 * 64)
        self.deconv1 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2)
        self.deconv2 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2)
        # output_padding=1 turns the 63x63 output of this layer into 64x64.
        self.deconv_recon = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, output_padding=1)

    def forward(self, z):
        """Decode ``z`` into images of shape (batch, 3, 64, 64)."""
        h = F.relu(self.fc(z)).reshape(-1, 64, 7, 7)
        for deconv in (self.deconv1, self.deconv2):
            h = F.relu(deconv(h))
        return torch.sigmoid(self.deconv_recon(h))

## VAE

All right, let's put everything together and build a sweet $\beta$-VAE!

In [None]:
class VAE(nn.Module):
    """β-VAE built from an Encoder, reparameterized sampling and a Decoder.

    ``forward(x)`` returns ``(x_recon, mu, logvar)``: the reconstruction plus
    the posterior parameters, which the loss needs for its KL term.

    Args:
        latent_size: dimensionality of the latent space.
    """

    def __init__(self, latent_size):
        super().__init__()
        self.encoder = Encoder(latent_size)
        self.decoder = Decoder(latent_size)

    def forward(self, x):
        mu, logvar = self.encoder(x)
        return self.decoder(sample(mu, logvar)), mu, logvar

## Loss

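The only difference from the standard VAE loss is a weighting factor $\beta$ on the Kullback-Leibler (KL) divergence term:

$$\mathcal{L} = \mathrm{BCE}(\hat{x}, x) + \beta \, D_{KL}\left(q(z \mid x) \,\|\, \mathcal{N}(0, I)\right)$$

With $\beta > 1$, the approximate posterior is pushed harder towards the isotropic Gaussian prior. This encourages a disentangled latent representation, at the cost of some reconstruction quality. We use $\beta = 3$ by default.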

In [None]:
def vae_loss(x_recon, x, mu, logvar, beta=3):
    """β-VAE loss: summed reconstruction BCE plus beta times the KL divergence.

    Args:
        x_recon: reconstructed images with values in [0, 1].
        x: target images with values in [0, 1].
        mu: posterior means.
        logvar: posterior log-variances.
        beta: weight of the KL term (beta=1 recovers the plain VAE loss).

    Returns:
        A scalar tensor summed over the whole batch (not averaged).
    """
    reconstruction = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + beta * kl_divergence

## Data

We'll use the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset, which contains a bunch of celebrity images. Let's assume that you've extracted this dataset under the directory `data/img_align_celeba/`.

In [None]:
batch_size = 4
test_size = 100

class Dataset(torch.utils.data.Dataset):
    """CelebA images loaded from a list of file paths.

    Each image is center-cropped to 178x178, resized to 64x64 and converted
    to a float tensor in [0, 1] by ``ToTensor``.

    Args:
        paths: list of image file paths.
        augment: if True (the default), randomly flip images horizontally.
            Pass False for evaluation data so that it is deterministic.
    """

    def __init__(self, paths, augment=True):
        self.paths = paths
        pipeline = [transforms.CenterCrop(178), transforms.Resize(64)]
        if augment:
            pipeline.append(transforms.RandomHorizontalFlip())
        pipeline.append(transforms.ToTensor())
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        x = datasets.folder.default_loader(self.paths[index])
        x = self.transform(x)
        return x

# glob's result order depends on the filesystem; sort so that the
# train/test split is the same on every machine and every run.
paths = sorted(glob.glob('data/img_align_celeba/*.jpg'))
if len(paths) <= test_size:
    raise RuntimeError(
        f'Expected more than {test_size} images in data/img_align_celeba/, '
        f'found {len(paths)}'
    )

train_data = Dataset(paths[:-test_size])
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

# No random flipping for the test set, so evaluations are reproducible.
test_data = Dataset(paths[-test_size:], augment=False)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

Let's plot some images from the dataset.

In [None]:
def plot_grid(images, rows, cols):
    """Plot images in a ``rows`` x ``cols`` grid of subplots.

    Args:
        images: iterable of image tensors, either (C, H, W) or (H, W). They
            may live on any device and may require grad; each one is detached
            and copied to the CPU before plotting. The number of images should
            not exceed ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.
    """
    fig = plt.figure(figsize=(cols * 2, rows * 2))
    for i, image in enumerate(images, start=1):
        # matplotlib cannot read CUDA tensors or tensors that require grad.
        image = image.detach().cpu()
        if image.dim() == 3:
            # matplotlib expects channels last: (C, H, W) -> (H, W, C).
            image = image.permute(1, 2, 0)
            if image.shape[2] == 1:
                # Single-channel (H, W, 1) is not a valid imshow shape.
                image = image.squeeze(2)
        fig.add_subplot(rows, cols, i)
        # cmap only affects 2-D (grayscale) images; RGB images ignore it.
        plt.imshow(image.numpy(), cmap='gray')
        plt.axis('off')

# Show a random selection of (augmented) training images.
rows = cols = 5
indices = np.random.choice(np.arange(len(train_data)), rows * cols, replace=False)
images = list(map(train_data.__getitem__, indices))
plot_grid(images, rows, cols)

## Training

Now we're ready to train our $\beta$-VAE. We'll use the Adam optimization algorithm with its default settings.

In [None]:
latent_size = 16

vae = VAE(latent_size=latent_size).to(device)
optimizer = optim.Adam(vae.parameters())

epochs = 10
# Print a progress tick roughly every 10% of an epoch. max() prevents a
# ZeroDivisionError (modulo by zero) when there are fewer than 10 batches.
progress_every = max(1, len(train_loader) // 10)
for epoch in range(1, epochs + 1):
    vae.train()
    running_loss = 0
    for batch, x in enumerate(train_loader, start=1):
        x = x.to(device)

        optimizer.zero_grad()
        x_recon, mu, logvar = vae(x)
        loss = vae_loss(x_recon, x, mu, logvar)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if batch % progress_every == 0:
            print('=', end='')
    # vae_loss sums over the batch, so dividing by the number of training
    # images gives the average loss per image.
    print(f'> Epoch: {epoch}, Loss: {running_loss / len(train_data)}')

torch.save(vae.state_dict(), 'beta-vae.pth')

In [None]:
# Uncomment to load the saved model (saved as 'beta-vae.pth' by the training cell)
#vae.load_state_dict(torch.load('beta-vae.pth', map_location=device))

In [None]:
# Latent traversal: encode one training image, then vary each latent
# dimension in turn over [-3, 3] while keeping the others fixed, and decode.
# Row i of the resulting grid shows the effect of latent dimension i.
n_steps = 5
with torch.no_grad():
    # The model lives on `device`, so the input batch must be moved there too.
    x = next(iter(train_loader)).to(device)
    x_recon, mu, logvar = vae(x)
    # To compare inputs with their reconstructions instead, uncomment:
    # plot_grid([im.cpu() for pair in zip(x, x_recon) for im in pair], rows=len(x), cols=2)

    z = sample(mu, logvar)[0]
    ls = torch.linspace(-3, 3, steps=n_steps, device=z.device)
    # One copy of z per (dimension, value) pair: row i * n_steps + j has
    # dimension i set to ls[j]. Decoding them as one batch is faster than
    # calling the decoder once per image.
    z_grid = z.repeat(len(z) * n_steps, 1)
    for i in range(len(z)):
        z_grid[i * n_steps:(i + 1) * n_steps, i] = ls
    images = vae.decoder(z_grid).cpu()
plot_grid(list(images), rows=len(z), cols=n_steps)

All right, let's reconstruct some random test images and see what they look like.

In [None]:
# Reconstruct random test images. The top row shows the originals and the
# bottom row their reconstructions.
cols = 10
indices = np.random.choice(np.arange(len(test_data)), cols, replace=False)
# Each dataset item is already a full (3, 64, 64) image tensor; no flattening
# or channel selection is needed for this convolutional model.
images = [test_data[index] for index in indices]
with torch.no_grad():
    # vae returns (x_recon, mu, logvar); keep only the reconstructions.
    recon = vae(torch.stack(images).to(device))[0].cpu()
plot_grid(images + list(recon), 2, cols)

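Our latent space has 16 dimensions, so we can't visualize all of it in a single 2D grid. Instead, we can vary two latent dimensions over a grid while keeping the remaining dimensions fixed, and decode each grid point to see what the model has learned!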

In [None]:
# Sweep latent dimensions 0 and 1 over a rows x cols grid of values in
# [-3, 3]. All other latent dimensions are held at 0 (the prior mean).
# Rows vary dimension 0 and columns vary dimension 1.
rows = cols = 10
with torch.no_grad():
    a = torch.linspace(-3, 3, rows)
    b = torch.linspace(-3, 3, cols)
    z = torch.zeros(rows * cols, latent_size)
    # cartesian_prod yields (a[k // cols], b[k % cols]) for row k, matching
    # the row-major order in which plot_grid fills the subplots.
    z[:, :2] = torch.cartesian_prod(a, b)
    samples = vae.decoder(z.to(device)).cpu()
plot_grid(list(samples), rows, cols)

Here, we can see how the generated faces change as we move along these two latent dimensions. Pretty cool!

Well, that's it! If you want to explore the $\beta$-VAE further, you could experiment with different values of $\beta$ and different sizes of the latent space, and think about how the latent space could be visualized. You could also experiment with other activation functions and optimization algorithms.