In [None]:
import torch
import torch.nn as nn
import torch.utils as utils
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%matplotlib inline

In [None]:
# Both splits share the same preprocessing: images are converted to tensors.
to_tensor = transforms.ToTensor()
mnist_train = dataset.MNIST("./", train=True, transform=to_tensor, download=True)
mnist_val = dataset.MNIST("./", train=False, transform=to_tensor, download=True)

batch_size = 100

# Training batches are reshuffled on every pass; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size=batch_size, shuffle=False)

Let's explore the MNIST dataset

In [None]:
fig, axs = plt.subplots(1, 6, figsize=(20, 5))

# Show the first six training digits together with their labels.
for i, ax in enumerate(axs):
    # Index the dataset once per sample instead of once for the image and
    # once for the label.
    image, label = mnist_train[i]
    ax.imshow(image.squeeze().numpy(), cmap='gray')
    ax.set_title("Label: {}".format(label))
    ax.axis('off')  # hide ticks and frame; a plain loop replaces the np.vectorize trick

**[TO DO]** Create a function which samples Gaussian noise with a given mean and standard deviation. This will be used to corrupt the input.

In [None]:
def gaussian_noise(mean, std, shape, device=None):
    """Sample a tensor of Gaussian noise with the given mean and standard deviation.

    Standard-normal samples from ``torch.randn`` are scaled by ``std`` and
    shifted by ``mean``.

    Args:
        mean: Mean of the noise distribution.
        std: Standard deviation of the noise distribution.
        shape: Shape of the tensor to generate, as accepted by ``torch.randn``.
        device: Optional device on which to create the noise directly, which
            avoids a separate host-to-device copy. ``None`` (the default)
            keeps the previous behaviour of using PyTorch's default device.

    Returns:
        A tensor of the requested shape containing the sampled noise.
    """
    return torch.randn(shape, device=device) * std + mean

**[TO DO]** Build a Convolutional Autoencoder

