# Import Modules and Libraries

In [1]:
import os

import imageio
import matplotlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

# MNIST Data

In [2]:
# Mini-batch size shared by the training and test loaders.
bs = 100

# Both MNIST splits live under the same root and use the same ToTensor transform.
# Only the training split requests a download.
mnist_root = './mnist_data/'
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST(root=mnist_root, train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(root=mnist_root, train=False, transform=to_tensor, download=False)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

# VAE Class

In [3]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input through two ReLU hidden layers to the
    mean and log-variance of a diagonal Gaussian over the latent code. The
    decoder mirrors it and ends in a sigmoid, so every output lies in [0, 1].

    Args:
        x_dim: size of the flattened input (and of the reconstruction).
        h_dim1: width of the outer hidden layer (first in encoder, last in decoder).
        h_dim2: width of the inner hidden layer.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # Kept so forward() can flatten inputs to the configured size
        # instead of assuming 784 features.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Return ``(mu, log_var)`` for a batch of flattened inputs."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h) # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + std * eps`` with ``eps ~ N(0, I)`` (reparameterization)."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def decoder(self, z):
        """Map latent codes to reconstructions with values in [0, 1]."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        ``x`` is flattened to ``(-1, x_dim)`` before encoding.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

# Instantiate the MNIST-sized VAE (784 -> 512 -> 256 -> 2) and place it on the
# GPU when one is available; otherwise it stays on the CPU.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2)
vae.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))

# Optimizer and Loss Function

In [4]:
# Adam over all VAE parameters; lr=1e-3 is Adam's default, written out for visibility.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
# return reconstruction error + KL divergence losses
def loss_function(recon_x, x, mu, log_var):
    """Negative ELBO summed over the batch.

    Args:
        recon_x: decoder output with values in [0, 1], shape ``(batch, features)``.
        x: target batch; reshaped to ``(-1, features)`` to match ``recon_x``.
        mu: latent means.
        log_var: latent log-variances.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL divergence of
        the approximate posterior from a standard normal prior.
    """
    # Flatten the target to the reconstruction's feature width rather than a
    # fixed 784, so the loss works for any input size.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims.
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# Training

In [5]:
def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Prints the per-sample loss every 100 batches and the epoch's average
    per-sample loss at the end.

    Args:
        epoch: epoch number, used only in the printed progress messages.
    """
    vae.train()
    # Follow the model's device instead of calling .cuda() unconditionally,
    # which fails on CPU-only machines where the model was never moved.
    device = next(vae.parameters()).device
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))
    # The loss is summed per batch, so divide by the sample count for a per-sample mean.
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

# Evaluation

In [6]:
def test():
    """Evaluate the VAE on ``test_loader`` without gradient tracking.

    Returns:
        Average per-sample loss (as computed by ``loss_function``) over the
        test set. The same value is also printed.
    """
    vae.eval()
    # Follow the model's device instead of calling .cuda() unconditionally,
    # which fails on CPU-only machines where the model was never moved.
    device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, mu, log_var = vae(data)

            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

In [13]:
test_loss_array = []  # per-epoch average test loss, plotted later
images = []           # one image grid per epoch, assembled into a GIF below

# save_image and mimsave do not create missing directories, so make sure the
# output folder exists before the first epoch finishes.
os.makedirs('./outputs', exist_ok=True)

# Fixed latent codes, decoded after every epoch so the frames are comparable.
# Placed on the model's device rather than assuming CUDA is present.
z = torch.randn(64, 2).to(next(vae.parameters()).device)
for epoch in range(1, 301):
    train(epoch)
    test_loss = test()
    test_loss_array.append(test_loss)
    with torch.no_grad():
        # No .detach() needed: nothing is tracked inside no_grad.
        sample = vae.decoder(z).cpu()
        sample_images = sample.view(64, 1, 28, 28)
        save_image(sample_images, f"./outputs/gen_img_{epoch}.png")
        images.append(make_grid(sample_images))

