In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

In [2]:
# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that receives the image grids written during training.
# exist_ok=True makes the call idempotent (safe to re-run the cell) and
# avoids the check-then-create race of a separate os.path.exists() test.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

# Hyper-parameters
# -- model dimensions
image_size = 28 * 28    # features per flattened 28x28 image
h_dim = 400             # hidden-layer width
z_dim = 20              # size of the latent code
# -- optimisation schedule
num_epochs = 15
batch_size = 128
learning_rate = 0.001

# MNIST dataset: training split stored under ../../data, with every image
# converted to a tensor by ToTensor().
dataset = torchvision.datasets.MNIST(
    root='../../data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Data loader: serves shuffled mini-batches of `batch_size` samples.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [3]:
# VAE model
class VAE(nn.Module):
    """Fully-connected variational auto-encoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of the
    latent code; the decoder maps a latent code back to an output of the
    same size as the input, squashed into [0, 1] by a sigmoid.

    Args:
        image_size: Number of features in one flattened input sample.
        h_dim: Width of the hidden layer in both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder: input -> hidden
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head: mean
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head: log-variance
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder: latent -> hidden
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder: hidden -> output

    def encode(self, x):
        """Return ``(mu, log_var)`` of the latent distribution for input ``x``."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent code ``z`` to an output with values in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid (same result).
        return torch.sigmoid(self.fc5(h))

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_reconst, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

# Build the model from the configured hyper-parameters instead of relying on
# the constructor defaults, so that editing image_size / h_dim / z_dim takes
# effect and the latent size matches the z_dim used when sampling.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [4]:
# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass on the flattened mini-batch (labels are not used)
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)
        
        # Compute reconstruction loss and kl divergence
        # For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43
        # reduction='sum' is the supported spelling of the deprecated
        # size_average=False: the loss is summed over all elements of the batch.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # Backprop and optimize
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" 
                   .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))
    
    # End of epoch: write image grids without tracking gradients
    with torch.no_grad():
        # Save the sampled images (decoded from standard-normal latent codes)
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))

        # Save the reconstructed images; `x` is the last mini-batch of the epoch,
        # shown side by side (concatenated along the width) with its reconstruction
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

Epoch[1/15], Step [10/469], Reconst Loss: 34476.1562, KL Div: 4142.1797
Epoch[1/15], Step [20/469], Reconst Loss: 29967.1055, KL Div: 1003.6094
Epoch[1/15], Step [30/469], Reconst Loss: 27397.9844, KL Div: 1246.0935
Epoch[1/15], Step [40/469], Reconst Loss: 26431.6250, KL Div: 666.5424
Epoch[1/15], Step [50/469], Reconst Loss: 25659.1621, KL Div: 812.8797
Epoch[1/15], Step [60/469], Reconst Loss: 24803.2910, KL Div: 892.4062
Epoch[1/15], Step [70/469], Reconst Loss: 24819.5508, KL Div: 1005.0826
Epoch[1/15], Step [80/469], Reconst Loss: 23206.4805, KL Div: 1317.9355
Epoch[1/15], Step [90/469], Reconst Loss: 22770.8203, KL Div: 1289.4001
Epoch[1/15], Step [100/469], Reconst Loss: 21283.3418, KL Div: 1376.3140
Epoch[1/15], Step [110/469], Reconst Loss: 21517.2090, KL Div: 1533.3827
Epoch[1/15], Step [120/469], Reconst Loss: 21062.2832, KL Div: 1618.5515
Epoch[1/15], Step [130/469], Reconst Loss: 19934.9707, KL Div: 1752.8223
Epoch[1/15], Step [140/469], Reconst Loss: 18701.0117, KL Div: 

