In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


### How does a VAE work?
1) The encoder outputs two vectors:
- Mean vector
- Log-variance vector (the standard deviation is derived from it as exp(0.5 * log-variance))

2) We sample a latent vector by adding the mean vector to the standard-deviation vector multiplied element-wise by random noise drawn from a standard normal distribution. The resulting vector has the same dimensions as each of the two vectors.

3) The resulting vector is passed to the decoder to reconstruct the image.

4) The loss being optimized is a combination of:
- KL divergence loss: measures how far the distribution described by the mean and variance vectors deviates from a standard normal distribution (mean 0, variance 1).
- MSE: the reconstruction loss, which measures how closely the decoded image matches the input image.

In [None]:
class VAE(nn.Module):
  """Fully connected variational autoencoder.

  The encoder maps a flattened input of width ``x_dim`` through two hidden
  layers to a mean vector and a log-variance vector of width ``latent_dim``.
  A latent sample is drawn from those two vectors and the decoder mirrors the
  encoder to map the sample back to width ``x_dim``.

  Args:
    x_dim: number of features of the flattened input.
    h_dim1: width of the first (outer) hidden layer.
    h_dim2: width of the second (inner) hidden layer.
    latent_dim: width of the latent vector.
  """

  def __init__(self, x_dim, h_dim1, h_dim2, latent_dim):
    super().__init__()
    # encoder section:
    self.d1 = nn.Linear(x_dim, h_dim1)  # input -> hidden layer 1
    self.d2 = nn.Linear(h_dim1, h_dim2)  # hidden layer 1 -> hidden layer 2

    # mean and log-variance heads:
    self.d31 = nn.Linear(h_dim2, latent_dim)  # hidden layer 2 -> mean
    self.d32 = nn.Linear(h_dim2, latent_dim)  # hidden layer 2 -> log variance

    # decoder section:
    self.d4 = nn.Linear(latent_dim, h_dim2)  # latent -> hidden layer 2
    self.d5 = nn.Linear(h_dim2, h_dim1)  # hidden layer 2 -> hidden layer 1
    self.d6 = nn.Linear(h_dim1, x_dim)  # hidden layer 1 -> output

  def encoder(self, x):
    """Return ``(mean, log_var)`` for an input whose last dim is ``x_dim``."""
    h = F.relu(self.d1(x))
    h = F.relu(self.d2(h))
    return self.d31(h), self.d32(h)

  def sampling(self, mean, log_var):
    """Draw ``mean + eps * std`` with ``eps`` sampled from a standard normal."""
    std = torch.exp(log_var * 0.5)  # log variance -> standard deviation
    eps = torch.randn_like(std)  # noise with the same shape as std
    return eps.mul(std).add_(mean)

  def decoder(self, z):
    """Map a latent vector ``z`` to an output of width ``x_dim`` in [0, 1]."""
    # The original read the undefined name ``x`` here instead of ``z``.
    h = F.relu(self.d4(z))
    h = F.relu(self.d5(h))
    # torch.sigmoid replaces the deprecated F.sigmoid.
    return torch.sigmoid(self.d6(h))

  def forward(self, x):
    """Return ``(reconstruction, mean, log_var)`` for the batch ``x``.

    ``x`` is flattened to ``(-1, x_dim)`` first; the width is taken from the
    first linear layer rather than being hard-coded to 784.
    """
    mean, log_var = self.encoder(x.view(-1, self.d1.in_features))
    z = self.sampling(mean, log_var)
    return self.decoder(z), mean, log_var

In [None]:
def train_batch(data, model, optimizer, loss_function):
  """Perform one optimizer step on a single batch.

  Puts ``model`` in training mode, moves ``data`` to the module-level
  ``device`` (not defined in this cell; it must exist before this is called),
  runs the forward pass, back-propagates the first value returned by
  ``loss_function`` and steps ``optimizer``.

  Returns:
    A 5-tuple: the three values returned by ``loss_function`` followed by
    the mean of the log-variance output and the mean of the mean output.
  """
  model.train()
  optimizer.zero_grad()  # clear gradients left over from the previous step
  batch = data.to(device)
  recon, mu, logvar = model(batch)
  total, recon_err, kl = loss_function(recon, batch, mu, logvar)
  total.backward()
  optimizer.step()
  logvar_avg = logvar.mean()
  mu_avg = mu.mean()
  return total, recon_err, kl, logvar_avg, mu_avg

@torch.no_grad()
def validate_batch(data, model, loss_function):
  """Evaluate a single batch with gradient tracking disabled.

  Puts ``model`` in eval mode, moves ``data`` to the module-level ``device``
  (not defined in this cell; it must exist before this is called) and
  computes the loss terms for the batch.

  Returns:
    A 5-tuple: the three values returned by ``loss_function`` followed by
    the mean of the log-variance output and the mean of the mean output.
  """
  model.eval()
  batch = data.to(device)
  recon, mu, logvar = model(batch)
  total, recon_err, kl = loss_function(recon, batch, mu, logvar)
  logvar_avg = logvar.mean()
  mu_avg = mu.mean()
  return total, recon_err, kl, logvar_avg, mu_avg

def loss_function(recon_x, x, mean, log_var):
  """Compute the VAE loss as reconstruction error plus KL divergence.

  Args:
    recon_x: reconstruction, with the feature width as its last dimension.
    x: original input; it is flattened to ``(-1, recon_x.size(-1))`` before
      being compared with ``recon_x``.
    mean: latent mean vector produced by the encoder.
    log_var: latent log-variance vector produced by the encoder.

  Returns:
    ``(total, recon, kld)`` where ``recon`` is the summed squared error,
    ``kld`` is the summed KL divergence term and ``total`` is their sum.
  """
  # ``F.mse`` does not exist; the functional MSE is ``F.mse_loss``.
  # The flatten width follows recon_x instead of a hard-coded 784.
  target = x.view(-1, recon_x.size(-1))
  RECON = F.mse_loss(recon_x, target, reduction='sum')
  # Summed KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
  KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
  return RECON + KLD, RECON, KLD
