In [14]:
import os
# Number GPUs by PCI bus id (stable across reboots) and expose only GPU 0.
# This must run before torch is imported for CUDA to honour it.
os.environ.update(CUDA_DEVICE_ORDER='PCI_BUS_ID', CUDA_VISIBLE_DEVICES='0')

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image


In [16]:
# Mini-batch size shared by the train and test loaders.
bs = 128
data_root = './mnist_data/'
# ToTensor is stateless, so one instance can serve both splits.
to_tensor = transforms.ToTensor()

# download=True is a no-op when the files are already present. Using it for the
# test split too means that split no longer silently relies on the train-split
# call above having fetched the archives first.
train_dataset = datasets.MNIST(root=data_root, train=True, transform=to_tensor, download=True)
test_dataset = datasets.MNIST(root=data_root, train=False, transform=to_tensor, download=True)

# Shuffle only the training data; a fixed test order keeps evaluation repeatable.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [17]:
def train(epoch, device, weight):
    """Run one optimisation epoch of the notebook-global ``vae`` over ``train_loader``.

    Uses the globals ``vae``, ``optimizer``, ``train_loader`` and ``loss_function``.

    Args:
        epoch: epoch number, used only for logging.
        device: torch device the batches are moved to.
        weight: multiplier on the KL term, forwarded to ``loss_function``.

    Prints a progress line every 100 batches and the epoch's mean per-sample loss.
    """
    vae.train()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    running_total = 0
    for step, (images, _labels) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()

        reconstruction, mu, log_var = vae(images)
        batch_loss = loss_function(reconstruction, images, mu, log_var, weight)

        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

        if step % 100 == 0:
            seen = step * len(images)
            pct = 100. * step / n_batches
            # The loss is summed over the batch, so divide to report a per-sample value.
            per_sample = batch_loss.item() / len(images)
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, n_samples, pct, per_sample))
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, running_total / n_samples))

In [18]:
def test(device, weight):
    """Evaluate the notebook-global ``vae`` on ``test_loader`` and print the mean loss.

    Uses the globals ``vae``, ``test_loader`` and ``loss_function``. Gradients are
    disabled throughout.

    Args:
        device: torch device the batches are moved to.
        weight: multiplier on the KL term, forwarded to ``loss_function``.
    """
    vae.eval()
    total = 0
    with torch.no_grad():
        for images, _labels in test_loader:
            images = images.to(device)
            reconstruction, mu, log_var = vae(images)
            # loss_function returns a batch sum; accumulate, then divide by the dataset size.
            total += loss_function(reconstruction, images, mu, log_var, weight).item()

    mean_loss = total / len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(mean_loss))

# Task 1: Design the variational autoencoder (VAE) network for MNIST

