# **Using a Variational Autoencoder to Generate Images**

Variational Autoencoders (VAEs) are powerful generative models with diverse applications, ranging from generating fake faces to synthesizing music. In this repo, a VAE is trained on the Fashion-MNIST dataset to generate fake clothing images. We start by importing the required libraries and defining the loss function. The loss combines the Binary Cross-Entropy (BCE) loss, which measures reconstruction error, with the Kullback–Leibler (KL) divergence loss. The KL term pulls each encoding's distribution toward a standard normal prior, so that encodings stay close to one another while remaining distinct; this yields a smooth latent space that allows interpolation and the sampling of new images.



In [None]:
!pip3 install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

In [2]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image

import imageio
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from tqdm import tqdm

# Reusable tensor -> PIL image converter used when assembling GIF frames.
to_pil_image = transforms.ToPILImage()
# Global plotting style for the loss curves.
matplotlib.style.use('ggplot')

def image_to_vid(images):
    """Write ``images`` as the frames of ``./generated_images.gif``.

    Each element is converted with the module-level ``to_pil_image`` and then
    to a NumPy array before being handed to ``imageio.mimsave``.

    :param images: iterable of images accepted by ``to_pil_image``
    """
    frames = []
    for frame in images:
        frames.append(np.array(to_pil_image(frame)))
    imageio.mimsave('./generated_images.gif', frames)

def save_reconstructed_images(recon_images, epoch):
    """Save a batch of reconstructions to ``./output{epoch}.jpg``.

    :param recon_images: image tensor batch (moved to CPU before saving)
    :param epoch: epoch number embedded in the file name
    """
    target_path = f"./output{epoch}.jpg"
    save_image(recon_images.cpu(), target_path)

def save_loss_plot(train_loss, valid_loss):
    """Plot training and validation loss curves, save them to ./loss.jpg and show them.

    :param train_loss: sequence of training losses, one entry per epoch
    :param valid_loss: sequence of validation losses, one entry per epoch
    """
    fig = plt.figure(figsize=(10, 7))
    plt.plot(train_loss, color='orange', label='train loss')
    plt.plot(valid_loss, color='red', label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('./loss.jpg')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def final_loss(bce_loss, mu, logvar):
    """Add the KL-divergence term to the reconstruction loss.

    The KL divergence between the approximate posterior N(mu, sigma^2) and the
    standard normal prior N(0, I) is

        KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

    where ``logvar`` holds log(sigma^2). The sum runs over every element of
    ``mu``/``logvar`` (all batch and latent entries), so it pairs with a
    reconstruction loss that is also summed rather than averaged.

    :param bce_loss: reconstruction loss (e.g. summed BCE)
    :param mu: mean of the latent Gaussian
    :param logvar: log-variance of the latent Gaussian
    :returns: ``bce_loss + KL``
    """
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_loss + kld

## **Creating a Variational Autoencoder class**

The class requires the following components:

1. Encoder and decoder modules.
2. The parameters of the latent distribution (mean and log-variance) and the activations of the last layer.
3. A sampling function that draws a sample from the distribution with the given parameters (the reparameterization trick).
 

In [87]:
class Encoder(nn.Module):
    """Convolutional encoder for the VAE.

    Four convolutions followed by global average pooling produce a feature
    vector. Linear heads turn it into the mean and log-variance of a diagonal
    Gaussian over the latent space. A sample is drawn with the
    reparameterization trick, projected to ``out_channels`` features and
    reshaped to ``(batch, out_channels, 1, 1)`` for the decoder.

    :param in_channels: channels of the input images
    :param hidden_channels: width of the first conv layer (doubled in the next two)
    :param out_channels: channels of the last conv layer and of the returned ``z``
    :param latent_dim: dimensionality of the latent Gaussian
    :param kernel_size: kernel size used by every conv layer
    """

    def __init__(self, in_channels, hidden_channels, out_channels, latent_dim, kernel_size):
        super().__init__()
        # Needed in forward() to reshape the projected latent sample.
        self.out_channels = out_channels
        # Three stride-2 convolutions downsample; the last one uses stride 1.
        self.enc1 = nn.Conv2d(
            in_channels=in_channels, out_channels=hidden_channels,
            kernel_size=kernel_size, stride=2, padding=1,
        )
        self.enc2 = nn.Conv2d(
            in_channels=hidden_channels, out_channels=hidden_channels * 2,
            kernel_size=kernel_size, stride=2, padding=1,
        )
        self.enc3 = nn.Conv2d(
            in_channels=hidden_channels * 2, out_channels=hidden_channels * 4,
            kernel_size=kernel_size, stride=2, padding=1,
        )
        self.enc4 = nn.Conv2d(
            in_channels=hidden_channels * 4, out_channels=out_channels,
            kernel_size=kernel_size, stride=1, padding=1,
        )
        # Fully connected layers for learning representations.
        self.fc1 = nn.Linear(out_channels, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)
        # Projects a latent sample back up to the decoder's input width.
        self.fc2 = nn.Linear(latent_dim, out_channels)

    def forward(self, x):
        """Encode ``x`` and draw a reparameterized latent sample.

        :param x: image batch, presumably ``(batch, in_channels, H, W)``
        :returns: ``(mu, log_var, z)``; ``mu`` and ``log_var`` are
            ``(batch, latent_dim)`` and ``z`` is ``(batch, out_channels, 1, 1)``
        """
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))
        x = F.relu(self.enc4(x))
        batch = x.shape[0]
        # Global average pooling collapses the spatial dims -> (batch, out_channels).
        x = F.adaptive_avg_pool2d(x, 1).reshape(batch, -1)
        hidden = self.fc1(x)
        mu = self.fc_mu(hidden)
        log_var = self.fc_log_var(hidden)
        z = self.reparameterize(mu, log_var)
        # Reshape using the configured width so any out_channels works.
        z = self.fc2(z).view(batch, self.out_channels, 1, 1)
        return mu, log_var, z

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick).

        :param mu: mean from the encoder's latent space
        :param log_var: log variance from the encoder's latent space
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

