In [37]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
import os

from transformers import ViTFeatureExtractor
from transformers import Trainer
from transformers import ViTForImageClassification
from transformers import TrainingArguments
from datasets import load_metric

import torch
from torch.utils.data import Dataset

seed = 42  # notebook-wide random seed; NOTE(review): seed_everything() below defaults to the same value — keep the two in sync

In [2]:
def seed_everything(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and every visible CUDA device). It also configures cuDNN to use
    deterministic kernels so repeated runs give the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    # Has no effect on the already-running interpreter's hashing; it only
    # influences Python subprocesses launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all GPUs, not only the current one; silently a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick possibly non-deterministic
    # algorithms, which would defeat deterministic=True above.
    torch.backends.cudnn.benchmark = False

seed_everything()

In [3]:
# Report which accelerator is available. NOTE(review): in the visible notebook
# `device` is informational only — Trainer handles device placement itself.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Currently using "{device.upper()}" device.')

In [4]:
# Root folder of each dataset split; each holds one sub-folder per class.
train_path, valid_path, test_path = (
    r'../input/100-bird-species/train',
    r'../input/100-bird-species/valid',
    r'../input/100-bird-species/test',
)

In [156]:
def _class_name(path):
    """Return the class label for an image path: its parent folder's name.

    Runs of whitespace are collapsed to a single space so that folder names
    containing repeated spaces map to one consistent label across splits.
    """
    folder = os.path.basename(os.path.dirname(path))  # portable, unlike split('/')
    return ' '.join(folder.split())


def _load_split(root):
    """Return ``(paths, labels)`` for every ``root/<class>/*.jpg``, sorted by path."""
    paths = sorted(glob(os.path.join(root, '*', '*.jpg')))
    return paths, [_class_name(p) for p in paths]


train_paths, train_labels = _load_split(train_path)
valid_paths, valid_labels = _load_split(valid_path)
test_paths, test_labels = _load_split(test_path)
# A wrong dataset path would otherwise silently yield a 0-class model.
assert train_paths, f'no *.jpg images found under {train_path!r}'

# Contiguous ids in sorted class-name order; the reverse map is derived from
# the forward one so the two can never disagree.
id_2_label = dict(enumerate(sorted(set(train_labels))))
label_2_id = {label: idx for idx, label in id_2_label.items()}

# Every valid/test label must exist in the training vocabulary. Otherwise the
# dataset would only fail later, with a KeyError in the middle of evaluation.
_unknown_labels = (set(valid_labels) | set(test_labels)) - set(label_2_id)
assert not _unknown_labels, f'labels absent from the train split: {sorted(_unknown_labels)}'

In [6]:
model_name_or_path = 'google/vit-base-patch16-224-in21k'  # Hugging Face Hub checkpoint id
# Preprocessing (resize / normalise) configuration that ships with the checkpoint.
# NOTE(review): newer transformers releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor — confirm against the pinned version.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    model_name_or_path,
)

In [157]:
class BirdDataset(Dataset):
    """Map-style dataset yielding model-ready inputs for bird images.

    Each item is the processor's output plus an integer ``labels`` entry,
    which is the format ``collate_fn`` consumes.

    Args:
        paths: Image file paths.
        labels: Class-name label for each path (same order and length as ``paths``).
        processor: Callable invoked as ``processor(image, return_tensors='pt')``.
            Defaults to the notebook-level ``feature_extractor``.
        label2id: Mapping from class name to integer id. Defaults to the
            notebook-level ``label_2_id``.

    Raises:
        ValueError: If ``paths`` and ``labels`` differ in length.
    """

    def __init__(self, paths, labels, processor=None, label2id=None):
        if len(paths) != len(labels):
            raise ValueError(
                f'paths and labels must have the same length '
                f'({len(paths)} != {len(labels)})'
            )
        self.paths = paths
        self.labels = labels
        # Fall back to the notebook globals so existing two-argument calls keep working.
        self.processor = feature_extractor if processor is None else processor
        self.label2id = label_2_id if label2id is None else label2id

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, ix):
        # Decode inside a context manager so the file handle is closed right
        # away instead of lingering until garbage collection. convert('RGB')
        # gives grayscale/RGBA/CMYK files the same 3-channel layout as the
        # rest, which torch.cat in collate_fn requires.
        with Image.open(self.paths[ix]) as img:
            image = img.convert('RGB')
        label = self.label2id[self.labels[ix]]
        return self.transforms({'image': image, 'label': label})

    def transforms(self, inputs):
        """Run the processor on ``inputs['image']`` and attach ``inputs['label']`` as ``labels``."""
        outputs = self.processor(inputs['image'], return_tensors='pt')
        outputs['labels'] = inputs['label']
        return outputs
    