In [19]:
# YOUR CODE!!

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: x_dim -> h_dim1 -> h_dim2, followed by two parallel heads that
    produce mu and log_var (each z_dim wide).
    Decoder: z_dim -> h_dim2 -> h_dim1 -> x_dim, squashed with a sigmoid.

    Args:
        x_dim: number of features per flattened input sample
            (784 for 28x28 MNIST images).
        h_dim1: width of the first hidden layer.
        h_dim2: width of the second hidden layer.
        z_dim: latent dimensionality.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()
        # Keep the flattened input size so forward() can reshape any input
        # (e.g. (B, 1, 28, 28) image batches) without a hard-coded 784.
        # A plain int attribute is not part of state_dict, so old checkpoints still load.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc3_1 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc3_2 = nn.Linear(h_dim2, z_dim)  # log-variance head

        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Map flattened inputs of shape (B, x_dim) to (mu, log_var), each (B, z_dim)."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc3_1(h), self.fc3_2(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw z = mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick).

        Writing z this way keeps it differentiable with respect to mu and log_var.
        """
        std = torch.exp(0.5 * log_var)  # sigma = exp(log_var / 2)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)

    def decoder(self, z):
        """Map latent codes (B, z_dim) to reconstructions (B, x_dim) with values in [0, 1]."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        Returns:
            (reconstruction of shape (B, x_dim), mu, log_var)
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        mu, log_var = self.encoder(x.reshape(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [20]:
# Use the (single visible) GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [21]:
# 784-512-256-50 architecture: flattened 28x28 input, 50-dimensional latent space.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=50)
# Module.to() moves parameters in place and returns the same module object.
vae = vae.to(device) if torch.cuda.is_available() else vae

In [22]:
# Display the layer summary to check the layer dimensions.
vae

VAE(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3_1): Linear(in_features=256, out_features=50, bias=True)
  (fc3_2): Linear(in_features=256, out_features=50, bias=True)
  (fc4): Linear(in_features=50, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)

# Task 2: Design the autoencoder loss function with a weighted KL-divergence (KLD) term

In [23]:
# Adam over all VAE parameters; lr=1e-3 is Adam's default, spelled out for visibility.
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

In [24]:

def loss_function(recon_x, x, mu, log_var, weight):
    """Reconstruction error plus a weighted KL divergence, summed over the batch.

    Args:
        recon_x: decoder output with values in [0, 1], shape (B, D).
        x: target batch with B*D elements in total (e.g. (B, 1, 28, 28));
            it is reshaped to match ``recon_x``.
        mu: encoder means.
        log_var: encoder log-variances (same shape as ``mu``).
        weight: multiplier on the KL term; 1.0 gives the standard VAE objective.

    Returns:
        Scalar tensor: BCE + weight * KLD, both summed (not averaged) over the batch.
    """
    # Match the reconstruction's layout instead of hard-coding 784, so the loss
    # works for whatever input size the VAE was built with.
    BCE = F.binary_cross_entropy(recon_x, x.reshape_as(recon_x), reduction='sum')
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with sigma^2 = exp(log_var).
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + weight * KLD

In [25]:
# tqdm.auto picks the notebook widget or the console bar without the
# TqdmExperimentalWarning that tqdm.autonotebook emits inside Jupyter.
from tqdm.auto import tqdm

# Multiplier on the KL term in loss_function (1.0 = standard VAE objective).
weight = 1.0
n_epochs = 50

for epoch in tqdm(range(1, n_epochs + 1)):
    train(epoch, device, weight)
    test(device, weight)

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 178.2026
====> Test set loss: 142.8234
====> Epoch: 2 Average loss: 132.3210
====> Test set loss: 123.7309
====> Epoch: 3 Average loss: 119.9631
====> Test set loss: 115.4444
====> Epoch: 4 Average loss: 114.2481
====> Test set loss: 111.7911
====> Epoch: 5 Average loss: 111.2220
====> Test set loss: 109.1984
====> Epoch: 6 Average loss: 108.8526
====> Test set loss: 107.4432
====> Epoch: 7 Average loss: 107.1075
====> Test set loss: 105.9531
====> Epoch: 8 Average loss: 105.8969
====> Test set loss: 105.1582
====> Epoch: 9 Average loss: 104.9341
====> Test set loss: 104.4805
====> Epoch: 10 Average loss: 104.1572
====> Test set loss: 103.7627
====> Epoch: 11 Average loss: 103.5508
====> Test set loss: 103.5600
====> Epoch: 12 Average loss: 103.0151
====> Test set loss: 103.1518
====> Epoch: 13 Average loss: 102.5538
====> Test set loss: 102.8196
====> Epoch: 14 Average loss: 102.0851
====> Test set loss: 102.4738
====> Epoch: 15 Average loss: 101.7770
====

====> Epoch: 26 Average loss: 99.1719
====> Test set loss: 100.3470
====> Epoch: 27 Average loss: 99.0315
====> Test set loss: 100.2639
====> Epoch: 28 Average loss: 98.8812
====> Test set loss: 100.3864
====> Epoch: 29 Average loss: 98.7591
====> Test set loss: 100.2273
====> Epoch: 30 Average loss: 98.5875
====> Test set loss: 99.9982
====> Epoch: 31 Average loss: 98.4487
====> Test set loss: 99.9561
====> Epoch: 32 Average loss: 98.4141
====> Test set loss: 99.4911
====> Epoch: 33 Average loss: 98.2799
====> Test set loss: 99.8044
====> Epoch: 34 Average loss: 98.1827
====> Test set loss: 99.7127
====> Epoch: 35 Average loss: 98.1179
====> Test set loss: 99.6649
====> Epoch: 36 Average loss: 97.9904
====> Test set loss: 99.3304
====> Epoch: 37 Average loss: 97.9021
====> Test set loss: 99.5451
====> Epoch: 38 Average loss: 97.8004
====> Test set loss: 99.6752
====> Epoch: 39 Average loss: 97.7582
====> Test set loss: 99.3283
====> Epoch: 40 Average loss: 97.6346
====> Test set loss:

# Task 3: Generate new digit images by decoding random latent vectors

In [26]:
# Decode random latent vectors drawn from the N(0, I) prior and save them as an image grid.
n_samples = 64
latent_dim = vae.fc4.in_features  # latent size the decoder was built with (z_dim)
out_dir = os.path.join('./samples', 'problem 2')

vae.eval()
with torch.no_grad():
    # Sample directly on the target device; no extra .to(device) copy is needed.
    z = torch.randn(n_samples, latent_dim, device=device)
    sample = vae.decoder(z)

# Create the real output folder, including the 'problem 2' subdirectory.
# Creating only './samples' made save_image fail with FileNotFoundError on a fresh run.
os.makedirs(out_dir, exist_ok=True)
save_image(sample.view(n_samples, 1, 28, 28), os.path.join(out_dir, 'sample_50D.png'))