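# APPLICATION 9 - VARIATIONAL AUTOENCODER (VAE)
Trains a fully connected VAE on MNIST and, after every epoch, saves test-set reconstructions and images sampled from the latent prior to the `results` directory.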

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# define hyperparameters
image_size = 784    # 28 x 28 MNIST images, flattened
hidden_dim = 400
latent_dim = 20
batch_size = 128
epochs     = 15
seed       = 0      # seeds torch's global RNG: weight init, train shuffling, latent sampling

torch.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST Dataset
train_dataset = torchvision.datasets.MNIST(root      = './data',
                                           train     = True,
                                           transform = transforms.ToTensor(),
                                           download  = True)

# download=True is a no-op when the files already exist, and it keeps this split
# loadable even if the train split was not fetched first.
test_dataset  = torchvision.datasets.MNIST(root      = './data',
                                           train     = False,
                                           transform = transforms.ToTensor(),
                                           download  = True)

train_loader = torch.utils.data.DataLoader(dataset    = train_dataset,
                                           batch_size = batch_size,
                                           shuffle    = True)

# No shuffling for evaluation: the summed test loss does not depend on order. With a
# fixed order, the first batch (used for the reconstruction grid in `test`) shows the
# same digits every epoch, so the saved images are comparable across epochs.
test_loader  = torch.utils.data.DataLoader(dataset    = test_dataset,
                                           batch_size = batch_size,
                                           shuffle    = False)

In [3]:
# create directory to save the reconstructed and sampled images
sample_dir = 'results'
# exist_ok=True makes this cell idempotent and avoids the exists()/makedirs() race
os.makedirs(sample_dir, exist_ok=True)

In [4]:
# create VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened images.

    Encoder: image_size -> hidden_dim (ReLU), followed by two linear heads that give the
    mean and log-variance of a diagonal Gaussian over a latent_dim-sized code.
    Decoder: latent_dim -> hidden_dim (ReLU) -> image_size (sigmoid), so every output
    value lies in [0, 1].

    Layer sizes come from the module-level ``image_size``, ``hidden_dim`` and
    ``latent_dim`` constants.
    """

    def __init__(self):
        super().__init__()
        # Encoder layers (creation order fixes parameter init order and state_dict keys).
        self.fc1        = nn.Linear(image_size, hidden_dim)
        self.fc2_mean   = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers.
        self.fc3        = nn.Linear(latent_dim, hidden_dim)
        self.fc4        = nn.Linear(hidden_dim, image_size)

    def encode(self, x):
        """Map flattened images ``x`` to the posterior parameters ``(mu, log_var)``."""
        hidden = self.fc1(x).relu()
        return self.fc2_mean(hidden), self.fc2_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decoder(self, z):
        """Map latent codes ``z`` back to flattened images with values in [0, 1]."""
        hidden = F.relu(self.fc3(z))
        return self.fc4(hidden).sigmoid()

    def forward(self, x):
        """Encode ``x``, draw one latent sample, and decode it.

        ``x`` is flattened to ``(-1, image_size)`` first, so image batches such as
        ``(B, 1, 28, 28)`` are accepted. Returns ``(reconstructed, mu, logvar)``.
        """
        flat = x.view(-1, image_size)
        mu, logvar = self.encode(flat)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar

In [5]:
# define model and optimizer
# Build the VAE and move its parameters to the selected device (Module.to returns the module).
model = VAE()
model = model.to(device=device)
# Adam over all model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [6]:
# loss function
def loss_function(reconstructed_image, original_image, mu, logvar):
    """Negative ELBO for a batch, summed (not averaged) over the batch.

    Args:
        reconstructed_image: decoder output with values in [0, 1], e.g. shape (B, D).
        original_image: target images with the same number of elements in any shape
            (e.g. (B, 1, 28, 28)); reshaped to match ``reconstructed_image``.
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor equal to the summed binary cross-entropy reconstruction loss plus
        KL(N(mu, exp(logvar)) || N(0, I)).
    """
    # reshape (not view) also accepts non-contiguous targets. Matching the prediction's
    # shape removes the dependency on the global image_size.
    target = original_image.reshape(reconstructed_image.shape)
    bce = F.binary_cross_entropy(reconstructed_image, target, reduction = 'sum')
    # Closed-form Gaussian KL: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
    kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)
    return bce + kld

In [7]:
# train function
def train(epoch, log_interval=100):
    """Run one training epoch over ``train_loader`` and print progress.

    Uses the notebook-level ``model``, ``optimizer``, ``train_loader``, ``device`` and
    ``loss_function``.

    Args:
        epoch: epoch number, used only in the log messages.
        log_interval: print the per-sample loss of every ``log_interval``-th batch
            (default 100, the previously hard-coded value).
    """
    model.train()
    train_loss = 0
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        reconstructed, mu, logvar = model(images)
        loss = loss_function(reconstructed, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Read the scalar once; .item() forces a device sync.
        batch_loss = loss.item()
        train_loss += batch_loss

        if i % log_interval == 0:
            # The loss is summed over the batch; divide by the batch length for a per-sample value.
            print("Train Epoch {} [Batch {}/{}]\tLoss: {:.3f}".format(epoch, i, len(train_loader), batch_loss / len(images)))

    # Summed loss divided by the dataset size gives the mean per-sample loss for the epoch.
    print('=====> Epoch {}, Average Loss: {:.3f}'.format(epoch, train_loss / len(train_loader.dataset)))

# test function
def test(epoch):
    """Evaluate on ``test_loader``, save a reconstruction grid, and print the mean loss.

    The first 5 images of the first test batch (top row) and their reconstructions
    (bottom row) are written to ``<sample_dir>/reconstructed_<epoch>.png``.
    Uses the notebook-level ``model``, ``test_loader``, ``device`` and ``sample_dir``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)
            reconstructed, mu, logvar = model(images)
            loss = loss_function(reconstructed, images, mu, logvar)
            test_loss += loss.item()
            if batch_idx == 0:
                # view_as(images) rather than view(batch_size, ...): correct for any batch length
                comparison = torch.cat([images[:5], reconstructed.view_as(images)[:5]])
                save_image(comparison.cpu(),
                           os.path.join(sample_dir, 'reconstructed_' + str(epoch) + '.png'),
                           nrow = 5)

    # Summed loss divided by the dataset size gives the mean per-sample test loss.
    print('=====> Average Test Loss: {:.3f}'.format(test_loss / len(test_loader.dataset)))

In [8]:
# main function: train, evaluate, and sample from the prior after every epoch
num_samples = 64   # images per sample grid (save_image's default nrow=8 gives an 8 x 8 grid)
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Draw latent codes from the N(0, I) prior; the width must match the model's latent_dim.
        sample = torch.randn(num_samples, latent_dim, device=device)
        generated = model.decoder(sample).cpu()
        save_image(generated.view(-1, 1, 28, 28),
                   os.path.join(sample_dir, 'sample_' + str(epoch) + '.png'))

Train Epoch 1 [Batch 0/469]	Loss: 552.154
Train Epoch 1 [Batch 100/469]	Loss: 182.914
Train Epoch 1 [Batch 200/469]	Loss: 156.408
Train Epoch 1 [Batch 300/469]	Loss: 140.119
Train Epoch 1 [Batch 400/469]	Loss: 132.480
=====> Epoch 1, Average Loss: 165.406
=====> Average Test Loss: 128.565
Train Epoch 2 [Batch 0/469]	Loss: 132.577
Train Epoch 2 [Batch 100/469]	Loss: 123.357
Train Epoch 2 [Batch 200/469]	Loss: 119.947
Train Epoch 2 [Batch 300/469]	Loss: 124.771
Train Epoch 2 [Batch 400/469]	Loss: 123.862
=====> Epoch 2, Average Loss: 122.272
=====> Average Test Loss: 116.179
Train Epoch 3 [Batch 0/469]	Loss: 116.523
Train Epoch 3 [Batch 100/469]	Loss: 114.146
Train Epoch 3 [Batch 200/469]	Loss: 118.713
Train Epoch 3 [Batch 300/469]	Loss: 114.624
Train Epoch 3 [Batch 400/469]	Loss: 114.380
=====> Epoch 3, Average Loss: 115.050
=====> Average Test Loss: 112.250
Train Epoch 4 [Batch 0/469]	Loss: 110.737
Train Epoch 4 [Batch 100/469]	Loss: 113.304
Train Epoch 4 [Batch 200/469]	Loss: 108.278
