<a href="https://colab.research.google.com/github/rodrigorochag/Transfer-Learning-CNN/blob/main/transfer_learning_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os, time
from tempfile import TemporaryDirectory

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Install the Kaggle library
!pip install -q kaggle

# Set Kaggle credentials
import json
import os

# Replace the following with your Kaggle username and API key

kaggle_info = {
    "username": "user_kaggle",
    "key": "api_key"
}
# Save Kaggle credentials to a JSON file
os.makedirs("/root/.kaggle", exist_ok=True)
with open("/root/.kaggle/kaggle.json", "w") as file:
    json.dump(kaggle_info, file)

# Change the permissions of the file
!chmod 600 /root/.kaggle/kaggle.json

# Download the dataset
!kaggle datasets download -d misrakahmed/vegetable-image-dataset
# Unzip the dataset
!unzip -q /content/vegetable-image-dataset.zip


Dataset URL: https://www.kaggle.com/datasets/misrakahmed/vegetable-image-dataset
License(s): CC-BY-SA-4.0
Downloading vegetable-image-dataset.zip to /content
100% 533M/534M [00:30<00:00, 20.1MB/s]
100% 534M/534M [00:30<00:00, 18.3MB/s]


In [4]:
# Channel-wise mean / std shared by both pipelines (the ImageNet statistics
# that torchvision's pretrained weights are normally paired with).
imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training pipeline: resize, then random crop + horizontal flip as augmentation.
train_steps = [
    transforms.Resize((224,224)),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    imagenet_norm,
]

# Validation pipeline: deterministic resize + centre crop, no augmentation.
val_steps = [
    transforms.Resize((224,224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    imagenet_norm,
]

# Keyed by split name; looked up as data_transforms['train'] / data_transforms['val'].
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}

In [5]:
# The two subsets returned by random_split() share ONE underlying ImageFolder,
# so assigning `.dataset.transform` twice leaves both splits with whichever
# transform was assigned last (the validation one) and silently disables the
# training augmentation. To keep the pipelines independent, the same folder is
# opened twice -- once per transform -- and both views are indexed with
# disjoint slices of a single shuffled permutation.
train_root = '/content/Vegetable Images/train'
dataset = datasets.ImageFolder(train_root, transform=data_transforms['train'])
val_view = datasets.ImageFolder(train_root, transform=data_transforms['val'])
# Both views must enumerate the same samples in the same order for the index
# split below to be meaningful.
assert len(val_view) == len(dataset)

train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # 20% for validation

# Seeded generator so the train/validation membership is reproducible across
# kernel restarts.
split_rng = torch.Generator().manual_seed(42)
shuffled_indices = torch.randperm(len(dataset), generator=split_rng).tolist()

# Subset keeps a `.dataset` attribute, so `train_dataset.dataset.classes`
# continues to work for downstream cells.
train_dataset = torch.utils.data.Subset(dataset, shuffled_indices[:train_size])
val_dataset = torch.utils.data.Subset(val_view, shuffled_indices[train_size:])

In [6]:
# Split name -> dataset, used for both the loaders and the size table.
splits = {"train": train_dataset, "val": val_dataset}

# One DataLoader per split. Only the training split is shuffled: validation
# metrics are sums over the whole split, so reordering it buys nothing.
dataloaders = {
    name: torch.utils.data.DataLoader(
        split,
        batch_size=4,
        shuffle=(name == "train"),
        num_workers=4
    ) for name, split in splits.items()
}

# Class labels come from the ImageFolder wrapped by the training subset.
class_names = train_dataset.dataset.classes
# Number of samples per split; used to turn summed loss / correct counts into averages.
dataset_sizes = {name: len(split) for name, split in splits.items()}
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class_names, dataset_sizes, device



(['Bean',
  'Bitter_Gourd',
  'Bottle_Gourd',
  'Brinjal',
  'Broccoli',
  'Cabbage',
  'Capsicum',
  'Carrot',
  'Cauliflower',
  'Cucumber',
  'Papaya',
  'Potato',
  'Pumpkin',
  'Radish',
  'Tomato'],
 {'train': 12000, 'val': 3000},
 device(type='cuda', index=0))

In [7]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train; switched to training mode here.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(loss_sum, correct_count)`` -- the batch losses weighted by batch
        size and summed (a float), and the number of correct predictions
        (a tensor once at least one batch was seen, the int ``0`` otherwise).
        Neither value is divided by the dataset size.
    """
    model.train()
    loss_sum, correct_count = 0.0, 0

    for batch_inputs, batch_labels in dataloader:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        # Gradients accumulate by default, so clear them before each batch.
        optimizer.zero_grad()

        logits = model(batch_inputs)
        # Predicted class = index of the largest output along dim 1.
        predicted = torch.max(logits, 1).indices
        batch_loss = criterion(logits, batch_labels)

        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so the caller can divide the
        # total by the number of samples.
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        correct_count += (predicted == batch_labels.data).sum()

    return loss_sum, correct_count


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without updating any weights.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(loss_sum, correct_count)`` -- the batch losses weighted by batch
        size and summed (a float), and the number of correct predictions
        (a tensor once at least one batch was seen, the int ``0`` otherwise).
        Neither value is divided by the dataset size.
    """
    model.eval()
    loss_sum, correct_count = 0.0, 0

    # No gradients are needed for evaluation, so disable tracking for the pass.
    with torch.no_grad():
        for batch_inputs, batch_labels in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            # Predicted class = index of the largest output along dim 1.
            predicted = torch.max(logits, 1).indices
            batch_loss = criterion(logits, batch_labels)

            # Weight the batch loss by the batch size so the caller can divide
            # the total by the number of samples.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            correct_count += (predicted == batch_labels.data).sum()

    return loss_sum, correct_count


def train_model(model, criterion, optimizer, scheduler, num_epochs=3):
    """Train ``model`` for ``num_epochs`` epochs and restore its best weights.

    Relies on names defined at notebook level: ``dataloaders`` and
    ``dataset_sizes`` (both keyed by ``'train'`` / ``'val'``), ``device``, and
    the ``train_epoch`` / ``validate_epoch`` helpers.

    Args:
        model: network to optimise; modified in place.
        criterion: loss callable forwarded to the epoch helpers.
        optimizer: optimizer forwarded to ``train_epoch``.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training pass.
        num_epochs: number of train + validation epochs to run.

    Returns:
        ``(model, history)`` where ``model`` carries the weights of the epoch
        with the highest validation accuracy and ``history`` holds one
        ``[train_acc, val_acc, train_loss, val_loss]`` row of plain Python
        floats per epoch.
    """
    since = time.time()

    # Checkpoints live in a temporary directory that is removed on exit.
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
        # Save the initial weights so there is always a checkpoint to reload.
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0
        history = []

        for epoch in range(num_epochs):
            # Each epoch has a training and validation phase
            train_loss, train_corrects = train_epoch(model, dataloaders['train'], criterion, optimizer, device)
            scheduler.step()
            val_loss, val_corrects = validate_epoch(model, dataloaders['val'], criterion, device)

            # The helpers return sums; turn them into per-sample averages.
            # float(...) converts the correct-count tensors into host-side
            # Python floats so the history holds one numeric type throughout,
            # not a mix of 0-dim tensors and floats.
            train_loss /= dataset_sizes['train']
            train_acc = float(train_corrects) / dataset_sizes['train']
            val_loss /= dataset_sizes['val']
            val_acc = float(val_corrects) / dataset_sizes['val']

            history.append([train_acc, val_acc, train_loss, val_loss])
            print(f'Epoch {epoch}/{num_epochs - 1}: '
                  f'Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, '
                  f'Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}')

            # Checkpoint the weights whenever validation accuracy improves.
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), best_model_params_path)

        time_elapsed = time.time() - since
        print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Load the best checkpoint before the temporary directory disappears.
        model.load_state_dict(torch.load(best_model_params_path))

    return model, history

