In [1]:
# Imports
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision

import matplotlib.pyplot as plt

import numpy as np

from pyro.distributions.relaxed_straight_through import RelaxedBernoulliStraightThrough

import autoencoders
from torchvision.datasets import Omniglot

from skimage.measure import compare_ssim as ssim
from skimage.measure import compare_mse as mse
from skimage.measure import compare_psnr as psnr
from skimage.measure import compare_nrmse as nrmse

import pandas as pd

In [2]:
# Fix the random seeds so that repeated runs are reproducible
TORCH_SEED, NUMPY_SEED = 7212, 2980
torch.manual_seed(TORCH_SEED)
np.random.seed(NUMPY_SEED)

In [3]:
# Select which dataset to use
dataset_name = "fashion-mnist"
training = False
img_size = 28

# Supported datasets: name -> (torchvision.datasets class name, folder with the files).
# The class is looked up lazily so only the selected dataset has to exist in the
# installed torchvision version.
supported_datasets = {
    # Mnist digits dataset
    "mnist": ("MNIST", './Datasets/mnist/'),
    # Fashion Mnist clothing dataset
    "fashion-mnist": ("FashionMNIST", './Datasets/fashion-mnist/'),
    # Kuzushiji Mnist Japanese characters dataset
    "kmnist": ("KMNIST", './Datasets/kmnist/'),
}

# Fail right here with a clear message instead of printing a warning and then
# crashing further down with an unrelated NameError on `data`.
if dataset_name not in supported_datasets:
    raise ValueError(
        "Could not find dataset: " + repr(dataset_name)
        + " (expected one of " + ", ".join(sorted(supported_datasets)) + ")"
    )

dataset_class_name, dataset_root = supported_datasets[dataset_name]

# download=False: the dataset files are expected to already be present in dataset_root
data = getattr(torchvision.datasets, dataset_class_name)(
    root=dataset_root,
    train=training,
    transform=torchvision.transforms.ToTensor(),
    download=False,
)

# Get the name of the dataset for saving the trained models
save_extention = data.__class__.__name__

In [4]:
# Prefer the gpu when one is available, otherwise fall back to the cpu
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Currently running on:", device)

Currently running on: cuda


In [5]:
def load_autoencoder(model_name, saved_as=None, run_on="cpu"):
    """Instantiate an autoencoder class and load its saved weights.

    Args:
        model_name: Name of a class in the `autoencoders` module; it is
            instantiated without arguments.
        saved_as: Base name (without ".pt") of the weights file inside
            "./Trained Models/". When empty or None it defaults to
            `model_name + save_extention`, where `save_extention` is the
            notebook-level dataset class name.
        run_on: Device (string or torch.device) the model is moved to and the
            weights are mapped onto.

    Returns:
        The model in eval mode, with weights loaded, on `run_on`.
    """
    # Set the save name to the model name if no name was specified
    if not saved_as:
        saved_as = model_name + save_extention

    # Select what type of autoencoder to load
    used_autoencoder = getattr(autoencoders, model_name)
    autoencoder = used_autoencoder().to(run_on)

    # Load the weights from the saved file. map_location remaps the stored
    # tensors onto the target device, so a checkpoint written on a gpu can
    # still be loaded when running on the cpu.
    state_dict = torch.load("./Trained Models/" + saved_as + ".pt", map_location=run_on)
    autoencoder.load_state_dict(state_dict)

    # Set the training flag to False
    autoencoder.eval()

    # A model without parameters has nothing on the gpu; avoid StopIteration
    first_parameter = next(autoencoder.parameters(), None)
    in_cuda = first_parameter is not None and first_parameter.is_cuda

    print("Loaded:", model_name, "| In Cuda:", in_cuda)
    return autoencoder

In [6]:
# Restore the three trained autoencoders, in this order, on the selected device
incremental_autoencoder, random_autoencoder, autoencoder = (
    load_autoencoder(model_name, run_on=device)
    for model_name in ("IncrementalAutoEncoder", "RandomAutoEncoder", "AutoEncoder")
)

