In [15]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [35]:
# Define hyperparameters
image_size = 784    # flattened MNIST image (28 * 28 pixels)
h_dim = 400         # width of the hidden layer in encoder and decoder
latent_dim = 20     # dimensionality of the latent code z
batch_size = 128
num_epochs = 10
seed = 0            # fixed seed so weight init, data shuffling and sampling repeat across runs

# Seed torch's global RNG before any layer is created or any batch is shuffled.
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# MNIST dataset; ToTensor() turns each image into a float tensor.
train_dataset = torchvision.datasets.MNIST(root='./data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data/',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# Data loaders: only the training set is shuffled.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Create the directory for reconstructed and sampled images if it does not exist yet.
sample_dir = 'results'
os.makedirs(sample_dir, exist_ok=True)


![vae](https://user-images.githubusercontent.com/30661597/78418103-a2047200-766b-11ea-8205-c7e5712715f4.png)

In [36]:
# Define the model: a fully-connected variational autoencoder.
class VAE(nn.Module):
    """Fully-connected VAE for flattened images.

    Encoder: image_size -> h_dim (ReLU) -> two linear heads of width
    latent_dim producing the mean ``mu`` and log-variance ``logvar``.
    Decoder: latent_dim -> h_dim (ReLU) -> image_size, followed by a sigmoid.
    Layer widths come from the module-level ``image_size``, ``h_dim`` and
    ``latent_dim`` at construction time.
    """

    def __init__(self):
        super().__init__()
        # Registration order is deliberate: it fixes parameter order and the
        # order in which the RNG is consumed for weight initialisation.
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2_mean = nn.Linear(h_dim, latent_dim)    # head for mu
        self.fc2_logvar = nn.Linear(h_dim, latent_dim)  # head for logvar
        self.fc3 = nn.Linear(latent_dim, h_dim)
        self.fc4 = nn.Linear(h_dim, image_size)

    def encode(self, x):
        """Return (mu, logvar) computed from a batch of flat inputs."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc2_mean(hidden)
        logvar = self.fc2_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I), where sigma = exp(logvar / 2)."""
        sigma = (logvar / 2).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise

    def decode(self, z):
        """Map latent codes to sigmoid outputs of width image_size."""
        return torch.sigmoid(self.fc4(F.relu(self.fc3(z))))

    def forward(self, x):
        """Flatten, encode, sample z and decode; return (reconstruction, mu, logvar)."""
        flat = x.view(-1, image_size)
        mu, logvar = self.encode(flat)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    
# Build the VAE, move it to the selected device, and optimise all its parameters with Adam.
model = VAE()
model.to(device)  # nn.Module.to moves parameters in place and returns the same module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

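$\mathcal{L}(X) = -\mathbb{E}_{z \sim Q(z \mid X)}\left[\log P(X \mid z)\right] + D_{KL}\left[\mathcal{N}(\mu(X), \Sigma(X)) \,\|\, \mathcal{N}(0, I)\right]$

The first (reconstruction) term is implemented as a summed binary cross-entropy between the decoder output and the input pixels. The second term is the KL divergence between the encoder's Gaussian and the standard normal prior.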

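#### $D_{KL}\left[\mathcal{N}(\mu(X), \Sigma(X)) \,\|\, \mathcal{N}(0, I)\right]=\frac{1}{2} \sum_{k}\left(\sigma_k^{2}(X)+\mu_k^{2}(X)-1-\log \sigma_k^{2}(X)\right)$

where $\Sigma(X) = \operatorname{diag}(\sigma^2(X))$. The encoder predicts $\log \sigma^2(X)$ directly (`logvar` in the code), so $\sigma_k^2 = \exp(\texttt{logvar}_k)$.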

In [37]:
# Define loss function
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a batch, summed over pixels and samples.

    Args:
        recon_x: decoder output, rows of ``image_size`` values in [0, 1].
        x: target images; reshaped to ``(-1, image_size)`` before comparison.
        mu: encoder means.
        logvar: encoder log-variances.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the closed-form
        KL divergence of N(mu, exp(logvar)) from N(0, I).
    """
    target = x.view(-1, image_size)
    reconstruction = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Closed-form KL term; operand order kept so float results match exactly.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()

    return reconstruction + kl_divergence

In [38]:
# Define train function
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Every 100th batch prints that batch's loss divided by its size; at the
    end prints the summed loss divided by the training-set size. Uses the
    module-level ``model``, ``optimizer``, ``train_loader``, ``device`` and
    ``loss_function``.
    """
    model.train()
    running_total = 0
    for batch_idx, (batch, _labels) in enumerate(train_loader):
        batch = batch.to(device)

        recon, mu, logvar = model(batch)
        loss = loss_function(recon, batch, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        running_total += loss.item()
        optimizer.step()

        if batch_idx % 100 != 0:
            continue
        n_in_batch = len(batch)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\t Average loss: {:.6f}'.format(
            epoch,
            batch_idx * n_in_batch,
            len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            loss.item() / n_in_batch))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, running_total / len(train_loader.dataset)))
    


In [40]:
# Define test function
def test(epoch):
    """Evaluate the model on ``test_loader`` and save one reconstruction grid.

    Prints the summed test loss divided by the test-set size. Writes
    ``reconstruction_<epoch>.png`` into ``sample_dir``. The image holds up to
    8 images from the first test batch, followed by their reconstructions.
    Uses the module-level ``model``, ``test_loader``, ``device``,
    ``loss_function`` and ``sample_dir``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(test_loader):
            images = images.to(device)

            recon_batch, mu, logvar = model(images)
            test_loss += loss_function(recon_batch, images, mu, logvar).item()

            if batch_idx == 0:
                n = min(images.size(0), 8)
                # Reshape to the input's own shape: hard-coding batch_size here
                # fails whenever the first batch is shorter than batch_size.
                comparison = torch.cat([images[:n],
                                        recon_batch.view_as(images)[:n]])
                # Save once per epoch (previously re-saved on every batch).
                save_image(comparison.cpu(),
                           os.path.join(sample_dir, 'reconstruction_' + str(epoch) + '.png'))

    test_loss /= len(test_loader.dataset)
    print('====> Test loss: {:.4f}'.format(test_loss))

In [41]:
# Train the model
for epoch in range(1, num_epochs + 1):
    train(epoch)
    test(epoch)  # test() calls model.eval(), so sampling below runs in eval mode

    # Decode batch_size latent codes drawn from N(0, I) and save them as one image grid.
    # The path is built from sample_dir (the directory created in the config cell)
    # rather than a duplicated 'results/' literal.
    with torch.no_grad():
        z = torch.randn(batch_size, latent_dim).to(device)
        sample = model.decode(z).cpu()
        save_image(sample.view(batch_size, 1, 28, 28),
                   os.path.join(sample_dir, 'sample_' + str(epoch) + '.png'))

====> Epoch: 1 Average loss: 165.0513
====> Test loss: 127.6641
====> Epoch: 2 Average loss: 121.3749
====> Test loss: 115.6545
====> Epoch: 3 Average loss: 114.4816
====> Test loss: 111.8110
====> Epoch: 4 Average loss: 111.5546
====> Test loss: 109.9430
====> Epoch: 5 Average loss: 109.9045
====> Test loss: 108.4127
====> Epoch: 6 Average loss: 108.7364
====> Test loss: 107.7080
====> Epoch: 7 Average loss: 107.8712
====> Test loss: 106.9747
====> Epoch: 8 Average loss: 107.2442
====> Test loss: 106.4121
====> Epoch: 9 Average loss: 106.7085
====> Test loss: 106.1986
====> Epoch: 10 Average loss: 106.2655
====> Test loss: 105.6931
