In [1]:
from __future__ import print_function
from urllib.request import urlretrieve
from torch.utils.data import Dataset
from torch import optim

import os
import pickle as pkl
import numpy as np
import torch

import torch
import torch.utils.data

from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import torch.nn as nn

import matplotlib.pyplot as plt
from IPython.display import clear_output

In [2]:
# ---- Runtime configuration shared by the cells below ----
cuda = torch.cuda.is_available()
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if cuda else torch.device("cpu")

batch_size = 64     # samples per mini-batch for both data loaders
log_interval = 100  # the training loss is printed every `log_interval` batches
epochs = 10         # number of train/test passes run by the driver cell

# The data cache and the results folder are created under the working directory.
root = os.getcwd()
print(f"Current working directory: {root}")

Current working directory: C:\Users\Zuzka\git\02456_DL-sketch_features_VAE\vae


In [3]:
def load_mnist_binarized(root):
    """Load the binarized MNIST splits, downloading and caching them on first use.

    On the first call the three ``.amat`` text files are fetched from the URLs
    below, parsed with ``np.loadtxt`` and pickled together into
    ``<root>/bin-mnist/mnist.pkl.gz``. Later calls only read that cache.

    Note: despite its ``.gz`` suffix the cache is a plain, uncompressed pickle;
    the file name is kept so caches written by earlier runs stay usable.

    Args:
        root: Directory under which the ``bin-mnist`` cache folder is created.

    Returns:
        The three cached arrays, in the order ``(x_train, x_valid, x_test)``.
    """
    datapath = os.path.join(root, 'bin-mnist')
    os.makedirs(datapath, exist_ok=True)
    dataset = os.path.join(datapath, "mnist.pkl.gz")

    if not os.path.isfile(dataset):

        datafiles = {
            "train": "http://www.cs.toronto.edu/~larocheh/public/"
                     "datasets/binarized_mnist/binarized_mnist_train.amat",
            "valid": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                     "binarized_mnist/binarized_mnist_valid.amat",
            "test": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                    "binarized_mnist/binarized_mnist_test.amat"
        }
        datasplits = {}
        for split, url in datafiles.items():
            print("Downloading %s data..." % (split))
            # urlretrieve returns (local_path, headers); only the path is needed.
            datasplits[split] = np.loadtxt(urlretrieve(url)[0])

        # Write to a scratch file and rename it into place, so an interrupted
        # dump never leaves a truncated cache that would break every later run.
        partial = dataset + ".tmp"
        with open(partial, "wb") as cache_file:
            pkl.dump([datasplits['train'], datasplits['valid'], datasplits['test']], cache_file)
        os.replace(partial, dataset)

    # NOTE(review): pickle.load can execute arbitrary code; the cache is expected
    # to be a file this function wrote itself, never one from an untrusted source.
    with open(dataset, "rb") as cache_file:
        x_train, x_valid, x_test = pkl.load(cache_file)
    return x_train, x_valid, x_test


class BinMNIST(Dataset):
    """Map-style dataset over binarized MNIST images.

    The input rows are converted once to a float tensor and reshaped to
    ``(N, 1, 28, 28)``. Each sample is optionally passed through ``transform``
    and then moved to ``device`` when it is accessed.
    """

    def __init__(self, data, device='cpu', transform=None):
        channels, height, width = 1, 28, 28
        images = torch.tensor(data, dtype=torch.float)
        self.data = images.view(-1, channels, height, width)
        self.device = device
        self.transform = transform

    def __len__(self):
        # Number of images, i.e. the size of the leading dimension.
        return self.data.size(0)

    def __getitem__(self, idx):
        item = self.data[idx]
        if not self.transform:
            return item.to(self.device)
        return self.transform(item).to(self.device)


def get_binmnist_datasets(root, device='cpu'):
    """Build the BinMNIST datasets used for training and evaluation.

    The validation rows are appended to the training rows, and the test rows
    back both the second and the third returned dataset.

    Args:
        root: Directory passed on to ``load_mnist_binarized``.
        device: Device handed to every ``BinMNIST`` instance.

    Returns:
        Tuple of three ``BinMNIST`` datasets: (train + valid, test, test).
    """
    train_rows, valid_rows, test_rows = load_mnist_binarized(root)
    # Merge train and validation rows, following
    # https://github.com/casperkaae/LVAE/blob/master/run_models.py (line 401)
    merged_rows = np.concatenate((train_rows, valid_rows), axis=0)

    train_set = BinMNIST(merged_rows, device=device)
    valid_set = BinMNIST(test_rows, device=device)
    test_set = BinMNIST(test_rows, device=device)
    return train_set, valid_set, test_set