Loaded: IncrementalAutoEncoder | In Cuda: True
Loaded: RandomAutoEncoder | In Cuda: True
Loaded: AutoEncoder | In Cuda: True


In [7]:
def _corruption_bit_map(n_rows, n_bits, n_corrupted, corruption_type):
    """Build an (n_rows, n_bits) 0/1 array with n_corrupted ones per row marking the bits to corrupt."""
    # Ones at the end of every row: this is the "incremental" layout
    bit_map = np.concatenate((np.zeros((n_rows, n_bits - n_corrupted)),
                              np.ones((n_rows, n_corrupted))), axis=1)
    # Flip if the corruption should be at the front
    if corruption_type == "reverse_incremental":
        bit_map = np.flip(bit_map, 1)
    # Shuffle every row independently if the corruption should be in random positions
    if corruption_type == "random":
        for row in bit_map:
            np.random.shuffle(row)
    return bit_map


def autoencoder_evaluation(autoencoder, test_data, compare_function=mse, corruption_type=None, display_img=False):
    """Score reconstructions while an increasing share of the binary code is corrupted.

    All images in `test_data.data` are flattened to `img_size*img_size`,
    scaled by 1/255, encoded with `autoencoder.encoder` and binarised with a
    RelaxedBernoulliStraightThrough sample. For k = 0, 8, 16, ... code bits,
    the selected bits are replaced with fair coin flips, the code is decoded
    with `autoencoder.decoder`, and every decoded image is compared with its
    original. Reads the notebook-level `img_size` and `device`.

    Args:
        autoencoder: Model exposing `encoder` and `decoder` callables.
        test_data: Dataset whose `.data` attribute is a tensor of images.
        compare_function: Callable `(original, decoded) -> value`; its
            `__name__` is stored as the metric name.
        corruption_type: "incremental" (corrupt the last k bits),
            "reverse_incremental" (the first k bits) or "random"
            (k random positions per image).
        display_img: When True, print the score of one sample image and plot
            its reconstruction at every corruption level.

    Returns:
        A list with one dict per (corruption level, image) holding the keys
        "model", "metric", "value", "corruption_type" and
        "corruption_percentage" (stored as the fraction k / code length).
        Returns None, after printing a message, when `corruption_type` is not
        recognised.
    """
    # Validate up front so no expensive encoding is done for an invalid request
    if corruption_type not in ("random", "incremental", "reverse_incremental"):
        print("Failed to recognize the corruption type. Should be either incremental, reverse_incremental or random")
        return None

    results = []
    model_name = autoencoder.__class__.__name__
    metric_name = compare_function.__name__

    # Pure inference: skip building an autograd graph for the whole test set
    with torch.no_grad():
        # Store the data in one large tensor
        test_images = test_data.data.view(-1, img_size*img_size).type(torch.FloatTensor).to(device)/255

        # Get the number of images
        n_images = len(test_images)

        # Encode the images in binary form
        encoded_images = autoencoder.encoder(test_images)
        binary_encoded = RelaxedBernoulliStraightThrough(1, logits=encoded_images).rsample()

        x, y = binary_encoded.shape

        # Move the originals to numpy once instead of once per image and level
        original_images = test_images.detach().cpu().numpy()

        # Image shown when display_img is set; clamped so small datasets work too
        display_index = min(999, n_images - 1)

        # Loop to add more and more corruption
        for k in range(0, y+1, 8):
            mask = np.where(_corruption_bit_map(x, y, k, corruption_type))

            # Always start from the clean code. Corrupting one tensor in place
            # across iterations would let earlier corruption pile up, so that
            # "random" corruption ends up larger than the reported k/y.
            corrupted = binary_encoded.clone()

            # Replace masked values with random zeroes and ones
            noise = torch.bernoulli(torch.full((x, y), 0.5)).to(device)
            corrupted[mask] = noise[mask]

            # Decode the binary form back to an image
            decoded_images = autoencoder.decoder(corrupted).detach().cpu().numpy()

            displayed_image = None
            displayed_value = None

            for i in range(n_images):
                value = compare_function(original_images[i], decoded_images[i])

                if display_img and i == display_index:
                    displayed_image = decoded_images[i]
                    displayed_value = value

                # Save the results in a dict to add to the list
                results.append({"model": model_name,
                                "metric": metric_name,
                                "value": value,
                                "corruption_type": corruption_type,
                                "corruption_percentage": k/y})

            # Display the image that has been corrupted
            if display_img and displayed_image is not None:
                print("Model used:", model_name)
                print("Current corruption:", k)
                print(metric_name, ":", displayed_value)
                plt.clf()
                plt.axis('off')
                plt.imshow(np.reshape(displayed_image, (img_size, img_size)), cmap='gray')
                plt.show(); plt.pause(0.05)

    return results

