# Variable Auto Encoder

## Imports

In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, dataset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from tqdm.notebook import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Hyper Parameters

In [40]:
batch_train = 256
batch_test = 1
epochs = 100
lr = 0.0001

## Load dataset

In [51]:
def load_dataset(anomaly, batch_train, batch_test):
    path = './data/mnist'
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    #import the original dataset
    orig_train_dataset = MNIST(path, train=True, download=True, transform=transform)
    orig_test_dataset = MNIST(path, train=False, download=True, transform=transform)

    #remove the anomaly from the train dataset
    idx = orig_train_dataset.train_labels != anomaly
    train_dataset = dataset.Subset(orig_train_dataset, np.where(idx==True)[0])
    #split the test set into normal and anomaly dataset
    idx = orig_test_dataset.test_labels != anomaly
    test_dataset = dataset.Subset(orig_test_dataset, np.where(idx==True)[0])
    anomaly_test_dataset = dataset.Subset(orig_test_dataset, np.where(idx==False)[0])

    #create dataloaders
    train_dl = DataLoader(train_dataset, batch_size=batch_train, shuffle=True, num_workers=12)
    test_dl = DataLoader(test_dataset, batch_size=batch_test, shuffle=False, num_workers=12, drop_last=True)
    anomaly_dl = DataLoader(anomaly_test_dataset, batch_size=batch_test, shuffle=False, num_workers=12, drop_last=True)

    return train_dl, test_dl, anomaly_dl

## Loss Function

In [20]:
def VAE_Loss(x, x_hat, mu, logvariance):
    '''
    The original reconstruction loss used by Kingma and Welling was the cross-entropy log(p(x|z)). 
    As p(x|z) ~ N(mu, sigma), the cross-entropy loss is proportional to the MSE loss between the input and the output.
    https://www.deeplearningbook.org/contents/ml.html (page 130)
    '''
    mse_loss = nn.functional.mse_loss(x, x_hat)
    '''
    The KL Divergence between N(mu, sigma) and N(0, 1) is 
    0.5 * sum_over_batch(1 + log(sigma^2) - mu^2 - sigma^2)
    For sigma values close to 0, the output of log(sigma^2) can explode
    Thus, we learn log(sigma^2) instead (logvariance)
    '''
    KL_loss = -0.5 * torch.sum(1 + logvariance - mu ** 2 - logvariance.exp())

    return KL_loss + mse_loss

## VAE

### Auxiliar Modules

In [21]:
class Conv_Cell(nn.Module):
    '''
    @input:
        input_size: number of filters of the input tensor
        n_filters: number of filters of the output tensor
    @output:
        tensor with size = (batch_size, n_filters, X.size()[2]/2, X.size()[3]/2)
    '''
    def __init__(self, input_size, n_filters):
        super(Conv_Cell, self).__init__()
        conv = nn.Conv2d(input_size, n_filters, kernel_size=3, stride=2, padding=1, bias=False)
        nn.init.xavier_uniform_(conv.weight)
        self.layers = nn.Sequential(
            conv,
            nn.BatchNorm2d(n_filters),
            nn.ReLU()
        )

    def forward(self, X):
        return self.layers(X)

class Deconv_Cell(nn.Module):
    '''
    @input:
        input_size: number of filters of the input tensor
        n_filters: number of filters of the output tensor
    @output:
        tensor with size = (batch_size, n_filters, X.size()[2]*2, X.size()[3]*2)
    '''
    def __init__(self, input_size, n_filters):
        super(Deconv_Cell, self).__init__()
        deconv = nn.ConvTranspose2d(input_size, n_filters, kernel_size=4, stride=2, padding=1, bias=False)
        nn.init.xavier_uniform_(deconv.weight)
        self.layers = nn.Sequential(
            deconv,
            nn.BatchNorm2d(n_filters),
            nn.ReLU()
        )

    def forward(self, X):
        return self.layers(X)

class Flatten(nn.Module):
    def forward(self, X):
        return X.view(X.size()[0], -1)

class Deflatten(nn.Module):
    '''
    @input:
        out_size: height/width of the output tensor
        n_filters: number of filters of the output tensor
    @output:
        tensor with size = (batch_size, n_filters, out_size, out_size)
    '''
    def __init__(self, out_size, n_filters):
        super(Deflatten, self).__init__()
        self.out_size = out_size
        self.n_filters = n_filters

    def forward(self, X):
        return X.view(X.size()[0], self.n_filters, self.out_size, -1)

### VAE Module

