In [None]:
import itertools
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

from pt_runner.cnn import CheckpointHandler, DataHandlerPT, EarlyStopper, calc_metrices


In [None]:
# Run selection: leave DT_REF as None to start a new run, or set it to the
# timestamp of an existing checkpoint (e.g. "2025-05-28_12-35") to resume it.
# NEW_RUN is derived from DT_REF so the two settings can never disagree
# (NEW_RUN = False with DT_REF = None would try to load "./checkpoints/None.pth").
DT_REF = None
NEW_RUN = DT_REF is None

In [None]:
RANDOM_STATE = 0
# RANDOM_STATE is passed to the data split; also seed torch's global RNG so that
# weight initialisation and DataLoader shuffling are reproducible across runs.
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)

In [None]:
# Load the pickled MNIST subset; the cells below read its "_X" and "_Y" entries.
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
with open("mnist_small.pickle", mode="rb") as pickle_file:
    data = pickle.load(pickle_file)

In [None]:
# Cast features to float64 and labels to int32, then report shape and dtype of each.
# NOTE(review): presumably DataHandlerPT converts these to the tensor dtypes the
# model/loss need (float32 inputs, int64 targets) -- confirm.
_X, _Y = data["_X"].astype(np.float64), data["_Y"].astype(np.int32)
for _arr in (_X, _Y):
    print(_arr.shape)
    print(_arr.dtype)

In [None]:
# Hand the raw arrays to the project's PyTorch data handler; splitting and
# scaling happen later via `split_and_scale` in the training cell.
data_handler = DataHandlerPT(
    _X=_X,
    _Y=_Y,
)

In [None]:
class SimpleCNN(nn.Module):
    """Small image classifier: two conv/ReLU/max-pool stages, a 4x4 adaptive
    average pool and a single linear layer producing class logits.

    The adaptive pool fixes the flattened feature size at 32 * 4 * 4, which
    decouples `fc1` from the input resolution.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of channels in the input images (1 for grayscale).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.adaptive_pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc1 = nn.Linear(32 * 4 * 4, num_classes)

    def forward(self, X):
        """Map a (batch, in_channels, H, W) tensor to (batch, num_classes) logits."""
        X = self.max_pool(self.relu(self.conv1(X)))
        X = self.max_pool(self.relu(self.conv2(X)))
        X = self.adaptive_pool(X)
        X = torch.flatten(X, 1)  # (batch, 32, 4, 4) -> (batch, 512)
        return self.fc1(X)


model = SimpleCNN(num_classes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Scheduler is stepped with the validation loss, so "min" mode and its patience
# count validation rounds rather than epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=5)
loss_fn = nn.CrossEntropyLoss()

In [None]:
from torchinfo import summary

# Summarise with the input the model is actually trained on: single-channel
# 28x28 images (the misclassification plot reshapes each sample to 28x28).
input_size = (100, 1, 28, 28)  # (batch_size, channels, height, width)
summary(model, input_size=input_size)

To monitor training, run `tensorboard --logdir=src/T03_cnn/runs` (the path assumes TensorBoard is launched from the repository root).


In [None]:
n_epochs = 100  # number of epochs to run
batch_size = 10  # size of each batch
validation_interval = 4  # validate every 4th epoch (plus the first epoch of the run)
log_name = "C1"  # TensorBoard main tag grouping the loss/F1 curves


def _weighted_f1(logit_arr, label_arr):
    """Concatenate per-batch logits/labels and return calc_metrices' weighted-avg F1."""
    logits = torch.cat(logit_arr, dim=0)
    labels = torch.cat(label_arr, dim=0)
    metrices, _, _ = calc_metrices(logits=logits, labels=labels.view(-1))
    return metrices["weighted avg"]["f1-score"]


