In [2]:
# ! pip install -r requirements.txt
# ! pre-commit install
# ! pre-commit run --all-files

In [3]:
# !jupyter nbconvert --to script .\VGG_transfer_learning.ipynb

In [None]:
import io
import logging
import sys
from datetime import datetime
from typing import Literal


def configure_logging(mode: Literal["a", "w"]) -> logging.Logger:
    if hasattr(__builtins__, "__IPYTHON__"):  # execution in Jupyter
        sys.stdout = io.TextIOWrapper(sys.stdout, encoding="utf-8")
    else:  # execution in CPython
        sys.stdout.reconfigure(encoding="utf-8")  # type: ignore

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(threadName)s %(processName)s %(message)s",
        handlers=[
            logging.FileHandler(
                filename=f"veggie-net.log", mode=mode, encoding="utf-8"
            ),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger(__name__)

In [5]:
def disablePILDecompressionBombError() -> None:
    from PIL import Image, ImageFile

    Image.MAX_IMAGE_PIXELS = None  # Prevents PIL DecompressionBombError
    ImageFile.LOAD_TRUNCATED_IMAGES = True

In [6]:
# Needs to be here at top as on Windows the function is otherwise non pickleable
def worker_init_fn(worker_id: int) -> None:
    disablePILDecompressionBombError()
    logger = configure_logging("a")
    logger.info(f"Initialized worker {worker_id}")

In [7]:
# Setup logging before any torch module is imported
logger = configure_logging("w")

In [8]:
import os
import random
from types import SimpleNamespace

import numpy as np
import torch
import torch_directml
from ignite.metrics import Accuracy, Fbeta, Precision, Recall
from torch import nn, optim
from torch.utils.data import DataLoader, RandomSampler, random_split
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchvision.datasets import ImageFolder
from torchvision.models import VGG, VGG19_Weights, vgg19

  from torch.distributed.optim import ZeroRedundancyOptimizer


In [9]:
BATCH_SIZE = 8
EPOCHS = 12
TARGET_SET = "Original dataset"  # target folder inside dataset (archive.zip)

In [10]:
custom_seed = np.random.randint(2141403747)
random.seed(custom_seed)  # apparently seed must be set at several places
torch.manual_seed(custom_seed)

<torch._C.Generator at 0x264bdc58af0>

In [11]:
def configure_device() -> torch.device:
    return (
        torch.device("cuda")
        if torch.cuda.is_available()
        else torch_directml.device(torch_directml.default_device())
        if torch_directml.is_available()
        else torch.device("cpu")
    )

In [12]:
device = configure_device()
device

device(type='privateuseone', index=0)

In [13]:
dataset_full = ImageFolder(
    root=os.path.join(".", "archive", TARGET_SET),
    transform=VGG19_Weights.IMAGENET1K_V1.transforms(),  # Img inference transformation piepeline: https://pytorch.org/vision/main/models/generated/torchvision.models.vgg19.html
)

data_sets = SimpleNamespace(
    **dict(
        zip(
            ["train", "validation", "test"],
            random_split(dataset=dataset_full, lengths=[0.6, 0.3, 0.1]),
        )
    )
)

data_loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "persistent_workers": True,
    # "num_workers": max(
    #     2, (os.cpu_count() or 0) // 2
    # ),  # too many workers may cause python to crash because of RAM usage
    "num_workers": 8,
    "shuffle": True,
    # "in_order": False, # Only introduced in PyTorch v2.6
    "drop_last": True,
    "worker_init_fn": worker_init_fn,
}

if "cuda" in device.type.lower():
    data_loader_kwargs | {"pin_memory": True, "pin_memory_device": device.type}


data_loaders = SimpleNamespace(
    train=DataLoader(data_sets.train, **data_loader_kwargs),
    validation=DataLoader(data_sets.validation, **data_loader_kwargs),
    test=DataLoader(data_sets.test, **data_loader_kwargs),
)

