In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

# Configure device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Create a dir if not exist
sample_dir = 'sample'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
    
# Hyper-parameters
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# MNIST dataset
dataset = torchvision.datasets.MNIST(root='../data',
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [3]:
# One can change the perceptron to convolutional blocks to get better feature extraction
class VAE(nn.Module):
    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(h_dim, z_dim)
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)
        
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(log_var/2)
        eps = torch.rand_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h = F.relu(self.fc4(z))
        return torch.sigmoid(self.fc5(h))
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var
    
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
# Train the model
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass
        x = x.to(device).reshape(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and KL divergence
        # For KL divergence, see Appendix B in VAE paper
        reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Reconst loss: {:.4f}, KL Div {:.4f}'
                  .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))

    # Test the model
    with torch.no_grad():
        # Save the sampled image
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))
        
        # Save the reconstructed images
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))



Epoch [1/15], Step [10/469], Reconst loss: 33712.7227, KL Div 4422.1274
Epoch [1/15], Step [20/469], Reconst loss: 27894.6680, KL Div 1041.6532
Epoch [1/15], Step [30/469], Reconst loss: 27420.6738, KL Div 597.5938
Epoch [1/15], Step [40/469], Reconst loss: 26834.4844, KL Div 416.7742
Epoch [1/15], Step [50/469], Reconst loss: 24267.0410, KL Div 466.3390
Epoch [1/15], Step [60/469], Reconst loss: 24075.9395, KL Div 449.4094
Epoch [1/15], Step [70/469], Reconst loss: 22104.5781, KL Div 584.0891
Epoch [1/15], Step [80/469], Reconst loss: 20768.4746, KL Div 746.8918
Epoch [1/15], Step [90/469], Reconst loss: 20719.3047, KL Div 689.8594
Epoch [1/15], Step [100/469], Reconst loss: 19616.2695, KL Div 753.3087
Epoch [1/15], Step [110/469], Reconst loss: 18845.2559, KL Div 801.7220
Epoch [1/15], Step [120/469], Reconst loss: 17329.9512, KL Div 861.5785
Epoch [1/15], Step [130/469], Reconst loss: 17896.9727, KL Div 871.6507
Epoch [1/15], Step [140/469], Reconst loss: 17631.4336, KL Div 842.5442

Epoch [3/15], Step [230/469], Reconst loss: 11066.5713, KL Div 1214.3119
Epoch [3/15], Step [240/469], Reconst loss: 11001.5078, KL Div 1246.1538
Epoch [3/15], Step [250/469], Reconst loss: 10750.4072, KL Div 1255.2090
Epoch [3/15], Step [260/469], Reconst loss: 10979.9355, KL Div 1265.2710
Epoch [3/15], Step [270/469], Reconst loss: 11275.7764, KL Div 1308.3527
Epoch [3/15], Step [280/469], Reconst loss: 11304.4746, KL Div 1220.7709
Epoch [3/15], Step [290/469], Reconst loss: 11002.0039, KL Div 1298.2063
Epoch [3/15], Step [300/469], Reconst loss: 11002.4883, KL Div 1218.9883
Epoch [3/15], Step [310/469], Reconst loss: 11114.6670, KL Div 1268.0280
Epoch [3/15], Step [320/469], Reconst loss: 10825.9854, KL Div 1269.0654
Epoch [3/15], Step [330/469], Reconst loss: 11060.1338, KL Div 1241.2736
Epoch [3/15], Step [340/469], Reconst loss: 10862.3115, KL Div 1278.0352
Epoch [3/15], Step [350/469], Reconst loss: 11108.5049, KL Div 1275.9846
Epoch [3/15], Step [360/469], Reconst loss: 10590.7

Epoch [5/15], Step [450/469], Reconst loss: 10157.6406, KL Div 1224.2507
Epoch [5/15], Step [460/469], Reconst loss: 10448.3418, KL Div 1296.4609
Epoch [6/15], Step [10/469], Reconst loss: 10489.5527, KL Div 1302.7861
Epoch [6/15], Step [20/469], Reconst loss: 10061.7012, KL Div 1249.8083
Epoch [6/15], Step [30/469], Reconst loss: 10560.0742, KL Div 1327.0112
Epoch [6/15], Step [40/469], Reconst loss: 10707.6416, KL Div 1338.1689
Epoch [6/15], Step [50/469], Reconst loss: 10070.0977, KL Div 1243.3440
Epoch [6/15], Step [60/469], Reconst loss: 10075.8906, KL Div 1298.6704
Epoch [6/15], Step [70/469], Reconst loss: 9996.4473, KL Div 1286.9905
Epoch [6/15], Step [80/469], Reconst loss: 10506.2871, KL Div 1303.5341
Epoch [6/15], Step [90/469], Reconst loss: 10989.9150, KL Div 1273.8584
Epoch [6/15], Step [100/469], Reconst loss: 9900.9805, KL Div 1291.3372
Epoch [6/15], Step [110/469], Reconst loss: 10262.5957, KL Div 1330.7041
Epoch [6/15], Step [120/469], Reconst loss: 10627.1270, KL Div