def _train_epoch(model, loader, optimizer, loss_fn, max_grad_norm=1.0):
    """Run one training pass over `loader`.

    Returns:
        (mean loss per sample, weighted-average F1 over the training batches)
    """
    model.train()
    total_loss = 0.0
    logit_arr, label_arr = [], []
    for X_batch, Y_batch in loader:
        optimizer.zero_grad()
        Y_pred = model(X_batch)
        loss = loss_fn(Y_pred, Y_batch.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        # loss_fn returns the batch mean; re-weight by batch size so the epoch
        # average is per sample even when the last batch is smaller.
        total_loss += loss.item() * X_batch.size(0)
        # detach(): storing Y_pred itself would keep every batch's autograd
        # graph alive until the end of the epoch.
        logit_arr.append(Y_pred.detach())
        label_arr.append(Y_batch)
    return total_loss / len(loader.dataset), _weighted_f1(logit_arr, label_arr)


@torch.no_grad()
def _evaluate(model, loader, loss_fn):
    """Evaluate `model` on `loader` without tracking gradients.

    Returns:
        (mean loss per sample, weighted-average F1)
    """
    model.eval()
    total_loss = 0.0
    logit_arr, label_arr = [], []
    for X_batch, Y_batch in loader:
        Y_pred = model(X_batch)
        total_loss += loss_fn(Y_pred, Y_batch.view(-1)).item() * X_batch.size(0)
        logit_arr.append(Y_pred)
        label_arr.append(Y_batch)
    return total_loss / len(loader.dataset), _weighted_f1(logit_arr, label_arr)


# Save/load
cph = CheckpointHandler()
cph.make_dir("./checkpoints")
if NEW_RUN:
    dt = cph.get_dt()
    log_dir = f"runs/{dt}"
    save_path = f"./checkpoints/{dt}.pth"
    epoch_start = 0
else:
    log_dir = f"runs/{DT_REF}"
    load_path = f"./checkpoints/{DT_REF}.pth"
    save_path = load_path
    model, optimizer, epoch, val_loss = cph.load(
        load_path=load_path, model=model, optimizer=optimizer
    )
    # The checkpoint is written after `epoch` finished training, so continue
    # with the next epoch instead of repeating it.
    epoch_start = epoch + 1
    print(f"Resuming from epoch: {epoch_start}")
    # NOTE(review): scheduler and early-stopper state are not part of the
    # checkpoint, so they restart from scratch when resuming.

epoch_end = epoch_start + n_epochs

# Initialize Components
# early_stopper is only called on validation epochs, so its patience is
# (presumably) measured in validation rounds rather than epochs.
early_stopper = EarlyStopper(patience=5)
# purge_step discards any events already logged at or after epoch_start.
writer = SummaryWriter(log_dir=log_dir, purge_step=epoch_start)

# Data
data_handler.split_and_scale(test_size=0.2, val_size=0.1, random_state=RANDOM_STATE)
ds_train = data_handler.get_train()
ds_val = data_handler.get_val()
loader_train = DataLoader(ds_train, batch_size=batch_size, shuffle=True)
loader_val = DataLoader(ds_val, batch_size=batch_size, shuffle=False)

# Main loop
try:
    for epoch in tqdm(
        range(epoch_start, epoch_end),
        initial=epoch_start,
        total=epoch_end,  # bar shows absolute epochs, epoch_start .. epoch_end
        desc="Epoch",
    ):
        avg_train_loss, avg_train_f1 = _train_epoch(
            model, loader_train, optimizer, loss_fn
        )
        writer.add_scalars(
            log_name, {"loss/train": avg_train_loss, "f1/train": avg_train_f1}, epoch
        )

        # Validate on the first epoch of the run and every validation_interval epochs;
        # validation curves are only logged when validation actually ran.
        if epoch % validation_interval != 0 and epoch != epoch_start:
            continue

        avg_val_loss, avg_val_f1 = _evaluate(model, loader_val, loss_fn)
        writer.add_scalars(
            log_name, {"loss/val": avg_val_loss, "f1/val": avg_val_f1}, epoch
        )

        scheduler.step(avg_val_loss)

        # Early Stopping and Checkpoint
        es = early_stopper(avg_val_loss)
        if es["best_loss"]:
            cph.save(
                save_path=save_path,
                model=model,
                optimizer=optimizer,
                val_loss=avg_val_loss,
                epoch=epoch,
            )
            print("Save model @ epoch:", epoch)
        if es["early_stop"]:
            print("Stopped at epoch:", epoch)
            break
finally:
    # Flush pending events and release the event files even if training fails.
    writer.close()


In [None]:
# Restore the checkpoint written when the early stopper last reported a new best
# validation loss, so the evaluation reflects those weights rather than whatever
# the final (possibly post-plateau) epoch left in `model`.
model, optimizer, best_epoch, best_val_loss = cph.load(
    load_path=save_path, model=model, optimizer=optimizer
)
print(f"Evaluating checkpoint from epoch: {best_epoch}")

model.eval()
with torch.no_grad():
    X_val, Y_val = ds_val[:]
    test_pred = model(X_val)  # logits for the whole validation split in one batch
    final_loss = loss_fn(test_pred, Y_val.view(-1))
print(f"Val loss: {final_loss.item():.4f}")

In [None]:
# Classification report for the validation predictions; isPrint=True presumably
# makes calc_metrices print it. Also yields the predicted and true label arrays
# that the misclassification cells below compare.
metrics_kwargs = dict(logits=test_pred, labels=Y_val.view(-1), isPrint=True)
metrices, Y_pred_labels, Y_true_labels = calc_metrices(**metrics_kwargs)

In [None]:
# Find misclassification
loc = Y_pred_labels != Y_true_labels
print(f"Missclassification: {loc.sum()} out of {loc.shape[0]}")

In [None]:
# Show up to 15 misclassified validation samples with their true/predicted labels.
fig, axes2D = plt.subplots(3, 5, figsize=(12, 8))
X_val_miss = X_val[loc]
Y_val_miss = Y_true_labels[loc]
Y_pred_miss = Y_pred_labels[loc]
n_miss = int(loc.sum())  # computed once instead of on every loop iteration

for idx, ax in enumerate(axes2D.flat):
    if idx < n_miss:
        # Each sample holds one 28x28 grayscale image (784 values).
        ax.imshow(X_val_miss[idx].reshape(28, 28), cmap="gray")
        # int() avoids titles like "True=tensor(3)" when labels are tensors.
        ax.set_title(f"True={int(Y_val_miss[idx])}, Pred={int(Y_pred_miss[idx])}")
    else:
        ax.axis("off")  # Hide unused axes
    ax.set_xticks([])
    ax.set_yticks([])

fig.suptitle(
    f"Misclassified validation samples (showing {min(n_miss, axes2D.size)} of {n_miss})"
)
fig.tight_layout()
plt.show()