In [8]:
# Start with an empty list; every evaluation cell below appends its records to it
results = list()

In [9]:
%%time

# Collect MSE scores under incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=mse, corruption_type="incremental")

print(len(results))

990000
CPU times: user 2min 28s, sys: 3.74 s, total: 2min 32s
Wall time: 1min 21s


In [10]:
%%time

# Collect SSIM scores under incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=ssim, corruption_type="incremental")

print(len(results))

1980000
CPU times: user 5min 58s, sys: 4.07 s, total: 6min 2s
Wall time: 4min 51s


In [11]:
%%time

# Collect PSNR scores under incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=psnr, corruption_type="incremental")

print(len(results))

2970000
CPU times: user 3min 5s, sys: 3.66 s, total: 3min 9s
Wall time: 1min 57s


In [12]:
%%time

# Collect NRMSE scores under incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=nrmse, corruption_type="incremental")

print(len(results))

3960000
CPU times: user 2min 43s, sys: 3.48 s, total: 2min 46s
Wall time: 1min 35s


In [13]:
%%time

# Collect MSE scores under reverse incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=mse, corruption_type="reverse_incremental")

print(len(results))

4950000
CPU times: user 2min 32s, sys: 3.49 s, total: 2min 35s
Wall time: 1min 23s


In [14]:
%%time

# Collect SSIM scores under reverse incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=ssim, corruption_type="reverse_incremental")

print(len(results))

5940000
CPU times: user 5min 37s, sys: 3.46 s, total: 5min 40s
Wall time: 4min 28s


In [15]:
%%time

# Collect PSNR scores under reverse incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=psnr, corruption_type="reverse_incremental")

print(len(results))

6930000
CPU times: user 2min 51s, sys: 3.51 s, total: 2min 54s
Wall time: 1min 43s


In [16]:
%%time

# Collect NRMSE scores under reverse incremental corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=nrmse, corruption_type="reverse_incremental")

print(len(results))

7920000
CPU times: user 2min 46s, sys: 3.47 s, total: 2min 49s
Wall time: 1min 37s


In [17]:
%%time

# Collect MSE scores under random corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=mse, corruption_type="random")

print(len(results))

8910000
CPU times: user 3min 12s, sys: 3.54 s, total: 3min 15s
Wall time: 2min


In [18]:
%%time

# Collect SSIM scores under random corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=ssim, corruption_type="random")

print(len(results))

9900000
CPU times: user 6min 11s, sys: 3.71 s, total: 6min 15s
Wall time: 5min


In [19]:
%%time

# Collect PSNR scores under random corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=psnr, corruption_type="random")

print(len(results))

10890000
CPU times: user 3min 33s, sys: 3.27 s, total: 3min 36s
Wall time: 2min 22s


In [20]:
%%time

# Collect NRMSE scores under random corruption, one evaluation per model
for evaluated_model in (incremental_autoencoder, autoencoder, random_autoencoder):
    results += autoencoder_evaluation(evaluated_model, data, compare_function=nrmse, corruption_type="random")

print(len(results))

11880000
CPU times: user 3min 27s, sys: 3.59 s, total: 3min 30s
Wall time: 2min 16s


In [21]:
# Turn the collected result records into a pandas dataframe
df = pd.DataFrame(results)

# Persist the dataframe as a pickle file named after the dataset
dataframe_path = "./Dataframes/" + save_extention + "_DF.pkl"
df.to_pickle(dataframe_path)