In [14]:
def configure_model(device) -> VGG:
    model = vgg19(
        weights=VGG19_Weights.DEFAULT
    )  # model initialization with pretrained default weights

    logger.info("MODEL INIT INFO:")
    logger.info(summary(model))  # use torchinfo for parameter info

    model.features.requires_grad_(False)  # freeze all the blocks of model
    model.avgpool.requires_grad_(False)
    model.classifier.requires_grad_(False)

    model.classifier[-1] = nn.Linear(
        in_features=4096, out_features=6, bias=True
    )  # replace the last classification layer
    model.classifier[-1].requires_grad_(True)  # unfreeze new classification layer

    logger.info("MODEL MODIFIED INFO:")
    logger.info(summary(model))  # most params should be non trainable

    model = model.to(device)  # send model with forzen layers (hopefully) to GPU

    return model

In [15]:
model = configure_model(device=device)

2025-02-17 00:06:17,169 [INFO] MainThread MainProcess MODEL INIT INFO:
Layer (type:depth-idx)                   Param #
VGG                                      --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       36,928
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
│    └─Conv2d: 2-6                       73,856
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       147,584
│    └─ReLU: 2-9                         --
│    └─MaxPool2d: 2-10                   --
│    └─Conv2d: 2-11                      295,168
│    └─ReLU: 2-12                        --
│    └─Conv2d: 2-13                      590,080
│    └─ReLU: 2-14                        --
│    └─Conv2d: 2-15                      590,080
│    └─ReLU: 2-16                        --
│    └─Conv2d: 2-17                      590,080
│    └─R

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    filter(
        lambda model_parameter: model_parameter.requires_grad, model.parameters()
    ),  # optimize just the non frozen parameters
    lr=0.001,  # start learning rate
    momentum=0.9,
)
scheduler = optim.lr_scheduler.StepLR(
    optimizer, step_size=10, gamma=0.1
)  # adjust learning rate per epoch

In [None]:
def train_one_epoch(epoch: int, summary_writer: SummaryWriter) -> float:
    logger.info(f"Starting training epoch: {epoch}")
    running_loss_training = 0.0
    last_loss = 0.0

    for i, data in enumerate(data_loaders.train):
        logger.debug(f"iteration: {i}\tepoch-test: {epoch}")
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # zero the parameter gradients
        outputs = model(inputs)  # Make predictions for this batch

        loss = criterion(outputs, labels)  # Compute the loss and its gradients
        loss.backward()  # backpropagation

        optimizer.step()  # Adjust learning weights

        running_loss_training += loss.item()

        if i % BATCH_SIZE == 0:
            last_loss = running_loss_training / (i + 1)  # loss per mini-batch
            tb_x = epoch * len(data_loaders.train) + i + 1

            summary_writer.add_scalar("Loss/train", last_loss, tb_x)

    scheduler.step()  # Adjust learning rate
    summary_writer.flush()

    return last_loss

In [18]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
summary_writer = SummaryWriter(log_dir=f"runs/veggie_trainer_{timestamp}")
model_path_best_epoch = ""

precision = Precision(device=device)
recall = Recall(device=device)
accuracy = Accuracy(device=device)
f1 = Fbeta(beta=1.0, precision=precision, recall=recall, device=device)


def train():
    best_loss_validation = float("inf")

    for epoch in range(EPOCHS + 1):
        model.train(True)
        avg_loss_train = train_one_epoch(epoch, summary_writer)
        running_loss_test = 0.0
        avg_loss_validation = 0.0

        model.eval()  # Set the model to evaluation mode, disabling dropout and using population statistics for batch normalization

        logger.info(f"Starting validation epoch: {epoch}")
        with (
            torch.no_grad()
        ):  # Disable gradient computation and reduce memory consumption.
            for i, data in enumerate(data_loaders.validation):
                logger.debug(f"iteration: {i}\tepoch-validation: {epoch}")
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                running_loss_test += loss.item()

                # Log the running loss averaged per batch
                avg_loss_validation = running_loss_test / (i + 1)

                precision.update((outputs, labels))
                recall.update((outputs, labels))
                accuracy.update((outputs, labels))
                f1.update((outputs, labels))

        logger.info(f"LOSS train {avg_loss_train} vs. validation {avg_loss_validation}")

        summary_writer.add_scalar("Loss/validation", avg_loss_validation, epoch + 1)
        summary_writer.add_scalar(
            "Precision/validation", precision.compute().mean().item(), epoch + 1
        )
        summary_writer.add_scalar(
            "Recall/validation", recall.compute().mean().item(), epoch + 1
        )
        summary_writer.add_scalar("Accuracy/validation", accuracy.compute(), epoch + 1)
        summary_writer.add_scalar("F1/validation", f1.compute(), epoch + 1)

        precision.reset()
        recall.reset()
        accuracy.reset()
        f1.reset()
        summary_writer.flush()

        # Track best performance, and save the model's state
        if avg_loss_validation < best_loss_validation:
            best_loss_validation = avg_loss_validation

            global model_path_best_epoch
            model_path_best_epoch = os.path.join(
                ".", f"veggie-net-{timestamp}-{epoch}.pth"
            )
            torch.save(model.state_dict(), model_path_best_epoch)

            model_path_previous_epoch = os.path.join(
                ".", f"veggie-net-{timestamp}-{epoch - 1}.pth"
            )
            if os.path.exists(model_path_previous_epoch):
                os.remove(model_path_previous_epoch)

    logger.info("Finished Training")

