In [1]:
# Imports
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision

import matplotlib.pyplot as plt

import numpy as np

from pyro.distributions.relaxed_straight_through import RelaxedBernoulliStraightThrough

import autoencoders

from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_mse as mse
from skimage.measure import compare_psnr as psnr
from skimage.measure import compare_nrmse as nrmse

In [2]:
# Set seed for reproduciblity
torch.manual_seed(7212)
np.random.seed(2980)

In [3]:
# Mnist digits dataset (download if necessary)
test_data = torchvision.datasets.FashionMNIST(
    root='./Datasets/fashion-mnist/',
    train=False,
    transform=torchvision.transforms.ToTensor(),                        
    download=False,
)

In [4]:
# Select which device to use, use the gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Currently running on:", device)

Currently running on: cuda


In [5]:
def load_autoencoder(model_name, saved_as=None, run_on="cpu"):
    
    # Set the save name to the model name if no name was specified
    if not saved_as:
        saved_as = model_name
    
    # Select what type of autoencoder to load
    used_autoencoder = getattr(autoencoders, model_name)
    autoencoder = used_autoencoder().to(run_on)

    # Load the weights from the saved file
    autoencoder.load_state_dict(torch.load("./Trained Models/" + saved_as + ".pt"))
    
    # Set the training flag to False
    autoencoder.eval()
    
    print("Loaded:", model_name, "| In Cuda:", next(autoencoder.parameters()).is_cuda)
    return autoencoder

In [6]:
incremental_autoencoder = load_autoencoder("IncrementalAutoEncoder", run_on=device)
random_autoencoder = load_autoencoder("RandomAutoEncoder", run_on=device)
autoencoder = load_autoencoder("AutoEncoder", run_on=device)

Loaded: IncrementalAutoEncoder | In Cuda: True
Loaded: RandomAutoEncoder | In Cuda: True
Loaded: AutoEncoder | In Cuda: True


In [7]:
TEST_IMAGES = 10

In [8]:
def autoencoder_corruption(autoencoder):
    
    test_images = test_data.data.view(-1, 28*28).type(torch.FloatTensor).to(device)/255
    
    n_images = len(test_images)
    total_mse, total_ssim, total_nrmse, total_psnr = 0, 0, 0, 0
    
#     encoded_images, decoded_images = autoencoder(test_images)

    encoded_images = autoencoder.encoder(test_images)
    binary_encoded = RelaxedBernoulliStraightThrough(1, logits=encoded_images).rsample()
    
    x,y = binary_encoded.shape
    
    for k in range(0, y+1, 8):
        mask = np.where(np.concatenate((np.zeros((x,y-k)), np.ones((x, k))), axis=1))
        binary_encoded[mask] = torch.distributions.Bernoulli(probs=(torch.ones(x,y)/2)).sample().to(device)[mask]

        decoded_images = autoencoder.decoder(binary_encoded)
        
        for i in range(n_images):
            test_image = test_images[i].cpu().detach().numpy()
            decoded_test_image = decoded_images[i].cpu().detach().numpy()
            total_mse += mse(test_image, decoded_test_image)
            total_ssim += ssim(test_image, decoded_test_image)
            total_nrmse += nrmse(test_image, decoded_test_image)
            total_psnr += psnr(test_image, decoded_test_image)

        value_mse = total_mse/n_images
        value_ssim = total_ssim/n_images
        value_nrmse = total_nrmse/n_images
        value_psnr = total_psnr/n_images

        print("Current corruption:", k)
        print("MSE:", value_mse, "SSIM:", value_ssim, "NRMSE:", value_nrmse, "PSNR:", value_psnr)

In [9]:
autoencoder_corruption(incremental_autoencoder)

Current corruption: 0
MSE: 0.020172986028902842 SSIM: 0.6827972592986662 NRMSE: 0.3445529666404924 PSNR: 17.594020590353413
Current corruption: 8
MSE: 0.04034581035156237 SSIM: 1.365571889117742 NRMSE: 0.6890275819459023 PSNR: 35.18875379780395
Current corruption: 16
MSE: 0.06052258064582815 SSIM: 2.048313042977949 NRMSE: 1.0335298718322325 PSNR: 52.78293174053138
Current corruption: 24
MSE: 0.0806888171164809 SSIM: 2.7309823151522736 NRMSE: 1.3780090992960083 PSNR: 70.37958222784656
Current corruption: 32
MSE: 0.10086977763824481 SSIM: 3.4134614357153734 NRMSE: 1.7225696029222075 PSNR: 87.97297860692653
Current corruption: 40
MSE: 0.12106494315056823 SSIM: 4.095860477774638 NRMSE: 2.0672126635691916 PSNR: 105.56317460274228
Current corruption: 48
MSE: 0.14127248513681687 SSIM: 4.778063132928884 NRMSE: 2.4119864823992208 PSNR: 123.1503976905014
Current corruption: 56
MSE: 0.16149556774550677 SSIM: 5.45987606894193 NRMSE: 2.75690438250935 PSNR: 140.7340622203826
Current corruption: 64
M