In [1]:
import torch

from torchmetrics.classification import MulticlassAccuracy

In [2]:
from utils_v2 import Flowers102Classifier, plot_training_runs, get_train_val_test_loader, FineTuneType, TrainingRun



TypeError: __init__() got an unexpected keyword argument 'scale'

In [3]:
# Build all four data loaders once; mixup augmentation is enabled for training.
# These names are read as globals by transfer_learning_on_backbone below.
(
    train_loader,
    train_val_loader,
    validation_loader,
    test_loader,
) = get_train_val_test_loader(mixup=True)

TRAINING SIZE: 816
VALIDATION SIZE: 204
TRAINING SIZE: 6149


In [None]:
def _train_phase(fc, accuracy, optimizer, scheduler, epochs, checkpoint_path, training_run):
    """Run one training stage of ``fc`` on the notebook-level data loaders.

    Thin wrapper so the feature-extraction and fine-tuning stages pass the same
    loaders, metric, checkpoint path and ``TrainingRun`` in the same order.
    Reads the globals ``train_loader``, ``train_val_loader`` and
    ``validation_loader`` created by ``get_train_val_test_loader`` in an earlier cell.
    """
    fc.train_multiple_epochs_and_save_best_checkpoint(
        train_loader,
        train_val_loader,
        validation_loader,
        accuracy,
        optimizer,
        scheduler,
        epochs,
        checkpoint_path,
        training_run,
    )


def transfer_learning_on_backbone(backbones, feature_extract_epochs, fine_tune_epochs):
    """Run two-stage transfer learning for each backbone, evaluate on test, and plot.

    The choice of backbone (pre-trained model) is treated as a hyper-parameter.
    For every backbone:

    1. Feature extraction: freeze everything except the newly added layers
       (``FineTuneType.NEW_LAYERS``) and train for ``feature_extract_epochs``.
    2. Fine-tuning: make all parameters trainable (``FineTuneType.ALL``) and
       train for ``fine_tune_epochs`` with a fresh optimizer/scheduler.

    Both stages use the same checkpoint path ``f"{backbone}_Flowers102_best.pt"``;
    ``train_multiple_epochs_and_save_best_checkpoint`` saves a checkpoint whenever
    validation accuracy beats the previous best (see its log output).
    The model is then evaluated once on the held-out ``test_loader``.

    Relies on the notebook globals ``train_loader``, ``train_val_loader``,
    ``validation_loader`` and ``test_loader``.

    Args:
        backbones: iterable of backbone names accepted by ``Flowers102Classifier``.
        feature_extract_epochs: number of epochs for the feature-extraction stage.
        fine_tune_epochs: number of epochs for the fine-tuning stage.

    Returns:
        None. All runs are plotted via ``plot_training_runs``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    training_runs = {}
    for backbone in backbones:
        best_cp_path = f'{backbone}_Flowers102_best.pt'

        # ---- Stage 1: feature extraction (train only the new classification layer) ----
        print(f"Running feature extraction on a {backbone} backbone for {feature_extract_epochs} epochs.\n")
        fc = Flowers102Classifier(backbone=backbone)
        fc.to(device)

        # Freeze all the weights except for the newly added layer(s).
        fc.fine_tune(FineTuneType.NEW_LAYERS)
        optimizer = torch.optim.Adam(fc.parameters())
        # Multiply the LR by 0.3 every 6 scheduler steps (presumably once per epoch
        # inside the training helper -- confirm in utils_v2).
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.3)
        # Shared by both stages and the final test evaluation. NOTE(review): whether
        # its accumulated state is reset between passes depends on utils_v2 -- confirm.
        accuracy = MulticlassAccuracy(num_classes=102, average='micro').to(device)

        # One TrainingRun collects the history of both stages for plotting.
        training_run = TrainingRun()
        training_runs[backbone] = training_run

        _train_phase(fc, accuracy, optimizer, scheduler, feature_extract_epochs, best_cp_path, training_run)

        print(f"Done with feature extraction for {backbone}-based model. Ran for {feature_extract_epochs} epochs.")

        # NOTE(review): this is whatever get_metrics("val") holds, presumably the most
        # recent validation pass rather than the best epoch -- confirm in utils_v2.
        best_val_accuracy = fc.get_metrics("val")['accuracy']
        print(f"[{backbone}] Best val accuracy after feature extraction is {best_val_accuracy}\n")

        # ---- Stage 2: fine-tuning (all parameters trainable) ----
        print(f"Running fine-tuning for {fine_tune_epochs} epochs")
        fc.fine_tune(FineTuneType.ALL)

        # NOTE(review): lr=1e-8 only applies to param groups from get_optimizer_params()
        # that do not set their own 'lr'; if none do, fine-tuning barely moves the weights.
        optimizer = torch.optim.Adam(fc.get_optimizer_params(), lr=1e-8)
        # Every 2 steps reduce the LR to 70% of the previous value.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.7)

        _train_phase(fc, accuracy, optimizer, scheduler, fine_tune_epochs, best_cp_path, training_run)

        # ---- Final evaluation on the held-out test split ----
        # NOTE(review): this scores the weights from the last fine-tuning epoch, not the
        # best checkpoint saved at best_cp_path; reload it first for a best-model score.
        print("-------------------------------------------------------------------------")
        print(f"Test accuracy of {backbone}-based model after fine-tuning.")
        fc.eval()
        # Label as "Test": this is the test loader, not the validation loader.
        fc.evaluate(test_loader, accuracy, 0, "Test")
        print("-------------------------------------------------------------------------")

    # Plot the recorded history of every backbone.
    plot_training_runs(training_runs)
# Backbones (pre-trained models) to compare; the backbone is a hyper-parameter.
backbones = ["resnet18"]
transfer_learning_on_backbone(
    backbones,
    fine_tune_epochs=8,
    feature_extract_epochs=16,
)

Running feature extraction on a resnet18 backbone for 16 epochs.

[1] Train Loss: 4.72572
[1] Train Loss: 3.76083, Accuracy: 0.23197
[1] Val Loss: 4.23376, Accuracy: 0.08185
Current valdation accuracy 8.18 is better than previous best of 0.00. Saving checkpoint.
[2] Train Loss: 3.49054
[2] Train Loss: 2.72579, Accuracy: 0.67668
[2] Val Loss: 3.38485, Accuracy: 0.42857
Current valdation accuracy 42.86 is better than previous best of 8.18. Saving checkpoint.
[3] Train Loss: 2.57728
[3] Train Loss: 1.93658, Accuracy: 0.87139
[3] Val Loss: 2.77192, Accuracy: 0.54762
Current valdation accuracy 54.76 is better than previous best of 42.86. Saving checkpoint.
[4] Train Loss: 1.93020
[4] Train Loss: 1.38395, Accuracy: 0.94111
[4] Val Loss: 2.32633, Accuracy: 0.63839
Current valdation accuracy 63.84 is better than previous best of 54.76. Saving checkpoint.
[5] Train Loss: 1.42808
[5] Train Loss: 0.98940, Accuracy: 0.96635
[5] Val Loss: 1.98740, Accuracy: 0.69792
Current valdation accuracy 69.79 