In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transform
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report
import torch.nn as nn

In [2]:
# Evaluation-time preprocessing pipeline.
# NOTE: the top-of-notebook import binds the *module* torchvision.transforms to the
# name `transform`; assigning the pipeline to that same name via
# `transform.Compose(...)` made this cell fail when re-run (the second run sees a
# Compose object, not the module). Referencing the module as `transforms` keeps the
# cell idempotent while still exposing the pipeline as `transform` for later cells.
from torchvision import transforms

transform = transforms.Compose([
    # int size -> the shorter image edge is resized to 224, aspect ratio kept ...
    transforms.Resize(224),
    # ... then the central 224x224 patch is taken.
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Standard ImageNet channel statistics used by torchvision's pretrained models.
    # Presumably these match the normalization used when best_densenet201.pth was
    # trained — TODO confirm against the training notebook.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [3]:
# Held-out evaluation split: ImageFolder expects one sub-directory per class and
# applies the preprocessing pipeline defined in the previous cell to every image.
test_dir = '../../dataset-dapa/test/'
test_dataset = ImageFolder(root=test_dir, transform=transform)

# shuffle=False keeps batch order identical to test_dataset.samples, so collected
# predictions line up with the files; num_workers=0 loads in the main process.
test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Rebuild the DenseNet-201 architecture with a task-specific classifier head and
# restore the fine-tuned weights from 'best_densenet201.pth'.
#
# weights=None: this cell used to download the ImageNet-1K weights, but the strict
# load_state_dict() below replaces every parameter and buffer, so those weights were
# never used — fetching them only added a network dependency and startup time.
model = models.densenet201(weights=None)
features = model.features
# Freezing the backbone only matters when fine-tuning; evaluation below runs under
# torch.no_grad(). Kept so `features` is in the same state as before for later cells.
features.requires_grad_(False)

num_classes = len(test_dataset.classes)
# Swap the ImageNet head for one sized to this dataset. It must mirror the head used
# at training time, otherwise the strict load_state_dict() fails on key/shape mismatch.
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, num_classes)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while loading.
model.load_state_dict(
    torch.load('best_densenet201.pth', map_location=device, weights_only=True)
)
print('is cuda available?', torch.cuda.is_available())
model = model.to(device)


is cuda available? True


In [5]:
# Evaluate the fine-tuned model on the test split: overall accuracy, then a
# per-class precision / recall / F1 / accuracy table.
correct = 0
total = 0
all_preds = []
all_labels = []

model.eval()  # disable dropout and use BatchNorm running statistics
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest logit. argmax replaces the legacy
        # torch.max(outputs.data, 1) idiom; `.data` is discouraged and no_grad()
        # already prevents graph construction.
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

if total == 0:
    raise RuntimeError(f"test_loader yielded no samples; check test_dir={test_dir!r}")

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%\n")

# Pass the label indices explicitly so the report stays aligned with target_names
# even when some class never appears in either the ground truth or the predictions
# (otherwise sklearn infers fewer labels and raises a size-mismatch ValueError).
# zero_division=0 makes the undefined-metric case explicit instead of warning.
report = classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(test_dataset.classes))),
    target_names=test_dataset.classes,
    output_dict=True,
    zero_division=0,
)

print(f"{'Class':<25} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Accuracy':<10}")
print("-" * 65)
for class_name in test_dataset.classes:
    cls_report = report[class_name]
    precision = cls_report['precision']
    recall = cls_report['recall']
    f1 = cls_report['f1-score']
    # Per-class accuracy (correctly predicted samples of this class / samples of
    # this class) is by definition the class recall. The former
    # `(recall * support) / support` was an identity that also divided by zero for
    # a class with no test samples.
    acc = recall
    print(f"{class_name:<25} {precision:<10.2f} {recall:<10.2f} {f1:<10.2f} {acc*100:<10.2f}")

Test Accuracy: 90.45%

Class                     Precision  Recall     F1-Score   Accuracy  
-----------------------------------------------------------------
algal_spot                0.95       0.95       0.95       95.29     
brown_blight              0.82       0.92       0.87       91.79     
gray_blight               0.90       0.83       0.86       82.82     
healthy                   0.97       0.81       0.88       80.67     
helopeltis                0.97       0.93       0.95       92.67     
red-rust                  0.83       0.83       0.83       83.33     
red-spider-infested       1.00       1.00       1.00       100.00    
red_spot                  0.84       0.98       0.91       98.26     
white-spot                1.00       0.91       0.95       90.91     
