# ResNet classifier on MNIST

## References

* fastai 2022 / 2023 course part II:
    * [notebook 13](https://github.com/fastai/course22p2/blob/master/nbs/13_resnet.ipynb)
    * [lesson 18](https://course.fast.ai/Lessons/lesson18.html)
* He et al. 2015, [Deep Residual Learning for Image Recognition](http://arxiv.org/abs/1512.03385)

## Setup

In [None]:
# IPython autoreload extension: mode 2 re-imports modified modules before code
# runs, so edits to the project package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchinfo
import tqdm
from sklearn import metrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

import random_neural_net_models.convolution_lecun1990 as conv_lecun1990
import random_neural_net_models.data as rnnm_data
import random_neural_net_models.learner as rnnm_learner
import random_neural_net_models.losses as rnnm_losses
import random_neural_net_models.resnet as resnet
import random_neural_net_models.telemetry as telemetry
import random_neural_net_models.utils as utils

# Apply seaborn's default theme to all subsequent matplotlib figures.
sns.set_theme()

In [None]:
# When True, the notebook stops (via the SystemExit cell further down) after
# the overfitting experiments, skipping the full 10-digit training.
DO_OVERFITTING_ONLY = True

In [None]:
# Fetch the `mnist_784` dataset from OpenML; cache=True keeps a local copy so
# later runs do not download it again.
mnist = fetch_openml("mnist_784", version=1, cache=True, parser="auto")

Setting random seeds for reproducibility

In [None]:
# Fix the random seed to 42 via the project helper (presumably seeds the
# Python / NumPy / PyTorch RNGs — see utils.make_deterministic).
utils.make_deterministic(42)

Selecting the compute device

In [None]:
# Pick the compute device via the project helper and display it.
device = utils.get_device()
device

In [None]:
# Pixel features and digit labels of the full dataset (labels are strings,
# e.g. "5", as the comparisons below show).
X, y = mnist["data"], mnist["target"]
(X.shape, y.shape)

Selecting a few images to overfit on (only images of the digit 5)

In [None]:
# Keep only images of the digit 5: the first n0 are the overfitting batch,
# the following n1 are a held-out check.
n0 = 32
n1 = 1_000
is_5 = y == "5"
X_5, y_5 = X.loc[is_5], y.loc[is_5]
X0, y0 = X_5.iloc[:n0], y_5.iloc[:n0]
X1, y1 = X_5.iloc[n0 : n0 + n1], y_5.iloc[n0 : n0 + n1]
X0.shape, X1.shape

## Defining dataset and dataloader

In [None]:
# Training set: the n0 images to overfit on; "test" set: the next n1 images
# of the same digit.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_test = conv_lecun1990.DigitsDataset(X1, y1)

In [None]:
# Show the first training image with its label as a sanity check.
sample = ds[0]
image, label = sample[0], sample[1]
plt.imshow(image, cmap="gray", origin="upper")
plt.title(f"Label: {label}")
plt.axis("off")
plt.tight_layout()

Defining the dataloaders

In [None]:
# One unshuffled batch holding all n0 training images (so every step sees the
# same data); the held-out images are evaluated in batches of 500.
batch_size = n0
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=False)
dataloader_test = DataLoader(ds_test, batch_size=500, shuffle=False)

## overfitting

defining the model

In [None]:
# ResNet with channel widths 8, 16, 32, 64 passed as `nfs`
# (128 / 256 could be appended for a deeper variant).
model = resnet.ResNet(nfs=(8, 16, 32, 64))

# Wrap the network so parameters and gradients are recorded every 10 steps and
# activations every step, only for layers whose names match the patterns.
layer_patterns = (r".*conv\d", ".*lin")
model = telemetry.ModelTelemetry(
    model,
    parameters_every_n=10,
    activations_every_n=1,
    gradients_every_n=10,
    activations_name_patterns=(".*act.*",),
    gradients_name_patterns=layer_patterns,
    parameters_name_patterns=layer_patterns,
)
model.double()  # float64 parameters
model.to(device);

Output shapes of successive `ResBlock`s for a single 28×28 input image (batch size 1):

* `ResBlock(1, 8, stride=1, ks=3)`: `[1, 8, 28, 28]`
* `ResBlock(8, 16, stride=2, ks=3)`: `[1, 16, 14, 14]`
* `ResBlock(16, 32, stride=2, ks=3)`: `[1, 32, 7, 7]`
* `ResBlock(32, 64, stride=2, ks=3)`: `[1, 64, 4, 4]`
* `ResBlock(64, 128, stride=2, ks=3)`: `[1, 128, 2, 2]`
* `ResBlock(128, 256, stride=2, ks=3)`: `[1, 256, 1, 1]`

In [None]:
# Optional shape check for a single ResBlock; uncomment to run.
# m = resnet.ResBlock(128, 256, stride=2, ks=3)
# res = torchinfo.summary(m, input_size=(1, 128, 2, 2))

In [None]:
# Output shape of the inspected block (needs `res` from the commented-out cell above).
# res.summary_list[0].output_size

In [None]:
# Layer-by-layer summary for a float64 input batch of shape (batch_size, 28, 28).
torchinfo.summary(model, input_size=(batch_size, 28, 28), dtypes=[torch.double])

In [None]:
# Adam optimizer; plain SGD(model.parameters(), lr=0.1) is a possible alternative.
opt = Adam(params=model.parameters(), lr=3e-2, eps=1e-5)
opt

In [None]:
# Global step counter used as the x-axis for all telemetry records.
_iter = 0

training loop

In [None]:
n_epochs = 100

# Overfit the fixed batch: one optimizer step per batch, logging the training
# loss and parameter statistics (sampled by the telemetry wrapper's configured
# interval) under the global step counter `_iter`.
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        y_pred = model(xb)  # fed to F.cross_entropy, i.e. treated as logits

        loss = F.cross_entropy(y_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

print("Done!")

plotting gradients

In [None]:
# Recorded gradient statistics per tracked layer (log scale).
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
# Recorded activation statistics of the layers matching ".*act.*" (log scale).
model.draw_activation_stats(yscale="log")

plotting losses

In [None]:
# Training loss per step.
model.draw_loss_history_train()

In [None]:
# Test-loss history on a log scale.
# NOTE(review): the overfitting loop above never calls
# `model.loss_history_test`, so there is presumably nothing to plot here.
model.draw_loss_history_test(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
# Histograms of the tracked weights and biases across training iterations.
model.draw_parameter_stats()

In [None]:
# First training batch; with batch_size == n0 this is the whole training set.
train_features, train_labels = next(iter(dataloader))

In [None]:
# Switch to evaluation mode before inspecting predictions.
model.eval();

inspecting predictions

In [None]:
# Forward the inspected batch without building an autograd graph.
# NOTE: training feeds this output to F.cross_entropy, so these are
# unnormalised logits despite the variable name.
train_features = train_features.to(device)
with torch.no_grad():
    pred_probs = model(train_features)
pred_probs[:3, :]

In [None]:
# Predicted class = index of the largest output in each row.
outputs_np = pred_probs.detach().cpu().numpy()
y_pred = np.argmax(outputs_np, axis=1)
y_pred

In [None]:
# True labels of the inspected batch, for comparison with `y_pred`.
train_labels

In [None]:
# Report the batch shapes, then show the first image with its label and prediction.
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

first_image = train_features[0].cpu()
first_label = train_labels[0]

plt.imshow(first_image, cmap="gray")
plt.axis("off")
plt.show()
print(f"Label: {first_label}, pred: {y_pred[0]}")

computing performance on the held-out images of the digit 5

In [None]:
# Evaluate on the held-out images: evaluation mode, no autograd graph (avoids
# keeping activations for backprop on every batch), and collect true labels
# and argmax predictions as NumPy arrays.
model.eval()
ys_pred = []
ys_true = []
with torch.no_grad():
    for test_features, test_labels in dataloader_test:
        test_features = test_features.to(device)
        pred_probs = model(test_features)

        y_pred = pred_probs.cpu().numpy().argmax(axis=1)

        ys_true.append(test_labels.numpy())
        ys_pred.append(y_pred)

ys_true = np.concatenate(ys_true)
ys_pred = np.concatenate(ys_pred)

In [None]:
# Quick look at the first few labels vs. predictions.
ys_true[:3], ys_pred[:3]

In [None]:
# Fraction of correctly classified held-out images and its complement.
accuracy = metrics.accuracy_score(y_true=ys_true, y_pred=ys_pred)
error_rate = 1.0 - accuracy
print(f"* Accuracy: {accuracy:.2%}")
print(f"* Error rate: {error_rate:.2%}")

* Accuracy: 100.00%
* Error rate: 0.00%

In [None]:
# Confusion matrix of true vs. predicted digits. The seaborn theme draws grid
# lines over the matrix; hide only those (plt.axis("off") would also hide the
# class tick labels and the "True label" / "Predicted label" titles).
cm_display = metrics.ConfusionMatrixDisplay.from_predictions(ys_true, ys_pred)
cm_display.ax_.grid(False)
plt.show()

In [None]:
# Remove the hooks installed by the telemetry wrapper now that this experiment is done.
model.clean_hooks()

## overfitting with `Learner`

In [None]:
# Shapes of the overfitting subset reused with `Learner`.
X0.shape, y0.shape

In [None]:
# The same n0 images, wrapped in the project's MNIST dataset (labels not
# one-hot encoded) with its block collate function, for use with `Learner`.
collate_fn = rnnm_data.collate_mnist_dataset_to_block_with_labels
ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0, one_hot=False)
dl_train = DataLoader(ds_train, batch_size=n0, collate_fn=collate_fn, shuffle=True)
next(iter(dl_train)).label

In [None]:
# Fresh `ResNet2` with channel widths 8, 16, 32, 64 (128 / 256 left out).
model = resnet.ResNet2(nfs=(8, 16, 32, 64))

In [None]:
n_epochs = 100
lr = 1e-1
# optim.SGD(model.parameters(), lr=lr) is a possible alternative optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)
loss = rnnm_losses.CrossEntropyMNIST()
save_dir = Path("./models")

# Settings shared by the telemetry callbacks.
monitor_kwargs = dict(every_n=100, max_depth_search=4)
layer_patterns = (r".*conv\d", ".*lin")

loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    name_patterns=(".*act.*",), **monitor_kwargs
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    name_patterns=layer_patterns, **monitor_kwargs
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    name_patterns=layer_patterns, **monitor_kwargs
)

# A one-cycle LR scheduler callback is attached later, after the LR search.
callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

# Learning-rate range search; argument meanings are defined by rnnm_learner.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
)

In [None]:
# Run the learning-rate search on the training loader; results are collected
# in `lr_find_callback`.
learner.find_learning_rate(
    dl_train, n_epochs=200, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the result of the learning-rate search.
lr_find_callback.plot()

In [None]:
# One-cycle schedule peaking at `lr`, sized for n_epochs * len(dl_train) steps
# and registered with the learner (presumably stepped after every batch, as
# the callback name suggests).
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train on the overfitting batch (no validation loader).
learner.fit(dl_train, n_epochs=n_epochs)

In [None]:
# Training loss recorded by the loss callback.
loss_callback.plot()

In [None]:
# Predicted class per training image. Note dl_train shuffles, so the order may
# not match `y0` shown in the next cell.
y_prob = learner.predict(dl_train)
y_prob.argmax(dim=1)

In [None]:
# True labels of the overfitting subset.
y0.values

In [None]:
# Recorded parameter statistics.
parameters_callback.plot()

In [None]:
# Recorded gradient statistics.
gradients_callback.plot()

In [None]:
# Recorded activation statistics.
activations_callback.plot()

In [None]:
# Stop here when only the overfitting experiments are wanted: the raised
# SystemExit interrupts execution of the remaining cells.
if DO_OVERFITTING_ONLY:
    raise SystemExit("Skipping training beyond overfitting.")

## Classifying all 10 digits

In [None]:
# 80/20 split into a train+validation pool and the test set (X2, y2), then
# 80/20 split of the pool into training (X0, y0) and validation (X1, y1).
X_pool, X2, y_pool, y2 = train_test_split(X, y, test_size=0.2, random_state=42)
X0, X1, y0, y1 = train_test_split(X_pool, y_pool, test_size=0.2, random_state=42)

In [None]:
# Training, validation and test datasets for the full 10-digit task.
ds = conv_lecun1990.DigitsDataset(X0, y0)
ds_valid = conv_lecun1990.DigitsDataset(X1, y1)
ds_test = conv_lecun1990.DigitsDataset(X2, y2)

In [None]:
# Shuffled training batches of 256 (an incomplete final batch is dropped).
# Validation / test loaders keep every sample: dropping the last partial batch
# there would silently exclude images from the reported loss and accuracy
# (e.g. 200 of the validation images with batches of 500).
batch_size = 256
dataloader = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_valid = DataLoader(
    ds_valid, batch_size=500, shuffle=False, drop_last=False
)
dataloader_test = DataLoader(
    ds_test, batch_size=500, shuffle=False, drop_last=False
)

defining the model

In [None]:
# Smaller ResNet for the 10-digit task: channel widths 8 and 16 passed as `nfs`
# (32 / 64 / 128 / 256 left out).
model = resnet.ResNet(nfs=(8, 16))

# Telemetry: parameters, activations and gradients all sampled every 100
# steps, restricted to layers whose names match the patterns.
layer_patterns = (r".*conv\d", ".*lin")
model = telemetry.ModelTelemetry(
    model,
    parameters_every_n=100,
    activations_every_n=100,
    gradients_every_n=100,
    activations_name_patterns=(".*act.*",),
    gradients_name_patterns=layer_patterns,
    parameters_name_patterns=layer_patterns,
)
model.double()  # float64 parameters
model.to(device);

In [None]:
# Adam optimizer; plain SGD(model.parameters(), lr=0.1) is a possible alternative.
opt = Adam(params=model.parameters(), lr=3e-2, eps=1e-5)
opt

In [None]:
# Reset the global step counter used as the x-axis for telemetry.
_iter = 0

training loop

In [None]:
n_epochs = 2

# Train for n_epochs. Every step logs the training loss and parameter
# statistics under the global counter `_iter`; after each epoch the
# cross-entropy over all validation batches is logged as the test loss.
model.train()
for epoch in tqdm.tqdm(range(n_epochs), desc="Epochs", total=n_epochs):
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        y_pred = model(xb)

        loss = F.cross_entropy(y_pred, yb)

        opt.zero_grad()
        loss.backward()
        opt.step()

        model.loss_history_train(loss, _iter)
        model.parameter_history(_iter)

        _iter += 1

    # compute validation loss
    with torch.no_grad():
        model.eval()
        ys_pred, ys_true = [], []
        for xb, yb in dataloader_valid:
            xb = xb.to(device)
            yb = yb.to(device)

            y_pred = model(xb)
            ys_pred.append(y_pred)
            ys_true.append(yb)

        # Concatenate all batches so a single mean loss covers every
        # validation batch.
        y_true = torch.cat(ys_true, dim=0)
        y_pred = torch.cat(ys_pred, dim=0)
        loss = F.cross_entropy(y_pred, y_true)
        model.loss_history_test(loss, _iter)

        # Back to training mode for the next epoch. Note the model is also
        # left in training mode after the final epoch, so later evaluation
        # cells must switch to eval mode themselves.
        model.train()

print("Done!")

plotting gradients

In [None]:
# Recorded gradient statistics per tracked layer (log scale).
model.draw_gradient_stats(yscale="log", figsize=(12, 20))

plotting activations

In [None]:
# Recorded activation statistics (log scale).
model.draw_activation_stats(yscale="log")

plotting losses

In [None]:
# Training loss per step.
model.draw_loss_history_train()

In [None]:
# Validation loss logged at the end of each epoch (log scale).
model.draw_loss_history_test(yscale="log")

drawing histograms of the weights and biases across training iterations

In [None]:
# Histograms of the tracked weights and biases across training iterations.
model.draw_parameter_stats()

In [None]:
# First test batch for a quick visual check of predictions.
test_features, test_labels = next(iter(dataloader_test))

In [None]:
# The training loop leaves the model in training mode, so switch to eval mode
# before predicting; no autograd graph is needed for inference.
# NOTE: these are the model's raw outputs (logits), not probabilities.
model.eval()
test_features = test_features.to(device)
with torch.no_grad():
    pred_probs = model(test_features)
pred_probs[:3, :]

In [None]:
# Predicted class = index of the largest output in each row.
outputs_np = pred_probs.detach().cpu().numpy()
y_pred = np.argmax(outputs_np, axis=1)
y_pred[:3]

In [None]:
# First few true labels, for comparison with the predictions above.
test_labels[:3]

In [None]:
# Report the batch shapes, then show the first test image with its label and prediction.
print(f"Feature batch shape: {test_features.size()}")
print(f"Labels batch shape: {test_labels.size()}")

first_image = test_features[0].cpu()
first_label = test_labels[0]

plt.imshow(first_image, cmap="gray")
plt.axis("off")
plt.show()
print(f"Label: {first_label}, pred: {y_pred[0]}")

computing test set performance

In [None]:
# Evaluate on the whole test set. The training loop leaves the model in
# training mode, so switch to eval mode first (layers such as batch norm or
# dropout behave differently in training mode); run without an autograd graph.
model.eval()
ys_pred = []
ys_true = []
with torch.no_grad():
    for test_features, test_labels in dataloader_test:
        test_features = test_features.to(device)
        pred_probs = model(test_features)

        y_pred = pred_probs.cpu().numpy().argmax(axis=1)

        ys_true.append(test_labels.numpy())
        ys_pred.append(y_pred)

ys_true = np.concatenate(ys_true)
ys_pred = np.concatenate(ys_pred)

In [None]:
# Quick look at the first few labels vs. predictions.
ys_true[:3], ys_pred[:3]

In [None]:
# Fraction of correctly classified test images and its complement.
accuracy = metrics.accuracy_score(y_true=ys_true, y_pred=ys_pred)
error_rate = 1.0 - accuracy
print(f"* Accuracy: {accuracy:.2%}")
print(f"* Error rate: {error_rate:.2%}")

* Accuracy: 98.69%
* Error rate: 1.31%

In [None]:
# Confusion matrix of true vs. predicted digits. The seaborn theme draws grid
# lines over the matrix; hide only those (plt.axis("off") would also hide the
# digit tick labels and the "True label" / "Predicted label" titles).
cm_display = metrics.ConfusionMatrixDisplay.from_predictions(ys_true, ys_pred)
cm_display.ax_.grid(False)
plt.show()

In [None]:
# Remove the hooks installed by the telemetry wrapper.
model.clean_hooks()

## using `Learner`

In [None]:
# Train / validation / test splits wrapped in the project's MNIST dataset
# (labels not one-hot encoded) for use with `Learner`.
ds_train = rnnm_data.MNISTDatasetWithLabels(X0, y0, one_hot=False)
ds_valid = rnnm_data.MNISTDatasetWithLabels(X1, y1, one_hot=False)
ds_test = rnnm_data.MNISTDatasetWithLabels(X2, y2, one_hot=False)

In [None]:
# Access one sample to check the dataset works (display suppressed by ";").
ds_train[0];

In [None]:
# Shuffled training batches of 256 (an incomplete final batch is dropped).
# Validation / test loaders keep every sample: dropping a final partial batch
# would silently exclude images from evaluation and make the predictions of
# `learner.predict(dl_test)` shorter than the ground truth `y2`.
dl_train = DataLoader(
    ds_train,
    batch_size=256,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=True,
    drop_last=True,
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=500,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=False,
    drop_last=False,
)
dl_test = DataLoader(
    ds_test,
    batch_size=500,
    collate_fn=rnnm_data.collate_mnist_dataset_to_block_with_labels,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Inspect one collated training batch.
next(iter(dl_train))

In [None]:
# Fresh `ResNet2` with channel widths 8 and 16 (32 / 64 / 128 / 256 left out).
model = resnet.ResNet2(nfs=(8, 16))

In [None]:
n_epochs = 2
lr = 1e-3
# optim.SGD(model.parameters(), lr=lr), optionally with momentum, is a
# possible alternative optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)

loss = rnnm_losses.CrossEntropyMNIST()
save_dir = Path("./models")

# Settings shared by the telemetry callbacks.
monitor_kwargs = dict(every_n=100, max_depth_search=4)
layer_patterns = (r".*conv\d", ".*lin")

loss_callback = rnnm_learner.TrainLossCallback()
activations_callback = rnnm_learner.TrainActivationsCallback(
    name_patterns=(r".*act.*",), **monitor_kwargs
)
gradients_callback = rnnm_learner.TrainGradientsCallback(
    name_patterns=layer_patterns, **monitor_kwargs
)
parameters_callback = rnnm_learner.TrainParametersCallback(
    name_patterns=layer_patterns, **monitor_kwargs
)
# Created but not registered below; add it to `callbacks` to enable early stopping.
early_stopping_callback = rnnm_learner.EarlyStoppingCallback(patience=3)

callbacks = [
    loss_callback,
    activations_callback,
    gradients_callback,
    parameters_callback,
]

# Learning-rate range search; argument meanings are defined by rnnm_learner.
lr_find_callback = rnnm_learner.LRFinderCallback(1e-5, 10, 100)

learner = rnnm_learner.Learner(
    model,
    optimizer,
    loss,
    callbacks=callbacks,
    save_dir=save_dir,
    device=device,
    show_epoch_progress=True,
)

In [None]:
# Run the learning-rate search on the training loader; results are collected
# in `lr_find_callback`.
learner.find_learning_rate(
    dl_train, n_epochs=2, lr_find_callback=lr_find_callback
)

In [None]:
# Plot the result of the learning-rate search.
lr_find_callback.plot()

In [None]:
# Peak learning rate chosen after inspecting the LR search plot.
lr = 1e-3
# One-cycle schedule sized for n_epochs * len(dl_train) steps, registered with
# the learner (presumably stepped after every batch, as the callback name suggests).
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    epochs=n_epochs,
    steps_per_epoch=len(dl_train),
)
scheduler_callback = rnnm_learner.EveryBatchSchedulerCallback(scheduler)
learner.update_callback(scheduler_callback)

In [None]:
# Train with validation on `dl_valid`.
learner.fit(dl_train, n_epochs=n_epochs, dataloader_valid=dl_valid)

In [None]:
# Loss curves; `window=100` presumably smooths over 100 steps.
loss_callback.plot(window=100)

In [None]:
# Validation losses recorded during `fit`.
losses_valid = loss_callback.get_losses_valid()
losses_valid

In [None]:
# Recorded parameter statistics.
parameters_callback.plot()

In [None]:
# Recorded gradient statistics.
gradients_callback.plot()

In [None]:
# Recorded activation statistics.
activations_callback.plot()

computing test set performance

In [None]:
# Model outputs for every test batch (dl_test is not shuffled, so rows follow
# the order of `y2`).
y_prob = learner.predict(dl_test)

In [None]:
# Predicted digit per test image. Detach and move to the CPU first so
# `.numpy()` also works if `learner.predict` returns a GPU tensor or one
# attached to autograd (both are no-ops for a plain CPU tensor).
ys_pred = y_prob.argmax(dim=1).detach().cpu().numpy()
ys_pred

In [None]:
# Integer ground-truth labels of the test split (the raw targets are digit strings).
ys_true = np.array(list(map(int, y2.values)))
ys_true

In [None]:
# Fraction of correctly classified test images and its complement.
accuracy = metrics.accuracy_score(y_true=ys_true, y_pred=ys_pred)
error_rate = 1.0 - accuracy
print(f"* Accuracy: {accuracy:.2%}")
print(f"* Error rate: {error_rate:.2%}")

* Accuracy: 97.46%
* Error rate: 2.54%

In [None]:
# Confusion matrix of true vs. predicted digits. The seaborn theme draws grid
# lines over the matrix; hide only those (plt.axis("off") would also hide the
# digit tick labels and the "True label" / "Predicted label" titles).
cm_display = metrics.ConfusionMatrixDisplay.from_predictions(ys_true, ys_pred)
cm_display.ax_.grid(False)
plt.show()