# Convert each grid tensor to an array and write the frames out as an animated GIF.
to_pil_image = transforms.ToPILImage()
imgs = [np.array(to_pil_image(img)) for img in images]
imageio.mimsave('./outputs/generator_images.gif', imgs)

====> Epoch: 1 Average loss: 144.1354
====> Test set loss: 144.7856
====> Epoch: 2 Average loss: 143.4704
====> Test set loss: 144.3308
====> Epoch: 3 Average loss: 142.7403
====> Test set loss: 143.4863
====> Epoch: 4 Average loss: 142.1092
====> Test set loss: 143.0147
====> Epoch: 5 Average loss: 141.4834
====> Test set loss: 142.4181
====> Epoch: 6 Average loss: 141.1192
====> Test set loss: 141.7776
====> Epoch: 7 Average loss: 140.6068
====> Test set loss: 141.2196
====> Epoch: 8 Average loss: 140.2573
====> Test set loss: 141.8938
====> Epoch: 9 Average loss: 139.9168
====> Test set loss: 140.8777
====> Epoch: 10 Average loss: 139.5372
====> Test set loss: 141.3325
====> Epoch: 11 Average loss: 139.5833
====> Test set loss: 141.1212
====> Epoch: 12 Average loss: 139.1619
====> Test set loss: 140.8209
====> Epoch: 13 Average loss: 138.8994
====> Test set loss: 140.4356
====> Epoch: 14 Average loss: 138.5886
====> Test set loss: 140.4808
====> Epoch: 15 Average loss: 138.6782
====

====> Epoch: 22 Average loss: 137.1684
====> Test set loss: 139.3342
====> Epoch: 23 Average loss: 137.1659
====> Test set loss: 139.5767
====> Epoch: 24 Average loss: 137.0491
====> Test set loss: 139.7014
====> Epoch: 25 Average loss: 137.0252
====> Test set loss: 139.4172
====> Epoch: 26 Average loss: 136.6156
====> Test set loss: 139.1442
====> Epoch: 27 Average loss: 136.4574
====> Test set loss: 139.2008
====> Epoch: 28 Average loss: 136.6023
====> Test set loss: 139.0635
====> Epoch: 29 Average loss: 136.2632
====> Test set loss: 138.8421
====> Epoch: 30 Average loss: 136.0486
====> Test set loss: 138.5980
====> Epoch: 31 Average loss: 135.8402
====> Test set loss: 138.5825
====> Epoch: 32 Average loss: 135.7657
====> Test set loss: 138.3948
====> Epoch: 33 Average loss: 135.6819
====> Test set loss: 138.8234
====> Epoch: 34 Average loss: 135.8921
====> Test set loss: 138.6725
====> Epoch: 35 Average loss: 135.7387
====> Test set loss: 138.8002
====> Epoch: 36 Average loss: 135.

====> Epoch: 44 Average loss: 135.0337
====> Test set loss: 137.9329
====> Epoch: 45 Average loss: 135.2689
====> Test set loss: 138.7123
====> Epoch: 46 Average loss: 134.9942
====> Test set loss: 138.3189
====> Epoch: 47 Average loss: 135.0246
====> Test set loss: 138.1083
====> Epoch: 48 Average loss: 134.8569
====> Test set loss: 138.1991
====> Epoch: 49 Average loss: 134.7966
====> Test set loss: 138.0646
====> Epoch: 50 Average loss: 134.7680
====> Test set loss: 138.2108
====> Epoch: 51 Average loss: 134.7811
====> Test set loss: 138.1453
====> Epoch: 52 Average loss: 134.4069
====> Test set loss: 138.0201
====> Epoch: 53 Average loss: 134.2314
====> Test set loss: 137.8835
====> Epoch: 54 Average loss: 134.2168
====> Test set loss: 138.0851
====> Epoch: 55 Average loss: 134.2454
====> Test set loss: 138.0105
====> Epoch: 56 Average loss: 134.2742
====> Test set loss: 137.7804
====> Epoch: 57 Average loss: 134.0523
====> Test set loss: 138.0289
====> Epoch: 58 Average loss: 134.

