In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from utils import mnist, plot_graphs, plot_mnist # functions for loading and plotting MNIST
import numpy as np

import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# loads dataset

mnist_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
           ])
train_loader, valid_loader, test_loader = mnist(valid=10000, transform=mnist_transform)

In [3]:
# Encoder and decoder classes

class Encoder(nn.Module):
    def __init__(self, latent_size=10):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(28*28, latent_size)
    
    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        return x
    
class Decoder(nn.Module):
    def __init__(self, latent_size=10):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_size, 28*28)
    
    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        return x

In [4]:
# Class for autoencoder

class Net(nn.Module):
    def __init__(self, latent_size=10, loss_fn=F.mse_loss, lr=1e-4, l2=0.):
        super(Net, self).__init__()
        self.latent_size = latent_size
        self.E = Encoder(latent_size)
        self.D = Decoder(latent_size)
        self.loss_fn = loss_fn
        self._rho_loss = None
        self._loss = None
        self.optim = optim.Adam(self.parameters(), lr=lr, weight_decay=l2)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        h = self.E(x)
        self.data_rho = h.mean(0) # calculates rho from encoder activations
        out = self.D(h)
        return out
    
    def decode(self, h):
        with torch.no_grad():
            return self.D(h)
    
    def rho_loss(self, rho, size_average=True):        
        dkl = - rho * torch.log(self.data_rho) - (1-rho)*torch.log(1-self.data_rho) # calculates KL divergence
        if size_average:
            self._rho_loss = dkl.mean()
        else:
            self._rho_loss = dkl.sum()
        return self._rho_loss
    
    def loss(self, x, target, **kwargs):
        target = target.view(-1, 28*28)
        self._loss = self.loss_fn(x, target, **kwargs)
        return self._loss

In [5]:
# Making AEs with 16, 64 and 256 neurons in latent layer

models = {"16": Net(16), "64": Net(64), "256": Net(256)}
rho = 0.05
train_log = {k: [] for k in models}
test_log = {k: [] for k in models}

In [6]:
# Train function

def train(epoch, models, log=None, add_noise=False, half_image=False):
    train_size = len(train_loader.sampler)
    for batch_idx, (data, _) in enumerate(train_loader):
        for model in models.values():
            model.optim.zero_grad()
            inputs = data.clone().detach()
            if add_noise:
                inputs = noise_batch(inputs)
            if half_image:
                inputs = half_batch(inputs)
            output = model(inputs)
            rho_loss = model.rho_loss(rho)
            loss = model.loss(output, data) + rho_loss
            loss.backward()
            model.optim.step()
            
        if batch_idx % 200 == 0:
            line = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses ".format(
                epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))
            losses = " ".join(["{}: {:.6f}".format(k, m._loss.item()) for k, m in models.items()])
            print(line + losses)
            
    else:
        batch_idx += 1
        line = "Train Epoch: {} [{}/{} ({:.0f}%)]\tLosses ".format(
            epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))
        losses = " ".join(["{}: {:.6f}".format(k, m._loss.item()) for k, m in models.items()])
        if log is not None:
            for k in models:
                log[k].append((models[k]._loss, models[k]._rho_loss))
        print(line + losses)

In [7]:
avg_lambda = lambda l: "loss: {:.4f}".format(l)
rho_lambda = lambda p: "rho_loss: {:.4f}".format(p)
line = lambda i, l, p: "{}: ".format(i) + avg_lambda(l) + "\t" + rho_lambda(p)
  
# Test function    
    
def test(models, loader, log=None, add_noise=False, half_image=False):
    test_size = len(loader.sampler)

    test_loss = {k: 0. for k in models}
    rho_loss = {k: 0. for k in models}
    with torch.no_grad():
        for data, _ in loader:
            inputs = data.clone().detach()
            if add_noise:
                inputs = noise_batch(inputs)
            if half_image:
                inputs = half_batch(inputs)
            output = {k: m(inputs) for k, m in models.items()}
            for k, m in models.items():
                test_loss[k] += m.loss(output[k], data, reduction="sum").item()
                rho_loss[k] += m.rho_loss(rho, size_average=False).item()
    
    for k in models:
        test_loss[k] /= (test_size * 784)
        rho_loss[k] /= (test_size * models[k].latent_size)
        if log is not None:
            log[k].append((test_loss[k], rho_loss[k]))
    
    lines = "\n".join([line(k, test_loss[k], rho_loss[k]) for k in models]) + "\n"
    report = "Test set:\n" + lines        
    print(report)