Epoch[3/15], Step [220/469], Reconst Loss: 11353.8965, KL Div: 3013.3428
Epoch[3/15], Step [230/469], Reconst Loss: 12272.0156, KL Div: 3215.0293
Epoch[3/15], Step [240/469], Reconst Loss: 11686.2646, KL Div: 3037.6924
Epoch[3/15], Step [250/469], Reconst Loss: 11744.1641, KL Div: 3254.1033
Epoch[3/15], Step [260/469], Reconst Loss: 11040.2861, KL Div: 2954.5415
Epoch[3/15], Step [270/469], Reconst Loss: 11478.8867, KL Div: 3067.7881
Epoch[3/15], Step [280/469], Reconst Loss: 11804.4658, KL Div: 3041.0732
Epoch[3/15], Step [290/469], Reconst Loss: 11956.3584, KL Div: 3019.6196
Epoch[3/15], Step [300/469], Reconst Loss: 11839.9922, KL Div: 3172.8726
Epoch[3/15], Step [310/469], Reconst Loss: 11692.7793, KL Div: 3024.5417
Epoch[3/15], Step [320/469], Reconst Loss: 12010.2285, KL Div: 3192.1401
Epoch[3/15], Step [330/469], Reconst Loss: 11411.6504, KL Div: 3052.6565
Epoch[3/15], Step [340/469], Reconst Loss: 11539.7617, KL Div: 3104.0874
Epoch[3/15], Step [350/469], Reconst Loss: 11446.29

Epoch[5/15], Step [430/469], Reconst Loss: 11071.7031, KL Div: 3133.7319
Epoch[5/15], Step [440/469], Reconst Loss: 11175.2900, KL Div: 3245.7324
Epoch[5/15], Step [450/469], Reconst Loss: 11178.5010, KL Div: 3208.4116
Epoch[5/15], Step [460/469], Reconst Loss: 11034.5264, KL Div: 3196.6199
Epoch[6/15], Step [10/469], Reconst Loss: 10532.8252, KL Div: 3198.3989
Epoch[6/15], Step [20/469], Reconst Loss: 10869.7695, KL Div: 3250.6885
Epoch[6/15], Step [30/469], Reconst Loss: 11274.4404, KL Div: 3192.4487
Epoch[6/15], Step [40/469], Reconst Loss: 10863.2500, KL Div: 3236.5938
Epoch[6/15], Step [50/469], Reconst Loss: 10624.8916, KL Div: 3154.1694
Epoch[6/15], Step [60/469], Reconst Loss: 11004.0654, KL Div: 3140.5132
Epoch[6/15], Step [70/469], Reconst Loss: 10636.6221, KL Div: 3173.3687
Epoch[6/15], Step [80/469], Reconst Loss: 10462.3877, KL Div: 3201.3540
Epoch[6/15], Step [90/469], Reconst Loss: 10910.7754, KL Div: 3266.3635
Epoch[6/15], Step [100/469], Reconst Loss: 10920.7598, KL Di

Epoch[8/15], Step [190/469], Reconst Loss: 10151.4141, KL Div: 3078.5195
Epoch[8/15], Step [200/469], Reconst Loss: 10667.8379, KL Div: 3320.9587
Epoch[8/15], Step [210/469], Reconst Loss: 10772.2744, KL Div: 3241.6133
Epoch[8/15], Step [220/469], Reconst Loss: 10243.8154, KL Div: 3229.2324
Epoch[8/15], Step [230/469], Reconst Loss: 10641.3008, KL Div: 3211.5371
Epoch[8/15], Step [240/469], Reconst Loss: 10622.8047, KL Div: 3338.4800
Epoch[8/15], Step [250/469], Reconst Loss: 10016.8438, KL Div: 3066.4009
Epoch[8/15], Step [260/469], Reconst Loss: 10640.7188, KL Div: 3223.8586
Epoch[8/15], Step [270/469], Reconst Loss: 10922.4980, KL Div: 3275.2017
Epoch[8/15], Step [280/469], Reconst Loss: 9800.3223, KL Div: 3105.6597
Epoch[8/15], Step [290/469], Reconst Loss: 10654.2939, KL Div: 3207.1748
Epoch[8/15], Step [300/469], Reconst Loss: 10559.1426, KL Div: 3243.5662
Epoch[8/15], Step [310/469], Reconst Loss: 10352.0547, KL Div: 3241.9854
Epoch[8/15], Step [320/469], Reconst Loss: 11074.792

