<a href="https://colab.research.google.com/github/wandb/edu/blob/main/lightning/autoencoder/autoencoder-mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# Autoencoder Networks for MNIST

Sources: 
- [Full Article](https://wandb.ai/ayush-thakur/lit-ae/reports/Autoencoder-An-Excersise--VmlldzoxMDIwNjgz)
- [GitHub](https://github.com/wandb/edu/blob/main/lightning/autoencoder/autoencoder-mnist.ipynb)
- [W&B Project](https://wandb.ai/ayush-thakur/lit-ae/runs/2dsrj96p/overview)

In [None]:
# %%capture
# !pip install pytorch_lightning torchviz wandb

# repo_url = "https://raw.githubusercontent.com/wandb/edu/main/"
# utils_path = "lightning/utils.py"

# Download a util file of helper methods for this notebook
# !curl {repo_url + utils_path} > utils.py

import math
import os

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
import wandb
import torchmetrics


import utils

In [None]:
class LitAE(utils.LoggedLitModule):
  """Autoencoder wrapper for PyTorch Lightning.

  Chains a user-supplied ``encoder`` and ``decoder`` (both
  ``pl.LightningModule`` instances) into one model. It reads the loss
  function and optimizer settings from ``config`` and tracks output quality
  with the Peak Signal-to-Noise Ratio (PSNR). Weights & Biases logging is
  presumably implemented in the ``utils.LoggedLitModule`` base class, which
  is not shown here.

  Args:
    encoder: module mapping a batch of images to encodings.
    decoder: module mapping encodings back to images.
    config: dict providing ``"loss"`` (loss callable), ``"optimizer"``
      (optimizer class), and ``"optimizer.params"`` (keyword arguments
      used to instantiate that optimizer class).
  """

  def __init__(self, encoder, decoder, config):
    super().__init__()
    self.encoder, self.decoder = encoder, decoder

    self.loss = config["loss"]
    self.optimizer, self.optimizer_params = config["optimizer"], config["optimizer.params"]

    # PSNR quality metric, in decibels. Independent clones for training and
    # validation so their accumulated state is never shared; these lists are
    # presumably consumed by the base class — confirm in utils.
    from torchmetrics.image import PeakSignalNoiseRatio
    template = PeakSignalNoiseRatio()
    self.training_metrics = torch.nn.ModuleList([template.clone()])
    self.validation_metrics = torch.nn.ModuleList([template.clone()])

  def forward(self, x):
    """Reconstruct ``x`` by encoding it and then decoding the encoding."""
    encoding = self.encoder(x)
    return self.decoder(encoding)

  def configure_optimizers(self):
    """⚡ Build the configured optimizer over all model parameters (for .fit)."""
    optimizer_cls = self.optimizer
    return optimizer_cls(self.parameters(), **self.optimizer_params)

## Fully-Connected Encoder and Decoder

In [None]:
class EncoderFC(pl.LightningModule):
  """Encoder built from fully-connected (torch.nn.Linear) layers.

  Inputs are first resized with AdaptiveAvgPool2d to config["target_size"],
  so images of any size are accepted. They are then flattened (keeping the
  batch dimension) and passed through a small MLP. The result is one vector
  of length config["encoding_dim"] per example.

  Config keys read: "target_size", "activation", "encoding_dim".
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    target_size = config["target_size"]
    make_activation = config["activation"]  # a class; instantiated per use
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(output_size=target_size)
    self.flat_input_size = get_flat_size(target_size)
    self.layers = torch.nn.Sequential(
      # add modules here
      torch.nn.Linear(self.flat_input_size, 128),
      make_activation(),
      torch.nn.Linear(128, 64),
      make_activation(),
      torch.nn.Linear(64, config["encoding_dim"]),
    )

  def forward(self, x):
    resized = self.resize_layer(x)
    flat = resized.flatten(start_dim=1)  # merge every dim except the batch
    return self.layers(flat)


class DecoderFC(pl.LightningModule):
  """Decoder built from fully-connected (torch.nn.Linear) layers.

  Pipeline:
    1. Consumes one encoding vector of length config["encoding_dim"] per
       example.
    2. Expands it with an MLP to one value per pixel of
       config["target_size"].
    3. Reshapes that into a single-channel (batch, 1, H, W) image, the
       inverse of EncoderFC's flatten.
    4. Resizes the result with AdaptiveAvgPool2d to config["image_size"],
       so outputs match the original images.

  Config keys read: "encoding_dim", "target_size", "image_size",
  "activation".
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    target_size = self.config["target_size"]
    # AdaptiveAvgPool2d also accepts a single int, meaning a square size
    if np.ndim(target_size) == 0:
      target_size = (target_size, target_size)
    # (channels, height, width) of the un-flattened image; storing the real
    # target shape lets forward() reshape correctly for non-square targets
    self.unflat_dims = (1, *target_size)
    self.flat_output_size = get_flat_size(target_size)
    make_activation = self.config["activation"]
    self.layers = torch.nn.Sequential(
      # add modules here
      torch.nn.Linear(self.config["encoding_dim"], 64),
      make_activation(),
      torch.nn.Linear(64, 128),
      make_activation(),
      # Linear's third positional parameter is `bias`, so no activation is
      # passed here; the bias is kept via its default (True)
      torch.nn.Linear(128, self.flat_output_size),
      make_activation(),
    )
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(output_size=self.config["image_size"])

  def forward(self, x):
    x = self.layers(x)
    # reverse of the encoder's flatten: (batch, H*W) -> (batch, 1, H, W)
    x = torch.reshape(x, (x.shape[0], *self.unflat_dims))
    return self.resize_layer(x)


def get_new_dims(x):
  """Return a 4-d shape that un-flattens a batch of (roughly square) images.

  Args:
    x: a 2-d array/tensor of shape (batch, length). The row count is taken
      as floor(sqrt(length)); the column count is left to be inferred.

  Returns:
    (batch, 1, rows, -1): a single-channel image shape for torch.reshape,
    where -1 lets the column count be inferred from ``length``.

  Raises:
    ValueError: if ``x`` is not 2-dimensional.
  """
  if len(x.shape) != 2:
    raise ValueError(f"expects a batch of vectors, got shape {tuple(x.shape)}")
  batch, length = x.shape
  # exact integer square root; float math.sqrt can round up for large lengths
  rows = math.isqrt(length)
  return (batch, 1, rows, -1)


def get_flat_size(image_size):
  """Return the number of pixels in a single-channel image of ``image_size``.

  Args:
    image_size: a (height, width) sequence, or a single int meaning a
      square image (AdaptiveAvgPool2d's ``output_size`` also allows this).

  Returns:
    The product of the dimensions as a plain Python int, so it can be used
    directly as a layer size.
  """
  if np.ndim(image_size) == 0:  # a bare int means height == width
    image_size = (image_size, image_size)
  return int(np.prod(image_size))

## Convolutional Encoder and Decoder

In [None]:
class EncoderConv(pl.LightningModule):
  """Convolutional encoder.

  Inputs are resized with AdaptiveAvgPool2d to config["target_size"], so
  images of any size are accepted. This is the size of images fed to the
  network, matching EncoderFC. The encoder then applies two Conv2d +
  activation stages and a 2x2 max-pool, producing a
  (batch, config["encoding_dim"], H', W') feature map, i.e. one 3-d tensor
  per example.

  Expects single-channel inputs (the first Conv2d has in_channels=1).
  Config keys read: "target_size", "activation", "encoding_dim".
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    # resize to target_size (not image_size) so the "target_size" setting
    # controls both encoder variants consistently
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(output_size=self.config["target_size"])
    self.layers = torch.nn.Sequential(
      # add modules here
      torch.nn.Conv2d(1, 32, kernel_size=5),
      self.config["activation"](),
      torch.nn.Conv2d(32, self.config["encoding_dim"], kernel_size=3),
      self.config["activation"](),
      torch.nn.MaxPool2d(kernel_size=2),
    )

  def forward(self, x):
    x = self.resize_layer(x)
    return self.layers(x)


class DecoderConv(pl.LightningModule):
  """Convolutional (transposed-convolution) decoder.

  Consumes an encoding with config["encoding_dim"] channels per example.
  Two ConvTranspose2d + activation stages upsample it to a single channel.
  AdaptiveAvgPool2d then resizes the result to config["image_size"], so
  outputs match the original images.

  Config keys read: "encoding_dim", "activation", "image_size".
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    make_activation = config["activation"]  # a class; instantiated per use
    self.layers = torch.nn.Sequential(
      # add modules here
      torch.nn.ConvTranspose2d(config["encoding_dim"], 32, kernel_size=3, stride=2),
      make_activation(),
      torch.nn.ConvTranspose2d(32, 1, kernel_size=5),
      make_activation(),
    )
    self.resize_layer = torch.nn.AdaptiveAvgPool2d(output_size=config["image_size"])

  def forward(self, x):
    decoded = self.layers(x)
    return self.resize_layer(decoded)

## Training

To run training, execute the cell below.
You can configure the network and training procedure
by changing the values of the `config` dictionary.

Use the value of `erase` to switch tasks:
when `erase` is `True`,
a random portion of the input (but not the output!)
is erased before being fed to the network,
which makes the task a form of
[image in-painting](https://heartbeat.fritz.ai/guide-to-image-inpainting-using-machine-learning-to-edit-and-correct-defects-in-photos-3c1b0e13bbd0).
When it is `False`,
the input is unaltered,
and the task is a vanilla reconstruction task.

In between training runs,
especially runs that crashed,
you may wish to restart the notebook
and re-run the preceding cells
to get rid of accumulated state
(`Runtime > Restart runtime`).

In [None]:
###
# Setup Hyperparameters, Data, and Model
###


config = {  # dictionary of configuration hyperparameters
  "batch_size": 32,  # number of examples in a single batch
  "max_epochs": 10,  # number of times to pass over the whole dataset
  "image_size": (28, 28),  # size of images in this dataset
  "target_size": (28, 28),  # size of resized images fed to network
  "encoding_dim": 16,  # size/channel count of encoding of input
  "loss": torch.nn.MSELoss(),  # loss function
  "erase": False,  # set to False to deactivate input erasing, True to activate
  "activation": torch.nn.ReLU,  # activation function class (instantiated later)
  "optimizer": torch.optim.Adam,  # optimizer class (instantiated later)
  "optimizer.params":  # dict of hyperparameters for optimizer
    {"lr": 0.0001,  # learning rate to scale gradients
     "weight_decay": 0}  # if non-zero, reduce weights each batch
}

# 📝 if activated erases part of the image on each load;
# RandomErasing's p is a probability, so the erase flag becomes 0.0 or 1.0
eraser = torchvision.transforms.RandomErasing(
    p=float(config["erase"]), scale=(0.1, 0.2), ratio=(0.3, 3.3), value=0.13)

# 📸 set up the dataset of images
dmodule = utils.AutoEncoderMNISTDataModule(
    batch_size=config["batch_size"],
    transforms=eraser)
dmodule.prepare_data()
dmodule.setup()

# grab samples to log outputs on
samples = next(iter(dmodule.val_dataloader()))

# 🥅 instantiate the network
encoder = EncoderFC(config)
decoder = DecoderFC(config)
ae = LitAE(encoder, decoder, config)

###
# Train the model
###


with wandb.init(project="lit-ae", entity="wandb", config=config) as run:
  # 👀 watch the gradients and parameters, log to Weights & Biases
  wandb.watch(ae)

  image_logger = utils.ImageLogCallback(samples)  # logs inputs and outputs to Weights & Biases
  # logs the input and output weights to Weights & Biases
  filter_logger = utils.FilterLogCallback(  # for details see Challenge exercise below
    image_size=config["image_size"], log_input=True, log_output=True)

  # 👟 configure Trainer
  trainer = pl.Trainer(
      logger=pl.loggers.WandbLogger(
          log_model=True, save_code=True),  # log to Weights & Biases
      max_epochs=config["max_epochs"], log_every_n_steps=1,
      callbacks=[image_logger, filter_logger],
      enable_progress_bar=True)  # a bool switch, not a refresh rate

  # 🏃‍♀️ run the Trainer on the model
  trainer.fit(ae, dmodule)

### Exercises

The cell above will output links to Weights & Biases dashboards
where you can review the training process and the final resulting model.

These dashboards will be useful in working through the exercises below.

#### 1. Choosing an Output Activation

The default configuration uses a
[`ReLU` activation](https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html#torch.nn.ReLU)
on all layers.
This activation outputs exactly zero whenever its input is negative, so exact zeros are very common in its outputs.
Review the logged outputs of your neural network
in the Weights & Biases interface and look
for issues caused by this choice of activation.
If you notice any, try correcting them by using a different activation:
[`LeakyReLU`](https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html#torch.nn.LeakyReLU),
[`Sigmoid`](https://pytorch.org/docs/stable/generated/torch.nn.Sigmoid.html#torch.nn.Sigmoid), or
[`SiLU`](https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html#torch.nn.SiLU).

> _Note_: Outputs are typically quite different
from the internal parts of the network,
like activation values.
For example, here they are always between 0 and 1,
to match the data they are compared against,
while activation values have no such restriction.
Therefore, it's not necessary or even advisable to use
the same activation for the output layer as for the hidden layers.
This will require you to slightly change the network code above!

#### 2. Change Hyperparameters

Better results can be obtained by tweaking the hyperparameters.
Try out different values in the `config`
and in the model definition and see what happens
to your results.

Here are some suggestions:
- Does increasing `batch_size` from `32` help or hurt?
What happens to the runtime if you decrease it to `1`?
What happens if you increase it to the maximum of `50_000`?
- If you increase `max_epochs` from `10` and train for longer,
does the model get better or worse?
- What happens when you increase the size of the model? You can do this by
increasing the value of `encoding_dim`,
increasing the number of hidden layers,
or increasing the hidden layer sizes
(set by the first two arguments of `Linear` and `Conv2d`).

#### 3. Convolutional Autoencoders

Try out the convolutional version of the autoencoder
(`encoder = EncoderConv` and `decoder = DecoderConv`).
Compare its performance on the usual autoencoder task (`erase=0.`)
and on the in-filling task (`erase=1.`)
to that of the fully-connected network.
Why do you think the convolutional network struggles with the in-filling task?

#### **Challenge**: Regularization and Learned Weights
In addition to logging the inputs, outputs, and metrics,
the training run also logs the "filters" for the network --
the first and last weights, which are applied directly to the inputs
and directly produce the outputs, respectively.

These filters can be interpreted as images,
helping us see what the network is looking for in the inputs
and what it uses to construct the outputs.

In a well-trained neural network,
these filters would look like the "units" from which our inputs are built:
tiny pen-strokes or patches of brightness and/or darkness.

With the default settings, the learned filters don't look like that at all.
They either look like [white noise](https://en.wikipedia.org/wiki/White_noise),
like the static pattern that appears on a detuned television screen,
or they look like entire digits, memorized from the training set.

Getting the weights to converge to good filters is challenging.
You might use any or all of the following techniques:
- Increase `max_epochs` and continue training
long after the loss has stopped decreasing
- Add
[Dropout](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html)
([Dropout2d](https://pytorch.org/docs/stable/generated/torch.nn.Dropout2d.html#torch.nn.Dropout2d)
for convolutional networks)
after the activation layers.
Use a drop probability between `0.1` and `0.5`.
- Increase the `weight_decay` parameter of the optimizer,
trying values between `1e-12` and `1e-2`.
Weight decay is similar to ridge regression
or $\ell_2$-regularization from traditional ML
([read more here](https://towardsdatascience.com/this-thing-called-weight-decay-a7cd4bcfccab)).

#### **Challenge**: Skip Connections

A common technique for improving performance in computer vision models is adding
[skip connections](https://theaisummer.com/skip-connections/) --
transformations that "skip over" intervening layers.
These "shortcuts" allow information to flow more smoothly through the network
and stabilize training -- enabling more choices of optimizer,
layer size, and nonlinearity to reach good performance.

In a small autoencoder like this one, we might write one skip connection
from the input to the hidden layer
and another from the hidden layer to the output.
Try adding these to `EncoderConv` and `DecoderConv`.

> _Note_: In the style of the encoder and decoder modules above,
a module with a skip connection might have a `.forward` method like this one:
```python
def forward(self, xs):
  skip = self.skip(xs)
  for layer in self.layers:
    xs = layer(xs)
  return xs + skip
```
where `.skip` is a `torch.nn.Linear` or `torch.nn.Conv2d` layer. Notice that `xs` and `skip` are added together, and so need to be the same shape!

> _Note_: In the approach popularized by
[Residual Networks](https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035),
the skip connection doesn't transform the input at all
(akin, in the example above, to using `torch.nn.Identity` for `skip`).
This allows for very efficient and stable gradient flow
in networks of extreme depth,
but requires that the skipped-over layers don't change the shape.