# CNN Autoencoder on mnist

The digits can be reconstructed to a pretty decent degree, with the same weird effect of overfitting a single image. Training seems to go through phases in which the model parameters / loss plateau before improving further. The reconstructed images are also a little blurry, but subjectively less so than in the [Fashion-MNIST example used in the lectures](https://github.com/fastai/course22p2/blob/master/nbs/08_autoencoder.ipynb).

## References

* fastai 2022 / 2023 course part II:
    * [notebook 8](https://github.com/fastai/course22p2/blob/master/nbs/08_autoencoder.ipynb)
    * [lesson 15](https://course.fast.ai/Lessons/lesson15.html)

## Setup

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import random
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset

import random_neural_net_models.autoencoder_fastai2022 as ae
import random_neural_net_models.convolution_lecun1990 as conv_lecun1990
import random_neural_net_models.data as rnnm_data
import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.losses as rnnm_losses
import random_neural_net_models.telemetry as telemetry
import random_neural_net_models.utils as utils

sns.set_theme()

In [None]:
DO_OVERFITTING_ONLY = True

In [None]:
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting seeds

In [None]:
utils.make_deterministic(42)

Getting device

In [None]:
device = utils.get_device()
device

In [None]:
X = mnist["data"]
y = mnist["target"]
X.shape, y.shape

Selecting the first `n` images (here `n = 1`, i.e. a single image) to overfit on

In [None]:
n = 1
X0, y0 = X.iloc[:n], y.iloc[:n]
X0.shape

## Defining dataset and dataloader

In [None]:
ds = conv_lecun1990.DigitsDataset(X0, y0)

In [None]:
item = ds[0]
plt.imshow(item[0], cmap="gray", origin="upper")
plt.title(f"Label: {item[1]}")
plt.axis("off")
plt.tight_layout()

defining a dataloader

In [None]:
batch_size = 1
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)

In [None]:
item[0].shape

## overfitting

defining the model

In [None]:
# Build the CNN autoencoder and wrap it in the telemetry helper so that the
# loss, activations, gradients and parameters can be plotted after training.
model = ae.CNNAutoEncoder()
model = telemetry.ModelTelemetry(
    model,
    loss_names=("total",),
    # Regex patterns are raw strings: "\d" inside a plain string literal is an
    # invalid escape sequence and triggers a warning on recent Python versions.
    # NOTE(review): presumably matched against sub-module names -- confirm in
    # `telemetry`.
    activations_name_patterns=(r".*act.*",),
    gradients_name_patterns=(r".*conv\d$",),
    parameters_name_patterns=(r".*conv\d$",),
)
# NOTE(review): `.double()` presumably casts the parameters to float64 to match
# the dtype of the dataset items -- confirm.
model.double()
model.to(device);

In [None]:
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
loss_func = nn.MSELoss()

In [None]:
_iter = 0

training loop

In [None]:
n_epochs = 1_000

# Overfit the single-image dataloader: one optimisation step per batch, with
# the loss and parameter telemetry recorded at every global iteration.
model.train()
epoch_bar = tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs)
for epoch in epoch_bar:
    for i, (xb, _) in enumerate(dataloader):
        xb = xb.to(device)

        # Forward pass: the autoencoder's target is its own input.
        x_pred = model(xb)
        loss = loss_func(x_pred, xb)

        # Backward pass and SGD update.
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Telemetry, indexed by the global iteration counter.
        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

plotting gradients

In [None]:
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
model.draw_activation_stats(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
model.draw_parameter_stats()

plotting losses

In [None]:
model.draw_loss_history_train()

inference over samples

In [None]:
train_features, _ = next(iter(dataloader))

In [None]:
model.eval();

inspecting predictions

In [None]:
# Move the batch to the model's device and reconstruct it. Autograd is switched
# off because no backward pass follows, so no graph needs to be built.
train_features = train_features.to(device)
with torch.no_grad():
    preds = model(train_features)
preds[0, :5, :5]

In [None]:
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
# Show the first input image next to its reconstruction.
img = train_features[0].cpu()
img_pred = x_pred[0]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
panels = ((img, "Input image"), (img_pred, "Reconstructed image"))
for ax, (picture, title) in zip(axs, panels):
    ax.imshow(picture, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
model.clean_hooks()

## overfitting with `Learner`

In [None]:
ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0)
dl_train = DataLoader(
    ds_train,
    batch_size=1,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=True,
)
next(iter(dl_train))

In [None]:
model = ae.CNNAutoEncoder2()

In [None]:
# Training configuration for overfitting with the `Learner` abstraction.
n_epochs = 1_000
lr = 1e-2

optimizer = optim.Adam(model.parameters(), lr=lr)

loss = rnnm_losses.MSELossMNISTAutoencoder()
save_dir = Path("./models")

# Telemetry callbacks. The regex patterns are raw strings: "\d" inside a plain
# string literal is an invalid escape sequence and triggers a warning on recent
# Python versions.
# NOTE(review): `every_n=100` presumably means "record every 100th iteration"
# and the patterns presumably select sub-modules by name -- confirm in
# `rnnm_learner`.
loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*act.*",)
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*conv\d$",)
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*conv\d$",)
)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

# Bundle model, optimizer, loss and callbacks into the learner.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner.find_learning_rate(
    dl_train, n_epochs=200, lr_find_callback=lr_find_callback
)