- The Encoder is composed of three convolutional layers, each separated by a ReLU activation function. These layers are configured with output channels (or number of filters) in the following sequence: [32, 32, 64].
- The Decoder consists of three convolutional layers with a ReLU activation function in between. However, the number of output channels here is the reverse of the Encoder's configuration. Additionally, the Decoder employs [Pixel Shuffle](https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html) after each of the first two convolutional layers to upsample the input (therefore, make sure to predict $s^2$ channels for **each** output channel, which are to be allocated by Pixel Shuffle.) Use an upscale factor $s = 2$ for each of the two Pixel Shuffle Layers. The final convolutional layer in the Decoder reduces the output channels to match the number of channels in the target image. It is essential to ensure that the Decoder's output generates an image that matches the input image in shape, since we use a mean squared error reconstruction term.

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional autoencoder with a PixelShuffle-based decoder.

    The encoder is three 3x3 convolutions (32, 32, 64 filters) separated by
    ReLU; the first two use stride 2. The decoder is three 3x3 convolutions
    where each of the first two is followed by ReLU and a PixelShuffle with
    upscale factor 2, and the last one maps back to a single channel.

    Shape notes below are for a (batch, 1, 28, 28) input.
    """

    def __init__(self):
        super().__init__()

        # Encoder filters: 32 -> 32 -> 64
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # (batch, 32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),  # (batch, 32, 7, 7)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),  # (batch, 64, 7, 7)
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: each upsampling conv emits 32 * scale**2 channels so that
        # PixelShuffle leaves 32 channels at twice the spatial resolution.
        scale = 2
        pre_shuffle_channels = 32 * (scale ** 2)
        decoder_layers = [
            nn.Conv2d(64, pre_shuffle_channels, kernel_size=3, stride=1, padding=1),  # (batch, 128, 7, 7)
            nn.ReLU(),
            nn.PixelShuffle(scale),                                                   # (batch, 32, 14, 14)
            nn.Conv2d(32, pre_shuffle_channels, kernel_size=3, stride=1, padding=1),  # (batch, 128, 14, 14)
            nn.ReLU(),
            nn.PixelShuffle(scale),                                                   # (batch, 32, 28, 28)
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),                     # (batch, 1, 28, 28)
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)

In [None]:
def plot_losses(train_loss, val_loss):
    """Plot training and validation loss per epoch on the current figure.

    Args:
        train_loss: Sequence of training losses, one value per epoch.
        val_loss: Sequence of validation losses, plotted against the same
            epoch axis as ``train_loss``.
    """
    # Epochs are numbered from 1; the axis length follows the training series.
    epoch_numbers = range(1, len(train_loss) + 1)

    curves = (
        (train_loss, '--', 'Training loss'),
        (val_loss, 'r', 'Validation loss'),
    )
    for values, line_format, label in curves:
        plt.plot(epoch_numbers, values, line_format, label=label)

    plt.title('Training and validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [None]:
# Hyperparameters
learning_rate = 0.001
epochs = 5

# Seed PyTorch's global RNG before anything random happens, so that weight
# initialisation, the sampled noise and the shuffling of training batches are
# repeatable after Restart & Run All. (Bit-exact results on GPU are still not
# guaranteed.)
seed = 0
torch.manual_seed(seed)

# Model, reconstruction loss (mean squared error) and optimizer.
model = AutoEncoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# The validation noise is sampled only once and then reused, so that the
# validation loss is computed consistently from epoch to epoch.
val_noise_shape = (batch_size, 1, 28, 28)
val_noise = gaussian_noise(mean=0, std=0.5, shape=val_noise_shape).to(device)

**[TO DO]** Fill in the missing parts. These include:

- Constructing the training input image 
- Calculating the loss
- Constructing the validation input image

In [None]:
# Per-epoch average losses, consumed by the plotting cell afterwards.
training_losses = []
validation_losses = []

for i in range(epochs):

    # ---- Training pass ----
    model.train()
    epoch_loss_train = 0
    epoch_loss_val = 0

    for image, _ in train_loader:
        # For the training set, fresh noise is sampled for every batch.
        input_image = image.to(device)
        train_noise = gaussian_noise(mean=0, std=0.5, shape=input_image.shape).to(device)
        noisy_input = input_image + train_noise

        optimizer.zero_grad()
        outputs = model(noisy_input)

        # The target is the clean image: the model learns to remove the noise.
        loss = criterion(outputs, input_image)
        loss.backward()
        optimizer.step()
        epoch_loss_train += loss.item()

    train_loss = epoch_loss_train / len(train_loader)
    training_losses.append(train_loss)

    # ---- Validation pass ----
    model.eval()

    # No gradients are needed for evaluation; disabling them avoids building
    # autograd graphs and saves memory.
    with torch.no_grad():
        for val_image, _ in val_loader:
            input_image = val_image.to(device)
            # Slice the fixed noise to the batch size so that a smaller final
            # batch does not cause a shape mismatch.
            noisy_input = input_image + val_noise[:input_image.shape[0]]

            outputs = model(noisy_input)
            epoch_loss_val += criterion(outputs, input_image).item()

    val_loss = epoch_loss_val / len(val_loader)
    validation_losses.append(val_loss)

    print("Training loss:", train_loss,
          " Validation loss:", val_loss)

In [None]:
# Compare how the training and validation losses evolved over the epochs.
plot_losses(training_losses, validation_losses)

Visualize the output. Take a random sample from the validation set, corrupt it with noise and run the noisy image through the model. Does the model remove noise from this image? Remember to move the channels of the PyTorch Tensor to the end (as NumPy expects), move your Tensor to the CPU and convert it to NumPy!

In [None]:
# Put the model in evaluation mode and take the first validation batch as the
# pool of images to visualise.
model.eval()
val_batches = iter(val_loader)
images, _ = next(val_batches)

In [None]:
index = 3
input_image = images[index].unsqueeze(0).to(device)

# val_noise holds one noise map per batch slot. Use only the map at `index` so
# the model sees a batch of one image instead of the single image broadcast
# against all batch_size noise maps.
noisy_batch = input_image + val_noise[index:index + 1]

# Inference only: no gradients are required.
with torch.no_grad():
    denoised_batch = model(noisy_batch)

# Move the channel axis last, bring the tensors to the CPU and convert to NumPy.
original = images[index].permute(1, 2, 0).detach().cpu().numpy()
noisy_image = noisy_batch[0].permute(1, 2, 0).detach().cpu().numpy()
result = denoised_batch[0].permute(1, 2, 0).detach().cpu().numpy()

# One panel per image: clean input, corrupted input, model reconstruction.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (original, "Original"),
    (noisy_image, "Noisy input"),
    (result, "Denoised output"),
)
for ax, (picture, title) in zip(axs, panels):
    ax.imshow(picture, cmap='gray')
    ax.set_title(title)
    ax.axis('off')