In [36]:
class VAE(nn.Module):
    def __init__(self, input_size=1):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential( #input 28x28x1
            Conv_Cell(input_size, 16), #outsize 14x14x16
            Conv_Cell(16, 32), #outsize 7x7x32
            Flatten() #outsize 1568
        )

        self.fc_mu = nn.Linear(1568, 100)
        self.fc_logvariance = nn.Linear(1568, 100)


        deconv = nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1, bias=False)
        nn.init.xavier_uniform_(deconv.weight)
        self.decoder = nn.Sequential(
            nn.Linear(100, 1568), #outsize 1568
            Deflatten(7, 32), #outsize 7x7x32
            Deconv_Cell(32, 16), #outsize 14x14x16
            deconv #outsize 28x28x1
        )

    def sample_latent_vector(self, mu, logvariance):
        samples = torch.randn_like(mu)
        samples.to(device)
        '''
        logvariance = log(sigma^2) -> sigma = exp(logvariance/2)
        '''
        return mu + samples * (torch.exp(0.5*logvariance))

    def forward(self, X):
        X = self.encoder(X)
        mu = self.fc_mu(X)
        logvariance = self.fc_logvariance(X)
        X = self.sample_latent_vector(mu, logvariance)
        X = self.decoder(X)
        return X, mu, logvariance

## Training

In [37]:
def train(train_dl, anomaly_dl):
    model = VAE()
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    _, (anomaly_fixed_test, labels) = enumerate(anomaly_dl).__next__()
    anomaly_fixed_test = anomaly_fixed_test.to(device)

    training_losses = []
    anomaly_losses = []
    for epoch in tqdm(range(epochs), desc='Epochs'):
        total_loss=0.0
        anomaly_loss=0.0
        for i, data in enumerate(train_dl):
            x, labels = data
            x = x.to(device)
            #initialize gradients
            optimizer.zero_grad()
            #forward pass
            x_hat, mu, logvariance = model(x)
            #calculate loss
            loss = VAE_Loss(x, x_hat, mu, logvariance)
            #gradient descent
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        #check how the anomaly loss change during training
        with torch.no_grad():
            x_hat, mu, logvariance = model(anomaly_fixed_test)
            loss = VAE_Loss(anomaly_fixed_test, x_hat, mu, logvariance)
            anomaly_loss += loss.item()
        
        training_losses.append(total_loss/len(train_dl))
        anomaly_losses.append(anomaly_loss)

        tqdm.write(f'Loss: {total_loss/len(train_dl)} \t Anomaly Loss: {anomaly_loss}')
    return model, training_losses, anomaly_losses

In [38]:
#One class as anomaly, the others as normal classes
anomaly = 0
train_dl, test_dl, anomaly_dl = load_dataset(anomaly, batch_train, batch_test)
model0, training_losses0, anomaly_losses0 = train(train_dl, anomaly_dl)

