In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Hyper-parameters shared by every model class and evaluation cell in this notebook.
args = {
    'dim_h': 40,          # base width; hidden layers use multiples of this value
    'n_channel': 1,       # channels in the input images (MNIST is 1, i.e. greyscale)
    'n_z': 20,            # dimensionality of the latent space
    'sigma': 1.0,         # variance in n_z
    'lambda': 0.01,       # hyper-parameter weighting the discriminator loss
    'lr': 0.0002,         # learning rate for the Adam optimizer
    'epochs': 50,         # number of training epochs
    'batch_size': 256,    # mini-batch size for SGD
    'save': False,        # if True, save weights at each epoch of training
    'train': False,       # if True, train the networks; otherwise load saved weights
}

## Define the encoder and decoder classes

In [3]:
## create encoder model and decoder model
class EncoderD(nn.Module):
    """Deterministic convolutional encoder: image batch -> latent code.

    Four 4x4, stride-2, padding-1 convolutions widen the channels from
    ``n_channel`` to ``dim_h * 8`` while shrinking the spatial size
    (28 -> 14 -> 7 -> 3 -> 1 for 28x28 inputs). A linear layer then maps the
    ``dim_h * 8`` features to ``n_z`` latent dimensions.

    Args:
        args: mapping providing ``'n_channel'``, ``'dim_h'`` and ``'n_z'``.
    """

    def __init__(self, args):
        super(EncoderD, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # Convolutional feature extractor. The layer order is part of the
        # state_dict key names (conv.0, conv.2, ...), so it must not change
        # if previously saved checkpoints are to keep loading.
        self.conv = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )

        # final layer is fully connected
        self.fc = nn.Linear(self.dim_h * (2 ** 3), self.n_z)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape ``(batch, n_channel, H, W)``; H and W must reduce
                to 1x1 after the four convolutions (true for 28x28 inputs).

        Returns:
            Tensor of shape ``(batch, n_z)``.
        """
        x = self.conv(x)
        # Flatten everything except the batch axis. A bare ``squeeze()`` would
        # also drop the batch axis when batch == 1 and return shape (n_z,).
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x
    
## stochastic encoder: returns a sampled latent code z together with mu and logvar
class EncoderS(nn.Module):
    """Stochastic convolutional encoder: image batch -> (z, mu, logvar).

    The convolutional stack has the same layout as the deterministic encoder:
    four 4x4, stride-2, padding-1 convolutions ending with ``dim_h * 8``
    channels. Two parallel linear heads produce the mean and log-variance of a
    diagonal Gaussian, and ``z`` is drawn from it with the reparameterization
    trick.

    Note that a fresh sample is drawn on every forward pass regardless of
    ``train()`` / ``eval()`` mode.

    Args:
        args: mapping providing ``'n_channel'``, ``'dim_h'`` and ``'n_z'``.
    """

    def __init__(self, args):
        super(EncoderS, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # Convolutional feature extractor. The layer order is part of the
        # state_dict key names, so it must not change if previously saved
        # checkpoints are to keep loading.
        self.conv = nn.Sequential(
            nn.Conv2d(self.n_channel, self.dim_h, 4, 2, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h, self.dim_h * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 2, self.dim_h * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.Conv2d(self.dim_h * 4, self.dim_h * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.dim_h * 8),
            nn.ReLU(True),
        )

        # two fully connected heads: fc1 -> mu, fc2 -> logvar
        self.fc1 = nn.Linear(self.dim_h * (2 ** 3), self.n_z)
        self.fc2 = nn.Linear(self.dim_h * (2 ** 3), self.n_z)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + std * eps`` with ``eps`` standard normal.

        Args:
            mu: mean tensor.
            logvar: log-variance tensor, same shape as ``mu``.

        Returns:
            Sampled tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        # randn_like creates the noise on std's device with std's dtype, so
        # sampling keeps working when the module is moved off the CPU.
        eps = torch.randn_like(std)
        return mu + std * eps

    def bottleneck(self, h):
        """Map features ``h`` to ``(z, mu, logvar)`` via the two linear heads."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: tensor of shape ``(batch, n_channel, H, W)``; H and W must reduce
                to 1x1 after the four convolutions (true for 28x28 inputs).

        Returns:
            Tuple ``(z, mu, logvar)``, each of shape ``(batch, n_z)``.
        """
        h = self.conv(x)
        # Flatten everything except the batch axis. A bare ``squeeze()`` would
        # also drop the batch axis when batch == 1.
        h = h.reshape(h.size(0), -1)
        z, mu, logvar = self.bottleneck(h)

        return z, mu, logvar

class Decoder(nn.Module):
    """Decoder: latent code -> reconstructed image batch.

    A linear layer expands the ``n_z``-dimensional code to a
    ``(dim_h * 8, 7, 7)`` feature map. Three transposed convolutions then grow
    it to 28x28 (7 -> 10 -> 13 -> 28) with ``n_channel`` output channels, and a
    sigmoid squashes the result into [0, 1].

    Args:
        args: mapping providing ``'n_channel'``, ``'dim_h'`` and ``'n_z'``.
    """

    def __init__(self, args):
        super(Decoder, self).__init__()

        self.n_channel = args['n_channel']
        self.dim_h = args['dim_h']
        self.n_z = args['n_z']

        # first layer is fully connected
        self.fc = nn.Sequential(
            nn.Linear(self.n_z, self.dim_h * 8 * 7 * 7),
            nn.ReLU()
        )

        # Transposed convolutions. The layer order is part of the state_dict
        # key names, so it must not change if saved checkpoints are to load.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(self.dim_h * 8, self.dim_h * 4, 4),
            nn.BatchNorm2d(self.dim_h * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(self.dim_h * 4, self.dim_h * 2, 4),
            nn.BatchNorm2d(self.dim_h * 2),
            nn.ReLU(True),
            # Emit as many channels as the input images have, rather than a
            # hard-coded 1; identical to before when n_channel == 1.
            nn.ConvTranspose2d(self.dim_h * 2, self.n_channel, 4, stride=2),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Decode a batch of latent codes.

        Args:
            x: tensor whose last dimension is ``n_z``.

        Returns:
            Tensor of shape ``(batch, n_channel, 28, 28)`` with values in [0, 1].
        """
        x = self.fc(x)
        x = x.view(-1, self.dim_h * 8, 7, 7)
        x = self.deconv(x)
        return x