In [8]:
for epoch in range(1, 51):
    for model in models.values():
        model.train()
    train(epoch, models, train_log)
    for model in models.values():
        model.eval()
    test(models, valid_loader, test_log)

Test set:
16: loss: 0.8889	rho_loss: 0.0042
64: loss: 0.6831	rho_loss: 0.0050
256: loss: 0.5741	rho_loss: 0.0046

Test set:
16: loss: 0.7382	rho_loss: 0.0048
64: loss: 0.6012	rho_loss: 0.0047
256: loss: 0.5034	rho_loss: 0.0045

Test set:
16: loss: 0.6791	rho_loss: 0.0047
64: loss: 0.5574	rho_loss: 0.0046
256: loss: 0.4659	rho_loss: 0.0044

Test set:
16: loss: 0.6468	rho_loss: 0.0046
64: loss: 0.5290	rho_loss: 0.0045
256: loss: 0.4418	rho_loss: 0.0043

Test set:
16: loss: 0.6243	rho_loss: 0.0046
64: loss: 0.5098	rho_loss: 0.0045
256: loss: 0.4253	rho_loss: 0.0043

Test set:
16: loss: 0.6076	rho_loss: 0.0045
64: loss: 0.4961	rho_loss: 0.0044
256: loss: 0.4142	rho_loss: 0.0043

Test set:
16: loss: 0.5928	rho_loss: 0.0045
64: loss: 0.4847	rho_loss: 0.0044
256: loss: 0.4060	rho_loss: 0.0042

Test set:
16: loss: 0.5826	rho_loss: 0.0044
64: loss: 0.4757	rho_loss: 0.0044
256: loss: 0.3994	rho_loss: 0.0042



KeyboardInterrupt: 

Plots of loss on validation set show smooth decrease of reconstruction and rho loss (for 256 latent dimensions and 50 epochs training). The model did not overfit on train data.

In [None]:
fig, ax = plt.subplots(2, figsize=(12,9))
ax[0].plot(np.array(test_log["256"])[:,0])
ax[0].set_title("Test reconstruction loss")
ax[1].plot(np.array(test_log["256"])[:,1])
ax[1].set_title("Test rho loss")
plt.show()

Checking the model outputs from different inputs. 

1. Reconstruction of batch of images from test set;
2. Decoding identity matrix, where only one latent neuron is active;
3. Counting number of hidden neurons with activations > 0.5. The model with 256 neurons has only few of them active at same time. Setting activations of those with < 0.5 to 0 to check what decoder can generate from those few.

In [None]:
data, _ = next(iter(test_loader))

#(1.)
output = models["256"](data)
to_plot = output.view(-1, 1, 28, 28).clamp(0, 1).data.numpy()

