# 🌿 Plant Disease Detection using Vision Transformers (ViT)

This notebook demonstrates how to train a Vision Transformer (ViT) model to detect plant diseases using the PlantVillage dataset from Kaggle.

In [None]:

!pip install transformers datasets torchvision timm


In [None]:

from transformers import ViTForImageClassification, AutoImageProcessor, TrainingArguments, Trainer
from torchvision import transforms
from torch.utils.data import Dataset
import torch
from PIL import Image
import os, glob


In [None]:

class PlantDataset(Dataset):
    """Image-folder dataset laid out as ``image_dir/<class_name>/<image file>``.

    Args:
        image_dir: Root directory containing one sub-directory per class.
        label_map: Mapping from class sub-directory name to its label.
        transform: Callable applied to each RGB PIL image; its result is
            returned under the ``"pixel_values"`` key.

    Raises:
        KeyError: If an image lives in a sub-directory that has no entry in
            ``label_map``.
    """

    # Compared case-insensitively, so "leaf.JPG" is picked up on
    # case-sensitive filesystems as well.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg")

    def __init__(self, image_dir, label_map, transform):
        # Sort the matches: glob returns them in arbitrary order, and a stable
        # order keeps sample indices reproducible between runs.
        candidates = sorted(glob.glob(os.path.join(image_dir, "*", "*")))
        self.image_paths = [
            path
            for path in candidates
            if path.lower().endswith(self.IMAGE_EXTENSIONS)
        ]

        # The parent directory name of each image is its class name.
        class_names = [
            os.path.basename(os.path.dirname(path)) for path in self.image_paths
        ]
        unmapped = sorted(set(class_names) - set(label_map))
        if unmapped:
            # Report every missing folder at once instead of only the first.
            raise KeyError(
                f"No label_map entry for class folder(s): {unmapped}"
            )

        self.labels = [label_map[name] for name in class_names]
        self.transform = transform

    def __len__(self):
        """Return the number of images found under ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``{"pixel_values", "labels"}``."""
        # convert() produces a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert("RGB")
        return {
            "pixel_values": self.transform(image),
            # Explicit integer dtype so the label is int64 whatever numeric
            # type the label_map values have.
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }


In [None]:

# Load the image processor that ships with the checkpoint; its mean/std are
# reused below so inputs are normalized with the checkpoint's own statistics.
feature_extractor = AutoImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

# Resize to 224x224, convert to a tensor, then normalize per channel.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            feature_extractor.image_mean,
            feature_extractor.image_std,
        ),
    ]
)


In [None]:

# Class folder names listed in label order: each name's position in the list
# becomes its integer label. Adjust the list to match your dataset folders.
label_map = {
    class_name: class_index
    for class_index, class_name in enumerate(
        [
            "Apple___Black_rot",
            "Apple___Scab",
            "Apple___healthy",
            "Corn___Cercospora_leaf_spot Gray_leaf_spot",
            "Corn___Common_rust",
            "Corn___healthy",
            # Add more as needed
        ]
    )
}

# Root folders of the two splits; each holds one sub-folder per class.
train_dir, test_dir = "/path/to/PlantVillage/train", "/path/to/PlantVillage/test"

# Both splits share the same label map and preprocessing pipeline.
train_dataset = PlantDataset(image_dir=train_dir, label_map=label_map, transform=transform)
test_dataset = PlantDataset(image_dir=test_dir, label_map=label_map, transform=transform)


In [None]:

# Load the pretrained ViT backbone with a classification head sized to the
# number of classes. The id2label / label2id mappings are stored in the model
# config so predictions and saved checkpoints carry the class folder names
# instead of generic LABEL_<n> placeholders.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(label_map),
    id2label={index: name for name, index in label_map.items()},
    label2id=dict(label_map),
)


In [None]:

def compute_metrics(eval_pred):
    """Return top-1 accuracy for one evaluation pass.

    Args:
        eval_pred: The (predictions, label_ids) pair the Trainer hands to
            ``compute_metrics``; predictions are expected to be the model's
            per-class logits as a NumPy array.

    Returns:
        A dict with a single ``"accuracy"`` entry in the range [0, 1].
    """
    logits, labels = eval_pred
    predicted = logits.argmax(axis=-1)
    return {"accuracy": float((predicted == labels).mean())}


training_args = TrainingArguments(
    output_dir="./vit-plant-disease",
    # `eval_strategy` is the current name of the argument formerly spelled
    # `evaluation_strategy`; the old spelling is deprecated.
    eval_strategy="epoch",
    # Must match the evaluation schedule for load_best_model_at_end to work.
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # `processing_class` is the current name of the argument formerly spelled
    # `tokenizer`; passing the image processor here saves it with checkpoints.
    processing_class=feature_extractor,
    # Without this, evaluation reports the loss only.
    compute_metrics=compute_metrics,
)


In [None]:

# Run the fine-tuning loop configured by `training_args` on `train_dataset`.
trainer.train()


In [None]:

# Score the model on the Trainer's eval_dataset and show the returned metrics.
metrics = trainer.evaluate()
print(metrics)
