In [None]:
from torchsummary import summary # pip only

import itertools
import numpy as np

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision 
import torchvision.datasets as datasets

%load_ext autoreload
%autoreload 2
    
import models
import data

In [None]:
# Use the first GPU when CUDA is available and fall back to the CPU otherwise,
# so every later `.to(device)` call also works on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Fix the RNG seeds so repeated runs start from the same state.
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)

# Size lists handed to models.VAE below. 784 = 28*28, the length of one
# flattened image (the preview cell reshapes a sample back to 28x28).
latent_size = 2
enc_sizes = [784,500,latent_size]
dec_sizes = [latent_size,500,784]
batch_size = 100
# When False, the saved weights in 'vae.pt' are loaded and the training cell
# is skipped; when True, the model is trained and the weights are saved.
train = False

n_epochs = 200
# Report the running loss every `n_batches_print` batches.
n_batches_print = 100

# Length of the one-hot label vector used as the condition (10 classes).
conditional_size = 10

# Conditional variants: the first entry of each size list grows by the
# length of the condition vector.
cvae_enc_sizes = enc_sizes.copy()
cvae_enc_sizes[0] += conditional_size

cvae_dec_sizes = dec_sizes.copy()
cvae_dec_sizes[0] += conditional_size

# presumably the layer widths of a network mapping the condition to the
# latent prior -- confirm against the models module
cvae_prior_network_sizes = [conditional_size, 5, latent_size]

# Last expression of the cell: displays the three derived size lists.
cvae_enc_sizes, cvae_dec_sizes, cvae_prior_network_sizes

In [None]:
# Per-sample preprocessing: convert to a tensor, round every pixel value to
# the nearest integer, then flatten the image into a 1-D vector.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(torch.round),
    torchvision.transforms.Lambda(lambda img: img.view(-1)),
])

# Both splits share the same storage root and the same loader settings.
mnist_root = '../data'
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=8)

trainset = datasets.MNIST(root=mnist_root, train=True, download=True, transform=transforms)
trainloader = torch.utils.data.DataLoader(trainset, **loader_options)

testset = datasets.MNIST(root=mnist_root, train=False, download=True, transform=transforms)
testloader = torch.utils.data.DataLoader(testset, **loader_options)

In [None]:
# Pull a single batch from the training loader to sanity-check its contents.
first_batch = next(iter(trainloader))
x, y = first_batch

# One-hot encoding (10 classes) of the first label in the batch.
first_label_one_hot = torch.nn.functional.one_hot(y,10)[0]
print(x.shape, y.shape, first_label_one_hot)

# Reshape the first flattened sample back to 28x28 and display it.
first_image = x[0].view(28,28)
plt.imshow(first_image)

In [None]:
# Build the (unconditional) VAE from the size lists in the configuration cell.
vae = models.VAE(enc_sizes,dec_sizes)

if not train:
    # Reuse previously trained weights instead of training again.
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # that was saved from a GPU model can still be loaded where that GPU is
    # not available.
    vae.load_state_dict(torch.load('vae.pt', map_location=device))

vae.to(device)
print(vae)

In [None]:
# Train the VAE only when requested; otherwise the saved weights loaded
# earlier are used as-is.
if train:

    optimizer = torch.optim.Adam(vae.parameters())

    # n_epochs and n_batches_print are taken from the configuration cell;
    # they used to be re-assigned here with the same values.

    for epoch in range(n_epochs):

        running_loss = 0.0

        for it, (images, labels) in enumerate(trainloader):

            images, labels = images.to(device), labels.to(device)

            mu, logvariance, z, y = vae(images)
            # The objective is maximised, so its negation is the loss that
            # the optimizer minimises.
            loss = -models.VAE.variational_objective(images, mu, logvariance, z, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate a plain Python float: adding the tensor itself would
            # keep autograd history attached to the running total and make
            # the report below print a tensor repr instead of a number.
            running_loss += loss.item()
            if it % n_batches_print == n_batches_print-1:
                # Mean loss over the last `n_batches_print` batches.
                print(f'[{epoch+1} {it+1}] loss: {running_loss/n_batches_print}')
                running_loss = 0.0

    print('Finished Training')

    # Persist the trained weights so later runs can set train = False.
    torch.save(vae.state_dict(), "./vae.pt")

In [None]:
def test_model(vae, dataloader, device, conditional=False):
    elbo_avg = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        if conditional:
            out = vae(x, nn.functional.one_hot(y, 10).float())
            elbo_avg += models.CVAE.variational_objective(x, *out)
        else:
            out = vae(x)
            elbo_avg += models.VAE.variational_objective(x, *out)
        
    return elbo_avg/len(dataloader)


# Evaluate the unconditional VAE on each split, then report the average
# objective value returned by test_model.
train_elbo = test_model(vae, trainloader, device)
print(f"Training set ELBO = {train_elbo}")

test_elbo = test_model(vae, testloader, device)
print(f"Test set ELBO     = {test_elbo}")