In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd 
import random 
import torch
import torchvision
import torch.optim as optim

from torchvision import transforms
from torch.utils.data import DataLoader,random_split
from tqdm import tqdm

from vae import VariationalAutoencoder
from modules import train_epoch, test_epoch, plot_ae_outputs

In [None]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so that re-runs
# start from the same random state (weight init, dataset split, shuffling).
# Note: bit-exact GPU results additionally need torch.use_deterministic_algorithms.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Selected device: {device}')

In [None]:
# ToTensor converts each PIL image to a float (C, H, W) tensor scaled to [0, 1].
train_transform = transforms.Compose([
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
])

# MNIST is downloaded into the parent directory on first run.
mnist_trainset = torchvision.datasets.MNIST(root="..", train=True, download=True, transform=train_transform)
mnist_testset = torchvision.datasets.MNIST(root="..", train=False, download=True, transform=test_transform)

# To train on CIFAR-10 instead, use these two lines in place of the MNIST ones:
# mnist_trainset = torchvision.datasets.CIFAR10(root="..", train=True, download=True, transform=train_transform)
# mnist_testset = torchvision.datasets.CIFAR10(root="..", train=False, download=True, transform=test_transform)

batch_size = 128
VAL_FRACTION = 0.2  # share of the training set held out for validation
SPLIT_SEED = 42     # fixed so the train/validation split is reproducible

m = len(mnist_trainset)
# Size the validation split once and give the remainder to training, so the two
# lengths always sum to m (int(m - 0.2*m) + int(0.2*m) can come up one short).
val_len = int(m * VAL_FRACTION)
train_len = m - val_len
mnist_trainset, val_data = random_split(
    mnist_trainset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(mnist_trainset) + len(val_data) == m

# Reshuffle the training data every epoch so batches differ between epochs.
train_loader = DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(mnist_testset, batch_size=batch_size, shuffle=True)

In [None]:
# --- Model / optimiser hyperparameters ---
d = 4       # passed as latent_dims to the VAE
lr = 1e-3   # Adam learning rate

# The first constructor argument is the length of the first training image
# tensor, i.e. its leading dimension (presumably the channel count; 1 for MNIST).
vae = VariationalAutoencoder(
    len(mnist_trainset[0][0]),
    latent_dims=d,
    device=device,
)
vae.to(device)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim as optim`
# import; the training cell passes this optimizer object by that name.
optim = torch.optim.Adam(
    params=vae.parameters(),
    lr=lr,
    weight_decay=1e-5,  # small L2 penalty on the weights
)

In [None]:
num_epochs = 10
# Per-epoch history, consumed by the plotting cell below.
train_losses, val_losses = [], []

for epoch in range(num_epochs):
    # train_epoch / test_epoch come from `modules`; each returns a scalar loss.
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, valid_loader)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(
        epoch + 1, num_epochs, train_loss, val_loss))

In [None]:
# Learning curves, one point per *completed* epoch (1-based). Each curve is
# plotted against its own length rather than num_epochs, so the cell still works
# if the training loop was interrupted early.
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label="train")
ax.plot(range(1, len(val_losses) + 1), val_losses, label="validation")
ax.set(xlabel="epoch", ylabel="loss", title="VAE loss per epoch")
ax.legend()
plt.show()

In [None]:
def plot_ae_outputs(encoder,decoder,testset,device,n=10):
    """Plot the first ``n`` dataset images (top row) above their reconstructions (bottom row).

    NOTE: this definition shadows the ``plot_ae_outputs`` imported from
    ``modules`` at the top of the notebook.

    Args:
        encoder: encoder module; must expose a ``conv_shape`` attribute, which is
            forwarded to ``decoder``.
        decoder: decoder module, called as
            ``decoder(encoder(img), encoder.conv_shape, img.shape[2:])``.
        testset: indexable, sized dataset whose items are ``(image, label)`` pairs.
            Images are assumed to be (C, H, W) tensors (as produced by ToTensor);
            TODO confirm for datasets other than MNIST / CIFAR-10.
        device: torch device the models live on.
        n: number of images to show; clamped to ``len(testset)``.

    The models are put in eval mode for the forward passes, and their previous
    train/eval mode is restored afterwards.
    """

    def _to_image(tensor):
        # Drop the batch dim. Single-channel images become (H, W) shown in gray;
        # multi-channel ones go channels-last (H, W, C), the layout imshow accepts.
        array = tensor.detach().cpu().squeeze(0).numpy()
        if array.ndim == 3 and array.shape[0] == 1:
            return array[0], "gist_gray"
        if array.ndim == 3:
            return array.transpose(1, 2, 0), None
        return array, "gist_gray"

    n = min(n, len(testset))
    if n < 1:
        return

    previous_modes = (encoder.training, decoder.training)
    encoder.eval()
    decoder.eval()
    fig, axes = plt.subplots(2, n, figsize=(16, 4.5), squeeze=False)
    try:
        with torch.no_grad():
            for i in range(n):
                img = testset[i][0].unsqueeze(0).to(device)
                rec_img = decoder(encoder(img), encoder.conv_shape, img.shape[2:])
                for ax, tensor in ((axes[0, i], img), (axes[1, i], rec_img)):
                    pixels, cmap = _to_image(tensor)
                    ax.imshow(pixels, cmap=cmap)
                    ax.get_xaxis().set_visible(False)
                    ax.get_yaxis().set_visible(False)
        axes[0, n // 2].set_title('Original images')
        axes[1, n // 2].set_title('Reconstructed images')
    finally:
        # Don't leave the caller's models stuck in eval mode.
        encoder.train(previous_modes[0])
        decoder.train(previous_modes[1])
    plt.show()

In [None]:
# Final held-out evaluation: the loss on the test split as returned by
# test_epoch (from `modules`).
test_loss = test_epoch(vae, device, test_loader)
print("TEST loss: ", test_loss)

# Visual check: the first 20 test images next to their reconstructions.
plot_ae_outputs(
    vae.encoder,
    vae.decoder,
    mnist_testset,
    device,
    n=20,
)

In [None]:
# Persist only the learned parameters (state_dict), not the whole pickled module.
torch.save(obj=vae.state_dict(), f="model.pt")