In [2]:
from PIL import Image
import requests
from torchvision import transforms, models
from torch.utils.data import DataLoader
import torch
from transformers import ViTForImageClassification, ViTImageProcessor, AutoImageProcessor

In [3]:
import sys
import os
sys.path.append(os.path.abspath('../'))

In [4]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x[0] for x in batch]).reshape(len(batch), 3, 224, 224),
        'labels': torch.tensor([x[1] for x in batch])
    }   

In [5]:
from utils.load_dataset import PlantVillageDataset

model_name_or_path = 'google/vit-base-patch16-224-in21k'
path_dataset = '../Plant_leave_diseases_dataset_without_augmentation'

processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTImageProcessor.from_pretrained(model_name_or_path)
transform = transforms.Compose([
     lambda x: processor(images=x, return_tensors="pt")['pixel_values']
     ])

train_dataset = PlantVillageDataset(root_dir=path_dataset, train=True, transform=transform)
val_dataset = PlantVillageDataset(root_dir=path_dataset, train=False, transform=transform)


Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


In [6]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [7]:
from transformers import ViTForImageClassification
num_classes = len(train_dataset.label_to_idx)
# Initialize model  
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
    id2label={str(v): k for k, v in train_dataset.label_to_idx.items()},
    label2id=train_dataset.label_to_idx
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from tqdm import tqdm
# Device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()
epochs = 6
max_steps = 100

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    
    # Progress bar for steps instead of full dataset
    train_loader_tqdm = tqdm(range(max_steps), desc=f"Epoch {epoch+1}/{epochs}", leave=True)
    
    for i, batch in zip(train_loader_tqdm, train_loader):
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss
        train_loss += loss.item()
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        # Update tqdm progress bar
        train_loader_tqdm.set_postfix(loss=loss.item())

    avg_train_loss = train_loss / max_steps  # Normalize by max_steps
    print(f"Epoch {epoch+1} - Training Loss: {avg_train_loss:.4f}")
    
    model.eval()
    val_loss = 0.0
    
    # Progress bar for validation steps
    val_loader_tqdm = tqdm(range(max_steps), desc=f"Validating {epoch+1}/{epochs}", leave=True)

    with torch.no_grad():
        for i, batch in zip(val_loader_tqdm, val_loader):
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Forward pass
            outputs = model(**batch)
            loss = criterion(outputs.logits, batch["labels"])
            val_loss += loss.item()
            
            # Update tqdm progress bar
            val_loader_tqdm.set_postfix(loss=loss.item())

    avg_val_loss = val_loss / max_steps  # Normalize by max_steps
    print(f"Epoch {epoch+1} - Validation Loss: {avg_val_loss:.4f}")


Epoch 1/6: 100%|██████████| 100/100 [02:21<00:00,  1.41s/it, loss=2.07]


Epoch 1 - Training Loss: 2.6456


Validating 1/6: 100%|██████████| 100/100 [00:53<00:00,  1.87it/s, loss=3.07]


Epoch 1 - Validation Loss: 2.5237


Epoch 2/6: 100%|██████████| 100/100 [02:20<00:00,  1.41s/it, loss=1.13]


Epoch 2 - Training Loss: 1.4610


Validating 2/6: 100%|██████████| 100/100 [00:53<00:00,  1.86it/s, loss=1.82]


Epoch 2 - Validation Loss: 1.6368


Epoch 3/6: 100%|██████████| 100/100 [02:21<00:00,  1.42s/it, loss=0.619]


Epoch 3 - Training Loss: 0.8899


Validating 3/6: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s, loss=0.853]


Epoch 3 - Validation Loss: 1.1731


Epoch 4/6: 100%|██████████| 100/100 [02:19<00:00,  1.40s/it, loss=0.459]


Epoch 4 - Training Loss: 0.5982


Validating 4/6: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s, loss=0.539]


Epoch 4 - Validation Loss: 0.7667


Epoch 5/6: 100%|██████████| 100/100 [02:20<00:00,  1.41s/it, loss=0.426]


Epoch 5 - Training Loss: 0.4211


Validating 5/6: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s, loss=0.401]


Epoch 5 - Validation Loss: 0.5815


Epoch 6/6: 100%|██████████| 100/100 [02:21<00:00,  1.42s/it, loss=0.282]


Epoch 6 - Training Loss: 0.3215


Validating 6/6: 100%|██████████| 100/100 [00:53<00:00,  1.88it/s, loss=0.274]

Epoch 6 - Validation Loss: 0.3807





In [9]:
model.eval()
val_loss = 0.0
with torch.no_grad():
    for batch in val_loader:
        # put batch on device
        batch = {k:v.to(device) for k,v in batch.items()}
        
        # forward pass
        outputs = model(**batch)
        loss = criterion(outputs.logits, batch["labels"])
        
        val_loss += torch.sum(loss).detach().cpu().numpy()
        break  
print("Validation loss after epoch {epoch}:", val_loss/len(val_loader))

Validation loss after epoch {epoch}: 0.0002866196996960692


In [11]:
# Accuracy
from sklearn.metrics import accuracy_score
y_true = []
y_pred = []
with torch.no_grad():
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        y_true.extend(batch["labels"].detach().cpu().numpy())
        y_pred.extend(torch.argmax(outputs.logits, axis=1).detach().cpu().numpy())

accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.9872


In [12]:
torch.save(model.state_dict(), './XviT.pth')