====> Epoch: 65 Average loss: 134.3928
====> Test set loss: 137.4362
====> Epoch: 66 Average loss: 133.5712
====> Test set loss: 138.0400
====> Epoch: 67 Average loss: 133.4467
====> Test set loss: 137.8300
====> Epoch: 68 Average loss: 133.4778
====> Test set loss: 137.7764
====> Epoch: 69 Average loss: 133.5370
====> Test set loss: 137.9009
====> Epoch: 70 Average loss: 133.6592
====> Test set loss: 138.2768
====> Epoch: 71 Average loss: 133.3118
====> Test set loss: 138.1934
====> Epoch: 72 Average loss: 133.4669
====> Test set loss: 137.6915
====> Epoch: 73 Average loss: 133.2310
====> Test set loss: 137.7251
====> Epoch: 74 Average loss: 133.2822
====> Test set loss: 138.0213
====> Epoch: 75 Average loss: 133.3285
====> Test set loss: 138.3209
====> Epoch: 76 Average loss: 133.3084
====> Test set loss: 138.3067
====> Epoch: 77 Average loss: 133.2196
====> Test set loss: 137.7971
====> Epoch: 78 Average loss: 133.4675
====> Test set loss: 137.7203
====> Epoch: 79 Average loss: 133.

====> Epoch: 87 Average loss: 132.7432
====> Test set loss: 138.3582
====> Epoch: 88 Average loss: 132.4969
====> Test set loss: 137.6677
====> Epoch: 89 Average loss: 132.5842
====> Test set loss: 138.1423
====> Epoch: 90 Average loss: 133.0326
====> Test set loss: 138.4818
====> Epoch: 91 Average loss: 132.8289
====> Test set loss: 138.1801
====> Epoch: 92 Average loss: 132.4552
====> Test set loss: 138.2913
====> Epoch: 93 Average loss: 132.6277
====> Test set loss: 138.0687
====> Epoch: 94 Average loss: 132.5510
====> Test set loss: 137.9658
====> Epoch: 95 Average loss: 132.5669
====> Test set loss: 138.0894
====> Epoch: 96 Average loss: 132.4419
====> Test set loss: 137.8304
====> Epoch: 97 Average loss: 132.7145
====> Test set loss: 137.9470
====> Epoch: 98 Average loss: 132.3929
====> Test set loss: 137.9308
====> Epoch: 99 Average loss: 132.3338
====> Test set loss: 138.0985
====> Epoch: 100 Average loss: 132.4394
====> Test set loss: 137.9277
====> Epoch: 101 Average loss: 13

====> Epoch: 108 Average loss: 132.2973
====> Test set loss: 138.0799
====> Epoch: 109 Average loss: 132.2323
====> Test set loss: 137.7339
====> Epoch: 110 Average loss: 131.9890
====> Test set loss: 138.3188
====> Epoch: 111 Average loss: 132.1137
====> Test set loss: 137.8764
====> Epoch: 112 Average loss: 132.1965
====> Test set loss: 138.3827
====> Epoch: 113 Average loss: 132.1969
====> Test set loss: 138.2935
====> Epoch: 114 Average loss: 132.1679
====> Test set loss: 138.4215
====> Epoch: 115 Average loss: 132.4583
====> Test set loss: 138.3160
====> Epoch: 116 Average loss: 132.4475
====> Test set loss: 138.5337
====> Epoch: 117 Average loss: 132.3280
====> Test set loss: 137.7877
====> Epoch: 118 Average loss: 132.3620
====> Test set loss: 138.2223
====> Epoch: 119 Average loss: 132.7727
====> Test set loss: 139.2714
====> Epoch: 120 Average loss: 132.7619
====> Test set loss: 138.1194
====> Epoch: 121 Average loss: 132.1548
====> Test set loss: 137.7797
====> Epoch: 122 Ave

