In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from tqdm import tqdm

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
weights=models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
vit_model = models.vit_b_16(weights=weights)
preprocess = weights.transforms()

In [8]:
train_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/train', transform=preprocess)
val_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/validation', transform=preprocess)


In [9]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)


In [11]:
# Get the mapping between class numbers and class labels
class_to_idx = train_dataset.class_to_idx
class_to_idx

{'Apple Scab Leaf': 0,
 'Apple leaf': 1,
 'Apple rust leaf': 2,
 'Bell_pepper leaf': 3,
 'Bell_pepper leaf spot': 4,
 'Blueberry leaf': 5,
 'Cherry leaf': 6,
 'Corn Gray leaf spot': 7,
 'Corn leaf blight': 8,
 'Corn rust leaf': 9,
 'Peach leaf': 10,
 'Potato leaf': 11,
 'Potato leaf early blight': 12,
 'Potato leaf late blight': 13,
 'Raspberry leaf': 14,
 'Soyabean leaf': 15,
 'Squash Powdery mildew leaf': 16,
 'Strawberry leaf': 17,
 'Tomato Early blight leaf': 18,
 'Tomato Septoria leaf spot': 19,
 'Tomato leaf': 20,
 'Tomato leaf bacterial spot': 21,
 'Tomato leaf late blight': 22,
 'Tomato leaf mosaic virus': 23,
 'Tomato leaf yellow virus': 24,
 'Tomato mold leaf': 25,
 'Tomato two spotted spider mites leaf': 26,
 'grape leaf': 27,
 'grape leaf black rot': 28}

In [13]:
num_classes = len(train_dataset.classes)
vit_model.heads[-1] = nn.Linear(vit_model.heads[-1].in_features, num_classes)

In [14]:
vit_model.load_state_dict(torch.load('./models/vitb16_SWAGE2E_model.pth'))

<All keys matched successfully>

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(vit_model.parameters(), lr=0.001, momentum=0.9)


In [16]:
vit_model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

In [17]:
# Training loop
num_epochs = 1

for epoch in range(num_epochs):
    vit_model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch') as pbar:
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = vit_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total_train += labels.size(0)
            correct_train += predicted.eq(labels).sum().item()

            # Update progress bar
            pbar.set_postfix({'Training Loss': running_loss / (batch_idx + 1)})
            pbar.update()

        # Calculate training accuracy
        train_accuracy = 100.0 * correct_train / total_train

        # Calculate validation loss
        vit_model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(device)
                val_outputs = vit_model(val_inputs)
                val_loss += criterion(val_outputs, val_labels).item()

                _, predicted_val = val_outputs.max(1)
                total_val += val_labels.size(0)
                correct_val += predicted_val.eq(val_labels).sum().item()

        # Calculate validation accuracy
        val_accuracy = 100.0 * correct_val / total_val

        # Print training and validation metrics
        print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {running_loss / len(train_loader)}, Training Accuracy: {train_accuracy}%, Validation Loss: {val_loss / len(val_loader)}, Validation Accuracy: {val_accuracy}%')

print("Training complete!")

Epoch 1/1: 100%|████| 463/463 [07:30<00:00,  1.03batch/s, Training Loss=0.00242]

Epoch 1/1, Training Loss: 0.002416566229755633, Training Accuracy: 99.97297662478043%, Validation Loss: 1.285548043681238, Validation Accuracy: 73.76237623762377%
Training complete!





In [21]:
test_dataset = datasets.ImageFolder(root='./plantdoc_ViT_training/cropped_images_dataset/test', transform=preprocess)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# Calculate test loss
vit_model.eval()
test_loss = 0.0
correct_test = 0
total_test = 0
with torch.no_grad():
    for test_inputs, test_labels in test_loader:
        test_inputs, test_labels = test_inputs.to(device), test_labels.to(device)
        test_outputs = vit_model(test_inputs)
        test_loss += criterion(test_outputs, test_labels).item()

        _, predicted_test = test_outputs.max(1)
        total_test += test_labels.size(0)
        correct_test += predicted_test.eq(test_labels).sum().item()

# Calculate test accuracy
test_accuracy = 100.0 * correct_test / total_test
print(test_accuracy)

73.07692307692308


In [22]:
torch.save(vit_model.state_dict(), './models/vitb16_SWAGE2E_model.pth')