In [2]:
import torch
import torch.nn as nn
import torchvision.models as models

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report

In [1]:
# --- Evaluation configuration -------------------------------------------------

# Input geometry and batching used by the test transforms and DataLoader.
IMAGE_SIZE = 224
BATCH_SIZE = 16
NUM_WORKERS = 2

# Number of target classes in the dataset.
NUM_CLASSES = 9

# Per-channel statistics passed to transforms.Normalize.
# NOTE(review): these look like the usual ImageNet mean/std -- confirm they
# match the statistics used when the checkpoint was trained.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

In [5]:
# Location of the held-out test split, laid out one sub-folder per class
# (the layout ImageFolder expects).
test_dataset_path = '../../../dataset/test'

# Deterministic preprocessing: resize, centre-crop to IMAGE_SIZE, convert to a
# tensor, then normalise each channel.
# NOTE(review): int(IMAGE_SIZE * 1.14) evaluates to 255 for IMAGE_SIZE = 224,
# not the customary 256 -- confirm this matches the validation transform used
# during training before changing it.
test_transforms = transforms.Compose([
    transforms.Resize(size=int(IMAGE_SIZE * 1.14)),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
])

test_dataset = datasets.ImageFolder(test_dataset_path, transform=test_transforms)

# shuffle=False keeps the sample order fixed across passes over the loader.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [6]:
def test_model(model, loader, criterion, device):
    model.eval()
    test_loss, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            test_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

    avg_loss = test_loss/total
    avg_accuracy = correct/total

    return avg_loss, avg_accuracy

In [8]:
# Build an uninitialised ResNet-50 and the loss used for evaluation.
model = models.resnet50(weights=None)
criterion = nn.CrossEntropyLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trained_model_path = "best_resnet50.pth"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host;
# the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True on PyTorch versions that have it).
state_dict = torch.load(trained_model_path, map_location="cpu")

# resnet50(weights=None) builds a 1000-way classifier head. If the checkpoint
# stores a head of a different width (e.g. NUM_CLASSES outputs), resize the
# head to match so that load_state_dict does not fail on a shape mismatch.
fc_weight = state_dict.get("fc.weight")
if fc_weight is not None and fc_weight.shape[0] != model.fc.out_features:
    model.fc = nn.Linear(model.fc.in_features, fc_weight.shape[0])

model.load_state_dict(state_dict)
model.to(device)


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [11]:
# Run the held-out evaluation once and report the sample-averaged metrics.
test_loss, test_acc = test_model(
    model=model,
    loader=test_loader,
    criterion=criterion,
    device=device,
)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

Test Loss: 0.3229, Test Accuracy: 0.8874


In [14]:
all_preds = []
all_labels = []

# Pin the label set to the dataset's class indices (0 .. len(classes) - 1).
# Without it, scikit-learn infers the labels from the data, so a class missing
# from the test split or a prediction outside that range would misalign the
# matrix rows with `target_names` or make classification_report raise.
class_indices = list(range(len(test_dataset.classes)))

# Collect predictions and ground-truth labels for the whole test set.
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

# Create the confusion matrix and classification report.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_indices)
print("Confusion Matrix:\n", conf_matrix)
print("\nClassification Report:\n", 
    classification_report(
        all_labels,
        all_preds,
        labels=class_indices,
        target_names=test_dataset.classes,
    )
)


Confusion Matrix:
 [[160   0   4   3   0   2   0   1   0]
 [  0 117  11   1   4   0   0   1   0]
 [  4  25 126   1   4   0   0   3   0]
 [  0   0   0 145   2   0   0   3   0]
 [  0   0   1  15 131   0   0   3   0]
 [  2   0   2   0   1  19   0   0   0]
 [  1   0   0   0   0   0  18   2   0]
 [  3   5   0   7   0   0   0 157   0]
 [  1   0   0   0   0   0   0   0  10]]

Classification Report:
                      precision    recall  f1-score   support

         algal_spot       0.94      0.94      0.94       170
       brown_blight       0.80      0.87      0.83       134
        gray_blight       0.88      0.77      0.82       163
            healthy       0.84      0.97      0.90       150
         helopeltis       0.92      0.87      0.90       150
           red-rust       0.90      0.79      0.84        24
red-spider-infested       1.00      0.86      0.92        21
           red_spot       0.92      0.91      0.92       172
         white-spot       1.00      0.91      0.95    