In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms

import copy
import os

In [None]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed PyTorch's RNGs (CPU and all CUDA devices) so the new head's initialisation,
# DataLoader shuffling and the random augmentations repeat across Restart & Run All.
# NOTE: some GPU kernels can still be nondeterministic even with a fixed seed.
SEED = 42
torch.manual_seed(SEED)

device

In [None]:
# Per-channel ImageNet mean/std, matching the IMAGENET1K_V1 pretrained weights used for the model.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

# Training preprocessing: light random augmentation, then resize + normalise.
train_transform = transforms.Compose([
    # Random color jittering (brightness, contrast, saturation, hue)
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),

    # Random rotation in the range [-20, 20] degrees
    transforms.RandomRotation(20),

    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation preprocessing: identical resize + normalise, but no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Change these paths to point at your preferred dataset and output location.
train_data_dir = "dataset/orange blocked unblocked/train"
test_data_dir = "dataset/orange blocked unblocked/test"
path = "models/full finetuned models"  # directory where fine_tune() writes checkpoints

In [None]:
# Make sure the checkpoint directory (and any missing parents) exists; an existing directory is fine.
os.makedirs(name=path, exist_ok=True)

In [None]:
BATCH_SIZE = 8

train_dataset = datasets.ImageFolder(train_data_dir, transform=train_transform)
test_dataset = datasets.ImageFolder(test_data_dir, transform=test_transform)

# Both splits derive integer labels from their own folder names; if the folders differ,
# the same index would silently mean different classes in train and test.
assert train_dataset.class_to_idx == test_dataset.class_to_idx, (
    f"train/test class folders differ: {train_dataset.class_to_idx} vs {test_dataset.class_to_idx}"
)

# Integer label of every image, in each dataset's sample order.
train_labels = [label for _, label in train_dataset.samples]
test_labels = [label for _, label in test_dataset.samples]

# Shuffle only the training data. Evaluation metrics are order-independent, and a fixed
# test order keeps evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

In [None]:
# Class-folder name -> integer label mapping used by the training set.
class_to_idx = train_loader.dataset.class_to_idx
class_to_idx

In [None]:
def _epoch_stats(total_loss, total_correct, num_samples):
    """Turn per-epoch sums into ``(mean_loss, accuracy)``; ``(0.0, 0.0)`` for an empty dataset."""
    if num_samples == 0:
        return 0.0, 0.0
    return total_loss / num_samples, total_correct / num_samples


def _train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``; return ``(mean_loss, accuracy)``."""
    model.train()  # Set model to training mode
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # Explicitly enable autograd so training still works if called inside a no_grad context.
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        preds = outputs.argmax(dim=1)
        # Re-weight the (assumed batch-mean) loss by batch size so the epoch value is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    return _epoch_stats(running_loss, running_corrects, len(dataloader.dataset))


def _evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients; return ``(mean_loss, accuracy)``."""
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

    return _epoch_stats(running_loss, running_corrects, len(dataloader.dataset))


def fine_tune(model, train_dataloaders, test_dataloader, criterion, optimizer, num_epochs=20, save_epochs=None,
              save_dir=None, device=None, checkpoint_prefix='alexnet_finetuned'):
    """Train ``model`` for ``num_epochs`` epochs, keeping the weights with the best test accuracy.

    After every epoch the model is evaluated on ``test_dataloader``. Whenever the test accuracy
    strictly improves, a deep copy of the state dict is kept. Those best weights are loaded back
    into ``model`` before it is returned.

    NOTE(review): the "best" epoch is selected on the test set, so the reported best test accuracy
    is optimistically biased. Consider selecting on a held-out validation split (e.g. via
    ``random_split``) and reporting test accuracy once at the end.

    Args:
        model: module to train; updated in place and also returned.
        train_dataloaders: training DataLoader (its ``.dataset`` length is used for averaging).
        test_dataloader: evaluation DataLoader (its ``.dataset`` length is used for averaging).
        criterion: loss ``criterion(outputs, labels)``; assumed to return the batch mean, as
            ``CrossEntropyLoss`` does by default.
        optimizer: optimizer over the model's trainable parameters.
        num_epochs: number of passes over the training data.
        save_epochs: 1-based epoch numbers after which the best-so-far weights are saved.
            ``None`` or empty disables checkpointing.
        save_dir: checkpoint directory; defaults to the notebook-level ``path``.
        device: device the batches are moved to; defaults to the device of the model's first
            parameter (CPU if the model has no parameters).
        checkpoint_prefix: checkpoint file stem; files are named ``{prefix}_{epoch}_epochs.pth``.

    Returns:
        ``model`` with the best-test-accuracy weights loaded. If test accuracy never rose above 0,
        these are the weights the model had on entry.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    save_epochs = set(save_epochs or ())

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_test_acc = 0.0

    for epoch in range(num_epochs):
        print('-' * 20)
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))

        epoch_loss, epoch_acc = _train_one_epoch(model, train_dataloaders, criterion, optimizer, device)
        test_epoch_loss, test_epoch_acc = _evaluate(model, test_dataloader, criterion, device)

        best_acc = max(best_acc, epoch_acc)

        # Deep copy so later training steps don't mutate the saved snapshot.
        if test_epoch_acc > best_test_acc:
            best_test_acc = test_epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        print('Train Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc * 100))
        print('Test Loss: {:.4f} Acc: {:.4f}'.format(test_epoch_loss, test_epoch_acc * 100))

        if epoch + 1 in save_epochs:
            print(f'Saving model at epoch {epoch+1}')
            # Save the best-so-far snapshot directly. Loading it into `model` here would discard
            # this epoch's progress and restart training from an older state.
            out_dir = path if save_dir is None else save_dir
            torch.save(best_model_wts, os.path.join(out_dir, f'{checkpoint_prefix}_{epoch + 1}_epochs.pth'))

    print('Best training Acc: {:.4f}'.format(best_acc * 100))
    print('Best test Acc: {:.4f}'.format(best_test_acc * 100))

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When it is falsy, the model is left untouched. Layers created after this call (e.g. a
    replacement classifier head) remain trainable.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
# Set to True to freeze the pretrained backbone and train only the new classifier head.
FEATURE_EXTRACT = False

model = models.alexnet(weights="IMAGENET1K_V1")
# Freeze (if requested) before replacing the head, so the freshly created layer stays trainable.
set_parameter_requires_grad(model, FEATURE_EXTRACT)

# Replace the pretrained output layer with one sized to this dataset's number of classes,
# rather than hard-coding 2.
num_classes = len(train_dataset.classes)
model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, num_classes)

In [None]:
def count_parameters(model):
    """Print how many parameter elements ``model`` has in total and how many require gradients."""
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size
    print(f'Total parameters: {total_params}')
    print(f'Trainable parameters: {trainable_params}')

# Report total vs. trainable parameter counts (trainable drops if the backbone is frozen).
count_parameters(model=model)

In [None]:
NUM_EPOCHS = 100
# Checkpoint after epoch 2 and then every 5 epochs up to NUM_EPOCHS: [2, 5, 10, ..., 100].
SAVE_EPOCHS = [2, *range(5, NUM_EPOCHS + 1, 5)]

# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim docs recommend.
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
# Give the optimizer only the tensors that will actually be trained (matters when the backbone is frozen).
optimizer = optim.SGD((p for p in model.parameters() if p.requires_grad), lr=0.001, momentum=0.9)

model_ft = fine_tune(
    model,
    train_loader,
    test_loader,
    criterion,
    optimizer,
    num_epochs=NUM_EPOCHS,
    save_epochs=SAVE_EPOCHS,
)