class Decoder(nn.Module):
    """Transposed-convolution decoder for the VAE.

    Upsamples a ``(batch, out_channels, 1, 1)`` latent projection to an image
    with ``in_channels`` channels. With ``kernel_size=4`` the spatial size goes
    1 -> 4 -> 8 -> 16 -> 32. A final sigmoid keeps pixel values in (0, 1).

    :param in_channels: channels of the reconstructed images
    :param hidden_channels: base width (the first layer uses ``hidden_channels * 8``)
    :param out_channels: channels of the decoder input, i.e. the encoder's
        projected latent width
    :param kernel_size: kernel size used by every transposed conv layer
    """

    def __init__(self, in_channels, hidden_channels, out_channels, kernel_size):
        super().__init__()
        # Consume the configured latent-projection width rather than a fixed 64.
        self.dec1 = nn.ConvTranspose2d(
            in_channels=out_channels, out_channels=hidden_channels * 8,
            kernel_size=kernel_size, stride=1, padding=0,
        )
        self.dec2 = nn.ConvTranspose2d(
            in_channels=hidden_channels * 8, out_channels=hidden_channels * 4,
            kernel_size=kernel_size, stride=2, padding=1,
        )
        self.dec3 = nn.ConvTranspose2d(
            in_channels=hidden_channels * 4, out_channels=hidden_channels * 2,
            kernel_size=kernel_size, stride=2, padding=1,
        )
        self.dec4 = nn.ConvTranspose2d(
            in_channels=hidden_channels * 2, out_channels=in_channels,
            kernel_size=kernel_size, stride=2, padding=1,
        )

    def forward(self, z):
        """Decode ``z`` of shape ``(batch, out_channels, 1, 1)`` into images in (0, 1)."""
        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        return torch.sigmoid(self.dec4(x))
       
        