def collate_fn(batch):
    """Stack per-sample dicts into one batch for the Trainer.

    Each sample's ``pixel_values`` already carries a leading batch dimension
    of size 1, so they are concatenated (not stacked) along dim 0. The integer
    ``labels`` are gathered into a single tensor.
    """
    stacked_pixels = torch.cat([sample['pixel_values'] for sample in batch])
    label_tensor = torch.tensor([sample['labels'] for sample in batch])
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}

In [158]:
# Training split, plus the validation split used for periodic evaluation below.
train_ds, valid_ds = (
    BirdDataset(train_paths, train_labels),
    BirdDataset(valid_paths, valid_labels),
)

![ViT Architecture](https://miro.medium.com/max/700/0*YRDqyaLnCJscrYWV)

In [159]:
# Classification head sized to the number of training classes; the id/label
# maps are passed through so predictions can be decoded back to class names.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(set(train_labels)),
    id2label=id_2_label,
    label2id=label_2_id,
)

In [123]:
def compute_metrics(p):
    """Top-1 accuracy for the Trainer's evaluation loop.

    Computed directly with NumPy rather than via ``datasets.load_metric``,
    which is deprecated (removed in recent ``datasets`` releases) and fetches
    a metric script at runtime. The returned key and value match that
    accuracy metric.

    Args:
        p: ``EvalPrediction``-like object. ``p.predictions`` holds logits of
            shape ``(n_samples, n_classes)``, or a tuple whose first element
            is the logits (when the model returns extra outputs).
            ``p.label_ids`` holds the ``n_samples`` integer targets.

    Returns:
        ``{'accuracy': float}``.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(logits, axis=1)
    labels = np.asarray(p.label_ids)
    return {'accuracy': float((preds == labels).mean())}

In [160]:
training_args = TrainingArguments(
    output_dir="./vit-birds",
    per_device_train_batch_size=16,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy="steps",
    num_train_epochs=4,
    # fp16 mixed precision is GPU-only; transformers rejects fp16=True on a
    # CPU-only machine, so enable it only when CUDA is present.
    fp16=torch.cuda.is_available(),
    # With load_best_model_at_end, save_steps must be a multiple of eval_steps.
    save_steps=3000,
    eval_steps=3000,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
    seed=seed,  # use the notebook-wide seed constant
)

In [161]:
trainer = Trainer(
    # what to train and how
    model=model,
    args=training_args,
    # data
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    data_collator=collate_fn,
    # evaluation
    compute_metrics=compute_metrics,
    # NOTE(review): presumably passed so the preprocessing config is saved with
    # checkpoints; newer transformers releases deprecate `tokenizer=` in favour
    # of `processing_class=` — confirm against the pinned version.
    tokenizer=feature_extractor,
)

In [None]:
train_results = trainer.train()
# No path given, so the model goes to training_args.output_dir. Presumably this
# is the best checkpoint, since load_best_model_at_end=True.
trainer.save_model()

train_metrics = train_results.metrics
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()

In [None]:
# Validation metrics. This split also chose which checkpoint was kept
# (load_best_model_at_end=True), so treat these numbers as optimistic.
metrics = trainer.evaluate(valid_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Held-out test split: never used for training or checkpoint selection, so it
# gives the unbiased estimate of generalisation.
test_ds = BirdDataset(test_paths, test_labels)
test_metrics = trainer.evaluate(test_ds, metric_key_prefix="test")
trainer.log_metrics("test", test_metrics)
trainer.save_metrics("test", test_metrics)

### [0.992 Accuracy on train results](https://www.kaggle.com/code/pankratozzi/pytorch-vit-hf-birds-classification)