## Get final test loss for FMNIST

In [4]:
# FashionMNIST test split (train=False), stored under ./FMNIST/ and converted to tensors
dataset = 'FashionMNIST'
testset = datasets.FashionMNIST('./FMNIST/', train=False, download=True, transform=transforms.ToTensor())

# unshuffled loader; batch_size=10000 is meant to deliver the full test split as one batch
test_loader = DataLoader(testset, batch_size=10000, shuffle=False)

In [5]:
# loss for all models: mean squared error between the reconstruction and the
# input images; this one criterion object is reused by every evaluation cell
criterion = nn.MSELoss()

In [6]:
# standard AE: deterministic encoder + decoder, switched to inference mode
encoder, decoder = EncoderD(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/AE_encoder-best_fmnist.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/AE_decoder-best_fmnist.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
with torch.no_grad():
    for images, _ in test_loader:
        z_hat = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('AE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

AE final reconstruction loss 0.008279001340270042 on FashionMNIST


In [7]:
# VAE: stochastic encoder + decoder, switched to inference mode
encoder, decoder = EncoderS(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/VAE_encoder-best_fmnist.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/VAE_decoder-best_fmnist.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
# NOTE: z_hat is a random sample from the encoder, so without a fixed seed the
# reported loss varies slightly from run to run
with torch.no_grad():
    for images, _ in test_loader:
        z_hat, mu, logvar = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('VAE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

VAE final reconstruction loss 0.013607991859316826 on FashionMNIST


In [8]:
# WAE: deterministic encoder + decoder, switched to inference mode
encoder, decoder = EncoderD(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/WAEgan_encoder-best_fmnist.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/WAEgan_decoder-best_fmnist.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
with torch.no_grad():
    for images, _ in test_loader:
        z_hat = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('WAE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

WAE final reconstruction loss 0.009220210835337639 on FashionMNIST


## Get final test loss for MNIST

In [9]:
# MNIST test split (train=False), stored under ./data/ and converted to tensors
dataset = 'MNIST'
testset = datasets.MNIST('./data/', train=False, download=True, transform=transforms.ToTensor())

# unshuffled loader; batch_size=10000 is meant to deliver the full test split as one batch
test_loader = DataLoader(testset, batch_size=10000, shuffle=False)

In [10]:
# standard AE: deterministic encoder + decoder, switched to inference mode
encoder, decoder = EncoderD(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/AE_encoder-best.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/AE_decoder-best.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
with torch.no_grad():
    for images, _ in test_loader:
        z_hat = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('AE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

AE final reconstruction loss 0.004790022969245911 on MNIST


In [11]:
# VAE: stochastic encoder + decoder, switched to inference mode
encoder, decoder = EncoderS(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/VAE_encoder-best.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/VAE_decoder-best.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
# NOTE: z_hat is a random sample from the encoder, so without a fixed seed the
# reported loss varies slightly from run to run
with torch.no_grad():
    for images, _ in test_loader:
        z_hat, mu, logvar = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('VAE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

VAE final reconstruction loss 0.009450603276491165 on MNIST


In [12]:
# WAE: deterministic encoder + decoder, switched to inference mode
encoder, decoder = EncoderD(args), Decoder(args)
encoder.eval()
decoder.eval()

# load encoder and decoder weights from checkpoint; map_location='cpu' lets
# checkpoints that were saved from a GPU load on a CPU-only machine (the models
# and the data in this notebook never leave the CPU)
enc_checkpoint = torch.load('save/WAEgan_encoder-best.pth', map_location='cpu')
encoder.load_state_dict(enc_checkpoint)

dec_checkpoint = torch.load('save/WAEgan_decoder-best.pth', map_location='cpu')
decoder.load_state_dict(dec_checkpoint)

# evaluation only: turn autograd off so no graph is recorded for the large batch
with torch.no_grad():
    for images, _ in test_loader:
        z_hat = encoder(images)
        x_hat = decoder(z_hat)
        test_recon_loss = criterion(x_hat, images)
        print('WAE final reconstruction loss {} on {}'.format(test_recon_loss.item(), dataset))

WAE final reconstruction loss 0.007124471943825483 on MNIST
