In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from transformers import DeiTForImageClassification
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from google.colab import drive
import numpy as np

In [6]:
IMAGE_SIZE = 224
BATCH_SIZE = 2
NUM_WORKERS = 2
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

In [7]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(), 
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),                  
    transforms.ColorJitter(brightness=0.2, contrast=0.2), 
    transforms.ToTensor(),                                    
    transforms.Normalize(MEAN, STD)                            
])

val_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE), 
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

In [None]:
drive.mount('/content/drive')
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/train/', transform=train_transforms)
val_dataset   = datasets.ImageFolder(root='/content/drive/MyDrive/dataset-dapa/val/',   transform=val_transforms)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

In [None]:
if __name__ == "__main__":
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}") 
    print(f"Labels shape: {labels.shape}")

model = DeiTForImageClassification.from_pretrained(
    "facebook/deit-base-distilled-patch16-224",
    num_labels=9,
    ignore_mismatched_sizes=True
)

model.deit.requires_grad_(False)

model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.classifier.in_features, 9)
)

Batch shape: torch.Size([2, 3, 224, 224])
Labels shape: torch.Size([2])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of DeiTForImageClassification were not initialized from the model checkpoint at facebook/deit-base-distilled-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device: ', device)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3,
)

device:  cuda


In [None]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        logits = outputs.logits

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += images.size(0)

    return running_loss / total, correct / total

def validate(model, loader, criterion, device):
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            logits = outputs.logits

            loss = criterion(logits, labels)
            val_loss += loss.item() * images.size(0)
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

    return val_loss / total, correct / total

In [None]:
num_epochs = 30
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc     = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}: "
          f"Train loss {train_loss:.4f}, acc {train_acc:.4f} | "
          f"Val   loss {val_loss:.4f}, acc {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "DeiT.pth")

Epoch 1/30: Train loss 0.6044, acc 0.7966 | Val   loss 0.2977, acc 0.9091
Epoch 2/30: Train loss 0.3566, acc 0.8774 | Val   loss 0.2357, acc 0.9182
Epoch 3/30: Train loss 0.3035, acc 0.8904 | Val   loss 0.1917, acc 0.9323
Epoch 4/30: Train loss 0.2747, acc 0.9006 | Val   loss 0.1862, acc 0.9333
Epoch 5/30: Train loss 0.2806, acc 0.9038 | Val   loss 0.1720, acc 0.9424
Epoch 6/30: Train loss 0.2700, acc 0.9032 | Val   loss 0.1555, acc 0.9424
Epoch 7/30: Train loss 0.2874, acc 0.9040 | Val   loss 0.1701, acc 0.9384
Epoch 8/30: Train loss 0.2700, acc 0.9055 | Val   loss 0.1493, acc 0.9404
Epoch 9/30: Train loss 0.2603, acc 0.9066 | Val   loss 0.1611, acc 0.9354
Epoch 10/30: Train loss 0.2630, acc 0.9094 | Val   loss 0.1658, acc 0.9394
Epoch 11/30: Train loss 0.2705, acc 0.9049 | Val   loss 0.1222, acc 0.9566
Epoch 12/30: Train loss 0.2719, acc 0.9051 | Val   loss 0.1427, acc 0.9505
Epoch 13/30: Train loss 0.2744, acc 0.9083 | Val   loss 0.1569, acc 0.9455
Epoch 14/30: Train loss 0.2749, ac

In [None]:
save_path = '/content/drive/MyDrive/DeiT.pth'

torch.save(model.state_dict(), save_path)
print("Model saved successfully to Google Drive!")

Model saved successfully to Google Drive!


In [None]:
import os

save_dir = '/content/drive/MyDrive/models'
os.makedirs(save_dir, exist_ok=True)

save_path = os.path.join(save_dir, 'DeiT.pth')

torch.save(model.state_dict(), save_path)
print(f"Model saved at: {save_path}")

Model saved at: /content/drive/MyDrive/models/DeiT.pth


In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

correct = 0
total = 0
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.logits, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = 100 * correct / total
print(f"\nOverall Accuracy: {accuracy:.2f}%\n")

report = classification_report(all_labels, all_preds, target_names=val_dataset.classes, output_dict=True)

cm = confusion_matrix(all_labels, all_preds)
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)

print(f"{'Class':<25} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Class Acc':<10}")
print("-" * 70)
for idx, class_name in enumerate(val_dataset.classes):
    cls_report = report[class_name]
    precision = cls_report['precision']
    recall = cls_report['recall']
    f1 = cls_report['f1-score']
    acc = per_class_accuracy[idx]
    print(f"{class_name:<25} {precision:<10.2f} {recall:<10.2f} {f1:<10.2f} {acc*100:<10.2f}")



Overall Accuracy: 93.64%

Class                     Precision  Recall     F1-Score   Class Acc 
----------------------------------------------------------------------
algal_spot                0.99       0.96       0.97       95.86     
brown_blight              0.97       0.80       0.88       79.70     
gray_blight               0.85       0.98       0.91       97.55     
healthy                   0.99       0.88       0.93       88.00     
helopeltis                0.91       0.99       0.95       98.67     
red-rust                  1.00       0.96       0.98       95.65     
red-spider-infested       1.00       1.00       1.00       100.00    
red_spot                  0.93       0.98       0.95       97.66     
white-spot                0.91       1.00       0.95       100.00    


In [None]:
model.load_state_dict(torch.load('/content/drive/MyDrive/DeiT.pth'))
model.to(device)

class_names = val_dataset.classes

evaluate_model(model, val_loader, device, class_names=class_names)

Classification Report:

                     precision    recall  f1-score   support

         algal_spot     0.9878    0.9586    0.9730       169
       brown_blight     0.9725    0.7970    0.8760       133
        gray_blight     0.8503    0.9755    0.9086       163
            healthy     0.9851    0.8800    0.9296       150
         helopeltis     0.9080    0.9867    0.9457       150
           red-rust     1.0000    0.9565    0.9778        23
red-spider-infested     1.0000    1.0000    1.0000        21
           red_spot     0.9330    0.9766    0.9543       171
         white-spot     0.9091    1.0000    0.9524        10

           accuracy                         0.9364       990
          macro avg     0.9495    0.9479    0.9464       990
       weighted avg     0.9409    0.9364    0.9359       990



{'algal_spot': {'precision': 0.9878048780487805,
  'recall': 0.9585798816568047,
  'f1-score': 0.972972972972973,
  'support': 169.0},
 'brown_blight': {'precision': 0.9724770642201835,
  'recall': 0.7969924812030075,
  'f1-score': 0.8760330578512396,
  'support': 133.0},
 'gray_blight': {'precision': 0.8502673796791443,
  'recall': 0.9754601226993865,
  'f1-score': 0.9085714285714286,
  'support': 163.0},
 'healthy': {'precision': 0.9850746268656716,
  'recall': 0.88,
  'f1-score': 0.9295774647887324,
  'support': 150.0},
 'helopeltis': {'precision': 0.9079754601226994,
  'recall': 0.9866666666666667,
  'f1-score': 0.9456869009584664,
  'support': 150.0},
 'red-rust': {'precision': 1.0,
  'recall': 0.9565217391304348,
  'f1-score': 0.9777777777777777,
  'support': 23.0},
 'red-spider-infested': {'precision': 1.0,
  'recall': 1.0,
  'f1-score': 1.0,
  'support': 21.0},
 'red_spot': {'precision': 0.9329608938547486,
  'recall': 0.9766081871345029,
  'f1-score': 0.9542857142857143,
  'su