In [3]:
import torch

# Parameters

In [None]:
BATCH_SIZE = 5
PATH_TO_DATA = "../sample_data/"
N_EPOCHS = 10
LEARNING_RATE = 1e-3  # i.e. 0.001
WEIGHT_DECAY = 1e-3  # i.e. 0.001
SEED = 42  # seed for PyTorch's global random number generator

# Seed PyTorch's global RNG so that weight initialization and any DataLoader
# shuffling are reproducible across "Restart & Run All".
# The trailing ';' hides the returned Generator object in the cell output.
torch.manual_seed(SEED);


# 1. Prepare Data

In [None]:
from torch.utils.data import DataLoader
from classification.dataset import PokemonDataset, img_channel_padding_collate

In [None]:
# Training data, batched with the project's img_channel_padding_collate.
# shuffle=True draws a new random sample order every epoch. Without it, every
# epoch sees identical batches in on-disk order (possibly grouped by class),
# which biases the gradient updates.
tr_dataset = PokemonDataset(PATH_TO_DATA)
tr_dataloader = DataLoader(
    tr_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=img_channel_padding_collate,
)

Next, we should define a validation dataset:

In [None]:
# Validation data: same batch size and collate function as training, left in
# DataLoader's default (unshuffled) order.
# NOTE(review): this reads the same PATH_TO_DATA as the training set, so it
# presumably yields the same samples and the "validation" loss is measured on
# training data. TODO: point this at a held-out split.
val_dataset = PokemonDataset(PATH_TO_DATA)
val_dataloader = DataLoader(
    dataset=val_dataset,
    collate_fn=img_channel_padding_collate,
    batch_size=BATCH_SIZE,
)

# 2. Model Design

In [None]:
from torch.nn import Module, Sequential, Linear, ReLU
from torch import Tensor # Only for type-annotations

Planned layers: BatchNorm, ReLU, and Linear.

In [None]:
class ImageClassifier(Module):
    """Image classification network (currently a placeholder).

    NOTE(review): no layers are registered yet, so the model has no trainable
    parameters and ``forward`` returns its input unchanged. The planned design
    uses BatchNorm, ReLU and Linear layers. TODO: implement them.
    """

    def __init__(self) -> None:
        super(ImageClassifier, self).__init__()

    def forward(self, x: Tensor) -> Tensor:
        # Identity pass-through until real layers are added.
        output = x
        return output

# 3. Training: optimizers and backpropagation

First, let's pick an optimizer and a loss function:

In [None]:
# The network to train.
model = ImageClassifier()

# Cross-entropy loss for multi-class classification.
loss_fn = torch.nn.CrossEntropyLoss()

# The optimizer holds references to the model's parameters and updates them on
# every optimizer.step(). The learning rate and weight decay come from the
# configuration cell.
# NOTE(review): torch.optim optimizers raise ValueError on an empty parameter
# list, so ImageClassifier needs at least one trainable layer before this runs.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=WEIGHT_DECAY,
    lr=LEARNING_RATE,
)

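Note that PyTorch has two different Adam optimizers:
- `Adam` (by default) implements weight decay as an L2 penalty added to the gradient, so the decay is rescaled by Adam's adaptive per-parameter step size.
- `AdamW` decouples weight decay from the gradient and shrinks the weights directly at each step, as proposed in Loshchilov & Hutter, *Decoupled Weight Decay Regularization*.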

Now we can run the training loop:

In [None]:
for epoch in range(N_EPOCHS):
    # Switch layers such as BatchNorm/Dropout to training behaviour. This must be
    # repeated every epoch because the validation pass below switches to eval().
    model.train()
    tr_losses = []
    for input_data, target_data in tr_dataloader:
        pred_data = model(input_data)
        tr_loss = loss_fn(pred_data, target_data)

        # First, clear the gradients from the previous iteration: PyTorch
        # accumulates new gradients into .grad instead of overwriting them.
        optimizer.zero_grad()

        # The loss is a single-value tensor at the root of the autograd graph.
        # backward() walks that graph through the grad_fn attached to each tensor
        # and accumulates gradients into the leaves (the model parameters).
        tr_loss.backward()

        # The optimizer updates the model weights and biases using the gradients
        # that backward() just stored in each parameter's .grad.
        optimizer.step()

        # tensor.detach() returns a tensor with the same data but cut off from the
        # autograd graph. The stored losses then hold no references to this
        # iteration's graph, and the later stack/mean is not tracked by autograd.
        tr_losses.append(tr_loss.detach())

    # tensor.item() converts a one-element tensor to a Python number,
    # e.g. tensor(0.00135) -> 0.00135
    avg_tr_loss = torch.stack(tr_losses).mean().item()

    # After each training epoch, run a validation loop to monitor the model's
    # performance. eval() switches BatchNorm/Dropout to inference behaviour.
    # torch.no_grad() stops autograd from recording a graph, since no
    # backpropagation happens here; this saves memory.
    model.eval()
    with torch.no_grad():
        val_losses = []
        for input_data, target_data in val_dataloader:
            pred_data = model(input_data)
            val_losses.append(loss_fn(pred_data, target_data))
        avg_val_loss = torch.stack(val_losses).mean().item()

    # `epoch` is 0-based; print it 1-based so the final line reads N_EPOCHS/N_EPOCHS.
    print(f"Epoch: {epoch + 1}/{N_EPOCHS}, Avg.Tr.Loss: {avg_tr_loss}, Avg.Val.Loss: {avg_val_loss}")

# 4. Testing and evaluation

It is crucial that your test set is completely disjoint from your training and validation sets. It should also be representative of the real-world use cases of your model.

In [None]:
# Test data, one sample per batch so that each prediction is collected
# individually for the metric computations that follow.
# NOTE(review): this reads the same PATH_TO_DATA as training and validation,
# which contradicts the requirement above that the test set be disjoint.
# TODO: point this at a held-out test split.
test_dataset = PokemonDataset(PATH_TO_DATA)
test_dataloader = DataLoader(
    dataset=test_dataset,
    collate_fn=img_channel_padding_collate,
    batch_size=1,
)

Testing loop:

In [None]:
# Switch BatchNorm/Dropout layers to inference behaviour before evaluating.
model.eval()
with torch.no_grad():
    test_preds = []
    test_logits = []
    test_labels = []
    for input_data, target_data in test_dataloader:
        pred_data = model(input_data)
        # Keep the raw model outputs for score-based metrics such as ROC/AUC.
        test_logits.append(pred_data)
        # Predicted class = index of the largest output along the last (class) axis.
        # NOTE(review): assumes pred_data is (batch, num_classes), the layout that
        # CrossEntropyLoss expects in training. Confirm against the final model.
        pred_class = torch.argmax(pred_data, dim=-1)
        test_preds.append(pred_class)
        test_labels.append(target_data)


Here, let's look at the basic classification metrics: accuracy and F1-score.

We need to visualize our results for better interpretation:

In [24]:
import matplotlib.pyplot as plt

Then we can plot ROC curves (and compute the area under them, the AUC) to investigate how well the model's confidence scores separate the classes.

Lastly, the multi-class confusion matrix is crucial for understanding the model's behaviour: it shows which classes the model confuses with each other.

# 5. Model inference

Now use your model in production! Don't forget to apply the same preprocessing steps used during training, if there are any.