In [19]:
def test():
    logger.info("Running test")
    running_loss_test = 0.0
    avg_loss_test = 0.0

    with torch.no_grad():  # Disable gradient computation and reduce memory consumption.
        for i, data in enumerate(data_loaders.test):
            logger.debug(rf"iteration: {i}\of final test")
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss_test += loss.item()

            avg_loss_test = running_loss_test / (i + 1)

            precision.update((outputs, labels))
            recall.update((outputs, labels))
            accuracy.update((outputs, labels))
            f1.update((outputs, labels))

        summary_writer.add_scalar("Loss/test", avg_loss_test, EPOCHS)
        summary_writer.add_scalar(
            "Precision/test", precision.compute().mean().item(), EPOCHS
        )
        summary_writer.add_scalar("Recall/test", recall.compute().mean().item(), EPOCHS)
        summary_writer.add_scalar("Accuracy/test", accuracy.compute(), EPOCHS)
        summary_writer.add_scalar("F1/test", f1.compute(), EPOCHS)

        precision.reset()
        recall.reset()
        accuracy.reset()
        f1.reset()
        summary_writer.flush()

    logger.info(f"LOSS test {avg_loss_test}")
    logger.info("Test complete")

In [20]:
# # if __name__ == "__main__":
# train()
# model.load_state_dict(torch.load(model_path_best_epoch, map_location=device)) # load best performing model (might be from a previous epoch)
# test()

In [27]:
# # Load a previously trained model manually
device = configure_device()
model = configure_model(device=device)
model.load_state_dict(
    torch.load(
        os.path.join(".", "veggie-net-20250216_235626-1.pth"), weights_only=False
    )
)
model = model.to(device=device)
model.eval()

2025-02-17 00:34:42,571 [INFO] MainThread MainProcess MODEL INIT INFO:
Layer (type:depth-idx)                   Param #
VGG                                      --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       36,928
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
│    └─Conv2d: 2-6                       73,856
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       147,584
│    └─ReLU: 2-9                         --
│    └─MaxPool2d: 2-10                   --
│    └─Conv2d: 2-11                      295,168
│    └─ReLU: 2-12                        --
│    └─Conv2d: 2-13                      590,080
│    └─ReLU: 2-14                        --
│    └─Conv2d: 2-15                      590,080
│    └─ReLU: 2-16                        --
│    └─Conv2d: 2-17                      590,080
│    └─R

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [None]:
# manual (visual) validation
# from torchvision.transforms import ToPILImage
# disablePILDecompressionBombError()

# idx_to_class = {v: k for k, v in dataset_full.class_to_idx.items()}

# sample, label = data_sets.validation[random.randint(0, len(data_sets.validation) - 1)]
# sample_batch = sample.unsqueeze(0).to(device)


# ToPILImage()(sample).show()
# f"Prediction: {idx_to_class[model(sample_batch).argmax().item()]} | Actual Label: {idx_to_class[label]}"

'Prediction: Eggplant | Actual Label: Eggplant'

In [None]:
# usage: https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html
# ! tensorboard --logdir=./runs/
# go to http://localhost:6006