In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt

ModuleNotFoundError: No module named 'torchvision'

In [None]:
# Displaying routine

def display_images(in_, out, n=1, label=None, count=False):
    """Plot rows of four 28x28 images from `out`, optionally preceded by `in_`.

    For each of the `n` rows, images ``4*N .. 4*N+3`` are shown. When `in_`
    is given, a row of the corresponding input images is drawn first in its
    own figure, followed by a separate figure with the `out` images.

    Args:
        in_: tensor viewable as (-1, 28, 28) holding the input images, or
            None to skip the input rows.
        out: tensor viewable as (-1, 28, 28) holding the images to display.
        n: number of rows (groups of four images) to draw.
        label: optional prefix for the title of the input figures. When
            None, the title is shown without a prefix.
        count: when True, title every `out` image with its index.
    """
    title = 'real test data / reconstructions'
    if label is not None:
        # Previously `label + '...'` raised TypeError for the default label=None.
        title = f'{label} – {title}'

    # The conversions do not depend on the row index, so do them once.
    in_pic = None if in_ is None else in_.detach().cpu().view(-1, 28, 28)
    out_pic = out.detach().cpu().view(-1, 28, 28)

    for N in range(n):
        if in_pic is not None:
            plt.figure(figsize=(18, 4))
            plt.suptitle(title, color='w', fontsize=16)
            for i in range(4):
                plt.subplot(1, 4, i + 1)
                plt.imshow(in_pic[i + 4 * N])
                plt.axis('off')
        plt.figure(figsize=(18, 6))
        for i in range(4):
            plt.subplot(1, 4, i + 1)
            plt.imshow(out_pic[i + 4 * N])
            plt.axis('off')
            if count:
                plt.title(str(4 * N + i), color='w')

In [None]:
# Fix the random seeds (default and CUDA generators) so runs are repeatable

SEED = 1
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
from torchvision import datasets, transforms

# Override the MNIST resource list with (url, md5) pairs that point at the
# ossci-datasets S3 bucket.
_MNIST_BASE_URL = 'https://ossci-datasets.s3.amazonaws.com/mnist/'
datasets.MNIST.resources = [
    (_MNIST_BASE_URL + archive, md5)
    for archive, md5 in (
        ('train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'),
        ('train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'),
        ('t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'),
        ('t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c'),
    )
]

batch_sz = 256

# Transform applied to every image.
# Adding normalisation would require re-tuning beta in the loss:
# tr_ = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0,), (255,))])
tr_ = transforms.ToTensor()

# MNIST train / test splits; only the training constructor requests a download
train_dataset = datasets.MNIST('./data', train=True, transform=tr_, download=True)
test_dataset = datasets.MNIST('./data', train=False, transform=tr_)

# Data loaders: shuffle the training set, keep the test order fixed
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_sz, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_sz, shuffle=False)

In [None]:
# Pick the compute device: the first CUDA GPU when available, else the CPU

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Defining the model

# number of dimensions in the latent space
d = 20

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps a flattened input to ``2 * latent_dim`` values that are
    interpreted as the mean and log-variance of the latent distribution; the
    decoder maps a latent sample back to ``input_dim`` values in (0, 1).

    Args:
        latent_dim: size of the latent space (defaults to the module-level `d`).
        input_dim: number of values per flattened input (784 by default).
        hidden_dim: width of the hidden layers; ``latent_dim ** 2`` when None.
    """

    def __init__(self, latent_dim=d, input_dim=784, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = latent_dim ** 2
        # Keep the sizes on the instance so forward() does not depend on the
        # mutable global `d` after construction.
        self.latent_dim = latent_dim
        self.input_dim = input_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )
        # improve generative performance with ConvLayers
        # don't forget to add them to the decoder!

        # decoder requires a sample as input,
        # not just the output of the encoder.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def reparameterise(self, mu, logvar):
        """Draw a latent sample z for the given distribution parameters.

        In training mode returns ``mu + eps * std`` with ``eps`` drawn from a
        standard normal and ``std = exp(logvar / 2)``; in eval mode returns
        `mu` itself.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # noise carries no gradient
        return eps * std + mu

    def forward(self, x):
        """Encode `x`, sample a latent code and decode it.

        Returns:
            Tuple ``(x_hat, mu, logvar)``: the reconstruction of shape
            (batch, input_dim) and the latent mean / log-variance, each of
            shape (batch, latent_dim).
        """
        # reshape (not view) so non-contiguous inputs are accepted as well
        flat = x.reshape(-1, self.input_dim)
        mu_logvar = self.encoder(flat).view(-1, 2, self.latent_dim)
        mu = mu_logvar[:, 0, :]
        logvar = mu_logvar[:, 1, :]
        z = self.reparameterise(mu, logvar)
        return self.decoder(z), mu, logvar

