In [7]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image



In [8]:
# Hyperparameters
image_size = 28 * 28  # 784: each image is flattened to a (batch, 784) matrix before the encoder
hidden_dim = 400      # width of the hidden layer used by both encoder and decoder
latent_dim = 20       # dimensionality of the latent code z
batch_size = 128      # mini-batch size for both data loaders
epochs = 10           # number of passes over the training set

# Seed PyTorch's global RNG so that weight initialisation, shuffling and latent
# sampling start from the same state on every "Restart & Run All".
# NOTE(review): some CUDA kernels are non-deterministic, so GPU runs may still
# differ slightly between executions.
seed = 42
torch.manual_seed(seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# MNIST dataset
# Both splits live under the same root directory and use the ToTensor transform;
# only the training split passes download=True.
train_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=False,
    transform=transforms.ToTensor(),
)

# Mini-batch iterators over the two splits; both are shuffled.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Directory that receives the reconstructed and sampled image grids.
# exist_ok=True makes the call a no-op when the directory is already there and
# avoids the race between a separate os.path.exists() check and the creation.
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)

### VAE architecture followed

![68747470733a2f2f757365722d696d616765732e67697468756275736572636f6e74656e742e636f6d2f33303636313539372f37383431383130332d61323034373230302d373636622d313165612d383230352d6337653537313237313566342e706e67.png](attachment:62d401f9-de09-437b-ba9d-0ee94467d54d.png)

In [21]:
# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    Encoder: ``fc1`` -> ReLU -> two linear heads producing the mean and the
    log-variance of the approximate posterior.
    Decoder: ``fc3`` -> ReLU -> ``fc4`` -> sigmoid.

    Args:
        input_dim: number of features of a flattened input. Defaults to the
            notebook-level ``image_size``.
        hidden_size: width of the hidden layers. Defaults to the
            notebook-level ``hidden_dim``.
        latent_size: dimensionality of the latent code. Defaults to the
            notebook-level ``latent_dim``.
    """

    def __init__(self, input_dim=None, hidden_size=None, latent_size=None):
        super().__init__()

        # Fall back to the notebook hyperparameters so that a bare ``VAE()``
        # builds exactly the same network as before.
        input_dim = image_size if input_dim is None else input_dim
        hidden_size = hidden_dim if hidden_size is None else hidden_size
        latent_size = latent_dim if latent_size is None else latent_size

        # Encoder: input -> hidden, then two parallel heads hidden -> latent
        self.fc1 = nn.Linear(input_dim, hidden_size)
        self.fc2_mean = nn.Linear(hidden_size, latent_size)    # posterior mean
        self.fc2_logvar = nn.Linear(hidden_size, latent_size)  # posterior log-variance

        # Decoder: latent -> hidden -> input
        self.fc3 = nn.Linear(latent_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, input_dim)

    def encode(self, x):
        """Map flattened inputs to ``(mu, log_var)`` of the latent Gaussian."""
        h = F.relu(self.fc1(x))       # hidden representation
        mu = self.fc2_mean(h)         # mean
        log_var = self.fc2_logvar(h)  # log of the variance (not the standard deviation)
        return mu, log_var

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is drawn separately from ``mu`` and ``logvar`` so gradients
        can flow through both.
        """
        std = torch.exp(logvar / 2)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)  # same shape, dtype and device as std
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions squashed into (0, 1) by a sigmoid."""
        h = F.relu(self.fc3(z))
        out = torch.sigmoid(self.fc4(h))
        return out

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: input batch; it is flattened to ``(-1, fc1.in_features)``, e.g.
                ``(batch_size, 1, 28, 28)`` -> ``(batch_size, 784)``.

        Returns:
            Tuple ``(reconstructed, mu, logvar)``.
        """
        # Flatten using the encoder's own input width rather than a global.
        mu, logvar = self.encode(x.view(-1, self.fc1.in_features))
        z = self.reparametrize(mu, logvar)
        reconstructed = self.decode(z)

        return reconstructed, mu, logvar
    

# Build the network, move it to the selected device, and let Adam update all of its parameters.
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

- $Loss = -E[\log P(X \mid z)] + D_{KL}[N(\mu(X), \Sigma(X)) \,\|\, N(0, 1)]$





- $D_{KL}[N(\mu(X), \Sigma(X)) \,\|\, N(0, 1)] = \frac{1}{2} \sum_k \left( \exp(\Sigma(X)) + \mu^2(X) - 1 - \Sigma(X) \right)$

  Inside the sum, $\Sigma(X)$ denotes the log-variance predicted by the encoder (the `logvar` tensor in the code), which is why it appears both inside $\exp(\cdot)$ and on its own.