====> Test set loss: 138.1307
====> Epoch: 130 Average loss: 131.5866
====> Test set loss: 137.7514
====> Epoch: 131 Average loss: 131.5452
====> Test set loss: 137.9238
====> Epoch: 132 Average loss: 131.8417
====> Test set loss: 137.6447
====> Epoch: 133 Average loss: 131.7273
====> Test set loss: 138.1391
====> Epoch: 134 Average loss: 131.5776
====> Test set loss: 137.8894
====> Epoch: 135 Average loss: 131.5562
====> Test set loss: 137.6659
====> Epoch: 136 Average loss: 131.3215
====> Test set loss: 137.6024
====> Epoch: 137 Average loss: 131.2835
====> Test set loss: 137.8555
====> Epoch: 138 Average loss: 131.4306
====> Test set loss: 137.8125
====> Epoch: 139 Average loss: 131.1708
====> Test set loss: 137.6402
====> Epoch: 140 Average loss: 131.2577
====> Test set loss: 138.0522
====> Epoch: 141 Average loss: 131.7024
====> Test set loss: 138.3987
====> Epoch: 142 Average loss: 132.2325
====> Test set loss: 138.4693
====> Epoch: 143 Average loss: 131.5676
====> Test set loss:

====> Epoch: 151 Average loss: 131.0490
====> Test set loss: 137.8497
====> Epoch: 152 Average loss: 131.2748
====> Test set loss: 137.6719
====> Epoch: 153 Average loss: 131.2546
====> Test set loss: 137.9682
====> Epoch: 154 Average loss: 130.9724
====> Test set loss: 138.2380
====> Epoch: 155 Average loss: 130.9016
====> Test set loss: 137.6245
====> Epoch: 156 Average loss: 130.8682
====> Test set loss: 138.1454
====> Epoch: 157 Average loss: 131.0232
====> Test set loss: 137.8804
====> Epoch: 158 Average loss: 131.0883
====> Test set loss: 138.0251
====> Epoch: 159 Average loss: 130.9396
====> Test set loss: 137.9237
====> Epoch: 160 Average loss: 131.3245
====> Test set loss: 138.6091
====> Epoch: 161 Average loss: 131.3201
====> Test set loss: 138.0523
====> Epoch: 162 Average loss: 131.1661
====> Test set loss: 138.1454
====> Epoch: 163 Average loss: 131.1436
====> Test set loss: 137.9404
====> Epoch: 164 Average loss: 130.8729
====> Test set loss: 137.8927
====> Epoch: 165 Ave

====> Epoch: 172 Average loss: 130.7132
====> Test set loss: 137.9793
====> Epoch: 173 Average loss: 130.4952
====> Test set loss: 138.4130
====> Epoch: 174 Average loss: 130.7723
====> Test set loss: 138.1829
====> Epoch: 175 Average loss: 130.3818
====> Test set loss: 138.1142
====> Epoch: 176 Average loss: 130.5652
====> Test set loss: 138.0066
====> Epoch: 177 Average loss: 130.6557
====> Test set loss: 138.7145
====> Epoch: 178 Average loss: 130.9835
====> Test set loss: 138.3385
====> Epoch: 179 Average loss: 130.8109
====> Test set loss: 138.0338
====> Epoch: 180 Average loss: 130.7563
====> Test set loss: 138.2623
====> Epoch: 181 Average loss: 130.5384
====> Test set loss: 138.1714
====> Epoch: 182 Average loss: 130.4332
====> Test set loss: 138.2834
====> Epoch: 183 Average loss: 130.4042
====> Test set loss: 138.2416
====> Epoch: 184 Average loss: 130.4112
====> Test set loss: 138.1421
====> Epoch: 185 Average loss: 130.3800
====> Test set loss: 138.0399
====> Epoch: 186 Ave

