In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from tqdm import tqdm

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
weights=models.VGG16_Weights.IMAGENET1K_V1
vgg_model = models.vgg16(weights=weights)
preprocess = weights.transforms()

In [15]:
train_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/train', transform=preprocess)
val_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/validation', transform=preprocess)


In [16]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [17]:
num_classes = len(train_dataset.classes)
vgg_model.classifier[6] = nn.Linear(4096, num_classes)


In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg_model.parameters(), lr=0.001, momentum=0.9)


In [19]:
vgg_model.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [23]:
# Training loop
num_epochs = 1

for epoch in range(num_epochs):
    vgg_model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = vgg_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

            # Update progress bar
            pbar.set_postfix({'Training Loss': running_loss / (batch_idx + 1)})
            pbar.update()

        # Calculate training accuracy
        train_accuracy = 100.0 * correct_train / total_train

        # Calculate validation loss
        vgg_model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = vgg_model(val_inputs)
                val_loss += criterion(val_outputs, val_labels).item()

                _, predicted_val = val_outputs.max(1)
                total_val += val_labels.size(0)
                correct_val += predicted_val.eq(val_labels).sum().item()

        # Calculate validation accuracy
        val_accuracy = 100.0 * correct_val / total_val

        # Print training and validation metrics
        print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {running_loss / len(train_loader)}, Training Accuracy: {train_accuracy}%, Validation Loss: {val_loss / len(val_loader)}, Validation Accuracy: {val_accuracy}%')

print("Training complete!")

Epoch 1/1: 100%|██████| 232/232 [01:15<00:00,  3.09batch/s, Training Loss=0.138]

Epoch 1/1, Training Loss: 0.13848815952901375, Training Accuracy: 95.43304958789352%, Validation Loss: 1.8443398976600485, Validation Accuracy: 65.84158415841584%
Training complete!





In [24]:
test_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/test', transform=preprocess)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Calculate test loss
vgg_model.eval()
test_loss = 0.0
correct_test = 0
total_test = 0
with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
        test_outputs = vgg_model(test_inputs)
        test_loss += criterion(test_outputs, test_labels).item()

        _, predicted_test = test_outputs.max(1)
        total_test += test_labels.size(0)
        correct_test += predicted_test.eq(test_labels).sum().item()

# Calculate test accuracy
test_accuracy = 100.0 * correct_test / total_test
print(test_accuracy)

59.467455621301774


In [25]:
torch.save(vgg_model.state_dict(), './models/vgg16_model.pth')