<h1> Autoencoder </h1>
Autoencoders are a fairly straightforward network structure, characterised by a "bottleneck": an encoder "compresses" the input into a small representation, and a decoder upsamples it back to the original size. Trained to reconstruct its input at the output, the network learns a compressed representation of images. It could also be used for our segmentation problem! However, in segmentation we don't really want the network to just compress and restore our image; we want it to do some "work" and then give us a segmented version of the input!
<img src="https://miro.medium.com/max/3148/1*44eDEuZBEsmG_TCAKRI3Kw@2x.png" width="750" align="center">

[Autoencoders](https://towardsdatascience.com/applied-deep-learning-part-3-autoencoders-1c083af4d798)
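
To make the bottleneck idea concrete, here is a minimal sketch of a fully connected autoencoder. The class name `TinyAE`, the hidden width of 128 and the bottleneck size of 16 are illustrative assumptions; this is not the convolutional model built later in this notebook.

```python
import torch
import torch.nn as nn

class TinyAE(nn.Module):
    # Hypothetical toy model: the name and layer sizes are for illustration only
    def __init__(self, in_features=32 * 32, bottleneck=16):
        super().__init__()
        # Encoder squeezes the input down to `bottleneck` values
        self.encoder = nn.Sequential(
            nn.Linear(in_features, 128), nn.ReLU(), nn.Linear(128, bottleneck))
        # Decoder expands the code back to the input size
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck, 128), nn.ReLU(), nn.Linear(128, in_features))

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code), code

torch.manual_seed(0)
x = torch.randn(4, 32 * 32)          # a batch of 4 flattened 32x32 "images"
recon, code = TinyAE()(x)
print(x.shape, code.shape, recon.shape)
```

The code is much smaller than the input, so the network has to keep only the information it needs to rebuild the image.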

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as Datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.models as models
import torchvision.utils as vutils

import os
import random
import numpy as np
import math
from IPython.display import clear_output
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import trange, tqdm

In [None]:
batchSize = 64

# Define learning rate
lr = 1e-4

# Number of Training epochs
nepoch = 10

# Dataset location
root = "../../datasets"

# Noise level: probability that each pixel has its sign flipped
noise_scale = 0.3

In [None]:
use_cuda = torch.cuda.is_available()
gpu_indx  = 0
device = torch.device(gpu_indx if use_cuda else "cpu")

<h3>Create an MNIST dataset and dataloader</h3>

In [None]:
# Define our transform
# We'll upsample the images to 32x32 as it's easier to construct our network
transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])])

train_set = Datasets.MNIST(root=root, train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=batchSize, shuffle=True, num_workers=4)

test_set = Datasets.MNIST(root=root, train=False, transform=transform, download=True)
test_loader = DataLoader(test_set, batch_size=batchSize, shuffle=False, num_workers=4)

<h3>Transpose Convolution</h3>
The AE model introduces a new layer type, the "transpose convolution" (sometimes called "deconvolution", although it is not the mathematical inverse of a convolution).<br>
The transpose convolution is a "learnable upsampling" method and is roughly the reverse of a strided convolution! Each feature (pixel) in the input feature map scales a copy of the kernel, the scaled copies are placed in the output one stride apart, and any overlapping sections are added together. The easiest way to understand it is with the following animation (where the blue square is the input and green is the output).
<img src="https://miro.medium.com/max/986/1*yoQ62ckovnGYV2vSIq9q4g.gif" width="750" align="center">

[Blog: Transposed Convolutions explained](https://medium.com/apache-mxnet/transposed-convolutions-explained-with-ms-excel-52d13030c7e8)<br>
[Blog: Deconvolution and Checkerboard Artifacts](https://distill.pub/2016/deconv-checkerboard/)
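
The "multiply by a kernel and add the overlaps" behaviour can be checked on a tiny example. This is a sketch using `torch.nn.functional.conv_transpose2d` with a fixed kernel of ones (chosen for readability; in the network the kernel is learned).

```python
import torch
import torch.nn.functional as F

# 2x2 input feature map, shape (batch, channels, H, W)
x = torch.tensor([[[[1.0, 2.0],
                    [3.0, 4.0]]]])

# 3x3 kernel of ones, shape (in_channels, out_channels, kH, kW)
kernel = torch.ones(1, 1, 3, 3)

# Each input pixel stamps a scaled copy of the kernel into the output,
# with the copies placed `stride` pixels apart; overlapping values are summed
out = F.conv_transpose2d(x, kernel, stride=2)

# Output size: (H - 1) * stride + kernel_size = (2 - 1) * 2 + 3 = 5
print(out.shape)
print(out[0, 0])
```

The centre pixel of the output is covered by all four stamped kernels, so it holds 1 + 2 + 3 + 4 = 10, while the corners are covered by only one.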

## AE Network

In [None]:
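# We split up our network into two parts, the Encoder and the Decoder
class Encoder(nn.Module):
    def __init__(self, channels, ch=32, z=32):
        super(Encoder, self).__init__()
        # Create the Encoder layers
        # Use Conv2d layers to downsample!
        # NOTE: padding=1 is an assumption (only kernel and stride are specified);
        # with it each stride-2 layer halves the image: 32 -> 16 -> 8 -> 4

        # kernel - channels X ch X 4 X 4, stride 2
        self.conv1 = nn.Conv2d(channels, ch, kernel_size=4, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(ch)

        # kernel - ch X ch*2 X 4 X 4, stride 2
        self.conv2 = nn.Conv2d(ch, ch * 2, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(ch * 2)

        # kernel - ch*2 X ch*4 X 4 X 4, stride 2
        self.conv3 = nn.Conv2d(ch * 2, ch * 4, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(ch * 4)

        # kernel - ch*4 X z X 4 X 4, stride 1 (no padding assumed: 4 -> 1)
        self.conv_out = nn.Conv2d(ch * 4, z, kernel_size=4, stride=1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # No activation on the latent encoding
        return self.conv_out(x)

class Decoder(nn.Module):
    def __init__(self, channels, ch=32, z=32):
        super(Decoder, self).__init__()
        # Create the Decoder layers
        # Use ConvTranspose2d layers to upsample!
        # NOTE: the padding values are assumptions chosen to mirror the Encoder

        # kernel - z X ch*4 X 4 X 4, stride 2 (no padding assumed: 1 -> 4)
        self.conv1 = nn.ConvTranspose2d(z, ch * 4, kernel_size=4, stride=2)
        self.bn1 = nn.BatchNorm2d(ch * 4)

        # kernel - ch*4 X ch*2 X 4 X 4, stride 2 (4 -> 8)
        self.conv2 = nn.ConvTranspose2d(ch * 4, ch * 2, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(ch * 2)

        # kernel - ch*2 X ch X 4 X 4, stride 2 (8 -> 16)
        self.conv3 = nn.ConvTranspose2d(ch * 2, ch, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(ch)

        # kernel - ch X channels X 4 X 4, stride 2 (16 -> 32)
        self.conv4 = nn.ConvTranspose2d(ch, channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # tanh keeps the output in [-1, 1], matching the normalised images
        x = torch.tanh(self.conv4(x))

        return x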
    
class AE(nn.Module):
    def __init__(self, channel_in, ch=16, z=32):
        super(AE, self).__init__()
        self.encoder = Encoder(channels=channel_in, ch=ch, z=z)
        self.decoder = Decoder(channels=channel_in, ch=ch, z=z)

    def forward(self, x):
        encoding = self.encoder(x)
        x = self.decoder(encoding)
        return x, encoding
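
The layer specifications above give only kernel size and stride, so it helps to trace how the image size changes through the network. The sketch below uses the standard output-size formulas for `Conv2d` and `ConvTranspose2d` (with dilation 1). The padding values are an assumption: padding 1 for the stride-2 layers, and padding 0 for the encoder's final layer and the decoder's first layer. The helper names `conv_out` and `conv_transpose_out` are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    # Output size of a convolution along one spatial dimension
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel, stride, padding):
    # Output size of a transpose convolution along one spatial dimension
    return (n - 1) * stride - 2 * padding + kernel

# Encoder: three stride-2 layers, then the stride-1 output layer
# (kernel, stride, padding) per layer; the padding values are assumed
size = 32
encoder_sizes = [size]
for kernel, stride, padding in [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]:
    size = conv_out(size, kernel, stride, padding)
    encoder_sizes.append(size)

# Decoder: four stride-2 transpose convolutions
decoder_sizes = [size]
for kernel, stride, padding in [(4, 2, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1)]:
    size = conv_transpose_out(size, kernel, stride, padding)
    decoder_sizes.append(size)

print(encoder_sizes)  # [32, 16, 8, 4, 1]
print(decoder_sizes)  # [1, 4, 8, 16, 32]
```

Under these assumptions the encoding has spatial size 1x1, so its shape is `(batch, z, 1, 1)`, and the decoder restores the original 32x32 resolution.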

<h3>Visualize our data</h3>

In [None]:
# Get a test image
dataiter = iter(test_loader)
test_images = next(dataiter)[0]
# View the shape
test_images.shape

In [None]:
# Visualize the data!!!
plt.figure(figsize = (20,10))
out = vutils.make_grid(test_images[0:8], normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

### De-noising Autoencoder
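While an autoencoder can be used to simply compress the input into a lower-dimensional space, let's also see how we can use it to remove some noise from an image!<br>
We're going to simulate something close to [salt-and-pepper noise](https://en.wikipedia.org/wiki/Salt-and-pepper_noise): each pixel has its sign flipped with probability `noise_scale`, which on images normalised to [-1, 1] turns dark pixels bright and bright pixels dark.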

In [None]:
# Visualize the noisy data!!!
plt.figure(figsize = (20, 10))
random_sample = (torch.bernoulli((1 - noise_scale) * torch.ones_like(test_images)) * 2) - 1
noisy_test_img = random_sample * test_images

out = vutils.make_grid(noisy_test_img[0:8], normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))
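
The noise step can be checked in isolation. This sketch uses a random stand-in batch instead of MNIST (an assumption, so it runs without downloading anything) and confirms that roughly `noise_scale` of the pixels change sign while their magnitudes are untouched.

```python
import torch

torch.manual_seed(0)
noise_scale = 0.3  # same value as set at the top of the notebook

# Stand-in for a normalised image batch with values in [-1, 1]
images = torch.rand(8, 1, 32, 32) * 2 - 1

# bernoulli gives 1 with probability (1 - noise_scale), else 0;
# "* 2 - 1" maps that to +1 (keep the pixel) or -1 (flip its sign)
signs = torch.bernoulli((1 - noise_scale) * torch.ones_like(images)) * 2 - 1
noisy = signs * images

flipped_fraction = (signs == -1).float().mean().item()
print(f"fraction of flipped pixels: {flipped_fraction:.3f}")
```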

<h3>Create Network and Optimizer</h3>

In [None]:
# The size of the Latent Vector
latent_size = 128

# Create our network
ae_net = AE(channel_in=1, z=latent_size).to(device)

# Setup optimizer
optimizer = optim.Adam(ae_net.parameters(), lr=lr)

# MSE loss for reconstruction!
loss_func = nn.MSELoss()

loss_log = []
train_loss = 0

<h4>Network output</h4>

In [None]:
# Pass through a test image to make sure everything is working
recon_data, encoding = ae_net(test_images.to(device))

# View the Latent vector shape
encoding.shape

<h2>Start training!</h2>

In [None]:
pbar = trange(0, nepoch, leave=False, desc="Epoch")    
for epoch in pbar:
    pbar.set_postfix_str('Loss: %.4f' % train_loss)
    for i, data in enumerate(tqdm(train_loader, leave=False, desc="Training")):

        image = data[0].to(device)
        
        # Create the noisy data!
        random_sample = (torch.bernoulli((1 - noise_scale) * torch.ones_like(image)) * 2) - 1
        noisy_img = random_sample * image
        
        # Forward pass the image in the data tuple
        recon_data, _ = ae_net(noisy_img)
        
        # Calculate the MSE loss
        loss = loss_func(recon_data, image)
        
        # Log the loss
        loss_log.append(loss.item())
        train_loss = loss.item()
        
        # Take a training step
        ae_net.zero_grad()
        loss.backward()
        optimizer.step()
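
The loop above only logs the training loss. One possible way to measure denoising quality on held-out data is sketched below; the helper `evaluate_denoising` and the stand-in `IdentityAE` model are hypothetical and not part of the original notebook. The helper assumes the model returns a `(reconstruction, encoding)` tuple, as `AE.forward` does.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def evaluate_denoising(model, loader, noise_scale, device="cpu"):
    """Mean per-pixel MSE between reconstructions of noisy inputs and the clean images."""
    was_training = model.training
    model.eval()  # Batch-Norm uses running statistics in eval mode
    total, count = 0.0, 0
    for images, _ in loader:
        images = images.to(device)
        # Same sign-flip noise as in the training loop
        signs = torch.bernoulli((1 - noise_scale) * torch.ones_like(images)) * 2 - 1
        recon, _ = model(signs * images)
        total += F.mse_loss(recon, images, reduction="sum").item()
        count += images.numel()
    model.train(was_training)  # restore the previous mode
    return total / count

# Stand-in model and data so the sketch runs on its own
class IdentityAE(nn.Module):
    def forward(self, x):
        return x, x

torch.manual_seed(0)
fake_loader = [(torch.ones(4, 1, 32, 32), torch.zeros(4)) for _ in range(3)]
clean_score = evaluate_denoising(IdentityAE(), fake_loader, noise_scale=0.0)
noisy_score = evaluate_denoising(IdentityAE(), fake_loader, noise_scale=0.3)
print(clean_score, noisy_score)
```

In this notebook it could be called as `evaluate_denoising(ae_net, test_loader, noise_scale, device)` after training. The identity model shows the baseline a denoiser has to beat: it passes the noise straight through.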

## Results!

In [None]:
# Plot the loss over time
_ = plt.plot(loss_log)
_ = plt.title("MSE Loss")

In [None]:
# Ground Truth
plt.figure(figsize = (20,10))
out = vutils.make_grid(test_images[0:8], normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

In [None]:
# Noisy Input
plt.figure(figsize = (20,10))
out = vutils.make_grid(noisy_test_img[0:8], normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))

In [None]:
# Reconstruction
plt.figure(figsize = (20,10))
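# Switch to eval mode so Batch-Norm uses its running statistics, and skip gradient tracking
ae_net.eval()
with torch.no_grad():
    recon_data, _ = ae_net(noisy_test_img.to(device))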
out = vutils.make_grid(recon_data.detach().cpu()[0:8], normalize=True)
_ = plt.imshow(out.numpy().transpose((1, 2, 0)))