In [1]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder(root=r'C:\Users\Admin\OneDrive\Documents\colored_images', transform=transform)

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


In [2]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=5
)


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
import torch
import torch.nn as nn
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=3e-5)

In [4]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits
        
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    epoch_loss = running_loss / total
    accuracy = correct / total
    return epoch_loss, accuracy


In [5]:
def eval_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    
    epoch_loss = running_loss / total
    accuracy = correct / total
    return epoch_loss, accuracy


In [28]:
num_epochs = 5
best_val_acc = 0
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = eval_epoch(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  Train loss: {train_loss:.4f}, Train accuracy: {train_acc:.4f}")
    print(f"  Val loss: {val_loss:.4f}, Val accuracy: {val_acc:.4f}")

    model.eval()
    # with torch.no_grad():
    #     for images, labels in val_loader:
    #         images, labels = images.to(device), labels.to(device)
    #         outputs = model(images).logits
    #         _, preds = torch.max(outputs, 1)
    #         print("Validation labels sample:", labels[:5])
    #         print("Validation preds sample:", preds[:5])
    #         break

    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_vit_dr_model.pth')
        print("  Saved best model")




Epoch 1/5:
  Train loss: 0.2643, Train accuracy: 1.0000
  Val loss: 1.0850, Val accuracy: 0.0000
  Saved best model
Epoch 2/5:
  Train loss: 0.2514, Train accuracy: 1.0000
  Val loss: 1.0928, Val accuracy: 0.0000
  Saved best model
Epoch 3/5:
  Train loss: 0.2402, Train accuracy: 1.0000
  Val loss: 1.0989, Val accuracy: 0.0000
  Saved best model
Epoch 4/5:
  Train loss: 0.2297, Train accuracy: 1.0000
  Val loss: 1.1032, Val accuracy: 0.0000
  Saved best model
Epoch 5/5:
  Train loss: 0.2198, Train accuracy: 1.0000
  Val loss: 1.1064, Val accuracy: 0.0000
  Saved best model


In [33]:
from PIL import Image
import torch

model.load_state_dict(torch.load('best_vit_dr_model.pth'))
model.to(device)
model.eval()

def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)  
    return image.unsqueeze(0)  

image_tensor = preprocess_image(r'C:\Users\Admin\OneDrive\Documents\colored_images\p2.jpg').to(device)

with torch.no_grad():
    outputs = model(image_tensor).logits
    probs = torch.softmax(outputs, dim=1)
    predicted_class = torch.argmax(probs).item()

print(f"Predicted diabetic retinopathy class: {predicted_class}, Confidence: {probs[0, predicted_class]:.4f}")


Predicted diabetic retinopathy class: 1, Confidence: 0.3398