====> Epoch: 193 Average loss: 130.3614
====> Test set loss: 138.5624
====> Epoch: 194 Average loss: 130.3228
====> Test set loss: 138.2419
====> Epoch: 195 Average loss: 130.4132
====> Test set loss: 138.3754
====> Epoch: 196 Average loss: 130.3525
====> Test set loss: 138.6287
====> Epoch: 197 Average loss: 130.5051
====> Test set loss: 138.4451
====> Epoch: 198 Average loss: 130.1490
====> Test set loss: 138.7622
====> Epoch: 199 Average loss: 130.2782
====> Test set loss: 138.3823
====> Epoch: 200 Average loss: 130.5280
====> Test set loss: 138.0468
====> Epoch: 201 Average loss: 130.0728
====> Test set loss: 138.1659
====> Epoch: 202 Average loss: 130.3381
====> Test set loss: 138.7104
====> Epoch: 203 Average loss: 130.1712
====> Test set loss: 138.6847
====> Epoch: 204 Average loss: 130.0182
====> Test set loss: 138.4783
====> Epoch: 205 Average loss: 130.1966
====> Test set loss: 137.9868
====> Epoch: 206 Average loss: 130.1467
====> Test set loss: 138.6126
====> Epoch: 207 Ave

====> Epoch: 214 Average loss: 130.0714
====> Test set loss: 138.5886
====> Epoch: 215 Average loss: 129.9415
====> Test set loss: 138.7458
====> Epoch: 216 Average loss: 129.9529
====> Test set loss: 138.7886
====> Epoch: 217 Average loss: 130.0653
====> Test set loss: 138.7086
====> Epoch: 218 Average loss: 130.1845
====> Test set loss: 138.9287
====> Epoch: 219 Average loss: 130.3601
====> Test set loss: 138.8264
====> Epoch: 220 Average loss: 130.0420
====> Test set loss: 138.7402
====> Epoch: 221 Average loss: 130.1945
====> Test set loss: 138.8863
====> Epoch: 222 Average loss: 129.9811
====> Test set loss: 138.6814
====> Epoch: 223 Average loss: 129.9517
====> Test set loss: 139.3324
====> Epoch: 224 Average loss: 131.1506
====> Test set loss: 139.1039
====> Epoch: 225 Average loss: 130.6694
====> Test set loss: 138.6720
====> Epoch: 226 Average loss: 129.9029
====> Test set loss: 138.5253
====> Epoch: 227 Average loss: 129.8244
====> Test set loss: 139.1543
====> Epoch: 228 Ave

====> Epoch: 235 Average loss: 130.1590
====> Test set loss: 138.7983
====> Epoch: 236 Average loss: 130.0242
====> Test set loss: 139.1281
====> Epoch: 237 Average loss: 130.0689
====> Test set loss: 138.6701
====> Epoch: 238 Average loss: 129.9326
====> Test set loss: 139.0190
====> Epoch: 239 Average loss: 129.6161
====> Test set loss: 138.6885
====> Epoch: 240 Average loss: 129.4727
====> Test set loss: 138.5324
====> Epoch: 241 Average loss: 129.3967
====> Test set loss: 138.6660
====> Epoch: 242 Average loss: 129.5320
====> Test set loss: 138.9383
====> Epoch: 243 Average loss: 129.4844
====> Test set loss: 138.8263
====> Epoch: 244 Average loss: 129.5775
====> Test set loss: 138.9434
====> Epoch: 245 Average loss: 129.6149
====> Test set loss: 138.6782
====> Epoch: 246 Average loss: 129.5352
====> Test set loss: 139.0649
====> Epoch: 247 Average loss: 129.8054
====> Test set loss: 139.1066
====> Epoch: 248 Average loss: 130.2732
====> Test set loss: 139.1932
====> Epoch: 249 Ave

