In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transform
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, classification_report

In [2]:
transform = transform.Compose([
    transform.Resize(224),
    transform.CenterCrop(224),
    transform.ToTensor(),
    transform.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])

In [3]:
test_dataset = ImageFolder('../../dataset-dapa/Test/', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

In [None]:
model = models.mobilenet_v3_small(pretrained=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('is cuda available?', torch.cuda.is_available())

state_dict = torch.load('best_mobilenetv3_teadiseases.pth', map_location=device)
num_classes = 9
classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier[0].in_features, num_classes)
)
model.classifier = classifier

model.load_state_dict(state_dict, strict=False)
model = model.to(device)
model.eval()

  state_dict = torch.load('best_mobilenetv3_teadiseases.pth', map_location=device)


is cuda available? True


MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), 

In [5]:
correct = 0
total = 0
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%\n")

report = classification_report(all_labels, all_preds, target_names=test_dataset.classes, output_dict=True, zero_division=0)

print(f"{'Class':<25} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Accuracy':<10}")
print("-" * 65)
for class_name in test_dataset.classes:
    cls_report = report[class_name]
    precision = cls_report['precision']
    recall = cls_report['recall']
    f1 = cls_report['f1-score']
    support = cls_report['support']
    acc = (recall * support) / support
    print(f"{class_name:<25} {precision:<10.2f} {recall:<10.2f} {f1:<10.2f} {acc*100:<10.2f}")

Test Accuracy: 86.53%

Class                     Precision  Recall     F1-Score   Accuracy  
-----------------------------------------------------------------
algal_spot                0.90       0.85       0.87       85.29     
brown_blight              0.87       0.78       0.82       78.36     
gray_blight               0.77       0.85       0.81       84.66     
healthy                   0.86       0.91       0.88       90.67     
helopeltis                0.91       0.87       0.89       87.33     
red-rust                  0.68       0.79       0.73       79.17     
red-spider-infested       1.00       0.95       0.98       95.24     
red_spot                  0.92       0.91       0.92       91.28     
white-spot                0.83       0.91       0.87       90.91     
