In [None]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from transformers import ViTForImageClassification, ViTFeatureExtractor
from tqdm import tqdm

# Run configuration
# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Size of the classification head (number of target classes).
num_classes = 5

# Optimisation settings: mini-batch size, passes over the training set, Adam step size.
batch_size = 32
num_epochs = 20
lr = 1e-4

# File that receives the trained weights at the end of the run.
save_path = "vit_model.pth"

# Fetch the pretrained ViT checkpoint and the feature extractor published with it.
model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# num_labels sizes the classification head for this task; the library reports
# the classifier weights as newly initialized, so they must be fine-tuned.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
)
# Move the parameters onto the selected device before training starts.
model = model.to(device)

# Load and preprocess data
data_dir = '/content/drive/MyDrive/pupil_images/gaussian_filtered_images/gaussian_filtered_images'

# Define transforms
# Normalize with the mean/std stored in the checkpoint's own preprocessor
# config instead of hard-coded ImageNet-1k statistics, so the pixel range fed
# to the backbone matches what the pretrained preprocessor would produce.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=feature_extractor.image_mean,
        std=feature_extractor.image_std,
    ),
])

# Create ImageFolder dataset (one sub-directory of data_dir per class)
dataset = ImageFolder(root=data_dir, transform=transform)

# Fail early with a readable message: labels >= num_classes would otherwise
# surface later as an opaque error inside the loss computation.
if len(dataset.classes) > num_classes:
    raise ValueError(
        f"Found {len(dataset.classes)} class folders in {data_dir!r} "
        f"but num_classes is {num_classes}"
    )

# Split data into train and test sets (80% / 20%)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# A fixed seed keeps the held-out set identical across runs, so test numbers
# from different runs are comparable. Training-batch shuffling stays unseeded.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=split_generator,
)

# Create data loaders
# Background workers overlap image decoding / disk reads with model compute;
# pinned host memory only helps when batches are copied to a CUDA device.
loader_workers = min(2, os.cpu_count() or 1)
pin_batches = device.type == "cuda"
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=loader_workers,
    pin_memory=pin_batches,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=loader_workers,
    pin_memory=pin_batches,
)

# Optimizer and loss function
# Adam updates every parameter of the model with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# NOTE(review): the training and testing loops visible in this file read the
# loss from the model output (outputs.loss) and never call `criterion`;
# confirm whether anything else depends on it before removing.
criterion = torch.nn.CrossEntropyLoss()

# Training loop
# One pass over train_loader per epoch; prints the epoch's mean per-sample
# loss and the accuracy measured on the training batches as they were seen.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0  # running sum of per-sample losses for this epoch
    correct = 0
    total = 0
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', unit='batch'):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        # Passing labels makes the model compute the classification loss itself.
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # outputs.loss is a batch mean; scale it back to a sum so the short
        # final batch does not get the same weight as a full one.
        batch_count = labels.size(0)
        train_loss += loss.item() * batch_count
        predicted = outputs.logits.argmax(dim=1)
        total += batch_count
        correct += (predicted == labels).sum().item()

    if total == 0:
        raise RuntimeError("train_loader yielded no samples; check data_dir and the train/test split")

    train_accuracy = 100 * correct / total
    print(f"Train Loss: {train_loss / total:.4f} | Train Accuracy: {train_accuracy:.2f}%")

# Testing loop
# Single pass over test_loader in eval mode with gradients disabled; prints
# the mean per-sample loss and the accuracy on the held-out split.
model.eval()
test_loss = 0.0  # running sum of per-sample losses
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Testing', unit='batch'):
        images, labels = images.to(device), labels.to(device)

        # Passing labels makes the model compute the classification loss itself.
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss

        # outputs.loss is a batch mean; scale it back to a sum so the short
        # final batch does not get the same weight as a full one.
        batch_count = labels.size(0)
        test_loss += loss.item() * batch_count
        predicted = outputs.logits.argmax(dim=1)
        total += batch_count
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("test_loader yielded no samples; check data_dir and the train/test split")

test_accuracy = 100 * correct / total
print(f"Test Loss: {test_loss / total:.4f} | Test Accuracy: {test_accuracy:.2f}%")

# Persist the trained weights
# Only the state_dict (parameter/buffer tensors) is written, not the module
# object itself, so loading requires re-creating the model first.
torch.save(obj=model.state_dict(), f=save_path)
print(f"Model saved to {save_path}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/20: 100%|██████████| 92/92 [16:37<00:00, 10.84s/batch]


Train Loss: 0.8087 | Train Accuracy: 71.72%


Epoch 2/20: 100%|██████████| 92/92 [01:52<00:00,  1.22s/batch]


Train Loss: 0.6009 | Train Accuracy: 78.42%


Epoch 3/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.4817 | Train Accuracy: 82.55%


Epoch 4/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.3937 | Train Accuracy: 86.37%


Epoch 5/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.2626 | Train Accuracy: 91.60%


Epoch 6/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.1962 | Train Accuracy: 94.67%


Epoch 7/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.1571 | Train Accuracy: 95.32%


Epoch 8/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.1345 | Train Accuracy: 95.83%


Epoch 9/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0989 | Train Accuracy: 97.23%


Epoch 10/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0922 | Train Accuracy: 97.30%


Epoch 11/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0736 | Train Accuracy: 98.09%


Epoch 12/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0618 | Train Accuracy: 98.29%


Epoch 13/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0434 | Train Accuracy: 98.84%


Epoch 14/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0358 | Train Accuracy: 99.01%


Epoch 15/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0319 | Train Accuracy: 98.91%


Epoch 16/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0324 | Train Accuracy: 98.84%


Epoch 17/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0277 | Train Accuracy: 98.87%


Epoch 18/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0343 | Train Accuracy: 98.50%


Epoch 19/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.1054 | Train Accuracy: 96.45%


Epoch 20/20: 100%|██████████| 92/92 [01:51<00:00,  1.21s/batch]


Train Loss: 0.0634 | Train Accuracy: 97.54%


Testing: 100%|██████████| 23/23 [03:38<00:00,  9.49s/batch]


Test Loss: 0.8866 | Test Accuracy: 81.31%
Model saved to vit_model.pth