In [22]:
# Define loss
def loss_function(reconstructed_image,original_image,mu,logvar):
    """Negative ELBO of the VAE, summed over the batch.

    Args:
        reconstructed_image: decoder output of shape ``(batch, features)``
            with values in [0, 1].
        original_image: target images; reshaped to ``(-1, features)`` to match
            the reconstruction.
        mu: mean of the approximate posterior.
        logvar: log-variance of the approximate posterior.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL divergence
        between ``N(mu, exp(logvar))`` and ``N(0, 1)``.
    """
    # Take the feature width from the reconstruction instead of hard-coding 784,
    # so the loss works for any input size.
    target = original_image.reshape(-1, reconstructed_image.size(-1))
    bce = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
    # Closed-form KL: 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar)
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)

    return bce + kld

# Train Function

def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Prints the per-sample loss of every 100th batch and, at the end, the
    loss averaged over all samples of the training dataset.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    num_batches = len(train_loader)
    train_loss = 0
    for step, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)

        # Forward pass and loss
        recon, mean, log_variance = model(batch)
        batch_loss = loss_function(recon, batch, mean, log_variance)

        # Backpropagation and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        train_loss += batch_loss_value

        if step % 100 == 0:
            per_sample = batch_loss_value / len(batch)
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, step, num_batches, per_sample))

    epoch_average = train_loss / len(train_loader.dataset)
    print('=====> Epoch {}, Average Loss: {:.3f}'.format(epoch, epoch_average))
        
        
# Test Fxn
def test(epoch):
    """Evaluate the model on ``test_loader`` without tracking gradients.

    Accumulates the loss over the whole test set, prints its per-sample
    average, and saves a grid for the first batch: five original images on
    the first row and their reconstructions on the second.

    Args:
        epoch: epoch number, used in the name of the saved image.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx,(images,_) in enumerate(test_loader):
            images = images.to(device)

            reconstructed,mu,logvar = model(images)
            test_loss += loss_function(reconstructed,images,mu,logvar).item()
            if batch_idx == 0:
                # Reshape to the shape of the actual batch rather than assuming
                # it holds exactly batch_size images of 1x28x28.
                comparison = torch.cat([images[:5],reconstructed.view_as(images)[:5]])
                save_image(comparison.cpu(), 'results/reconstruction_' + str(epoch) + '.png', nrow = 5)
    print('=====> Average Test Loss: {:.3f}'.format(test_loss/len(test_loader.dataset)))

In [23]:
# Main loop: train and evaluate once per epoch, then save a grid of generated digits.
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Skip the encoder: draw z from the standard Gaussian distribution and
        # feed it to the decoder to generate samples. The width comes from
        # latent_dim so it always matches the decoder input.
        sample = torch.randn(64, latent_dim).to(device)
        generated = model.decode(sample).cpu()
        save_image(generated.view(64,1,28,28), 'results/sample_' + str(epoch) + '.png')

Train Epoch 1 [Batch 0/469]	Loss: 549.875
Train Epoch 1 [Batch 100/469]	Loss: 186.105
Train Epoch 1 [Batch 200/469]	Loss: 157.023
Train Epoch 1 [Batch 300/469]	Loss: 138.640
Train Epoch 1 [Batch 400/469]	Loss: 134.077
=====> Epoch 1, Average Loss: 165.400
=====> Average Test Loss: 128.136
Train Epoch 2 [Batch 0/469]	Loss: 127.534
Train Epoch 2 [Batch 100/469]	Loss: 121.143
Train Epoch 2 [Batch 200/469]	Loss: 125.326
Train Epoch 2 [Batch 300/469]	Loss: 116.087
Train Epoch 2 [Batch 400/469]	Loss: 117.952
=====> Epoch 2, Average Loss: 121.782
=====> Average Test Loss: 116.284
Train Epoch 3 [Batch 0/469]	Loss: 119.493
Train Epoch 3 [Batch 100/469]	Loss: 117.486
Train Epoch 3 [Batch 200/469]	Loss: 114.048
Train Epoch 3 [Batch 300/469]	Loss: 112.653
Train Epoch 3 [Batch 400/469]	Loss: 118.941
=====> Epoch 3, Average Loss: 114.698
=====> Average Test Loss: 111.828
Train Epoch 4 [Batch 0/469]	Loss: 117.989
Train Epoch 4 [Batch 100/469]	Loss: 111.440
Train Epoch 4 [Batch 200/469]	Loss: 109.636


In [None]:
### First Row contains original images
### Second row contains reconstructed images