In [4]:
# Width settings read by conv_VAE (previously tried larger values noted alongside).
ngf = 32     # base number of decoder feature maps (was 64)
ndf = 32     # base number of encoder feature maps (was 64)
nc = 1       # image channels
h_dim = 256  # length of the flattened convolutional feature vector (was 1024)


class conv_VAE(nn.Module):
    """Convolutional variational autoencoder.

    The shape notes below assume ``nc x 28 x 28`` inputs. The convolutional
    encoder reduces an image to an ``h_dim`` vector, fully connected heads
    produce the mean and log-variance of an ``nz``-dimensional latent, and a
    mirrored transposed-convolution decoder maps a latent back to an image
    with values in ``[0, 1]`` (final sigmoid).
    """

    def __init__(self, nz=32):
        super().__init__()

        self.have_cuda = True
        self.nz = nz

        encoder_layers = [
            # (nc) x 28 x 28 -> (ndf) x 14 x 14
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 7 x 7
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (h_dim) x 1 x 1
            nn.Conv2d(ndf * 4, h_dim, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            # (h_dim) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(h_dim, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4) x 7 x 7
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2) x 14 x 14
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> (nc) x 28 x 28, squashed into [0, 1]
            nn.ConvTranspose2d(ngf * 2, nc, 4, 2, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Encoder head: features -> hidden -> (mean, log-variance).
        self.fc1 = nn.Linear(h_dim, 512)
        self.fc21 = nn.Linear(512, nz)
        self.fc22 = nn.Linear(512, nz)

        # Decoder stem: latent -> hidden -> features fed to the decoder.
        self.fc3 = nn.Linear(nz, 512)
        self.fc4 = nn.Linear(512, h_dim)

        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder(x).view(-1, h_dim)
        hidden = self.fc1(features)
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def decode(self, z):
        """Map latent vectors ``z`` of shape ``(B, nz)`` to decoder outputs."""
        hidden = self.relu(self.fc3(z))
        # Reshape the flat feature vector into a 1 x 1 map for the transposed convs.
        feature_map = self.fc4(hidden).view(-1, h_dim, 1, 1)
        return self.decoder(feature_map)

    def reparametrize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample and decode; returns ``(decoded, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparametrize(mu, logvar)
        return self.decode(latent), mu, logvar

In [5]:
class VAE(nn.Module):
    """Fully connected VAE: 784 -> 400 -> 20-d latent -> 400 -> 784."""

    def __init__(self):
        super().__init__()

        # Encoder trunk and its two heads (mean and log-variance).
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        # Decoder.
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened images."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to 784 sigmoid outputs per sample."""
        hidden = F.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Flatten ``x`` to rows of 784, then encode, sample and decode."""
        flat = x.view(-1, 784)
        mu, logvar = self.encode(flat)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

In [6]:
# Datasets live on the CPU (default device); batches are moved to `device` in train()/test().
x_train, x_valid, x_test = get_binmnist_datasets(root)

# pin_memory is switched on only when CUDA is available.
train_loader = torch.utils.data.DataLoader(
    x_train,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=cuda,
)
test_loader = torch.utils.data.DataLoader(
    x_test,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=cuda,
)

In [7]:
def loss_function(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli decoder and a standard-normal prior.

    The target is reshaped to the reconstruction's shape before the binary
    cross-entropy is taken, so the function accepts both flat ``(B, 784)``
    reconstructions (``VAE``) and image-shaped ``(B, 1, 28, 28)`` ones
    (``conv_VAE``). Passing tensors of different shapes to
    ``F.binary_cross_entropy`` only warned on old PyTorch releases and raises
    ``ValueError`` on current ones.

    Args:
        recon_x: Reconstruction probabilities in ``[0, 1]``.
        x: Target images with the same number of elements as ``recon_x``.
        mu: Posterior means, shape ``(B, latent_dim)``.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Tuple ``(loss, kl_mean)``: the reconstruction term plus the KL term,
        both summed over all elements and the batch, and the KL divergence
        averaged over the samples of the batch.
    """
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # Closed-form KL(q(z|x) || N(0, I)) per sample; summing it gives the batch total.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    KLD = kl.sum()

    return BCE + KLD, kl.mean()

In [8]:
# Model under training; swap in VAE() here to use the fully connected variant.
model = conv_VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch histories appended to by train() and test().
train_loss_all, test_loss_all = [], []
kl_loss_train, kl_loss_test = [], []

def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Prints the per-sample loss every ``log_interval`` batches, then appends
    the epoch's mean KL value to ``kl_loss_train`` and its average loss per
    sample to ``train_loss_all``.

    Args:
        epoch: Epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0
    kl_per_batch = []
    n_samples = len(train_loader.dataset)

    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss, kld = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        kl_per_batch.append(kld.item())

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), n_samples,
                100. * batch_idx / len(train_loader),
                batch_loss / len(data)))

    kl_loss_train.append(np.mean(kl_per_batch))

    # The loss is summed over the batch, so dividing by the dataset size gives a per-sample mean.
    train_mean = running_loss / n_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_mean))
    train_loss_all.append(train_mean)


