# Vision Transformer (ViT) Image Classification

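This notebook fine-tunes a pretrained Vision Transformer (ViT) model from [Hugging Face](https://huggingface.co/models?search=vit) on the NEU surface-defect dataset, which the first code cell downloads from Kaggle in YOLO (object-detection) format. Because YOLO labels describe bounding boxes, each image is assigned the class of the first box in its label file, turning the detection labels into image-level classification labels. The notebook also demonstrates how to run inference with the fine-tuned model.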

In [None]:
!pip install -q transformers datasets evaluate accelerate torchvision pillow
!curl -L -o ./neu-yolo.zip https://www.kaggle.com/api/v1/datasets/download/zymzym/neu-yolo
!unzip -q neu-yolo.zip
!rm neu-yolo.zip

In [None]:
from pathlib import Path

import numpy as np
from datasets import Dataset, DatasetDict, Features, ClassLabel, Image as HFImage
import evaluate
from PIL import Image

import torch
from torchvision import transforms
from transformers import (
    AutoImageProcessor,
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
)


In [None]:
# Locate the dataset folders extracted into the working directory and fail
# fast if the expected YOLO layout (images/ + labels/ per split) is missing.
project_dir = Path.cwd()
train_dir = project_dir / 'train'
valid_dir = project_dir / 'valid'

for folder, description in ((train_dir, 'Training'), (valid_dir, 'Validation')):
    assert folder.exists(), f'{description} directory not found: {folder}'

# Each split sits one level deeper inside its folder: train/train and valid/valid.
train_split_dir = train_dir.joinpath('train')
test_split_dir = valid_dir.joinpath('valid')

split_dirs = dict(train=train_split_dir, test=test_split_dir)
for split_name, split_dir in split_dirs.items():
    images_dir, labels_dir = split_dir / 'images', split_dir / 'labels'
    assert images_dir.exists(), f"Missing {split_name} images directory: {images_dir}"
    assert labels_dir.exists(), f"Missing {split_name} labels directory: {labels_dir}"


In [None]:
# Class ids as written in the YOLO label files, mapped to defect names
# (the class set of the neu-yolo dataset downloaded above).
class_id_to_name = dict(enumerate([
    'crazing',
    'inclusion',
    'patches',
    'pitted_surface',
    'rolled_in_scale',
    'scratches',
]))

# ClassLabel identifies names by position, so list them in class-id order.
label_names = [name for _, name in sorted(class_id_to_name.items())]

# Schema shared by every split: an HF Image column plus an integer class label.
features = Features({'image': HFImage(), 'label': ClassLabel(names=label_names)})


def load_split(split_root):
    """Build an image-classification ``Dataset`` from one YOLO-style split folder.

    Every ``labels/<stem>.txt`` file is paired with the image in ``images/``
    that has the same stem. The class id of the first non-blank annotation
    line becomes the image-level label; further boxes are ignored. Images
    without a label file are skipped.

    Args:
        split_root: ``Path`` containing ``images/`` and ``labels/`` subfolders.

    Returns:
        A ``datasets.Dataset`` with ``image`` (file path) and ``label`` columns,
        typed by the module-level ``features``.

    Raises:
        ValueError: if the split has no label files, a label file holds no
            annotation, or a class id is outside the known label range.
        FileNotFoundError: if a label file has no image with the same stem.
    """
    images_dir = split_root / 'images'
    labels_dir = split_root / 'labels'
    num_classes = features['label'].num_classes

    # Index images by stem once instead of globbing the directory for every
    # label file (O(n) instead of O(n^2)). Matching on the exact stem also stops
    # a "foo.*" pattern from picking up e.g. "foo.bar.jpg". Sorting keeps the
    # "first candidate in sorted order wins" rule for duplicate stems.
    images_by_stem = {}
    for candidate in sorted(images_dir.iterdir()):
        if candidate.is_file():
            images_by_stem.setdefault(candidate.stem, candidate)

    label_paths = sorted(labels_dir.glob('*.txt'))
    if not label_paths:
        # An empty split would otherwise only fail later, e.g. in train_test_split.
        raise ValueError(f"No label files found in {labels_dir}")

    image_paths = []
    label_ids = []

    for label_path in label_paths:
        with label_path.open() as f:
            # Skip leading blank lines rather than treating them as "empty".
            first_line = next((line for line in map(str.strip, f) if line), '')
        if not first_line:
            raise ValueError(f"Empty label file: {label_path}")

        class_id = int(first_line.split()[0])
        if not 0 <= class_id < num_classes:
            raise ValueError(
                f"Class id {class_id} in {label_path} is outside [0, {num_classes})"
            )

        image_path = images_by_stem.get(label_path.stem)
        if image_path is None:
            raise FileNotFoundError(
                f"No image found for label file {label_path}"
            )

        image_paths.append(str(image_path))
        label_ids.append(class_id)

    return Dataset.from_dict(
        {'image': image_paths, 'label': label_ids}, features=features
    )


# Carve a validation subset out of the archive's training split (used for model
# selection); the archive's 'valid' folder is kept untouched as the test set.
VALIDATION_FRACTION = 0.2
SPLIT_SEED = 42

train_dataset_raw = load_split(train_split_dir)
# Distinct name: `test_dataset` is used later for the transformed test split.
test_dataset_raw = load_split(test_split_dir)

# Stratify on the ClassLabel column so every defect class keeps the same
# proportion in the train and validation subsets.
train_split = train_dataset_raw.train_test_split(
    test_size=VALIDATION_FRACTION,
    seed=SPLIT_SEED,
    shuffle=True,
    stratify_by_column='label',
)

dataset = DatasetDict(
    {
        'train': train_split['train'],
        'validation': train_split['test'],
        'test': test_dataset_raw,
    }
)
assert len(dataset['train']) + len(dataset['validation']) == len(train_dataset_raw)

num_labels = len(label_names)
label2id = {label: idx for idx, label in enumerate(label_names)}
id2label = {idx: label for label, idx in label2id.items()}

print('Labels:', label_names)
print('Train examples:', len(dataset['train']))
print('Validation examples:', len(dataset['validation']))
print('Test examples:', len(dataset['test']))


In [None]:
# Load the checkpoint's image processor and build torchvision pipelines that
# mirror its geometry (resize, plus center crop when the processor uses one)
# and its normalization statistics.
checkpoint = 'google/vit-base-patch16-224-in21k'
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

size = image_processor.size['height']
# Use both height and width so non-square processor targets stay non-square.
resize_hw = (size, image_processor.size['width'])
normalize = transforms.Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Processors that center-crop after resizing (do_center_crop with a crop_size)
# are mirrored with CenterCrop so tensors reach the model at the size it
# expects; when the processor does not crop, this is a plain resize.
geometry = [transforms.Resize(resize_hw)]
crop_size = getattr(image_processor, 'crop_size', None)
if getattr(image_processor, 'do_center_crop', False) and crop_size:
    geometry.append(transforms.CenterCrop((crop_size['height'], crop_size['width'])))

train_transforms = transforms.Compose([
    *geometry,
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

valid_transforms = transforms.Compose([
    *geometry,
    transforms.ToTensor(),
    normalize,
])


def _attach_pixel_values(examples, pipeline):
    """Add a ``'pixel_values'`` column by running ``pipeline`` on each RGB image.

    The batch dict is updated in place and returned, which is the shape a
    batched ``datasets`` preprocessing function must have (columns in, columns out).
    """
    rgb_images = (picture.convert('RGB') for picture in examples['image'])
    examples['pixel_values'] = list(map(pipeline, rgb_images))
    return examples


def preprocess_train(examples):
    """Batch preprocessing for training: random augmentation plus normalization."""
    return _attach_pixel_values(examples, train_transforms)


def preprocess_eval(examples):
    """Batch preprocessing for evaluation: deterministic resize plus normalization."""
    return _attach_pixel_values(examples, valid_transforms)


# Attach the preprocessing lazily with `with_transform` instead of `.map`:
# - `.map` ran RandomHorizontalFlip exactly once and stored the result, so every
#   epoch saw the same frozen "augmentation". A transform runs each time an
#   example is loaded, so the flip is re-sampled every epoch.
# - The normalized float tensors are no longer written into Arrow storage.
# `with_transform` returns new datasets and leaves `dataset` untouched, so this
# cell is safe to re-run. The preprocess functions and collate_fn are reused as-is.
train_dataset = dataset['train'].with_transform(preprocess_train)
eval_dataset = dataset['validation'].with_transform(preprocess_eval)
test_dataset = dataset['test'].with_transform(preprocess_eval)


In [None]:
# Accuracy metric from the `evaluate` library, shared by every evaluation call.
metric = evaluate.load('accuracy')


def compute_metrics(eval_pred):
    """Score arg-max predictions for ``Trainer`` evaluation.

    Args:
        eval_pred: a ``(logits, labels)`` pair, unpacked positionally.

    Returns:
        The result of the ``evaluate`` accuracy metric; its ``'accuracy'`` key
        is what ``metric_for_best_model`` selects on.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted_classes, references=references)

In [None]:
# Classification-head settings derived from the dataset's label space.
head_config = dict(
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label,
)
# Load the pretrained backbone with a head sized for `num_labels`; weights whose
# shapes do not match the checkpoint are freshly initialised instead of erroring.
model = ViTForImageClassification.from_pretrained(
    checkpoint,
    ignore_mismatched_sizes=True,
    **head_config,
)

In [None]:
# Fine-tuning configuration. Evaluation and checkpointing happen once per epoch
# so the epoch with the best validation accuracy can be restored at the end.
training_args = TrainingArguments(
    # output / integrations
    output_dir='vit-output',
    report_to='none',
    push_to_hub=False,
    # optimisation
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
    # evaluation, checkpointing and logging
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    logging_steps=10,
    # keep every dataset column: collate_fn reads 'pixel_values' and 'label'
    remove_unused_columns=False,
)

In [None]:
def collate_fn(batch):
    """Stack a list of examples into the tensor batch fed to the model.

    Args:
        batch: list of example dicts, each holding ``'pixel_values'`` (a tensor,
            or nested lists accepted by ``torch.tensor``) and an integer ``'label'``.

    Returns:
        ``{'pixel_values': stacked tensor, 'labels': int64 tensor}``. The label
        key is renamed to ``labels``, the keyword the model's forward pass takes.
    """
    def _to_tensor(values):
        # Tensors pass through untouched; anything else is converted.
        return values if isinstance(values, torch.Tensor) else torch.tensor(values)

    images = [_to_tensor(example['pixel_values']) for example in batch]
    targets = [example['label'] for example in batch]
    return {
        'pixel_values': torch.stack(images),
        'labels': torch.tensor(targets, dtype=torch.long),
    }

In [None]:
# Fine-tune, then report metrics on the validation split (used for model
# selection) and on the held-out test split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `processing_class` supersedes the deprecated `tokenizer=` argument
    # (transformers >= 4.46); the installs above are unpinned, so use the
    # current name. The processor is saved together with the model.
    processing_class=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

train_result = trainer.train()
trainer.save_model()
trainer.log_metrics('train', train_result.metrics)
trainer.save_metrics('train', train_result.metrics)
trainer.save_state()

metrics = trainer.evaluate()
trainer.log_metrics('eval', metrics)
trainer.save_metrics('eval', metrics)

# Without metric_key_prefix the test-set numbers would be keyed 'eval_*'
# (e.g. 'eval_accuracy') and be indistinguishable from validation metrics.
test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix='test')
trainer.log_metrics('test', test_metrics)
trainer.save_metrics('test', test_metrics)


In [None]:
# Inference on a single held-out image.
# Take the first image, in sorted (reproducible) order, from the test folder.
# That split was never used for training or model selection, so the
# prediction is a fair demo rather than a memorised training example.
image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif'}
image_path = next(
    (
        path for path in sorted(valid_dir.rglob('*'))
        if path.is_file() and path.suffix.lower() in image_extensions
    ),
    None,
)
if image_path is None:
    # Clearer than the bare StopIteration that next() would raise.
    raise FileNotFoundError(f'No image files found under {valid_dir}')

# The context manager closes the file handle once the pixels are converted.
with Image.open(image_path) as source_image:
    image = source_image.convert('RGB')

# The checkpoint's image processor applies its own resize and normalization.
inputs = image_processor(images=image, return_tensors='pt')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Run in float32 regardless of whether training used fp16.
model = trainer.model.to(device=device, dtype=torch.float32)
inputs = {k: v.to(device) for k, v in inputs.items()}
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    probs = outputs.logits.softmax(dim=-1).squeeze()

predicted_idx = int(probs.argmax())
predicted_label = id2label[predicted_idx]
confidence = float(probs[predicted_idx])

print(f'Image: {image_path.name}')
print(f'Predicted label: {predicted_label} (confidence {confidence:.2%})')


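> **Tip:** Training a ViT model can be computationally expensive. If you are running on limited hardware, consider reducing the number of epochs or the batch size. Note that `google/vit-base-patch16-224` is *not* a smaller alternative: it has the same ViT-Base architecture as the default checkpoint, additionally fine-tuned on ImageNet-1k. A genuinely smaller model, such as a ViT-Small/Tiny or DeiT-Small/Tiny checkpoint, can reduce cost. If you switch, make sure the training transforms match that checkpoint's image-processor settings (resize and crop size).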