HBox(children=(FloatProgress(value=0.0, description='Epochs', style=ProgressStyle(description_width='initial')…

Loss: 19.68285060436768 	 Anomaly Loss: 38.02560806274414
Loss: 2.3900828946976973 	 Anomaly Loss: 16.114274978637695
Loss: 1.2035951784376562 	 Anomaly Loss: 8.485692024230957
Loss: 0.7964525114855118 	 Anomaly Loss: 5.166813373565674
Loss: 0.6064379479405443 	 Anomaly Loss: 3.4739675521850586
Loss: 0.4975477078609918 	 Anomaly Loss: 2.272233486175537
Loss: 0.43076125214085775 	 Anomaly Loss: 1.8001391887664795
Loss: 0.38760357005356333 	 Anomaly Loss: 1.521540880203247
Loss: 0.35775885733627005 	 Anomaly Loss: 1.0550117492675781
Loss: 0.3346467215986647 	 Anomaly Loss: 0.9051533937454224
Loss: 0.3201422388031638 	 Anomaly Loss: 0.8537915945053101
Loss: 0.3072014921515651 	 Anomaly Loss: 0.6584595441818237
Loss: 0.2994239777150239 	 Anomaly Loss: 0.6771160364151001
Loss: 0.29357749442377035 	 Anomaly Loss: 0.6062570810317993
Loss: 0.2895015122975118 	 Anomaly Loss: 0.8280441761016846
Loss: 0.2860493386285545 	 Anomaly Loss: 0.6662386655807495
Loss: 0.285025578014244 	 Anomaly Loss: 0.

In [39]:
#One class as anomaly, the others as normal classes
anomaly = 1
train_dl, test_dl, anomaly_dl = load_dataset(anomaly, batch_train, batch_test)
model1, training_losses1, anomaly_losses1 = train(train_dl, anomaly_dl)

HBox(children=(FloatProgress(value=0.0, description='Epochs', style=ProgressStyle(description_width='initial')…

Loss: 23.195323292089967 	 Anomaly Loss: 45.687416076660156
Loss: 3.032850387145062 	 Anomaly Loss: 15.357585906982422
Loss: 1.4974287167674496 	 Anomaly Loss: 7.545223236083984
Loss: 0.96065728908398 	 Anomaly Loss: 4.646596908569336
Loss: 0.710412855516104 	 Anomaly Loss: 3.571070671081543
Loss: 0.5769590957611263 	 Anomaly Loss: 3.0875260829925537
Loss: 0.4952808292973943 	 Anomaly Loss: 1.9687163829803467
Loss: 0.43875918676014564 	 Anomaly Loss: 1.7380006313323975
Loss: 0.4021780289450184 	 Anomaly Loss: 1.5403180122375488
Loss: 0.3764853941506985 	 Anomaly Loss: 1.3360602855682373
Loss: 0.3550921660296771 	 Anomaly Loss: 1.2199890613555908
Loss: 0.3403446174540869 	 Anomaly Loss: 0.8661714792251587
Loss: 0.3279375338396963 	 Anomaly Loss: 0.8733000755310059
Loss: 0.32006992827228853 	 Anomaly Loss: 1.1052120923995972
Loss: 0.31323402637050074 	 Anomaly Loss: 0.7423142790794373
Loss: 0.3079343496822939 	 Anomaly Loss: 0.9276999235153198
Loss: 0.3053715771057454 	 Anomaly Loss: 1.2

In [86]:
anomaly = 0
train_dl = load_dataset(anomaly, batch_train, 512)[0]
mse_normal = 0.0
with torch.no_grad():
    for i, (x, label) in enumerate(train_dl):
        x = x.to(device)
        x_hat = model0(x)[0]
        mse = nn.functional.mse_loss(x, x_hat)
        mse_normal += mse.item()
    mse_normal /= len(train_dl)
print(mse_normal)

0.25726127814290656


In [44]:
anomaly = 1
train_dl = load_dataset(anomaly, batch_train, batch_test)[0]
mse_normal = 0.0
with torch.no_grad():
    for i, (x, label) in enumerate(train_dl):
        x = x.to(device)
        x_hat = model1(x)[0]
        mse = nn.functional.mse_loss(x, x_hat)
        mse_normal += mse.item()
    mse_normal /= len(train_dl)
print(mse_normal)

0.27666353109920994


## Evaluation

Using the reconstruction error as a classifier for now.

Might try implementing the reconstruction probability later. 
http://dm.snu.ac.kr/static/docs/TR/SNUDM-TR-2015-03.pdf

In [90]:
from sklearn.metrics import roc_auc_score
def calculate_auroc(model, anomaly, thresh):
    model.eval()
    model.to("cpu")
    _, normal_dl, anomaly_dl = load_dataset(anomaly, 1, 512)
    with torch.no_grad():
        y = np.empty(shape=(len(normal_dl)+len(anomaly_dl))*512)
        labels = np.empty(shape=(len(normal_dl)+len(anomaly_dl))*512)
        for i, (x, label) in enumerate(normal_dl):
            x_hat = model(x)[0]
            mse = np.square(x.numpy() - x_hat.numpy())
            mse = mse.reshape((mse.shape[0], -1)).mean(axis=1)
            y[i*512 : (i+1)*512] = (mse > thresh)
            labels[i*512 : (i+1)*512] = False
        offset = len(normal_dl)*512
        for i, (x, label) in enumerate(anomaly_dl):
            x_hat = model(x)[0]
            mse = np.square(x.numpy() - x_hat.numpy())
            mse = mse.reshape((mse.shape[0], -1)).mean(axis=1)
            y[offset+i*512 : offset+(i+1)*512] = (mse > thresh)
            labels[offset+i*512 : offset+(i+1)*512] = True
    return roc_auc_score(np.array(labels), np.array(y))


In [98]:
print(calculate_auroc(model0, 0, 0.3))

0.7402918198529411


In [99]:
print(calculate_auroc(model1, 1, 0.3))

0.3531709558823529


As expected, the digit 1 has a smaller auroc value. (similiar result was shown by An and Cho)