class ConvVAE(nn.Module):
    """Convolutional VAE wiring an ``Encoder`` to a ``Decoder``.

    :param in_channels: channels of the input / reconstructed images
    :param hidden_channels: base width of the conv stacks
    :param out_channels: width of the encoder's final features and latent projection
    :param latent_dim: dimensionality of the latent Gaussian
    :param kernel_size: kernel size shared by all conv layers
    """

    def __init__(self, in_channels, hidden_channels, out_channels, latent_dim, kernel_size):
        super().__init__()
        # Keep the construction hyper-parameters on the instance.
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.latent_dim = latent_dim

        self.encoder = Encoder(in_channels, hidden_channels, out_channels, latent_dim, kernel_size)
        self.decoder = Decoder(in_channels, hidden_channels, out_channels, kernel_size)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the input batch ``x``."""
        mu, log_var, latent = self.encoder(x)
        return self.decoder(latent), mu, log_var

## **Training and Validation**


In [4]:
def train(model, dataloader, dataset, device, optimizer, criterion):
    """Train ``model`` for one epoch.

    :param model: the VAE, called as ``model(x) -> (reconstruction, mu, logvar)``
    :param dataloader: yields batches whose first element is the input images
    :param dataset: unused; kept for call-site compatibility (the progress-bar
        length comes from ``len(dataloader)``)
    :param device: device each batch is moved to
    :param optimizer: optimizer stepping ``model``'s parameters
    :param criterion: reconstruction loss, called as ``criterion(reconstruction, x)``
    :returns: mean of the per-batch losses (reconstruction + KL) over the epoch
    :raises ValueError: if ``dataloader`` yields no batches
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    # len(dataloader) counts a trailing partial batch, unlike
    # int(len(dataset) / batch_size).
    for batch in tqdm(dataloader, total=len(dataloader)):
        images = batch[0].to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        loss = final_loss(criterion(reconstruction, images), mu, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / num_batches

def validate(model, dataloader, dataset, device, criterion):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    :param model: the VAE, called as ``model(x) -> (reconstruction, mu, logvar)``
    :param dataloader: yields batches whose first element is the input images
    :param dataset: unused; kept for call-site compatibility (the progress-bar
        length comes from ``len(dataloader)``)
    :param device: device each batch is moved to
    :param criterion: reconstruction loss, called as ``criterion(reconstruction, x)``
    :returns: ``(val_loss, recon_images)``, i.e. the mean per-batch loss and the
        reconstructions of the last batch
    :raises ValueError: if ``dataloader`` yields no batches
    """
    model.eval()
    running_loss = 0.0
    num_batches = 0
    recon_images = None
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            images = batch[0].to(device)
            reconstruction, mu, logvar = model(images)
            loss = final_loss(criterion(reconstruction, images), mu, logvar)
            running_loss += loss.item()
            num_batches += 1
            # Overwritten every iteration, so after the loop this holds the
            # reconstructions of the last batch, however many batches there are.
            recon_images = reconstruction
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return running_loss / num_batches, recon_images

## **Fashion-MNIST Data Download and Parameter Settings**
The Fashion-MNIST dataset consists of small 28×28-pixel grayscale images of 10 types of clothing, such as shoes, t-shirts and dresses: 60,000 training images and 10,000 test images. Both the training and validation (test) splits are available from torchvision. Here the images are resized to 32×32 so that they match the size of the decoder's output.

In [91]:
def _get_dataloaders(batch_size):
    """Download Fashion-MNIST if needed and return ``(trainset, train_loader, testset, test_loader)``."""
    # Resize to 32x32 so inputs match the decoder's output size.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    trainset = torchvision.datasets.FashionMNIST(root='./', train=True, download=True, transform=transform)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    # Validation set and validation data loader.
    testset = torchvision.datasets.FashionMNIST(root='./', train=False, download=True, transform=transform)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    return trainset, train_loader, testset, test_loader


def _generate_samples(model, latent_dim, out_channels, device, num_samples=64):
    """Decode ``num_samples`` draws from the N(0, I) latent prior into images (returned on CPU).

    The prior lives in the ``latent_dim``-dimensional space that the KL term
    regularizes. Each draw is passed through ``model.encoder.fc2`` and reshaped
    to ``(num_samples, out_channels, 1, 1)`` before decoding, the same path a
    reparameterized sample takes in the encoder's forward pass.
    """
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_samples, latent_dim, device=device)
        projected = model.encoder.fc2(z).view(num_samples, out_channels, 1, 1)
        return model.decoder(projected).cpu()


def main(learning_rate, batch_size, epochs):
    """Train the ConvVAE on Fashion-MNIST, save diagnostics, and generate new images.

    Writes ``./output{epoch}.jpg`` after every epoch, ``./generated_images.gif``,
    ``./loss.jpg`` and ``./reconstruction{epochs - 1}.png``. The last file is
    an 8x8 grid of images decoded from random latent samples.

    :param learning_rate: Adam learning rate
    :param batch_size: batch size for both data loaders
    :param epochs: number of training epochs
    :raises ValueError: if ``epochs`` is less than 1
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trainset, train_loader, testset, test_loader = _get_dataloaders(batch_size)

    # Set the learning parameters.
    hparams = {
        "in_channels": 1,
        "hidden_channels": 8,
        "out_channels": 64,
        "latent_dim": 16,
        "kernel_size": 4,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    model = ConvVAE(
        hparams['in_channels'], hparams['hidden_channels'], hparams['out_channels'],
        hparams['latent_dim'], hparams['kernel_size'],
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=hparams['learning_rate'])
    # Summed BCE, which matches the summed KL term in final_loss.
    criterion = nn.BCELoss(reduction='sum')

    grid_images = []  # reconstructed images per epoch, in PyTorch grid format
    train_loss = []
    valid_loss = []

    for epoch in range(epochs):
        print(f"Epoch {epoch+1} of {epochs}")
        train_epoch_loss = train(model, train_loader, trainset, device, optimizer, criterion)
        valid_epoch_loss, recon_images = validate(model, test_loader, testset, device, criterion)
        train_loss.append(train_epoch_loss)
        valid_loss.append(valid_epoch_loss)

        # Save the reconstructed images from the validation loop and keep a grid for the GIF.
        save_reconstructed_images(recon_images, epoch+1)
        grid_images.append(make_grid(recon_images.detach().cpu()))
        print(f"Train Loss: {train_epoch_loss:.4f}")
        print(f"Val Loss: {valid_epoch_loss:.4f}")

    image_to_vid(grid_images)
    save_loss_plot(train_loss, valid_loss)
    print('TRAINING COMPLETE')

    # Save an 8x8 grid of generated clothing images to show how well samples
    # from the latent prior decode into plausible items.
    sample = _generate_samples(model, hparams['latent_dim'], hparams['out_channels'], device)
    save_image(sample, f'./reconstruction{epoch}.png')

In [None]:
# Hyper-parameters for this training run.
learning_rate = 1e-3
batch_size = 64
epochs = 30

main(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)