#(2.)
decoded = models["256"].decode(torch.eye(256))
dec_to_plot = ((decoded.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()
with torch.no_grad():
    encoded = models["256"].E(data.view(-1, 28*28))
    
    #(3.)
    print("Number of neurons with activation > 0.5:\n", (encoded > 0.5).sum(1))
    encoded[encoded < 0.5] = 0.    
    decoded_f = models["256"].decode(encoded)
    f_to_plot = ((decoded_f.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()

Plotting original and reconstructed images. The model makes a fairly decent reconstruction of input.

In [None]:
plot_mnist(data.data.numpy(), (5, 10))
plot_mnist(to_plot, (5, 10))

Plot of images reconstructed from latent space with zeroed "inactive" neurons (activation < 0.5). Many of the numbers a recognizable, so sparse latent representation did learn useful features. The deactivated neurons mostly were responsible for background noise.

In [None]:
plot_mnist(f_to_plot, (5, 10))

Plot of decoded images from identity matrix. The images look very blurry, but even with only one neuron active in latent layer there can be seen some structure.

In [None]:
plot_mnist(dec_to_plot, (16, 16))

### Reconstruction of corrupted images.

Training of SAE to remove noise.

In [None]:
# Add noise. 

def noise_pixels(x):
    f = x + torch.randn_like(x)    
    return f
    
def noise_batch(batch):
    batch_z = batch.clone().detach()
    for i in range(batch_z.shape[0]):
        batch_z[i] = noise_pixels(batch_z[i])
    return batch_z

In [None]:
models = {"256": Net(256)}
rho = 0.05
train_log = {k: [] for k in models}
test_log = {k: [] for k in models}

In [None]:
for epoch in range(1, 51):
    for model in models.values():
        model.train()
    train(epoch, models, train_log, add_noise=True)
    for model in models.values():
        model.eval()
    test(models, valid_loader, test_log, add_noise=True)

In [None]:
fig, ax = plt.subplots(2, figsize=(12,9))
ax[0].plot(np.array(test_log["256"])[:,0])
ax[0].set_title("Test reconstruction loss")
ax[1].plot(np.array(test_log["256"])[:,1])
ax[1].set_title("Test rho loss")
plt.show()

In [None]:
data, _ = next(iter(test_loader))
inputs = noise_batch(data)
output = models["256"](inputs)
to_plot = output.view(-1, 1, 28, 28).clamp(0, 1).data.numpy()

decoded = models["256"].decode(torch.eye(256))
dec_to_plot = ((decoded.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()

with torch.no_grad():
    encoded = models["256"].E(inputs.view(-1, 28*28))    
    print("Number of neurons with activation > 0.5:\n", (encoded > 0.5).sum(1))
    encoded[encoded < 0.5] = 0.    
    decoded_f = models["256"].decode(encoded)
    f_to_plot = ((decoded_f.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()

After 50 epochs of training model did quite a good job of reconstructing heavily noised images.

In [None]:
plot_mnist(inputs.data.numpy(), (5, 10))
plot_mnist(to_plot, (5, 10))

And it performed even better with deactevated neurons (<0.5 set to 0) than when it was trained on clean images.

In [None]:
plot_mnist(f_to_plot, (5, 10))

Reconstructions from one active neuron also look more clear.

In [None]:
plot_mnist(dec_to_plot[:64], (8, 8))

Training on images cut in half.

In [None]:
# Cuts left part.

def half_pixels(x):
    f = x    
    f[0,:,:14] = 0    
    return f
    
def half_batch(batch):
    batch_z = batch.clone().detach() 
    for i in range(batch_z.shape[0]):
        batch_z[i] = half_pixels(batch_z[i])
    return batch_z

In [None]:
models = {"256": Net(256)}
rho = 0.05
train_log = {k: [] for k in models}
test_log = {k: [] for k in models}

In [None]:
for epoch in range(1, 51):
    for model in models.values():
        model.train()
    train(epoch, models, train_log, half_image=True)
    for model in models.values():
        model.eval()
    test(models, valid_loader, test_log, half_image=True)

In [None]:
fig, ax = plt.subplots(2, figsize=(12,9))
ax[0].plot(np.array(test_log["256"])[:,0])
ax[0].set_title("Test reconstruction loss")
ax[1].plot(np.array(test_log["256"])[:,1])
ax[1].set_title("Test rho loss")
plt.show()

In [None]:
data, _ = next(iter(test_loader))
inputs = half_batch(data)
output = models["256"](inputs)
to_plot = output.view(-1, 1, 28, 28).clamp(0, 1).data.numpy()

decoded = models["256"].decode(torch.eye(256))
dec_to_plot = ((decoded.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()

with torch.no_grad():
    encoded = models["256"].E(inputs.view(-1, 28*28))    
    print("Number of neurons with activation > 0.5:\n", (encoded > 0.5).sum(1))
    encoded[encoded < 0.5] = 0.    
    decoded_f = models["256"].decode(encoded)
    f_to_plot = ((decoded_f.view(-1, 1, 28, 28)+1)*0.5).clamp(0, 1).data.numpy()

Plot of reconstructed half images from test set.

In [None]:
plot_mnist(inputs.data.numpy(), (5, 10))
plot_mnist(to_plot, (5, 10))

Again, looks like making task more difficult for SAE forces it to extract better features (than training on original images).

In [None]:
plot_mnist(f_to_plot, (5, 10))

Although, reconstructions from identity matrix look worse in case of halved images.

In [None]:
plot_mnist(dec_to_plot[:64], (8, 8))