Epoch[10/15], Step [400/469], Reconst Loss: 10365.5488, KL Div: 3214.6853
Epoch[10/15], Step [410/469], Reconst Loss: 10123.2188, KL Div: 3210.8850
Epoch[10/15], Step [420/469], Reconst Loss: 10328.2969, KL Div: 3289.1870
Epoch[10/15], Step [430/469], Reconst Loss: 10005.1592, KL Div: 3220.5161
Epoch[10/15], Step [440/469], Reconst Loss: 9825.9482, KL Div: 3294.0273
Epoch[10/15], Step [450/469], Reconst Loss: 10330.3262, KL Div: 3305.0376
Epoch[10/15], Step [460/469], Reconst Loss: 10496.8770, KL Div: 3268.0764
Epoch[11/15], Step [10/469], Reconst Loss: 10258.6045, KL Div: 3142.7021
Epoch[11/15], Step [20/469], Reconst Loss: 10347.6465, KL Div: 3314.4019
Epoch[11/15], Step [30/469], Reconst Loss: 10187.2295, KL Div: 3208.0815
Epoch[11/15], Step [40/469], Reconst Loss: 10312.1699, KL Div: 3331.2324
Epoch[11/15], Step [50/469], Reconst Loss: 10456.3145, KL Div: 3286.3359
Epoch[11/15], Step [60/469], Reconst Loss: 10594.3730, KL Div: 3238.3462
Epoch[11/15], Step [70/469], Reconst Loss: 10

Epoch[13/15], Step [160/469], Reconst Loss: 10135.2217, KL Div: 3254.0334
Epoch[13/15], Step [170/469], Reconst Loss: 9997.2461, KL Div: 3200.6309
Epoch[13/15], Step [180/469], Reconst Loss: 10565.5742, KL Div: 3293.3523
Epoch[13/15], Step [190/469], Reconst Loss: 10244.9512, KL Div: 3205.3086
Epoch[13/15], Step [200/469], Reconst Loss: 10102.4189, KL Div: 3173.6504
Epoch[13/15], Step [210/469], Reconst Loss: 10047.5303, KL Div: 3246.3687
Epoch[13/15], Step [220/469], Reconst Loss: 10200.0166, KL Div: 3287.1147
Epoch[13/15], Step [230/469], Reconst Loss: 10216.1113, KL Div: 3282.0989
Epoch[13/15], Step [240/469], Reconst Loss: 9681.5361, KL Div: 3120.8838
Epoch[13/15], Step [250/469], Reconst Loss: 10530.1191, KL Div: 3351.5181
Epoch[13/15], Step [260/469], Reconst Loss: 10096.7012, KL Div: 3247.9810
Epoch[13/15], Step [270/469], Reconst Loss: 10190.6113, KL Div: 3282.5884
Epoch[13/15], Step [280/469], Reconst Loss: 10299.8789, KL Div: 3389.4844
Epoch[13/15], Step [290/469], Reconst Lo

Epoch[15/15], Step [370/469], Reconst Loss: 9794.1025, KL Div: 3137.1436
Epoch[15/15], Step [380/469], Reconst Loss: 10211.7988, KL Div: 3261.4084
Epoch[15/15], Step [390/469], Reconst Loss: 9901.6895, KL Div: 3242.2046
Epoch[15/15], Step [400/469], Reconst Loss: 10011.6865, KL Div: 3191.9458
Epoch[15/15], Step [410/469], Reconst Loss: 10282.7373, KL Div: 3267.4224
Epoch[15/15], Step [420/469], Reconst Loss: 10347.0586, KL Div: 3232.7456
Epoch[15/15], Step [430/469], Reconst Loss: 10169.0664, KL Div: 3353.2227
Epoch[15/15], Step [440/469], Reconst Loss: 10259.9092, KL Div: 3341.6943
Epoch[15/15], Step [450/469], Reconst Loss: 9802.7861, KL Div: 3277.4980
Epoch[15/15], Step [460/469], Reconst Loss: 10417.8652, KL Div: 3253.4810