In [8]:
# Reference model: ImageNet-pretrained ConvNeXt-Base whose last classifier
# layer is replaced by a new linear layer with one output per vegetable class.
num_classes = len(class_names)
model_orig = models.convnext_base(weights='IMAGENET1K_V1')
pretrained_head = model_orig.classifier[2]
model_orig.classifier[2] = nn.Linear(pretrained_head.in_features, num_classes)
model_orig = model_orig.to(device)

Downloading: "https://download.pytorch.org/models/convnext_base-6075fbad.pth" to /root/.cache/torch/hub/checkpoints/convnext_base-6075fbad.pth
100%|██████████| 338M/338M [00:02<00:00, 129MB/s]


In [9]:
# Full fine-tuning setup: pretrained ConvNeXt-Base with a new output layer
# sized to the number of classes; every parameter is handed to the optimizer.
model_ft = torchvision.models.convnext_base(weights='IMAGENET1K_V1')
ft_head_inputs = model_ft.classifier[2].in_features
model_ft.classifier[2] = nn.Linear(ft_head_inputs, len(class_names))
model_ft = model_ft.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(params=model_ft.parameters(), lr=0.0005)
# Multiply the learning rate by 0.1 every 6 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=6, gamma=0.1)

In [None]:
# Full fine-tuning run: returns the model with its best-validation weights
# together with the per-epoch accuracy / loss history.
model_ft, model_ft_history = train_model(
    model=model_ft,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=13,
)

In [None]:
# Partial fine-tuning setup: the pretrained network is frozen first, then the
# last classifier layer is replaced. The new layer is created after the freeze,
# so it is the only part that still requires gradients.
model_conv = models.convnext_base(weights='IMAGENET1K_V1')
model_conv.requires_grad_(False)
conv_head_inputs = model_conv.classifier[2].in_features
model_conv.classifier[2] = nn.Linear(conv_head_inputs, len(class_names))
model_conv = model_conv.to(device)

