In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix
from google.colab import drive
import numpy as np

In [None]:
# --- Configuration -----------------------------------------------------------
IMAGE_SIZE = 299              # input side length fed to Inception v3
BATCH_SIZE = 8
NUM_WORKERS = 2               # DataLoader worker processes
# Channel statistics passed to transforms.Normalize (the standard ImageNet values,
# matching the ImageNet-pretrained weights loaded below).
MEAN = [0.485, 0.456, 0.406]
STD  = [0.229, 0.224, 0.225]

# Seed the RNGs used by augmentation, shuffling and weight initialisation so a
# Restart & Run All reproduces the same run (cuDNN kernels may still add noise).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def _tensor_and_normalize():
    """Final steps shared by every pipeline: convert to a tensor, then normalize with MEAN/STD."""
    return [transforms.ToTensor(), transforms.Normalize(MEAN, STD)]


# Training pipeline: random crop, flips, rotation and colour jitter for augmentation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    + _tensor_and_normalize()
)

# Validation pipeline: deterministic resize followed by a centre crop to IMAGE_SIZE.
val_transforms = transforms.Compose(
    [transforms.Resize(IMAGE_SIZE), transforms.CenterCrop(IMAGE_SIZE)]
    + _tensor_and_normalize()
)

In [None]:
# Prints a notice and leaves the existing mount alone if Drive is already mounted.
drive.mount('/content/drive')

DATA_ROOT = '/content/drive/MyDrive/dataset-dapa'

train_dataset = datasets.ImageFolder(root=f'{DATA_ROOT}/train/', transform=train_transforms)
val_dataset   = datasets.ImageFolder(root=f'{DATA_ROOT}/val/',   transform=val_transforms)

# ImageFolder derives label indices from each split's own class folders, so a
# folder missing from one split would silently shift every later label index.
assert train_dataset.classes == val_dataset.classes, (
    f"train/val class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the notebook-wide batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


train_loader = _make_loader(train_dataset, shuffle=True)   # reshuffled every epoch
val_loader = _make_loader(val_dataset, shuffle=False)      # fixed order for evaluation

In [None]:
# Sanity-check one training batch. Inside a notebook kernel __name__ is "__main__",
# so this guard always runs; it only matters if the code is exported to a script.
if __name__ == "__main__":
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

# False: fine-tune every layer of the pretrained network.
# True:  freeze the pretrained layers and train only the new heads (main fc + auxiliary branch).
FREEZE_BACKBONE = False

model = models.inception_v3(weights=models.Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True)
for param in model.parameters():
    param.requires_grad = not FREEZE_BACKBONE

# One output per class folder instead of a hard-coded count.
num_classes = len(train_dataset.classes)

# Replace the ImageNet classifier. Newly created layers have requires_grad=True,
# so the new head is trainable whatever FREEZE_BACKBONE says.
classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, num_classes)
)
model.fc = classifier

if model.aux_logits:
    model.AuxLogits.fc = nn.Linear(model.AuxLogits.fc.in_features, num_classes)
    # Keep the whole auxiliary branch trainable, not only its new fc layer.
    for param in model.AuxLogits.parameters():
        param.requires_grad = True

Batch shape: torch.Size([8, 3, 299, 299])
Labels shape: torch.Size([8])


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# nn.Module.to() moves the parameters in place and returns the module. Assigning the
# result keeps the cell from echoing the full architecture repr into the output.
model = model.to(device)
print(f"Using device: {device}")

