In [5]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.svâ€¦

In [6]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from datasets import load_dataset
from transformers import ViTForImageClassification, TrainingArguments, Trainer
from transformers import ViTImageProcessor
import evaluate
import numpy as np
import os


# Configs
DATASET_DIR = "D:/project files/steve pest and weed detection/dataset"
MODEL_NAME = "google/vit-base-patch16-224-in21k"
OUTPUT_DIR = "./vit-weed-pest-model"

# Load dataset from folder using Hugging Face Datasets
dataset = load_dataset("imagefolder", data_dir=DATASET_DIR)

# Split into 80% train / 20% test
split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
train_ds = split_dataset["train"]
test_ds = split_dataset["test"]

# Extract label info
labels = train_ds.features["label"].names
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in id2label.items()}
num_labels = len(labels)

# Load Hugging Face ViT image processor
image_processor = ViTImageProcessor.from_pretrained(MODEL_NAME)

# Define transforms
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
])

test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
])

# Custom Dataset wrapper to apply transforms
class HFDataset(Dataset):
    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]  # PIL Image
        label = item["label"]
        if self.transform:
            image = self.transform(image)
        return {"pixel_values": image, "label": label}

    def __len__(self):
        return len(self.dataset)

# Wrap datasets with transforms
train_dataset = HFDataset(train_ds, transform=train_transforms)
test_dataset = HFDataset(test_ds, transform=test_transforms)

# Load ViT model
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)

# Load accuracy metric
accuracy = evaluate.load("accuracy")

# Metric computation
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=preds, references=labels)

# Data collator
def collate_fn(batch):
    pixel_values = torch.stack([x["pixel_values"] for x in batch])
    labels = torch.tensor([x["label"] for x in batch])
    return {"pixel_values": pixel_values, "labels": labels}

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=10,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    fp16=True,
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    save_total_limit=2,
    
    
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    tokenizer=image_processor  # only used for logging; optional
)

# Train
trainer.train()

# Final evaluation
metrics = trainer.evaluate()
print(f"ðŸ“Š Final Accuracy: {metrics['eval_accuracy']*100:.2f}%")

# Save the final model
trainer.save_model(OUTPUT_DIR)
print(f"âœ… Model saved to: {OUTPUT_DIR}")




Resolving data files:   0%|          | 0/20112 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/5014 [00:00<?, ?it/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,0.7158,0.546511,0.840418
2,0.3961,0.373409,0.874969
3,0.2975,0.333395,0.889883
4,0.2283,0.346098,0.886155
5,0.1546,0.2986,0.904052
6,0.122,0.303458,0.915237
7,0.0812,0.325454,0.913249
8,0.0595,0.345408,0.917226
9,0.0401,0.353958,0.918966
10,0.0285,0.360124,0.918469


ðŸ“Š Final Accuracy: 91.90%
âœ… Model saved to: ./vit-weed-pest-model


In [13]:
model.push_to_hub("sabari15/ViT-base16-fine-tuned-crop-disease-model", token="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")

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/sabari15/ViT-base16-fine-tuned-crop-disease-model/commit/d7c1bc128619b2eb73aa6bbf342ab580d0d8b185', commit_message='Upload ViTForImageClassification', commit_description='', oid='d7c1bc128619b2eb73aa6bbf342ab580d0d8b185', pr_url=None, repo_url=RepoUrl('https://huggingface.co/sabari15/ViT-base16-fine-tuned-crop-disease-model', endpoint='https://huggingface.co', repo_type='model', repo_id='sabari15/ViT-base16-fine-tuned-crop-disease-model'), pr_revision=None, pr_num=None)

In [14]:
dataset_path = "D:/project files/steve pest and weed detection/dataset"
dataset = load_dataset("imagefolder", data_dir=dataset_path)

Resolving data files:   0%|          | 0/20112 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/5014 [00:00<?, ?it/s]

In [15]:
import numpy as np

# Assuming 'label' is the column with class indices
labels = dataset['train']['label']
unique_labels = np.unique(labels)
print(unique_labels)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21]


In [16]:
label_names = dataset['train'].features['label'].names
print(label_names)


['Cashew anthracnose', 'Cashew gumosis', 'Cashew healthy', 'Cashew leaf miner', 'Cashew red rust', 'Cassava bacterial blight', 'Cassava brown spot', 'Cassava green mite', 'Cassava healthy', 'Cassava mosaic', 'Maize fall armyworm', 'Maize grasshoper', 'Maize healthy', 'Maize leaf beetle', 'Maize leaf blight', 'Maize leaf spot', 'Maize streak virus', 'Tomato healthy', 'Tomato leaf blight', 'Tomato leaf curl', 'Tomato septoria leaf spot', 'Tomato verticulium wilt']