====> Epoch: 256 Average loss: 129.3535
====> Test set loss: 139.3266
====> Epoch: 257 Average loss: 129.4988
====> Test set loss: 139.1232
====> Epoch: 258 Average loss: 129.9021
====> Test set loss: 139.1560
====> Epoch: 259 Average loss: 129.7351
====> Test set loss: 138.9571
====> Epoch: 260 Average loss: 129.6492
====> Test set loss: 138.9302
====> Epoch: 261 Average loss: 129.7550
====> Test set loss: 139.0253
====> Epoch: 262 Average loss: 129.7368
====> Test set loss: 139.6358
====> Epoch: 263 Average loss: 129.8098
====> Test set loss: 139.6406
====> Epoch: 264 Average loss: 129.8649
====> Test set loss: 139.5403
====> Epoch: 265 Average loss: 129.7692
====> Test set loss: 139.5935
====> Epoch: 266 Average loss: 129.7459
====> Test set loss: 139.2441
====> Epoch: 267 Average loss: 129.6379
====> Test set loss: 139.3390
====> Epoch: 268 Average loss: 129.8441
====> Test set loss: 139.7183
====> Epoch: 269 Average loss: 129.7479
====> Test set loss: 139.3571
====> Epoch: 270 Ave

====> Test set loss: 139.9408
====> Epoch: 278 Average loss: 130.3080
====> Test set loss: 139.1269
====> Epoch: 279 Average loss: 129.6780
====> Test set loss: 139.2284
====> Epoch: 280 Average loss: 129.4337
====> Test set loss: 139.4064
====> Epoch: 281 Average loss: 129.3137
====> Test set loss: 139.1243
====> Epoch: 282 Average loss: 129.4891
====> Test set loss: 138.9370
====> Epoch: 283 Average loss: 129.6301
====> Test set loss: 139.1803
====> Epoch: 284 Average loss: 129.5729
====> Test set loss: 139.6132
====> Epoch: 285 Average loss: 129.9409
====> Test set loss: 139.5146
====> Epoch: 286 Average loss: 129.5906
====> Test set loss: 139.1752
====> Epoch: 287 Average loss: 129.2008
====> Test set loss: 138.9114
====> Epoch: 288 Average loss: 129.1351
====> Test set loss: 139.2483
====> Epoch: 289 Average loss: 129.3405
====> Test set loss: 139.1222
====> Epoch: 290 Average loss: 129.5062
====> Test set loss: 139.4726
====> Epoch: 291 Average loss: 129.8593
====> Test set loss:

====> Epoch: 299 Average loss: 129.2232
====> Test set loss: 140.0319
====> Epoch: 300 Average loss: 129.3788
====> Test set loss: 139.8180


In [None]:
# test_loss_array holds the full per-sample test loss (BCE + KLD), not only the
# reconstruction term, so the curve is labelled accordingly.
fig, ax = plt.subplots()
epochs = range(1, len(test_loss_array) + 1)  # epochs are numbered from 1
ax.plot(epochs, test_loss_array, label='Test loss (BCE + KLD)')
ax.set_xlabel('Epoch')
ax.set_ylabel('Average loss per sample')
ax.set_title('VAE test loss')
ax.legend();

# Generate

In [None]:
with torch.no_grad():
    # Sample latent codes on the model's device rather than assuming CUDA.
    z = torch.randn(64, 2).to(next(vae.parameters()).device)
    # The decoder output is already on the model's device; no extra transfer needed.
    sample = vae.decoder(z)

    os.makedirs('./outputs', exist_ok=True)
    # The original never filled the '{}' placeholder, producing a file literally
    # named 'gen_img_{}.png'. Fill it with a tag that cannot collide with the
    # per-epoch images.
    save_image(sample.view(64, 1, 28, 28), './outputs/gen_img_{}.png'.format('final'))