In [None]:
lr_find_callback.plot()

In [None]:
# Set up the learning-rate schedule for the fit: a one-cycle policy peaking at
# `lr`, sized to `n_epochs` passes over `dl_train`.
lr = 1e-2
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
# NOTE(review): presumably steps the scheduler after every batch (going by the
# class name) -- confirm in `rnnm_learner`.
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
learner.fit(dl_train, n_epochs=n_epochs)

In [None]:
loss_callback.plot()

In [None]:
x_pred = learner.predict(dl_train)

In [None]:
# Show the first training image next to the learner's reconstruction of it.
img = next(iter(dl_train)).image[0]
img_pred = x_pred[0]
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
panels = ((img, "Input image"), (img_pred, "Reconstructed image"))
for ax, (picture, title) in zip(axs, panels):
    ax.imshow(picture, cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
parameters_callback.plot()

In [None]:
gradients_callback.plot()

In [None]:
activations_callback.plot()

In [None]:
# Stop a "run all" execution here when only the single-image overfitting
# experiments above are wanted; the cells below train on the full dataset.
if DO_OVERFITTING_ONLY:
    raise SystemExit("Skipping training beyond overfitting.")

## Reproducing 10 digits

In [None]:
def draw_pair(img: torch.Tensor, img_pred: torch.Tensor):
    """Show an input image and its reconstruction side by side in grayscale."""
    _, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = ((img, "Input image"), (img_pred, "Reconstructed image"))
    for ax, (picture, title) in zip(axs, panels):
        ax.imshow(picture, cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    plt.show()


def draw_n_pairs(
    input_features: torch.Tensor, x_pred: torch.Tensor, n: int = 5
):
    """Plot up to ``n`` (input, reconstruction) pairs, one figure per pair.

    The number of pairs is capped at ``len(input_features)``. Each input is
    moved to the CPU before plotting; ``x_pred`` entries are passed through
    unchanged.
    """
    count = min(n, len(input_features))
    print(f"Drawing {count} pairs")
    for idx in range(count):
        draw_pair(input_features[idx].cpu(), x_pred[idx])

In [None]:
# Hold out 20% of all samples as the test set (X2, y2) ...
X0, X2, y0, y2 = train_test_split(X, y, test_size=0.2, random_state=42)
# ... then split the *remaining* samples into train (X0, y0) and validation
# (X1, y1). Splitting the full `X, y` again with the same seed would make the
# validation set identical to the test set.
X0, X1, y0, y1 = train_test_split(X0, y0, test_size=0.2, random_state=42)

In [None]:
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_valid = conv_lecun1990.DigitsDataset(X1, y1)
ds_test = conv_lecun1990.DigitsDataset(X2, y2)

In [None]:
# Training batches are reshuffled every epoch; the validation and test loaders
# keep a fixed order and use a larger batch size.
batch_size = 256
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True)
dataloader_valid = DataLoader(ds_valid, batch_size=500, shuffle=False)
dataloader_test = DataLoader(ds_test, batch_size=500, shuffle=False)

defining the model

In [None]:
# Build a fresh CNN autoencoder for the full dataset and wrap it in the
# telemetry helper; gradients, activations and parameters are configured with
# an `*_every_n` of 100.
model = ae.CNNAutoEncoder()
model = telemetry.ModelTelemetry(
    model,
    loss_names=("total",),
    gradients_every_n=100,
    activations_every_n=100,
    parameters_every_n=100,
    # Regex patterns are raw strings: "\d" inside a plain string literal is an
    # invalid escape sequence and triggers a warning on recent Python versions.
    # NOTE(review): unlike the overfitting section these patterns have no
    # trailing "$" anchor -- confirm whether that is intentional.
    activations_name_patterns=(r".*act.*",),
    gradients_name_patterns=(r".*conv\d",),
    parameters_name_patterns=(r".*conv\d",),
)
# NOTE(review): `.double()` presumably casts the parameters to float64 to match
# the dtype of the dataset items -- confirm.
model.double()
model.to(device);

In [None]:
opt = SGD(
    model.parameters(),
    lr=0.1,
)

In [None]:
loss_func = nn.MSELoss()

In [None]:
_iter = 0

training loop

In [None]:
n_epochs = 8

# Train on the full training split; after every epoch evaluate the
# reconstruction loss on the validation split.
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for i, (xb, _) in tqdm.tqdm(
        enumerate(dataloader), desc="Batches", total=len(dataloader)
    ):
        xb = xb.to(device)
        x_pred = model(xb)

        # The autoencoder's target is its own input.
        loss = loss_func(x_pred, xb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

    # Compute the validation loss. This uses `dataloader_valid`, not the test
    # loader, so the test split stays untouched until the final evaluation.
    with torch.no_grad():
        model.eval()
        xs_pred, xs_true = [], []
        for xb, _ in dataloader_valid:
            xb = xb.to(device)

            x_pred = model(xb)
            xs_pred.append(x_pred)
            xs_true.append(xb)

        # Concatenate all batches so the loss is a single mean over the whole
        # split rather than a mean of per-batch means.
        x_pred = torch.cat(xs_pred, dim=0)
        x_true = torch.cat(xs_true, dim=0)
        loss_test = loss_func(x_pred, x_true)
        # Recorded through the telemetry's held-out-loss history.
        model.loss_history_test(loss_test, _iter)

        # Switch back to training mode for the next epoch.
        model.train()

print("Done!")

plotting gradients

In [None]:
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
model.draw_activation_stats(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
model.draw_parameter_stats()

plotting losses

In [None]:
model.draw_loss_history_train()

In [None]:
model.draw_loss_history_test()

In [None]:
test_features, _ = next(iter(dataloader_test))

In [None]:
model.eval();

inspecting predictions

In [None]:
# Move the test batch to the model's device and reconstruct it. Autograd is
# switched off because no backward pass follows, so no graph is built for the
# whole batch.
test_features = test_features.to(device)
with torch.no_grad():
    preds = model(test_features)
preds[0, :5, :5]

In [None]:
x_pred = preds.detach().cpu().numpy()
x_pred[0, :3, :5]

In [None]:
draw_n_pairs(test_features, x_pred, n=16)

In [None]:
model.clean_hooks()

## now with `Learner`

In [None]:
ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0)
ds_valid = rnnm_data.MNISTDatasetWithLabels(X1, y1)
ds_test = rnnm_data.MNISTDatasetWithLabels(X2, y2)

In [None]:
dl_train = DataLoader(
    ds_train,
    batch_size=256,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=True,
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=500,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=False,
)
dl_test = DataLoader(
    ds_test,
    batch_size=500,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=False,
)

In [None]:
model = ae.CNNAutoEncoder2()

In [None]:
# Training configuration for the full dataset with the `Learner` abstraction.
n_epochs = 8
lr = 1e-3

optimizer = optim.Adam(model.parameters(), lr=lr)

loss = rnnm_losses.MSELossMNISTAutoencoder()
save_dir = Path("./models")

# Telemetry callbacks. The regex patterns are raw strings: "\d" inside a plain
# string literal is an invalid escape sequence and triggers a warning on recent
# Python versions.
# NOTE(review): `every_n=100` presumably means "record every 100th iteration"
# and the patterns presumably select sub-modules by name -- confirm in
# `rnnm_learner`.
loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*act.*",)
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*conv\d$",)
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    every_n=100, max_depth_search=4, name_patterns=(r".*conv\d$",)
)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]


# Bundle model, optimizer, loss and callbacks into the learner.
learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner.find_learning_rate(
    dl_train, n_epochs=2, lr_find_callback=lr_find_callback
)

In [None]:
lr_find_callback.plot()

In [None]:
# Set up the learning-rate schedule for the fit: a one-cycle policy peaking at
# `lr`, sized to `n_epochs` passes over `dl_train`.
lr = 1e-2
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
# NOTE(review): presumably steps the scheduler after every batch (going by the
# class name) -- confirm in `rnnm_learner`.
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
learner.fit(dl_train, n_epochs=n_epochs)

In [None]:
loss_callback.plot()

In [None]:
parameters_callback.plot()

In [None]:
gradients_callback.plot()

In [None]:
activations_callback.plot()

In [None]:
x_pred = learner.predict(dl_test)

In [None]:
test_features = next(iter(dl_test))

In [None]:
draw_n_pairs(test_features.image, x_pred, n=16)