criterion = torch.nn.CrossEntropyLoss()
# Only the new output layer's parameters are given to the optimizer.
optimizer_conv = optim.SGD(params=model_conv.classifier[2].parameters(), lr=0.0005)
# Multiply the learning rate by 0.1 every 6 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_conv, step_size=6, gamma=0.1)

In [None]:
# Partial fine-tuning run (frozen backbone): returns the model with its
# best-validation weights together with the per-epoch accuracy / loss history.
model_conv, model_conv_history = train_model(
    model=model_conv,
    criterion=criterion,
    optimizer=optimizer_conv,
    scheduler=exp_lr_scheduler,
    num_epochs=13,
)

In [None]:
# Column order mirrors the per-epoch rows appended by train_model:
# [train_acc, val_acc, train_loss, val_loss].
history_columns = ["Train Accuracy", "Validation Accuracy", "Train Loss", "Validation Loss"]

# Wrap both training histories in DataFrames (one row per epoch) for plotting.
model_ft_history = pd.DataFrame(data=model_ft_history, columns=history_columns)
model_conv_history = pd.DataFrame(data=model_conv_history, columns=list(history_columns))

In [None]:
num_of_subplots = len(model_ft_history.columns)

# Derive the epoch axis from the recorded histories instead of hard-coding 13,
# so the figure stays valid when num_epochs is changed. Odd epochs (1-based)
# get a label and even ones a blank, matching the original 13-epoch layout.
epochs_recorded = max(len(model_ft_history), len(model_conv_history))
tick_positions = list(range(epochs_recorded))
tick_labels = [str(pos + 1) if pos % 2 == 0 else " " for pos in tick_positions]

fig, axs = plt.subplots(nrows=2, ncols=2, sharex=True)

# One panel per metric column, in row-major order:
# top row = accuracies, bottom row = losses.
for i, (title, ax) in enumerate(zip(model_ft_history.columns, axs.flat)):
    ax.plot(model_ft_history.iloc[:, i], label='Full fine-tuning')
    ax.plot(model_conv_history.iloc[:, i], label='Partial fine-tuning')
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.xaxis.set_ticks_position('none')
    ax.yaxis.set_ticks_position('none')
    if i > 1:  ax.set_xlabel('Epoch')      # bottom row only
    if i == 0: ax.set_ylabel('Accuracy')   # first panel of the top row
    if i == 2: ax.set_ylabel('Loss')       # first panel of the bottom row
    ax.set_title(title)
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
def visualize_model_predictions(model, img_path):
    """Return the raw outputs of ``model`` for the single image at ``img_path``.

    The image is converted to RGB, run through the notebook-level
    ``data_transforms['val']`` pipeline, given a batch dimension and moved to
    the notebook-level ``device``. The model is evaluated in eval mode without
    gradient tracking, and its previous training/eval mode is restored
    afterwards, even if the forward pass raises.

    Args:
        model: network to evaluate.
        img_path: path of the image file to load.

    Returns:
        The tensor returned by ``model`` for the one-image batch.
    """
    was_training = model.training
    model.eval()

    try:
        # Close the file once decoded; convert to RGB so grayscale / RGBA
        # files still match the 3-channel normalisation.
        with Image.open(img_path) as img_file:
            img = data_transforms['val'](img_file.convert('RGB'))
        # Add the batch dimension the model expects and move to the device.
        img = img.unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = model(img)
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(mode=was_training)

    return outputs

In [None]:
# Single source of truth for the image shown and scored below.
sample_path = "/content/Vegetable Images/train/Bean/0026.jpg"

fig, [ax1, ax2, ax3] = plt.subplots(nrows=1, ncols=3, figsize=(15, 3))

# The import cell never imports cv2, so the image is loaded with PIL (already
# imported as Image). PIL yields RGB directly -- no BGR->RGB conversion needed.
with Image.open(sample_path) as sample_file:
    img = sample_file.convert("RGB")
ax1.imshow(img)
ax1.set_title("Test Image")

original_preds = visualize_model_predictions(model_orig, sample_path)
finetuned_preds = visualize_model_predictions(model_ft, sample_path)

# One bar chart of raw model outputs per model, with a dashed zero reference line.
panels = (
    (ax2, original_preds, "Predictions before fine-tuning"),
    (ax3, finetuned_preds, "Predictions after fine-tuning"),
)
for ax, preds, title in panels:
    # preds holds one row (single-image batch); plot that row's per-class outputs.
    pd.Series(preds.data.cpu().numpy().tolist()[0]).plot(kind='bar', ax=ax)
    ax.set_xticklabels(class_names, rotation=60)
    ax.axhline(y=0, color='green', linestyle='--')
    ax.set_title(title)

plt.show()