In [0]:
import os
import pickle as pkl
from urllib.request import urlretrieve

import numpy as np
import torch
from torch.utils.data import Dataset

from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import shutil

import matplotlib.pyplot as plt
from IPython.display import clear_output

In [124]:
# Run configuration shared by every cell below.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")
batch_size = 64      # samples per mini-batch for both data loaders
log_interval = 100   # print a training progress line every this many batches
epochs = 10          # number of train/test passes run by the driver cell

# Seed PyTorch's RNG up front so weight initialisation, data shuffling and the
# noise drawn during reparameterisation are repeatable from a fresh kernel.
seed = 1
torch.manual_seed(seed)

# Directory under which the dataset cache and the results folder are created.
root = os.getcwd()
print(f"Current working directory: {root}")

Current working directory: /content


In [0]:
def load_mnist_binarized(root):
    """Load the binarized MNIST splits, downloading and caching them on first use.

    On the first call the three ``.amat`` text files are downloaded, parsed
    with ``np.loadtxt`` and stored together in one pickle at
    ``<root>/bin-mnist/mnist.pkl.gz``. Later calls only read that cache.

    Args:
        root: Directory under which the ``bin-mnist`` cache folder is created.

    Returns:
        Tuple ``(x_train, x_valid, x_test)`` with the three cached arrays.
    """
    datapath = os.path.join(root, 'bin-mnist')
    os.makedirs(datapath, exist_ok=True)
    # NOTE: despite the ".gz" suffix the cache is a plain, uncompressed pickle.
    # The name is kept so caches written by earlier runs are still found.
    dataset = os.path.join(datapath, "mnist.pkl.gz")

    if not os.path.isfile(dataset):

        datafiles = {
            "train": "http://www.cs.toronto.edu/~larocheh/public/"
                     "datasets/binarized_mnist/binarized_mnist_train.amat",
            "valid": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                     "binarized_mnist/binarized_mnist_valid.amat",
            "test": "http://www.cs.toronto.edu/~larocheh/public/datasets/"
                    "binarized_mnist/binarized_mnist_test.amat"
        }
        datasplits = {}
        for split, url in datafiles.items():
            print("Downloading %s data..." % (split))
            # Without a target filename urlretrieve stores the download in a
            # temporary file; remove it once parsed so it does not pile up.
            downloaded_path, _ = urlretrieve(url)
            try:
                datasplits[split] = np.loadtxt(downloaded_path)
            finally:
                os.remove(downloaded_path)

        # Write to a sibling file and rename it into place, so an interrupted
        # dump never leaves a truncated cache that later runs would trust.
        partial_dataset = dataset + ".tmp"
        with open(partial_dataset, "wb") as cache_file:
            pkl.dump([datasplits['train'], datasplits['valid'], datasplits['test']], cache_file)
        os.replace(partial_dataset, dataset)

    # SECURITY: unpickling can execute arbitrary code. This is acceptable only
    # because the cache is expected to have been written by this function.
    with open(dataset, "rb") as cache_file:
        x_train, x_valid, x_test = pkl.load(cache_file)
    return x_train, x_valid, x_test


class BinMNIST(Dataset):
    """Dataset of binarized MNIST digits served as ``1 x 28 x 28`` float tensors.

    Args:
        data: Array-like of flattened images; it is converted to a float
            tensor and viewed as ``(-1, 1, 28, 28)``.
        device: Device each sample is moved to when it is fetched.
        transform: Optional callable applied to a sample before the move.
    """

    def __init__(self, data, device='cpu', transform=None):
        channels, height, width = 1, 28, 28
        images = torch.tensor(data, dtype=torch.float)
        self.device = device
        self.data = images.view(-1, channels, height, width)
        self.transform = transform

    def __len__(self):
        # Number of images, i.e. the size of the leading dimension.
        return self.data.shape[0]

    def __getitem__(self, idx):
        item = self.data[idx]
        if not self.transform:
            return item.to(self.device)
        return self.transform(item).to(self.device)


def get_binmnist_datasets(root, device='cpu'):
    """Build the BinMNIST datasets from the cached binarized MNIST splits.

    The validation rows are appended to the training rows, following
    https://github.com/casperkaae/LVAE/blob/master/run_models.py (line 401).
    Because of that merge, the second and third returned datasets are BOTH
    built from the test split.

    Args:
        root: Directory handed to ``load_mnist_binarized``.
        device: Device passed on to every ``BinMNIST`` instance.

    Returns:
        Tuple of three ``BinMNIST`` datasets: train+valid, test, test.
    """
    train_rows, valid_rows, test_rows = load_mnist_binarized(root)
    merged_train_rows = np.concatenate((train_rows, valid_rows), axis=0)

    train_set = BinMNIST(merged_train_rows, device=device)
    valid_set = BinMNIST(test_rows, device=device)
    test_set = BinMNIST(test_rows, device=device)
    return train_set, valid_set, test_set

