In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from sklearn.metrics import classification_report, confusion_matrix
from google.colab import drive

# Expose Google Drive inside the Colab VM; the dataset and the checkpoint
# used below are both read from paths under /content/drive.
drive.mount('/content/drive')

# --- Evaluation configuration -------------------------------------------
NUM_CLASSES = 9    # number of output classes of the classifier head
IMAGE_SIZE = 224   # side length (pixels) used for resize / centre crop
BATCH_SIZE = 8     # images per evaluation batch
NUM_WORKERS = 2    # DataLoader worker processes

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Device:", DEVICE)

# Deterministic evaluation-time preprocessing (no augmentation):
# resize, centre-crop to IMAGE_SIZE x IMAGE_SIZE, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# ImageFolder derives the class list from the sub-directory names.
val_dataset = datasets.ImageFolder(
    '/content/drive/MyDrive/dataset-dapa/test/',
    transform=val_transforms,
)

# Fail early with a clear message if the folder layout does not match the
# classifier head: a missing or extra class folder would shift the label
# indices and make every downstream metric meaningless.
if len(val_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders but found "
        f"{len(val_dataset.classes)}: {val_dataset.classes}"
    )

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep a stable sample order for evaluation
    num_workers=NUM_WORKERS,
    # Pinned host memory only helps host-to-GPU copies; on a CPU-only
    # runtime it is useless and newer PyTorch versions warn about it.
    pin_memory=DEVICE.type == "cuda",
)

# Build the ViT backbone. `ignore_mismatched_sizes=True` lets the
# 1000-class checkpoint head be swapped for a NUM_CLASSES-wide one; the
# "newly initialized" warning this triggers is expected here because the
# head is replaced again below and the fine-tuned weights are loaded
# afterwards.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# Replace the head with Dropout + Linear. With this layout the linear
# layer's parameters live under `classifier.1.*`, which is what the
# checkpoint loaded below has to provide (load_state_dict is strict by
# default and raises on missing or unexpected keys).
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, NUM_CLASSES),
)

# `weights_only=True` restricts unpickling to tensors and plain
# containers, so a tampered checkpoint file cannot execute arbitrary code.
# `map_location` allows a GPU-saved checkpoint to load on a CPU runtime.
state_dict = torch.load(
    '/content/drive/MyDrive/models/ViT-15epoch.pth',
    map_location=DEVICE,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(DEVICE)
model.eval()  # inference mode: disables the dropout in the head

def evaluate_model(model, loader, device, class_names):
    """Evaluate a classifier on ``loader`` and print a metrics summary.

    Prints the overall accuracy followed by a table with precision,
    recall, F1-score and per-class accuracy for every entry of
    ``class_names``.

    Args:
        model: Callable that takes an image batch and returns an object
            with a ``logits`` attribute holding one row of class scores
            per sample. If it exposes ``training``/``eval``/``train``
            (as ``torch.nn.Module`` does) it is switched to eval mode
            for the duration of the call and restored afterwards.
        loader: Iterable yielding ``(images, labels)`` batches, where
            ``labels`` holds integer class indices.
        device: Device the image batches are moved to before the
            forward pass.
        class_names: Class names ordered so that position ``i`` is the
            name of label index ``i``.

    Returns:
        None. All results are printed.
    """
    # Make sure dropout is disabled even if the caller forgot to call
    # ``model.eval()``; restore the previous mode when done.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()

    all_preds = []
    all_labels = []
    try:
        with torch.no_grad():
            for images, labels in loader:
                logits = model(images.to(device)).logits
                preds = logits.argmax(dim=1)
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.tolist())
    finally:
        if was_training:
            model.train()

    total = len(all_labels)
    if total == 0:
        # An empty loader would otherwise end in a ZeroDivisionError.
        print("\nNo samples to evaluate.\n")
        return

    correct = sum(pred == label for pred, label in zip(all_preds, all_labels))
    acc = 100 * correct / total
    print(f"\nOverall Accuracy: {acc:.2f}%\n")

    # Pin the label set to the full class list so the report and the
    # confusion matrix always have one row per class, even when a class
    # never occurs in the evaluated data.
    label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Per-class accuracy = correctly classified / true samples of the
    # class. Classes without samples are reported as 0.0 (matching
    # ``zero_division=0`` above) instead of dividing by zero.
    per_class_acc = [
        hits / support if support else 0.0
        for hits, support in zip(cm.diagonal(), cm.sum(axis=1))
    ]

    print(f"{'Class':<25} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Class Acc':<10}")
    print("-" * 70)
    for idx, class_name in enumerate(class_names):
        cls = report[class_name]
        print(f"{class_name:<25} {cls['precision']:<10.2f} {cls['recall']:<10.2f} {cls['f1-score']:<10.2f} {per_class_acc[idx]*100:<10.2f}")

# Run the evaluation; class names come from the dataset so that their
# order matches the integer labels the loader yields.
evaluate_model(
    model=model,
    loader=val_loader,
    device=DEVICE,
    class_names=val_dataset.classes,
)

Mounted at /content/drive
Device: cpu


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([9]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([9, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Overall Accuracy: 94.57%

Class                     Precision  Recall     F1-Score   Class Acc 
----------------------------------------------------------------------
algal_spot                0.95       0.98       0.97       98.24     
brown_blight              0.93       0.84       0.88       84.33     
gray_blight               0.88       0.94       0.91       93.87     
healthy                   0.93       0.99       0.96       98.67     
helopeltis                0.98       0.95       0.97       95.33     
red-rust                  0.91       0.83       0.87       83.33     
red-spider-infested       1.00       1.00       1.00       100.00    
red_spot                  1.00       0.96       0.98       95.93     
white-spot                1.00       1.00       1.00       100.00    