def _show_image_grid(images, title):
    """Plot up to 64 images of 28 x 28 pixels on an 8 x 8 grid under ``title``."""
    # imshow needs host memory, so detach and copy to the CPU first.
    images = images.detach().cpu()
    f, axarr = plt.subplots(8, 8, figsize=(8, 8))
    for i, ax in enumerate(axarr.flat):
        ax.axis('off')
        # A batch may hold fewer than 64 images; leave the spare cells empty.
        if i < len(images):
            ax.imshow(images[i].reshape(28, 28), cmap="binary_r")
    plt.suptitle(title)
    plt.show()


def test(epoch, show_img=False):
    """Evaluate the model on ``test_loader`` without gradient tracking.

    Appends the mean KL value of the pass to ``kl_loss_test`` and the average
    loss per sample to ``test_loss_all``, and prints the latter. The
    ``results`` folder under ``root`` is created if it does not exist.

    Args:
        epoch: Epoch number; kept for interface compatibility, not used here.
        show_img: When true, plot the inputs, reconstructions and decoded
            prior samples for the first batch only. Defaults to ``False``,
            which matches the previous hard-coded behaviour.
    """
    os.makedirs(os.path.join(root, 'results'), exist_ok=True)
    model.eval()
    test_loss = 0
    batch_kl = []

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss, kld = loss_function(recon_batch, data, mu, logvar)
            test_loss += loss.item()

            batch_kl.append(kld.item())

            if show_img:
                _show_image_grid(data, 'Inputs')
                _show_image_grid(recon_batch, 'Reconstructions')

                # Take the latent width from the encoder output so the sample
                # fits whichever model is in use (20 for VAE, nz for conv_VAE).
                sample = torch.randn(64, mu.size(1), device=device)
                _show_image_grid(model.decode(sample), 'Latent space')

                # Only the first batch is visualised.
                show_img = False

    kl_loss_test.append(np.mean(batch_kl))

    test_mean = test_loss / len(test_loader.dataset)
    test_loss_all.append(test_mean)
    print('====> Test set loss: {:.4f}'.format(test_mean))

In [9]:
def plot_data(epoch_list, train_loss_all, test_loss_all, kl_loss_train, kl_loss_test):
    """Show two figures comparing training and testing curves per epoch.

    The first figure plots the overall loss, the second the KL term. In both,
    the training curve is solid blue and the testing curve dashed green.
    """
    panels = (
        (train_loss_all, test_loss_all, 'loss'),
        (kl_loss_train, kl_loss_test, 'KL loss'),
    )
    for train_curve, test_curve, y_label in panels:
        plt.plot(epoch_list, train_curve, color="blue")
        plt.plot(epoch_list, test_curve, color="green", linestyle="--")
        plt.legend(['Training', 'Testing'])
        plt.xlabel('epochs')
        plt.ylabel(y_label)
        plt.show()

In [None]:
if __name__ == "__main__":

    # Epoch numbers completed so far; this is the x-axis expected by plot_data().
    epoch_list = []
    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        epoch_list.append(epoch)

        # From the second epoch on, close each epoch with a separator line.
        # plot_data(epoch_list, train_loss_all, test_loss_all, kl_loss_train,
        # kl_loss_test) can be called here to redraw the curves after each epoch.
        if epoch > 1:
            print("################################################################")


  This is separate from the ipykernel package so we can avoid doing imports until




  This is separate from the ipykernel package so we can avoid doing imports until


====> Epoch: 1 Average loss: 149.4989


  This is separate from the ipykernel package so we can avoid doing imports until


====> Test set loss: 115.5637
====> Epoch: 2 Average loss: 109.2394
====> Test set loss: 104.2716
################################################################
