In [1]:
from transformers import ViTModel
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from datasets import load_dataset, load_metric
import os

## Define the Model

In [14]:
class ViTForImageClassification(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_classes)
        self.num_classes = num_classes

## ChatGPT

In [2]:
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

In [3]:
print(dataloaders)
print(dataset_sizes)
print(class_names)

{'train': <torch.utils.data.dataloader.DataLoader object at 0x2a1071ca0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x2a1071400>}
{'train': 244, 'val': 153}
['ants', 'bees']


In [5]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, ViTForImageClassification, AdamW

# Step 1: Prepare your custom dataset
# Assuming you have prepared your custom dataset and DataLoader, adapt as needed


# Step 2: Choose a pre-trained ViT model
model_name = "google/vit-base-patch16-224-in21k"  # Example pre-trained ViT model
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)

# Step 3: Modify the pre-trained model for fine-tuning
# Replace classification head
num_classes = 10  # Example: number of classes in your custom dataset
model.classifier = nn.Linear(model.config.hidden_size, num_classes)

# Step 4: Fine-tune the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-5)  # Example optimizer
criterion = nn.CrossEntropyLoss()  # Example loss function

num_epochs = 5  # Example number of epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in dataloaders["train"]:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataloaders["train"])
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Step 5: Evaluate the fine-tuned model
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in dataloaders["val"]:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Validation Accuracy: {accuracy:.4f}")

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch [1/5], Loss: 7.9350
Epoch [2/5], Loss: 4.9266
Epoch [3/5], Loss: 3.4581
Epoch [4/5], Loss: 2.7149
Epoch [5/5], Loss: 2.0886
Validation Accuracy: 0.9346