Inception3(
  (Conv2d_1a_3x3): BasicConv2d(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2a_3x3): BasicConv2d(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_2b_3x3): BasicConv2d(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (Conv2d_3b_1x1): BasicConv2d(
    (conv): Conv2d(64, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(80, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (Conv2d_4a_3x3): BasicConv2d(
    (conv): Conv2d(80, 192, kernel_size=(3, 3), stri

In [None]:
criterion = nn.CrossEntropyLoss()   # default reduction='mean' over the batch

LEARNING_RATE = 1e-3

# Give the optimizer exactly the parameters that are still trainable, so it always
# agrees with whatever freezing was applied when the model was built.
params_to_optimize = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.Adam(params_to_optimize, lr=LEARNING_RATE)

In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device, aux_weight=0.4):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to ``train()`` mode here. If its forward
            pass returns a tuple, the first element is taken as the main logits and
            the second as auxiliary logits (Inception v3 with ``aux_logits=True``).
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer over the model's trainable parameters.
        criterion: Loss taking ``(logits, labels)``; assumed to average over the
            batch (``reduction='mean'``), since it is re-weighted by batch size below.
        device: Device the batches are moved to.
        aux_weight: Weight of the auxiliary-head loss in the optimised objective.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Accuracy uses the main head
        only. With an auxiliary head, the loss includes ``aux_weight * aux_loss``,
        so it is not directly comparable with the main-head loss from ``validate``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        if isinstance(outputs, tuple):
            main_output, aux_output = outputs
            loss = criterion(main_output, labels) + aux_weight * criterion(aux_output, labels)
            preds = main_output.argmax(dim=1)
        else:
            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size   # mean loss -> batch sum
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_one_epoch(): loader yielded no samples")
    return running_loss / total, correct / total


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode here. If its forward
            pass returns a tuple, only the first element (main logits) is scored.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``; assumed to average over the
            batch (``reduction='mean'``), since it is re-weighted by batch size below.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples in ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            batch_size = images.size(0)
            val_loss += criterion(outputs, labels).item() * batch_size   # mean -> batch sum
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("validate(): loader yielded no samples")
    return val_loss / total, correct / total

In [None]:
num_epochs = 30
best_val_loss = float('inf')
best_epoch = None                       # 1-based epoch of the saved checkpoint, if any
best_ckpt_path = "InceptionV3-new.pth"  # local path for the best-validation-loss weights

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        torch.save(model.state_dict(), best_ckpt_path)

# After the loop, `model` holds the LAST epoch's weights. Reload the best-validation
# checkpoint so every later cell (saving, test evaluation) uses the selected model.
if best_epoch is not None:
    model.load_state_dict(torch.load(best_ckpt_path, map_location=device, weights_only=True))
    print(f"Restored best checkpoint from epoch {best_epoch} (val loss {best_val_loss:.4f})")

Epoch 1/30: Train loss 1.4806, acc 0.6433 | Val   loss 0.5925, acc 0.7778
Epoch 2/30: Train loss 0.9723, acc 0.7661 | Val   loss 0.3695, acc 0.8889
Epoch 3/30: Train loss 0.7985, acc 0.8150 | Val   loss 0.3555, acc 0.8990
Epoch 4/30: Train loss 0.7138, acc 0.8312 | Val   loss 0.2899, acc 0.9111
Epoch 5/30: Train loss 0.6094, acc 0.8519 | Val   loss 0.2485, acc 0.9313
Epoch 6/30: Train loss 0.5786, acc 0.8647 | Val   loss 0.2404, acc 0.9232
Epoch 7/30: Train loss 0.5465, acc 0.8740 | Val   loss 0.3593, acc 0.8747
Epoch 8/30: Train loss 0.4960, acc 0.8930 | Val   loss 0.2646, acc 0.8949
Epoch 9/30: Train loss 0.4700, acc 0.8926 | Val   loss 0.2282, acc 0.9152
Epoch 10/30: Train loss 0.4470, acc 0.8990 | Val   loss 0.2079, acc 0.9263
Epoch 11/30: Train loss 0.4128, acc 0.9064 | Val   loss 0.1565, acc 0.9515
Epoch 12/30: Train loss 0.3677, acc 0.9118 | Val   loss 0.2497, acc 0.9152
Epoch 13/30: Train loss 0.3560, acc 0.9166 | Val   loss 0.2164, acc 0.9283
Epoch 14/30: Train loss 0.3638, ac

In [None]:
import os

# Persist the current in-memory weights to Google Drive so they outlive the Colab VM.
save_dir = '/content/drive/MyDrive/models'
save_path = os.path.join(save_dir, 'InceptionV3-new.pth')

os.makedirs(save_dir, exist_ok=True)   # exist_ok makes the cell safe to re-run
torch.save(model.state_dict(), save_path)
print(f"Model saved at: {save_path}")

Model saved at: /content/drive/MyDrive/models/InceptionV3-new.pth


In [None]:
# Test-time preprocessing must match validation exactly (deterministic resize +
# centre crop + normalisation), so reuse that pipeline instead of keeping a copy.
test_transforms = val_transforms

In [None]:
test_dataset = datasets.ImageFolder(
    root='/content/drive/MyDrive/dataset-dapa/test/',
    transform=test_transforms
)

# The model's output index i was trained to mean train_dataset.classes[i]; the test
# split's folder-derived labels are only meaningful if they use the same mapping.
assert test_dataset.classes == train_dataset.classes, (
    f"test/train class folders differ: {test_dataset.classes} vs {train_dataset.classes}"
)

In [None]:
# Same loader settings as validation: no shuffling, so evaluation order is deterministic.
_test_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(test_dataset, **_test_loader_kwargs)

In [None]:
def evaluate_on_test(model, test_loader, device, class_names):
    """Print an sklearn precision/recall/F1 report for ``model`` on ``test_loader``.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode here. If its forward
            pass returns a tuple, only the first element (main logits) is used.
        test_loader: Iterable of ``(images, labels)`` batches with CPU label tensors.
        device: Device the images are moved to.
        class_names: Names indexed by label id (``class_names[i]`` is label ``i``).
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    # Pin the label set to 0..len(class_names)-1. Without `labels=`, sklearn infers
    # the classes from the data and raises if any class is absent from both the
    # labels and the predictions (length mismatch with target_names).
    print(classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(class_names))),
        target_names=class_names,
        digits=2,
        zero_division=0
    ))


In [None]:
# Label i in the test split corresponds to test_dataset.classes[i].
class_names = test_dataset.classes
evaluate_on_test(model, test_loader, device, class_names=class_names)

                     precision    recall  f1-score   support

         algal_spot       0.97      0.98      0.97       170
       brown_blight       0.96      0.98      0.97       134
        gray_blight       0.96      0.96      0.96       163
            healthy       0.94      0.99      0.96       150
         helopeltis       0.96      0.97      0.96       150
           red-rust       0.76      0.79      0.78        24
red-spider-infested       1.00      0.95      0.98        21
           red_spot       1.00      0.91      0.95       172
         white-spot       0.92      1.00      0.96        11

           accuracy                           0.96       995
          macro avg       0.94      0.95      0.94       995
       weighted avg       0.96      0.96      0.96       995



In [None]:
def per_class_accuracy(model, loader, device, class_names):
    """Print, and return, the accuracy of ``model`` on each true class in ``loader``.

    Per-class accuracy here is recall: of the samples whose true label is ``i``,
    the percentage that were predicted as ``i``.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode here. If its forward
            pass returns a tuple, only the first element (main logits) is used.
        loader: Iterable of ``(images, labels)`` batches with CPU label tensors.
        device: Device the images are moved to.
        class_names: Names indexed by label id (``class_names[i]`` is label ``i``).

    Returns:
        dict mapping each class name to its accuracy in percent. Classes with no
        samples in ``loader`` report 0.0.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.numpy())

    all_preds = np.asarray(all_preds, dtype=np.int64)
    all_labels = np.asarray(all_labels, dtype=np.int64)

    # Fixed-length counts keep index i == class i even when some class has no
    # samples. An unconstrained confusion matrix would drop that class and shift
    # every later row.
    num_classes = len(class_names)
    support = np.bincount(all_labels, minlength=num_classes)[:num_classes]
    hits = np.bincount(all_labels[all_preds == all_labels], minlength=num_classes)[:num_classes]

    print(f"{'Class':<25} {'Accuracy (%)':>12}")
    print("-" * 40)
    results = {}
    for i, class_name in enumerate(class_names):
        acc = 100.0 * hits[i] / support[i] if support[i] > 0 else 0.0
        results[class_name] = float(acc)
        print(f"{class_name:<25} {acc:>12.2f}")
    return results


In [None]:
# Label i in the test split corresponds to test_dataset.classes[i].
class_names = test_dataset.classes
per_class_results = per_class_accuracy(model, test_loader, device, class_names)

Class                     Accuracy (%)
----------------------------------------
algal_spot                       97.65
brown_blight                     97.76
gray_blight                      96.32
healthy                          98.67
helopeltis                       96.67
red-rust                         79.17
red-spider-infested              95.24
red_spot                         90.70
white-spot                      100.00
