In [35]:
from PIL import Image
import requests
from torchvision import transforms, models
from torch.utils.data import DataLoader
import torch
from transformers import ViTForImageClassification, ViTImageProcessor

In [36]:
import sys
import os
sys.path.append(os.path.abspath('../'))


In [131]:
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x[0] for x in batch]).reshape(len(batch), 3, 224, 224),
        'labels': torch.tensor([x[1] for x in batch])
    }   

In [125]:
from utils.load_dataset import PlantVillageDataset
model_name_or_path = 'google/vit-base-patch16-224-in21k'
path_dataset = '../Plant_leave_diseases_dataset_without_augmentation'
model = ViTImageProcessor.from_pretrained(model_name_or_path)
transform = transforms.Compose([
     lambda x: processor(images=x, return_tensors="pt")['pixel_values']
     ])

train_dataset = PlantVillageDataset(root_dir=path_dataset, train=True, transform=transform)
val_dataset = PlantVillageDataset(root_dir=path_dataset, train=False, transform=transform)


In [133]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [137]:
from transformers import ViTForImageClassification
num_classes = len(train_dataset.label_to_idx)
# Initialize model  
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
    id2label={str(v): k for k, v in train_dataset.label_to_idx.items()},
    label2id=train_dataset.label_to_idx
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [139]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

# put model on GPU, if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
epochs = 6
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    for batch in train_loader:
        # put batch on device
        batch = {k:v.to(device) for k,v in batch.items()}
        # forward pass
        outputs = model(**batch)
        loss = outputs.loss
        
        train_loss += loss.item()
        
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print("Loss after epoch {epoch}:", train_loss/len(train_loader))
    
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            # put batch on device
            batch = {k:v.to(device) for k,v in batch.items()}
            
            # forward pass
            outputs = model(**batch)
            loss = outputs.logits
            
            val_loss += torch.sum(loss).detach().cpu().numpy()
            break  
    print("Validation loss after epoch {epoch}:", val_loss/len(val_loader))

16
Loss after epoch {epoch}: 0.0021036590239386717
0


RuntimeError: a Tensor with 1248 elements cannot be converted to Scalar

In [145]:
model.eval()
val_loss = 0.0
with torch.no_grad():
    for batch in val_loader:
        # put batch on device
        batch = {k:v.to(device) for k,v in batch.items()}
        
        # forward pass
        outputs = model(**batch)
        loss = outputs.logits
        
        val_loss += torch.sum(loss).detach().cpu().numpy()
        break  
print("Validation loss after epoch {epoch}:", val_loss/len(val_loader))

0
Validation loss after epoch {epoch}: -0.0048491526


In [146]:
torch.save(model.state_dict(), './XviT.pth')