In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from tqdm import tqdm

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
weights=models.VGG19_Weights.IMAGENET1K_V1
vgg_model = models.vgg19(weights=weights)
preprocess = weights.transforms()

In [15]:
train_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/train', transform=preprocess)
val_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/validation', transform=preprocess)


In [16]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [17]:
num_classes = len(train_dataset.classes)
vgg_model.classifier[6] = nn.Linear(4096, num_classes)


In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vgg_model.parameters(), lr=0.001, momentum=0.9)


In [19]:
vgg_model.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [20]:
# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    vgg_model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = vgg_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

            # Update progress bar
            pbar.set_postfix({'Training Loss': running_loss / (batch_idx + 1)})
            pbar.update()

        # Calculate training accuracy
        train_accuracy = 100.0 * correct_train / total_train

        # Calculate validation loss
        vgg_model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = vgg_model(val_inputs)
                val_loss += criterion(val_outputs, val_labels).item()

                _, predicted_val = val_outputs.max(1)
                total_val += val_labels.size(0)
                correct_val += predicted_val.eq(val_labels).sum().item()

        # Calculate validation accuracy
        val_accuracy = 100.0 * correct_val / total_val

        # Print training and validation metrics
        print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {running_loss / len(train_loader)}, Training Accuracy: {train_accuracy}%, Validation Loss: {val_loss / len(val_loader)}, Validation Accuracy: {val_accuracy}%')

print("Training complete!")

Epoch 1/10: 100%|██████| 232/232 [01:26<00:00,  2.69batch/s, Training Loss=2.06]


Epoch 1/10, Training Loss: 2.0603473068311295, Training Accuracy: 39.2919875692474%, Validation Loss: 2.0141468706883883, Validation Accuracy: 38.44884488448845%


Epoch 2/10: 100%|██████| 232/232 [01:26<00:00,  2.67batch/s, Training Loss=1.31]


Epoch 2/10, Training Loss: 1.3117775634444993, Training Accuracy: 59.50547223348196%, Validation Loss: 1.9078307089052702, Validation Accuracy: 44.38943894389439%


Epoch 3/10: 100%|██████| 232/232 [01:27<00:00,  2.66batch/s, Training Loss=1.03]


Epoch 3/10, Training Loss: 1.0283687077462673, Training Accuracy: 67.2611809214971%, Validation Loss: 1.5992875444261652, Validation Accuracy: 51.98019801980198%


Epoch 4/10: 100%|█████| 232/232 [01:27<00:00,  2.65batch/s, Training Loss=0.783]


Epoch 4/10, Training Loss: 0.7829092769273396, Training Accuracy: 74.81421429536549%, Validation Loss: 1.4700637287215184, Validation Accuracy: 56.93069306930693%


Epoch 5/10: 100%|█████| 232/232 [01:27<00:00,  2.65batch/s, Training Loss=0.611]


Epoch 5/10, Training Loss: 0.6112436749490684, Training Accuracy: 79.57032833400892%, Validation Loss: 1.5600852182036953, Validation Accuracy: 59.9009900990099%


Epoch 6/10: 100%|█████| 232/232 [01:27<00:00,  2.65batch/s, Training Loss=0.506]


Epoch 6/10, Training Loss: 0.5059605732174783, Training Accuracy: 82.9752736116741%, Validation Loss: 1.8313346599277698, Validation Accuracy: 48.34983498349835%


Epoch 7/10: 100%|█████| 232/232 [01:27<00:00,  2.65batch/s, Training Loss=0.391]


Epoch 7/10, Training Loss: 0.39065042070659073, Training Accuracy: 87.00175651938928%, Validation Loss: 1.4855206266633774, Validation Accuracy: 62.211221122112214%


Epoch 8/10: 100%|█████| 232/232 [01:27<00:00,  2.65batch/s, Training Loss=0.286]


Epoch 8/10, Training Loss: 0.285964515317103, Training Accuracy: 90.27158492095663%, Validation Loss: 1.5546088110851615, Validation Accuracy: 62.54125412541254%


Epoch 9/10: 100%|█████| 232/232 [01:27<00:00,  2.64batch/s, Training Loss=0.264]


Epoch 9/10, Training Loss: 0.26436083816826855, Training Accuracy: 91.28496149169031%, Validation Loss: 2.012016293249632, Validation Accuracy: 51.65016501650165%


Epoch 10/10: 100%|████| 232/232 [01:27<00:00,  2.64batch/s, Training Loss=0.217]

Epoch 10/10, Training Loss: 0.21664040779759144, Training Accuracy: 92.94689906769355%, Validation Loss: 1.66814203403498, Validation Accuracy: 63.53135313531353%
Training complete!





In [21]:
test_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/test', transform=preprocess)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Calculate test loss
vgg_model.eval()
test_loss = 0.0
correct_test = 0
total_test = 0
with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
        test_outputs = vgg_model(test_inputs)
        test_loss += criterion(test_outputs, test_labels).item()

        _, predicted_test = test_outputs.max(1)
        total_test += test_labels.size(0)
        correct_test += predicted_test.eq(test_labels).sum().item()

# Calculate test accuracy
test_accuracy = 100.0 * correct_test / total_test
print(test_accuracy)

57.59368836291913


In [22]:
torch.save(vgg_model.state_dict(), './models/vgg19_model.pth')