# Variational Autoencoders (VAE)

Import the libraries.

In [None]:
import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import os
from torchvision.utils import save_image
from torch.autograd import Variable
from os.path import exists
import torch.nn.functional as F

Utility code: create the directory that stores intermediate images, and define a helper that reshapes flattened model outputs into image tensors.

In [None]:
# Create the output directory for intermediate reconstructions.
# exist_ok=True avoids the check-then-create race of a separate exists() test.
os.makedirs('./vae_img', exist_ok=True)
    
def to_img(x):
    """Map a batch of values to the [0, 1] range and reshape it to images.

    Each value is transformed with ``0.5 * (x + 1)`` (so -1 becomes 0 and
    1 becomes 1), clipped to ``[0, 1]``, and the result is viewed as
    ``(x.size(0), 1, 28, 28)``.
    """
    batch = x.size(0)
    rescaled = 0.5 * (x + 1)
    return rescaled.clamp(0, 1).view(batch, 1, 28, 28)

## Dataset

In [None]:
# Training hyper-parameters
num_epochs = 100
batch_size = 128
learning_rate = 1e-3

# Per-sample preprocessing: convert each image to a PyTorch tensor
img_transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, stored under ./data (download=True fetches it if needed)
dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=img_transform,
)

# Shuffled mini-batch iterator used by the training loop
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

## Model

In [None]:
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder for flattened inputs.

    The encoder maps ``input_dim`` features through one hidden layer to the
    mean and log-variance of a diagonal Gaussian over ``latent_dim``
    dimensions; the decoder maps a latent sample back to ``input_dim``
    values squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: number of features per sample (784 = 28 * 28 by default).
        hidden_dim: width of the hidden layer in both encoder and decoder.
        latent_dim: dimensionality of the latent space.

    The defaults reproduce the original 784-400-2-400-784 architecture, and
    the layer attribute names are unchanged, so existing checkpoints load.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc21 = torch.nn.Linear(hidden_dim, latent_dim)  # latent mean head
        self.fc22 = torch.nn.Linear(hidden_dim, latent_dim)  # latent log-variance head
        self.fc3 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc4 = torch.nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always lives on the
        same device (and has the same dtype) as ``std``; deciding the device
        from ``torch.cuda.is_available()`` fails when CUDA is present but
        the model is kept on the CPU.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        # torch.sigmoid replaces the deprecated F.sigmoid
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Instantiate the VAE and, when CUDA is available, move its parameters to the GPU
model = VAE()
if torch.cuda.is_available():
    model = model.cuda()
    