# Build the VAE with its default sizes, then move it to the chosen device
model = VAE()
model = model.to(device)

In [None]:
# Adam optimiser over all model parameters

learning_rate = 1e-3

optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Reconstruction + β * KL divergence losses summed over all elements and batch

def loss_function(x_hat, x, mu, logvar, β=1):
    """Return the summed reconstruction loss plus β times the KL term.

    Args:
        x_hat: reconstruction with values in [0, 1], e.g. shape (batch, 784).
        x: target images; reshaped to the shape of `x_hat`, so it must hold
            the same number of elements.
        mu: latent means.
        logvar: latent log-variances.
        β: weight of the KL term.

    Returns:
        Scalar tensor ``BCE + β * KLD``, both terms summed over every element
        and every sample of the batch.
    """
    # Reconstruction: binary cross-entropy against the flattened target.
    # Reshaping to x_hat's shape avoids the hard-coded 784 and also accepts
    # non-contiguous targets. (An MSE loss with reduction='sum' is a drop-in
    # alternative.)
    target = x.reshape(x_hat.shape)
    BCE = nn.functional.binary_cross_entropy(x_hat, target, reduction='sum')

    # Structure in latent space: 0.5 * sum(exp(logvar) - logvar - 1 + mu^2)
    KLD = 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2))

    return BCE + β * KLD

In [None]:
# Train the VAE and evaluate it on the test set after every epoch

epochs = 50
# Test-set latent codes (means, log-variances) and labels, one entry per epoch
codes = dict(μ=[], logσ2=[], y=[])
for epoch in range(epochs + 1):

    # ---- Training: skipped for epoch 0 so the untrained net is tested first
    if epoch > 0:
        model.train()  # training mode: reparameterise() samples z
        train_loss = 0
        for x, _ in train_loader:
            x = x.to(device)
            optimizer.zero_grad()  # clear gradients left from the previous step
            # forward pass and loss for this batch
            x_hat, mu, logvar = model(x)
            loss = loss_function(x_hat, x, mu, logvar)
            # backward pass and parameter update
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        print(f'====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}')

    # ---- Testing
    means, logvars, labels = [], [], []
    model.eval()  # eval mode: reparameterise() returns the mean
    test_loss = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device)
            x_hat, mu, logvar = model(x)
            test_loss += loss_function(x_hat, x, mu, logvar).item()
            # keep this batch's latent parameters and labels
            means.append(mu.detach())
            logvars.append(logvar.detach())
            labels.append(y.detach())

    # ---- Logging: store this epoch's codes and report the per-sample test loss
    codes['μ'].append(torch.cat(means))
    codes['logσ2'].append(torch.cat(logvars))
    codes['y'].append(torch.cat(labels))
    test_loss /= len(test_loader.dataset)
    print(f'====> Test set loss: {test_loss:.4f}')
    # show images from the last test batch next to their reconstructions
    display_images(x, x_hat, n=1, label=f'Epoch {epoch}')

In [None]:
# Decode N latent vectors drawn from a standard normal and show them

N = 16
z = torch.randn(N, d).to(device)
sample = model.decoder(z)
display_images(None, sample, n=N // 4, count=True)