In [0]:
# x_valid is built from the test split by get_binmnist_datasets and is not
# used by the cells visible in this notebook.
x_train, x_valid, x_test = get_binmnist_datasets(root)

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(x_train, batch_size=batch_size, shuffle=True, pin_memory=cuda)
# The test loss is summed over the whole test set, so batch order is irrelevant
# to it. Keeping the order fixed lets the images written during evaluation show
# the same digits at every epoch, so they can be compared across epochs.
test_loader  = torch.utils.data.DataLoader(x_test, batch_size=batch_size, shuffle=False, pin_memory=cuda)

Downloading train data...
Downloading valid data...
Downloading test data...


In [0]:
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian; the decoder maps a latent vector back to per-pixel
    probabilities through a sigmoid.

    Args:
        input_dim: Size of a flattened input (default 784, i.e. 28 * 28).
        hidden_dim: Width of the hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent variable.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Layer names and creation order are kept so existing state dicts and
        # the seeded initialisation stay compatible.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to values in (0, 1) of size ``input_dim``."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x``, encode, sample and decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Model and optimizer shared by train() and test().
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Per-epoch history: train() and test() append one value per call and
# plot_data() draws the curves. Each list is initialised exactly once.
train_loss_all = []
test_loss_all = []
kl_loss_train = []
kl_loss_test = []


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Compute the VAE training loss and the mean per-sample KL divergence.

    Args:
        recon_x: Decoder output with values in (0, 1), one row per sample.
        x: Input batch; it is viewed to the shape of ``recon_x``.
        mu: Latent means, shape ``(batch, latent_dim)``.
        logvar: Latent log-variances, same shape as ``mu``.

    Returns:
        Tuple ``(loss, kl_mean)``: ``loss`` is the summed binary cross-entropy
        plus the KL divergence summed over the batch; ``kl_mean`` is the KL
        divergence averaged over the samples of the batch.
    """
    # Match the target to the reconstruction instead of assuming 784 pixels.
    BCE = F.binary_cross_entropy(recon_x, x.view(recon_x.shape), reduction='sum')
    # KL(q(z|x) || N(0, I)) per sample; its sum and its mean are both needed,
    # so compute it once.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

    return BCE + kl.sum(), kl.mean()


def train(epoch):
    """Run one optimisation pass over ``train_loader`` and record its statistics.

    Uses the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device`` and ``log_interval``. Appends the mean of the per-batch KL
    values to ``kl_loss_train`` and the per-sample average loss to
    ``train_loss_all``.

    Args:
        epoch: Epoch number, used only in the printed progress messages.
    """
    model.train()
    train_loss = 0
    batch_kl = []

    for batch_idx, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss, kld = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        batch_kl.append(kld.item())

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    kl_loss_train.append(np.mean(batch_kl))

    # The loss is summed over samples, so divide by the dataset size.
    train_mean = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_mean))
    train_loss_all.append(train_mean)


def _show_image_grid(images, title):
    """Display up to 64 images of 28 * 28 values on an 8x8 grid titled ``title``."""
    f, axarr = plt.subplots(8, 8, figsize=(8, 8))
    for i, ax in enumerate(axarr.flat):
        # Leave the cell blank when the batch has fewer images than the grid.
        if i < len(images):
            ax.imshow(images[i].cpu().view(28, 28), cmap="binary_r")
        ax.axis('off')
    plt.suptitle(title)
    plt.show()


def test(epoch):
    """Evaluate the model on ``test_loader`` and record/save the results.

    Uses the module-level ``model``, ``test_loader``, ``device`` and
    ``batch_size``. Appends the mean of the per-batch KL values to
    ``kl_loss_test`` and the per-sample average loss to ``test_loss_all``.
    Once per epoch it writes, through the ``out_batch_*`` helpers,
    ``results/comparison_<epoch>.png`` (inputs and reconstructions of the last
    full-size batch) and, when ``save_img`` is set,
    ``results/latent_sample_<epoch>.png`` (decoded samples drawn from a
    standard normal).

    Args:
        epoch: Epoch number, used in the output file names.
    """
    show_img = False  # set to True to also display the image grids inline
    save_img = True

    model.eval()
    test_loss = 0
    batch_kl = []
    first_batch = None      # (inputs, reconstructions) shown when show_img is set
    last_full_batch = None  # (inputs, reconstructions) written as the comparison

    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss, kld = loss_function(recon_batch, data, mu, logvar)
            test_loss += loss.item()

            batch_kl.append(kld.item())

            if first_batch is None:
                first_batch = (data, recon_batch)
            if data.shape[0] == batch_size:
                last_full_batch = (data, recon_batch)

        # Decode 64 prior samples once per epoch. The latent size is read from
        # the decoder's first layer instead of being hard-coded.
        latent_dim = model.fc3.in_features
        sample = model.decode(torch.randn(64, latent_dim, device=device)).cpu()

    # Each image file is overwritten on every write, so one write per epoch
    # yields the same files as saving after every batch.
    if last_full_batch is not None:
        out_batch_comparison(last_full_batch[0], last_full_batch[1], epoch)

    if show_img and first_batch is not None:
        _show_image_grid(first_batch[0], 'Inputs')
        _show_image_grid(first_batch[1], 'Reconstructions')
        _show_image_grid(sample, 'Latent space')

    if save_img:
        out_batch_image(sample, "latent_sample", epoch)

    kl_loss_test.append(np.mean(batch_kl))

    # The loss is summed over samples, so divide by the dataset size.
    test_mean = test_loss / len(test_loader.dataset)
    test_loss_all.append(test_mean)
    print('====> Test set loss: {:.4f}'.format(test_mean))

In [0]:
def plot_data(epoch_list, train_loss_all, test_loss_all, kl_loss_train, kl_loss_test):
    """Plot training vs. testing curves for the overall loss and for the KL term.

    Two figures are shown one after the other: first the overall loss, then
    the KL loss. Training is drawn as a solid blue line and testing as a
    dashed green line.

    Args:
        epoch_list: X-axis values (epoch numbers).
        train_loss_all: Overall training loss per epoch.
        test_loss_all: Overall testing loss per epoch.
        kl_loss_train: Training KL loss per epoch.
        kl_loss_test: Testing KL loss per epoch.
    """
    # (training curve, testing curve, y-axis label) for each figure, in order.
    figures = (
        (train_loss_all, test_loss_all, 'loss'),
        (kl_loss_train, kl_loss_test, 'KL loss'),
    )
    for train_curve, test_curve, y_label in figures:
        plt.plot(epoch_list, train_curve, color="blue")
        plt.plot(epoch_list, test_curve, color="green", linestyle="--")
        plt.legend(['Training', 'Testing'])
        plt.xlabel('epochs')
        plt.ylabel(y_label)
        plt.show()

def out_batch_image(data, prefix, epoch):
    """Save a batch as an image grid at ``results/<prefix>_<epoch>.png``.

    Args:
        data: Tensor that can be viewed as ``(-1, 1, 28, 28)``.
        prefix: File name prefix.
        epoch: Epoch number appended to the file name.
    """
    images = data.view(-1, 1, 28, 28)
    # Relative path: resolved against the current working directory.
    target = f'results/{prefix}_{epoch}.png'
    save_image(images, target)

def out_batch_comparison(input, recon, epoch):
    """Save inputs followed by their reconstructions as one "comparison" image.

    Args:
        input: Batch of input images shaped like ``(N, 1, 28, 28)``.
        recon: Reconstructions; viewed as ``(-1, 1, 28, 28)`` before stacking.
        epoch: Epoch number used in the output file name.
    """
    recon_images = recon.view(-1, 1, 28, 28)
    # Stack along the batch dimension: all inputs first, then reconstructions.
    stacked = torch.cat((input, recon_images), dim=0)
    out_batch_image(stacked, "comparison", epoch)



In [218]:
if __name__ == "__main__":

    # Start every run from an empty results directory. The directory that is
    # checked, removed and recreated is the same absolute path, so a changed
    # working directory cannot make the cleanup hit a different folder.
    # NOTE(review): out_batch_image writes to the relative 'results/' folder,
    # which matches this path only while the working directory is still `root`.
    datapath = os.path.join(root, 'results')
    if os.path.exists(datapath):
        shutil.rmtree(datapath)
    os.makedirs(datapath, exist_ok=True)

    epoch_list = []
    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        epoch_list.append(epoch)
        # Skip plotting after the first epoch; plots start from the second.
        if epoch == 1:
            continue

        plot_data(epoch_list, train_loss_all, test_loss_all, kl_loss_train, kl_loss_test)
        print("################################################################")

        #clear_output(wait=True)


====> Epoch: 1 Average loss: 146.3072
====> Test set loss: 117.0949
====> Epoch: 2 Average loss: 113.4724
====> Test set loss: 109.7257
################################################################
====> Epoch: 3 Average loss: 108.8341
====> Test set loss: 106.9986
################################################################
====> Epoch: 4 Average loss: 106.6213
====> Test set loss: 105.5940
################################################################
====> Epoch: 5 Average loss: 105.3213
====> Test set loss: 104.5152
################################################################
====> Epoch: 6 Average loss: 104.3426
====> Test set loss: 103.7401
################################################################
====> Epoch: 7 Average loss: 103.6449
====> Test set loss: 103.4695
################################################################
====> Epoch: 8 Average loss: 103.0298
====> Test set loss: 102.8133
################################################################
=

In [0]:
# parser = argparse.ArgumentParser(description='VAE MNIST Example')
# parser.add_argument('--batch-size', type=int, default=128, metavar='N',
#                     help='input batch size for training (default: 128)')
# parser.add_argument('--epochs', type=int, default=10, metavar='N',
#                     help='number of epochs to train (default: 10)')
# parser.add_argument('--no-cuda', action='store_true', default=False,
#                     help='disables CUDA training')
# parser.add_argument('--seed', type=int, default=1, metavar='S',
#                     help='random seed (default: 1)')
# parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                     help='how many batches to wait before logging training status')
# args = parser.parse_args()
# args.cuda = not args.no_cuda and torch.cuda.is_available()

# torch.manual_seed(args.seed)


# kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# train_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=True, download=True,
#                    transform=transforms.ToTensor()),
#     batch_size=batch_size, shuffle=True)
# test_loader = torch.utils.data.DataLoader(
#     datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
#     batch_size=batch_size, shuffle=True)