# Squared-error reconstruction term, summed (not averaged) over all elements.
# reduction='sum' is the supported spelling of the deprecated size_average=False.
reconstruction_function = torch.nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: reconstructed images produced by the decoder.
        x: original images, same shape as ``recon_x``.
        mu: latent mean.
        logvar: latent log variance.

    Returns:
        A scalar tensor: the summed squared error between ``recon_x`` and
        ``x`` plus ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``, the KL
        divergence between N(mu, exp(logvar)) and the standard normal.
    """
    # Summed MSE (the original local was misleadingly named "BCE")
    reconstruction = reconstruction_function(recon_x, x)
    # KL divergence, written out of place so the formula reads directly
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kld


# Adam over all model parameters. The step size comes from the `learning_rate`
# hyper-parameter set with the other training settings, not a duplicated literal.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Training / Loading

In [None]:
# Checkpoint path: weights are loaded from this file when it exists; otherwise
# the model is trained and its state dict is saved here.
model_file = "vae.pth"

In [None]:
if exists(model_file):
    # map_location='cpu' lets a checkpoint written on a GPU machine load on a
    # CPU-only one; load_state_dict then copies the weights onto whatever
    # device the model's parameters already live on.
    model.load_state_dict(torch.load(model_file, map_location='cpu'))
    model.eval()
else:
    # Feed batches to the device the model is actually on
    device = next(model.parameters()).device
    # Per-iteration loss (averaged per sample), plotted after training
    losses = []
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_idx, (img, _) in enumerate(loader):
            # Flatten each image to a vector
            img = img.view(img.size(0), -1).to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(img)
            loss = loss_function(recon_batch, img, mu, logvar)
            loss.backward()
            optimizer.step()

            # .item() yields a plain float, so the running total does not
            # accumulate tensors
            batch_loss = loss.item()
            train_loss += batch_loss
            losses.append(batch_loss / len(img))
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch,
                    batch_idx * len(img),
                    len(loader.dataset), 100. * batch_idx / len(loader),
                    batch_loss / len(img)))

        print('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(loader.dataset)))
        if epoch % 10 == 0:
            # The decoder ends in a sigmoid, so reconstructions are already in
            # [0, 1]; they only need reshaping to (batch, 1, 28, 28). Applying
            # a [-1, 1] -> [0, 1] rescale here would wash the images out.
            save = recon_batch.detach().cpu().clamp(0, 1).view(-1, 1, 28, 28)
            save_image(save, './vae_img/image_{}.png'.format(epoch))

        # Checkpoint after every epoch
        torch.save(model.state_dict(), model_file)

    model.eval()
    # Defining the Plot Style
    plt.style.use('fivethirtyeight')
    plt.xlabel('Iterations')
    plt.ylabel('Loss')

    # Plotting the last 100 values
    plt.plot(losses[-100:])

## Display sample images

In [None]:
import random
import numpy as np
from mpl_toolkits.axes_grid1 import ImageGrid

model.eval()

num_samples = 10
# Run inputs on the device the model's parameters live on
device = next(model.parameters()).device

fig = plt.figure(figsize=(20., 5.))
# Top row: original images; bottom row: their reconstructions
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, num_samples),  # 2 rows x num_samples columns of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

# Inference only, so no autograd graph is needed
with torch.no_grad():
    for i in range(num_samples):
        # Draw uniformly from the whole dataset instead of a hard-coded index range
        image, _ = dataset[random.randrange(len(dataset))]

        img = image.reshape(-1, 28 * 28).to(device)
        reconstructed, _, _ = model(img)
        item = reconstructed.reshape(-1, 28, 28).cpu().numpy()

        grid[i].imshow(image[0])
        grid[i].grid(False)
        grid[num_samples + i].imshow(item[0])
        grid[num_samples + i].grid(False)


plt.savefig("nums_vae.png",transparent=True)
plt.show()


## Display latent space

In [None]:
import seaborn as sns
import pandas as pd

n_points = 10000
# Run inputs on the device the model's parameters live on
device = next(model.parameters()).device

# Draw dataset indices uniformly at random (with replacement)
sample_indices = torch.randint(len(dataset), size=(n_points,)).tolist()
samples = [dataset[idx] for idx in sample_indices]
images = torch.stack([img for img, _ in samples]).reshape(-1, 28 * 28).to(device)
labels = [label for _, label in samples]

# Only the encoder is needed to place samples in latent space, so encode the
# whole batch once instead of running a full forward pass per image and then
# sampling a second time.
with torch.no_grad():
    mu, logvar = model.encode(images)
    latent = model.reparametrize(mu, logvar)

# One row per sample: latent coordinates followed by the class label
latarray = np.column_stack([latent.cpu().numpy(), labels])
df = pd.DataFrame(data=latarray, columns=("x1", "x2", "label"))
sns.FacetGrid(df, hue="label", height=6).map(plt.scatter, 'x1', 'x2').add_legend()
plt.savefig("latent_space.png",transparent=True)
plt.show()


In [None]:
# Cast the label column to integers, then show the resulting column dtypes
data_types_dict = dict(label=int)
df = df.astype(data_types_dict)
print(df.dtypes)

In [None]:
# Scatter the latent coordinates on a square canvas, one colour per label.
# NOTE(review): this writes to the same "latent_space.png" as the FacetGrid
# plot and therefore overwrites it - confirm that is intended.
f, ax = plt.subplots(figsize=(12, 12))
sns.scatterplot(
    data=df,
    x="x1",
    y="x2",
    hue="label",
    palette="tab10",
    legend="full",
    linewidth=0,
    ax=ax,
)
plt.savefig("latent_space.png", transparent=True)

## Testing on a different input space

Visualize some MNIST samples

In [None]:
model.eval()

num_samples = 10
# Run inputs on the device the model's parameters live on
device = next(model.parameters()).device

fig = plt.figure(figsize=(20., 5.))
# Top row: original images; bottom row: their reconstructions
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, num_samples),  # 2 rows x num_samples columns of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

# Inference only, so no autograd graph is needed
with torch.no_grad():
    for i in range(num_samples):
        # Draw uniformly from the whole dataset instead of a hard-coded index range
        image, _ = dataset[random.randrange(len(dataset))]

        img = image.reshape(-1, 28 * 28).to(device)
        reconstructed, _, _ = model(img)
        item = reconstructed.reshape(-1, 28, 28).cpu().numpy()

        grid[i].imshow(image[0])
        grid[i].grid(False)
        grid[num_samples + i].imshow(item[0])
        grid[num_samples + i].grid(False)


plt.show()

Download the Fashion MNIST dataset from Kaggle.

In [None]:
# Download the Fashion MNIST Dataset
# TODO: placeholder - `fashionDataset` is currently bound to Ellipsis, not to a
# dataset object; replace `...` with the actual dataset loading code.
fashionDataset = ...

Visualize some Fashion MNIST samples

In [None]:
# Switch the model to evaluation mode (presumably before running it on the
# Fashion MNIST samples - the code that would do so is not written yet).
model.eval()

# TODO: placeholder (Ellipsis) - the visualization code for this cell is missing
...