<a href="https://colab.research.google.com/github/rishabt20/GenerativeDL/blob/main/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
import torchvision
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import distributions
from torchsummary import summary
from torchvision import transforms
from torch.autograd import Variable

In [2]:
# Train on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
# --- Hyperparameters ---------------------------------------------------------
D_DIM = 28 * 28          # flattened 28x28 input image (784 values)
BATCH_SIZE = 128         # images per mini-batch
EPOCHS = 1000            # full passes over the training set
LEARNING_RATE = 3e-3     # Adam step size
H_DIM = 256              # width of the hidden layers
Z_DIM = 16               # latent dimensionality

In [14]:
class Encoder(nn.Module):
  """Inference network q(z | x): an MLP that outputs a diagonal Gaussian.

  The input passes through two ReLU-activated linear layers, then two linear
  heads predict the latent mean and the latent log-variance.

  Args:
    D_dim: size of the input's last dimension.
    H_dim: width of the hidden layers.
    Z_dim: latent dimensionality.
  """

  def __init__(self, D_dim, H_dim, Z_dim):
    super().__init__()
    self.linear1 = nn.Linear(D_dim, H_dim)
    self.linear2 = nn.Linear(H_dim, H_dim)
    # Head predicting the mean of each latent dimension.
    self.mean = nn.Linear(H_dim, Z_dim)
    # Head predicting the log *variance* of each latent dimension
    # (attribute name kept for state_dict compatibility).
    self.log_covariance = nn.Linear(H_dim, Z_dim)

  def forward(self, x):
    """Return q(z | x) as an independent Normal per latent dimension.

    Args:
      x: tensor whose last dimension has size D_dim.

    Returns:
      torch.distributions.Normal whose loc/scale have last dimension Z_dim.
    """
    h = F.relu(self.linear1(x))
    h = F.relu(self.linear2(h))
    mean = self.mean(h)
    log_variance = self.log_covariance(h)
    # torch.distributions.Normal is parameterised by the standard deviation,
    # so convert the predicted log-variance: sigma = exp(0.5 * log sigma^2).
    # Passing exp(log_variance) directly would supply a variance as a std.
    std = torch.exp(0.5 * log_variance)
    return torch.distributions.Normal(loc=mean, scale=std)

In [15]:
class Decoder(nn.Module):
  """Generative network p(x | z): an MLP producing a unit-variance Gaussian.

  Args:
    Z_dim: latent dimensionality (size of the input's last dimension).
    H_dim: width of the hidden layer.
    D_out: output dimensionality.
  """

  def __init__(self, Z_dim, H_dim, D_out):
    super().__init__()
    self.linear1 = nn.Linear(Z_dim, H_dim)
    self.linear2 = nn.Linear(H_dim, D_out)

  def forward(self, x):
    """Decode latents into a Normal with tanh-squashed mean and unit scale."""
    hidden = F.relu(self.linear1(x))
    # tanh keeps every predicted pixel mean inside (-1, 1).
    loc = torch.tanh(self.linear2(hidden))
    unit_scale = torch.ones_like(loc)
    return torch.distributions.Normal(loc=loc, scale=unit_scale)

In [16]:
class VAE(nn.Module):
  """Variational autoencoder that chains an encoder and a decoder.

  Args:
    encoder: callable mapping an input to a distribution over latents that
      supports reparameterised sampling (``rsample``).
    decoder: callable applied to the latent sample; its result is returned
      as the second element of ``forward``'s output.
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, inp):
    """Encode ``inp``, draw one reparameterised latent sample, and decode it.

    Returns:
      Tuple ``(posterior, decoded)``: the encoder's distribution and the
      decoder's output for the sampled latent.
    """
    posterior = self.encoder(inp)
    # rsample keeps the draw differentiable w.r.t. the encoder's outputs.
    z = posterior.rsample()
    return posterior, self.decoder(z)

In [17]:
# ToTensor maps pixel values to [0, 1]; Normalize(mean=.5, std=1) then shifts
# them to [-0.5, 0.5].
# NOTE(review): the decoder's tanh mean spans (-1, 1); Normalize(.5, .5) would
# map the data onto that full range — confirm the intended scaling.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=.5, std=1)]
)
mnist = torchvision.datasets.MNIST(
    root='./content', download=True, transform=transform
)

In [18]:
# Shuffled mini-batches over the MNIST training set.
dataloader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Page-locked host memory only pays off when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
)

In [19]:
# Assemble encoder and decoder into a VAE and move it to the training device.
encoder = Encoder(D_dim=D_DIM, H_dim=H_DIM, Z_dim=Z_DIM)
decoder = Decoder(Z_dim=Z_DIM, H_dim=H_DIM, D_out=D_DIM)
vae = VAE(encoder=encoder, decoder=decoder).to(device)

In [20]:
# Show the layer structure of the assembled model.
print(str(vae))

VAE(
  (encoder): Encoder(
    (linear1): Linear(in_features=784, out_features=256, bias=True)
    (linear2): Linear(in_features=256, out_features=256, bias=True)
    (mean): Linear(in_features=256, out_features=16, bias=True)
    (log_covariance): Linear(in_features=256, out_features=16, bias=True)
  )
  (decoder): Decoder(
    (linear1): Linear(in_features=16, out_features=256, bias=True)
    (linear2): Linear(in_features=256, out_features=784, bias=True)
  )
)


In [21]:
# Adam over every parameter registered on the VAE (encoder and decoder).
optimizer = optim.Adam(params=vae.parameters(), lr=LEARNING_RATE)

In [None]:
from torch.distributions import MultivariateNormal

# Standard-normal prior p(z) = N(0, I), built on the training device so prior
# samples can be decoded without moving modules between devices.
mean = torch.zeros(Z_DIM, device=device)
covariance = torch.eye(Z_DIM, device=device)
distribution = MultivariateNormal(mean, covariance)

# Per-dimension N(0, 1) used in the KL term; created once instead of per batch.
standard_normal = torch.distributions.Normal(
    torch.tensor(0., device=device), torch.tensor(1., device=device)
)

vae = vae.to(device)
for epoch in range(EPOCHS):
  vae.train()
  epoch_loss = 0.0
  n_batches = 0
  for inputs, _ in dataloader:
    # Flatten each image into a D_DIM vector.
    inputs = inputs.view(-1, D_DIM).to(device)
    optimizer.zero_grad()
    hidden_distn, output_distn = vae(inputs)
    # Reconstruction term log p(x | z): a Gaussian log-likelihood (not a
    # squared error), summed over pixels and averaged over the batch.
    log_likelihood = output_distn.log_prob(inputs).sum(-1).mean()
    # KL(q(z | x) || N(0, I)), summed over latent dims, averaged over batch.
    kl_divergence_error = torch.distributions.kl_divergence(
        hidden_distn, standard_normal
    ).sum(-1).mean()
    # Minimise the negative ELBO.
    loss = -(log_likelihood - kl_divergence_error)
    loss.backward()
    optimizer.step()
    l = loss.item()
    epoch_loss += l
    n_batches += 1
  # Report the mean loss over the epoch, not just the final batch's loss;
  # max(..., 1) avoids dividing by zero for an empty loader.
  print(epoch, epoch_loss / max(n_batches, 1))

  # Decode one prior sample and display the decoder *mean*. Drawing from the
  # unit-variance output Gaussian would add N(0, 1) noise to every pixel,
  # swamping images whose normalised values lie in [-0.5, 0.5].
  vae.eval()
  with torch.no_grad():
    out = vae.decoder(distribution.sample())
    image = out.mean.view(28, 28).cpu()
  plt.imshow(image.numpy())
  plt.show()
