### Import Packages

In [1]:
import numpy as np
import torch
import torchvision.utils
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import AVTDataset
from models.model import VAE, vae_loss

### Hyperparameters

In [2]:
# Constants
use_gpu = True          # use cuda:0 when it is available
learning_rate = 1e-3    # Adam step size
num_epochs = 1000       # full passes over the training set
seed = 0                # RNG seed for weight init, DataLoader shuffling and (presumably) VAE sampling

# Seed every RNG the notebook relies on so that Restart & Run All is reproducible.
# torch.manual_seed also seeds the CUDA generators.
np.random.seed(seed)
torch.manual_seed(seed)

### Set up the dataset and DataLoader for the custom dataset

In [3]:
# Load every training image from the 'Images' folder as a tensor resized to 128x128.
# ToTensor runs first, so Resize operates on a float tensor, not on a PIL image.
# NOTE(review): bicubic resizing of a float tensor can overshoot [0, 1]
# (see torch.nn.functional.interpolate docs). Only to_img() clamps, and only for display.
# NOTE(review): newer torchvision expects transforms.InterpolationMode here. The PIL
# constant is kept because the checkpoint file name is derived from it.
interpolation_mode = Image.BICUBIC
image_size = (128, 128)
resize = transforms.Resize(image_size, interpolation=interpolation_mode)
transfs = transforms.Compose([transforms.ToTensor(), resize])
dataset_normal = AVTDataset(path='Images', subsample=1, transforms=transfs)

In [None]:
# Mini-batches of 32 training images, reshuffled at the start of every epoch.
loader = DataLoader(dataset=dataset_normal, shuffle=True, batch_size=32)

### Build the VAE Model

In [None]:
# Use the first GPU when it is requested and available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
vae = VAE().to(device)

# Report the size of the model (trainable parameters only).
num_params = sum(param.numel() for param in vae.parameters() if param.requires_grad)
print(f'Number of parameters: {num_params}')

### Train the Network

In [None]:
optimizer = torch.optim.Adam(params=vae.parameters(), lr=learning_rate, weight_decay=1e-5)

# Switch layers such as dropout/batch-norm (if the VAE has any) to training behaviour.
vae.train()

# Mean loss of each *completed* epoch. An entry is appended only after its epoch
# finishes, so interrupting the cell never leaves a partial, un-averaged sum here
# (the save cell stores train_loss_avg[-1] as the checkpoint loss).
train_loss_avg = []

print('Training ...')
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for image_batch, _ in loader:
        image_batch = image_batch.to(device)

        # Forward pass: the reconstructed images plus the latent mean and log-variance.
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # NOTE(review): vae_loss also receives mu/logvar, so it presumably adds a KL
        # term to the reconstruction error. Confirm in models/model.py.
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        # Backpropagation, then one optimizer step using those gradients.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise RuntimeError('The training DataLoader yielded no batches; check the dataset path.')

    train_loss_avg.append(epoch_loss / num_batches)
    print('Epoch [%d / %d] average reconstruction error: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))

### Save the trained weights

In [29]:
# Bundle the weights, optimizer state and training metadata into a single checkpoint.
state = {'epoch': num_epochs,
         'latent_dimension': vae.encoder.latent_dim,
         'model_state_dict': vae.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'loss': train_loss_avg[-1]}

# int() keeps the file name stable across Pillow/Python versions. str() of the
# PIL bicubic constant may render as '3' or as an enum name such as
# 'Resampling.BICUBIC', which would not match the path used when reloading.
checkpoint_path = 'VAE_Interpolation_%d_%d.pth' % (num_epochs, int(interpolation_mode))
torch.save(state, checkpoint_path)

In [41]:
# Load the checkpoint written by the save cell. The path is rebuilt from the same
# configuration values instead of being hard-coded; int() yields the numeric PIL
# filter code (3 for bicubic), matching the save cell's naming.
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
checkpoint_path = 'VAE_Interpolation_%d_%d.pth' % (num_epochs, int(interpolation_mode))
saved_net_state = torch.load(checkpoint_path, map_location=device)

Number of epochs we trained this network for = 1000


In [42]:
epochs_trained = saved_net_state["epoch"]
print('Number of epochs we trained this network for = {}'.format(epochs_trained))

Number of epochs we trained this network for = 1000


### Reload the saved model

In [38]:
# Rebuild the architecture on the active device, then restore the checkpointed weights.
# load_state_dict's result is the last expression, so Jupyter shows the key-matching report.
weights = saved_net_state['model_state_dict']
model = VAE().to(device)
model.load_state_dict(weights)

<All keys matched successfully>

### Load the test images from the `Test` folder

In [39]:
# Held-out images from the 'Test' folder, preprocessed with the same transforms as training.
test_dataset = AVTDataset(path='Test', subsample=1, transforms=transfs)
# shuffle=False (the default) keeps batch order deterministic for visual comparison.
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

### Test the model against new images

In [None]:
# Evaluate the model on the held-out test images.
# numpy, matplotlib and torchvision.utils are imported in the first cell.
plt.ion()

# eval() recursively sets training=False on the model and all of its submodules,
# so assigning model.training by hand is unnecessary.
model.eval()

def to_img(x):
    """Return ``x`` with every value clamped into the [0, 1] range.

    matplotlib's ``imshow`` expects float RGB data in [0, 1] and clips anything
    outside it with a warning, so images are clamped explicitly before display.
    """
    return x.clamp(min=0, max=1)

def show_image(img):
    """Clamp a channel-first (C, H, W) CPU image tensor and draw it with imshow."""
    clamped = to_img(img)
    # imshow wants channel-last (H, W, C) data.
    plt.imshow(clamped.numpy().transpose(1, 2, 0))

def visualise_output(images, model, max_images=50):
    """Reconstruct ``images`` with ``model`` and show the reconstructions as a grid.

    Args:
        images: batch tensor accepted by ``model`` and by ``make_grid``
            (B x C x H x W).
        model: VAE whose forward pass returns (reconstruction, mu, logvar).
        max_images: maximum number of images placed in the grid, taken from the
            start of the batch.
    """
    with torch.no_grad():
        re_images, _, _ = model(images.to(device))
    re_images = to_img(re_images.cpu())
    np_imagegrid = torchvision.utils.make_grid(re_images[:max_images], nrow=10, padding=5).numpy()
    plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
    plt.show()

# Take one batch from the test loader. The next() builtin is used because
# DataLoader iterators no longer provide the Python-2-style .next() method.
images, _ = next(iter(test_loader))

# First visualise the original images (starting at index 0, matching the reconstructions).
print('Original images')
show_image(torchvision.utils.make_grid(images[:50], nrow=10, padding=5))
plt.show()

# Reconstruct and visualise the same images using the VAE.
print('VAE reconstruction:')
visualise_output(images, model)