In [13]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from transformers import pipeline, ViTForImageClassification, ViTImageProcessor, ViTFeatureExtractor, Trainer, TrainingArguments


In [4]:
class LimitedImageFolder(ImageFolder):
    def __init__(self, root, transform=None, target_transform=None, limit_per_class=150):
        super(LimitedImageFolder, self).__init__(root, transform, target_transform)
        self.limit_per_class = limit_per_class
        self._limit_dataset()

    def _limit_dataset(self):
        class_counts = dict()
        new_samples = []
        for sample, target in self.samples:
            if target not in class_counts:
                class_counts[target] = 0
            if class_counts[target] < self.limit_per_class:
                new_samples.append((sample, target))
                class_counts[target] += 1
        self.samples = new_samples


In [5]:
data_root = "./data"

# Set the batch size and number of training epochs
batch_size = 8
num_epochs = 5

# Load the image data
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to the desired input size of the model
    transforms.ToTensor(),  # Convert images to tensors
])

dataset = LimitedImageFolder(data_root, transform=image_transforms)

In [7]:
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

In [9]:
label2id = {}
id2label = {}

for i, class_name in enumerate(dataset.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name

In [10]:
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor
 
    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings 

In [12]:

# Load the pre-trained ViT model and its feature extractor
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True
    )

image_processor = ViTImageProcessor.from_pretrained(model_name)

# Define the training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    save_steps=500,
    save_total_limit=2,
)

# Define the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,

)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./fine_tuned_model_2")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/150 [00:00<?, ?it/s]

{'train_runtime': 1498.9521, 'train_samples_per_second': 0.801, 'train_steps_per_second': 0.1, 'train_loss': 0.7241793823242187, 'epoch': 5.0}


In [15]:
# Set the path to the saved fine-tuned model
model_path = "./fine_tuned_model"

# Load the fine-tuned model and feature extractor
model = ViTForImageClassification.from_pretrained(model_path)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Define the image classification pipeline
image_classifier = pipeline("image-classification", feature_extractor=feature_extractor, image_processor=image_processor, model=model)

def perform_inference(image_path: str):
    results = image_classifier(image_path)
    # Get the predicted label and its corresponding class name
    predicted_label = results[0]['label']
    predicted_score = results[0]['score']
    return (predicted_label, predicted_score)

# Set the path to the new image for inference
yellow_example = "./data/yellow/zJ5A5CdlnTo.jpg"
other_example = "./data/other/14sLvv_ykmE.jpg"
# Perform inference
(label, score) = perform_inference(yellow_example)
# Print the predicted label and score
print("Predicted Label:", label)
print("Predicted Score:", score)


Predicted Label: other
Predicted Score: 0.6961593627929688


In [17]:
perform_inference(yellow_example)

('other', 0.6961593627929688)