In [18]:
import os
# Select the GPU for this process through the CUDA environment variables:
# devices are ordered by PCI bus id and only the device with index 1 is
# exposed. Both values are set before torch is imported in the next cell.
os.environ.update(
    CUDA_DEVICE_ORDER='PCI_BUS_ID',
    CUDA_VISIBLE_DEVICES='1',
)

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision

import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image


In [20]:
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single axis."""

    def forward(self, input):
        """Return ``input`` viewed as ``(input.size(0), -1)``."""
        leading = input.size(0)
        return input.view(leading, -1)


class UnFlatten(nn.Module):
    """View a tensor as ``(N, size, 1, 1)``, where N is its first dimension."""

    def forward(self, input, size=100):
        """Return ``input`` viewed as ``(input.size(0), size, 1, 1)``."""
        target_shape = (input.size(0), size, 1, 1)
        return input.view(*target_shape)

In [21]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input through two ReLU hidden layers to a
    mean and a log-variance vector of size ``z_dim``; the decoder mirrors the
    encoder and ends in a sigmoid, so reconstructions lie in [0, 1].

    Args:
        x_dim: size of one flattened input sample (e.g. 784 for 28x28 images).
        h_dim1: width of the first encoder / last decoder hidden layer.
        h_dim2: width of the second encoder / first decoder hidden layer.
        z_dim: dimensionality of the latent code.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super(VAE, self).__init__()

        # Kept so forward() can flatten its input to (N, x_dim) instead of
        # assuming a fixed input size.
        self.x_dim = x_dim

        # encoder part
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)  # mean head
        self.fc32 = nn.Linear(h_dim2, z_dim)  # log-variance head
        # decoder part
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

    def encoder(self, x):
        """Encode a flattened batch ``x`` and return ``(mu, log_var)``."""
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)  # mu, log_var

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * log_var)``.
        The noise is drawn on every call, regardless of train/eval mode.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)  # return z sample

    def decoder(self, z):
        """Decode latent codes ``z`` into reconstructions of width ``x_dim``."""
        h = F.relu(self.fc4(z))
        h = F.relu(self.fc5(h))
        # torch.sigmoid is the tensor-level equivalent of F.sigmoid.
        return torch.sigmoid(self.fc6(h))

    def forward(self, x):
        """Flatten ``x`` to ``(-1, x_dim)`` and return ``(recon, mu, log_var)``."""
        mu, log_var = self.encoder(x.view(-1, self.x_dim))
        z = self.sampling(mu, log_var)
        return self.decoder(z), mu, log_var

In [22]:
# Both MNIST splits are stored under ./mnist_data/ and converted with
# ToTensor(). Only the training split requests a download; the test split is
# created with download=False.
train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=False,
)

# Mini-batches of 128 samples; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

In [23]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [24]:
# 784 -> 512 -> 256 -> 2 encoder with a mirrored decoder. The model is placed
# on the `device` selected earlier (GPU if available, else CPU) rather than
# repeating the CUDA availability check here.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2).to(device)

In [25]:
# Adam over all VAE parameters; every optimizer hyperparameter is left at the
# library default.
optimizer = optim.Adam(params=vae.parameters())
def loss_function(recon_x, x, mu, log_var, beta=10):
    """Return the weighted VAE loss: summed BCE plus ``beta`` times the KL term.

    Args:
        recon_x: reconstructions with values in [0, 1]; the last dimension is
            the flattened sample width.
        x: target batch; it is viewed as ``(-1, recon_x.size(-1))`` so it may
            be passed un-flattened (e.g. ``(N, 1, 28, 28)``).
        mu: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.
        beta: weight applied to the KL term. Defaults to 10, the value that
            was previously hard-coded.

    Returns:
        Scalar tensor ``BCE + beta * KLD``, summed (not averaged) over the batch.
    """
    # Flatten the target to the reconstruction's width instead of assuming 784.
    target = x.view(-1, recon_x.size(-1))
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I).
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + beta * KLD

In [26]:
def train(epoch):
    """Run one training epoch of the global ``vae`` over ``train_loader``.

    Prints the per-sample loss of every 100th batch and the epoch's average
    per-sample loss.

    Args:
        epoch: epoch number, used only in the progress messages.

    Returns:
        The average training loss per sample for this epoch.
    """
    vae.train()
    # Send each batch to wherever the model lives, so the loop also runs on
    # CPU-only machines instead of failing on an unconditional .cuda().
    model_device = next(vae.parameters()).device
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(model_device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)))

    # The loss is summed over samples, so normalise by the dataset size.
    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

In [27]:
def test():
    """Evaluate the global ``vae`` on ``test_loader`` without gradients.

    Prints the average per-sample loss over the test set.

    Returns:
        The average test loss per sample.
    """
    vae.eval()
    # Send each batch to wherever the model lives, so evaluation also runs on
    # CPU-only machines instead of failing on an unconditional .cuda().
    model_device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(model_device)
            recon, mu, log_var = vae(data)

            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()

    # The loss is summed over samples, so normalise by the dataset size.
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

In [28]:
from tqdm.autonotebook import tqdm

# Train for a fixed number of epochs, evaluating on the test set after each
# one. Epochs are numbered from 1 so the printed messages are 1-based.
NUM_EPOCHS = 50

for epoch in tqdm(range(1, NUM_EPOCHS + 1)):
    train(epoch)
    test()

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

====> Epoch: 1 Average loss: 205.1307
====> Test set loss: 192.7945
====> Epoch: 2 Average loss: 191.4079
====> Test set loss: 190.6421
====> Epoch: 3 Average loss: 189.5318
====> Test set loss: 188.7500
====> Epoch: 4 Average loss: 188.2595
====> Test set loss: 187.7227
====> Epoch: 5 Average loss: 187.5780
====> Test set loss: 187.1112
====> Epoch: 6 Average loss: 187.0155
====> Test set loss: 186.1965
====> Epoch: 7 Average loss: 186.6446
====> Test set loss: 185.8432
====> Epoch: 8 Average loss: 186.1861
====> Test set loss: 185.8456
====> Epoch: 9 Average loss: 185.6800
====> Test set loss: 185.3806
====> Epoch: 10 Average loss: 185.2661
====> Test set loss: 185.1349
====> Epoch: 11 Average loss: 185.0022
====> Test set loss: 184.6427
====> Epoch: 12 Average loss: 184.5710
====> Test set loss: 184.1479
====> Epoch: 13 Average loss: 184.3184
====> Test set loss: 184.1066
====> Epoch: 14 Average loss: 184.0733
====> Test set loss: 183.8774
====> Epoch: 15 Average loss: 183.7597
====

In [12]:
with torch.no_grad():
    # Decode a regular grid of 2-D latent codes covering [-1, 1] x [-1, 1] and
    # save the decoded 28x28 images as one tiled picture.
    grid_size = 21
    x = np.linspace(-1, 1, grid_size)

    # Row-major grid of [j, -i] pairs (i indexes rows, j columns), so the
    # first latent coordinate increases left to right and the second
    # decreases from the top row to the bottom row of the saved image.
    cols, rows = np.meshgrid(x, -x)
    latent_grid = np.stack([cols.ravel(), rows.ravel()], axis=1)

    # Explicit float32 on the model's own device: matches the default dtype of
    # nn.Linear weights and avoids an unconditional .cuda() on CPU-only hosts.
    z = torch.tensor(
        latent_grid,
        dtype=torch.float32,
        device=next(vae.parameters()).device,
    )

    sample = vae.decoder(z)

    # save_image needs the output directory to exist.
    os.makedirs('./samples', exist_ok=True)
    save_image(sample.view(grid_size * grid_size, 1, 28, 28), './samples/sample_1' + '.png', nrow=grid_size)

torch.Size([441, 2])