Epoch [8/15], Step [210/469], Reconst loss: 9744.3340, KL Div 1242.2181
Epoch [8/15], Step [220/469], Reconst loss: 10422.3477, KL Div 1253.7498
Epoch [8/15], Step [230/469], Reconst loss: 10057.3584, KL Div 1332.6431
Epoch [8/15], Step [240/469], Reconst loss: 10256.9199, KL Div 1288.8892
Epoch [8/15], Step [250/469], Reconst loss: 10490.9160, KL Div 1321.7722
Epoch [8/15], Step [260/469], Reconst loss: 9987.5947, KL Div 1296.3562
Epoch [8/15], Step [270/469], Reconst loss: 10158.2920, KL Div 1250.3708
Epoch [8/15], Step [280/469], Reconst loss: 10221.9668, KL Div 1299.5718
Epoch [8/15], Step [290/469], Reconst loss: 9622.6152, KL Div 1315.6542
Epoch [8/15], Step [300/469], Reconst loss: 9808.3213, KL Div 1233.5071
Epoch [8/15], Step [310/469], Reconst loss: 10403.0410, KL Div 1286.2997
Epoch [8/15], Step [320/469], Reconst loss: 9486.7480, KL Div 1220.0457
Epoch [8/15], Step [330/469], Reconst loss: 9808.9033, KL Div 1302.8806
Epoch [8/15], Step [340/469], Reconst loss: 10101.1777, K

Epoch [10/15], Step [430/469], Reconst loss: 9894.5117, KL Div 1268.4297
Epoch [10/15], Step [440/469], Reconst loss: 10483.9160, KL Div 1334.4348
Epoch [10/15], Step [450/469], Reconst loss: 9809.6631, KL Div 1284.7487
Epoch [10/15], Step [460/469], Reconst loss: 9873.0664, KL Div 1296.5020
Epoch [11/15], Step [10/469], Reconst loss: 9905.4238, KL Div 1310.5449
Epoch [11/15], Step [20/469], Reconst loss: 10167.7705, KL Div 1304.3949
Epoch [11/15], Step [30/469], Reconst loss: 9727.7148, KL Div 1249.3481
Epoch [11/15], Step [40/469], Reconst loss: 9443.5732, KL Div 1258.9019
Epoch [11/15], Step [50/469], Reconst loss: 9882.3223, KL Div 1264.1727
Epoch [11/15], Step [60/469], Reconst loss: 9440.2949, KL Div 1277.7983
Epoch [11/15], Step [70/469], Reconst loss: 9815.6875, KL Div 1341.6930
Epoch [11/15], Step [80/469], Reconst loss: 10039.7070, KL Div 1296.2133
Epoch [11/15], Step [90/469], Reconst loss: 10164.0459, KL Div 1309.6594
Epoch [11/15], Step [100/469], Reconst loss: 9687.7285, 

Epoch [13/15], Step [190/469], Reconst loss: 10079.6543, KL Div 1308.5422
Epoch [13/15], Step [200/469], Reconst loss: 9828.4424, KL Div 1299.3047
Epoch [13/15], Step [210/469], Reconst loss: 9526.3223, KL Div 1258.7966
Epoch [13/15], Step [220/469], Reconst loss: 9879.8848, KL Div 1277.6836
Epoch [13/15], Step [230/469], Reconst loss: 9661.9629, KL Div 1259.8477
Epoch [13/15], Step [240/469], Reconst loss: 9626.6758, KL Div 1333.7888
Epoch [13/15], Step [250/469], Reconst loss: 9664.6875, KL Div 1301.2827
Epoch [13/15], Step [260/469], Reconst loss: 9195.1816, KL Div 1301.8851
Epoch [13/15], Step [270/469], Reconst loss: 9540.7207, KL Div 1294.4614
Epoch [13/15], Step [280/469], Reconst loss: 9994.4268, KL Div 1247.7679
Epoch [13/15], Step [290/469], Reconst loss: 9257.4502, KL Div 1270.3582
Epoch [13/15], Step [300/469], Reconst loss: 9428.9785, KL Div 1238.9316
Epoch [13/15], Step [310/469], Reconst loss: 9621.9121, KL Div 1240.5278
Epoch [13/15], Step [320/469], Reconst loss: 9152.

Epoch [15/15], Step [410/469], Reconst loss: 9551.4541, KL Div 1294.7408
Epoch [15/15], Step [420/469], Reconst loss: 9647.0000, KL Div 1310.4636
Epoch [15/15], Step [430/469], Reconst loss: 9178.6006, KL Div 1284.2456
Epoch [15/15], Step [440/469], Reconst loss: 9488.0664, KL Div 1243.7935
Epoch [15/15], Step [450/469], Reconst loss: 9507.0059, KL Div 1313.9934
Epoch [15/15], Step [460/469], Reconst loss: 9767.1982, KL Div 1294.7338
