In [None]:
from datasets import load_dataset

# Load the "train" split of the huggan/wikiart dataset and print its summary
# (feature names and row count).
ds = load_dataset(path="huggan/wikiart", split="train")
print(ds)



Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/45 [00:00<?, ?it/s]

Dataset({
    features: ['image', 'artist', 'genre', 'style'],
    num_rows: 81444
})


In [2]:
# Hold out 20% of the rows for evaluation; the fixed seed makes the split repeatable.
split_ds = ds.train_test_split(seed=42, test_size=0.2)

# Plain dict exposing exactly the two splits used by the rest of the notebook.
prepared_ds = dict(train=split_ds["train"], test=split_ds["test"])

In [4]:
from transformers import ViTImageProcessor

# Checkpoint identifier; reused later when the classification model is loaded.
model_name_or_path = "google/vit-base-patch16-224-in21k"

# Image processor matching that checkpoint.
processor = ViTImageProcessor.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)

In [5]:
# Pull one sample image (row 400 of the training split) to sanity-check the
# preprocessing; printing it shows the PIL image's mode and size.
image = prepared_ds['train'][400]['image']
print(image)

<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1921x1382 at 0x7F79A3500EC0>


In [6]:
# Run the processor on the sample image; as the cell's last expression the
# resulting dict (with a 'pixel_values' tensor) is displayed.
processor(image, return_tensors='pt')

{'pixel_values': tensor([[[[0.3725, 0.5451, 0.5216,  ..., 0.5451, 0.5529, 0.5373],
          [0.4745, 0.5686, 0.5529,  ..., 0.5529, 0.5686, 0.5529],
          [0.4824, 0.5451, 0.5608,  ..., 0.5451, 0.5608, 0.5216],
          ...,
          [0.3804, 0.3569, 0.3725,  ..., 0.4667, 0.3255, 0.4118],
          [0.3020, 0.3569, 0.3490,  ..., 0.5529, 0.4667, 0.4745],
          [0.2627, 0.3098, 0.3098,  ..., 0.6314, 0.5529, 0.5451]],

         [[0.4353, 0.6471, 0.5843,  ..., 0.6078, 0.6078, 0.5608],
          [0.5451, 0.6941, 0.6314,  ..., 0.6235, 0.6235, 0.5765],
          [0.5843, 0.6863, 0.6627,  ..., 0.6078, 0.6078, 0.5451],
          ...,
          [0.4588, 0.4588, 0.4667,  ..., 0.4980, 0.3647, 0.4588],
          [0.3804, 0.4510, 0.4275,  ..., 0.5608, 0.4824, 0.4902],
          [0.3333, 0.4039, 0.3961,  ..., 0.6314, 0.5686, 0.5608]],

         [[0.4118, 0.6157, 0.5843,  ..., 0.4275, 0.4275, 0.3961],
          [0.5451, 0.6706, 0.6471,  ..., 0.4353, 0.4431, 0.4118],
          [0.5843, 0.6784

In [7]:
# Show the column names and the full feature schema, one per line.
print(
    prepared_ds["train"].column_names,
    prepared_ds["train"].features,
    sep="\n",
)

# Class names of the "style" label column; the list position is the class id.
labels = prepared_ds["train"].features["style"].names
print("Art Period Labels:", labels)



['image', 'artist', 'genre', 'style']
{'image': Image(mode=None, decode=True, id=None), 'artist': ClassLabel(names=['Unknown Artist', 'boris-kustodiev', 'camille-pissarro', 'childe-hassam', 'claude-monet', 'edgar-degas', 'eugene-boudin', 'gustave-dore', 'ilya-repin', 'ivan-aivazovsky', 'ivan-shishkin', 'john-singer-sargent', 'marc-chagall', 'martiros-saryan', 'nicholas-roerich', 'pablo-picasso', 'paul-cezanne', 'pierre-auguste-renoir', 'pyotr-konchalovsky', 'raphael-kirchner', 'rembrandt', 'salvador-dali', 'vincent-van-gogh', 'hieronymus-bosch', 'leonardo-da-vinci', 'albrecht-durer', 'edouard-cortes', 'sam-francis', 'juan-gris', 'lucas-cranach-the-elder', 'paul-gauguin', 'konstantin-makovsky', 'egon-schiele', 'thomas-eakins', 'gustave-moreau', 'francisco-goya', 'edvard-munch', 'henri-matisse', 'fra-angelico', 'maxime-maufra', 'jan-matejko', 'mstislav-dobuzhinsky', 'alfred-sisley', 'mary-cassatt', 'gustave-loiseau', 'fernando-botero', 'zinaida-serebriakova', 'georges-seurat', 'isaac-lev

In [17]:
def process_example(example):
    """Preprocess one dataset row.

    Runs the module-level ``processor`` on ``example['image']`` (requesting
    PyTorch tensors) and stores ``example['style']`` under ``'labels'`` in the
    processor's output.

    Args:
        example: Mapping with an ``'image'`` entry and a ``'style'`` entry.

    Returns:
        The processor output with the extra ``'labels'`` entry.
    """
    encoded = processor(example['image'], return_tensors='pt')
    encoded['labels'] = example['style']
    return encoded

In [None]:

def transform(example):
    """On-the-fly preprocessing for ``Dataset.with_transform``.

    ``datasets`` always calls a formatting transform with a *batch*: a dict
    mapping each column name to a list of values, even when a single row is
    requested. The processor output therefore already carries a leading batch
    dimension that lines up with the list of labels, and that dimension must
    be kept. Squeezing it away breaks batches of size one, because the row is
    afterwards extracted by indexing ``[0]`` on what is then the channel axis.

    Args:
        example: Batch dict with an ``'image'`` list of PIL images and a
            ``'style'`` list of class ids. The parameter name is kept for
            backward compatibility even though it holds a batch.

    Returns:
        The processor output (``'pixel_values'`` with one entry per image)
        plus a ``'labels'`` entry holding the batch's style ids.
    """
    # Normalise every image to three channels so any non-RGB file in the
    # dataset is presented to the processor in the same layout as the rest.
    images = [img.convert('RGB') for img in example['image']]
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = example['style']
    return inputs


In [30]:
# Attach the lazy preprocessing transform to both splits; the images are only
# processed when rows are actually read.
train_ds, test_ds = (
    prepared_ds[split].with_transform(transform) for split in ("train", "test")
)

In [20]:
import torch

def collate_fn(batch):
    """Combine per-example dicts into one batch dict of tensors.

    Args:
        batch: Sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'labels'`` value.

    Returns:
        Dict with ``'pixel_values'`` stacked along a new first dimension and
        ``'labels'`` gathered into a single tensor.
    """
    pixel_rows = []
    label_rows = []
    for sample in batch:
        pixel_rows.append(sample['pixel_values'])
    for sample in batch:
        label_rows.append(sample['labels'])

    return {
        'pixel_values': torch.stack(pixel_rows),
        'labels': torch.tensor(label_rows),
    }

In [21]:
import numpy as np

import evaluate
# Accuracy metric object, loaded once and reused for every evaluation pass.
metric = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    Args:
        p: Object exposing ``predictions`` (per-class scores; the arg-max over
            axis 1 is taken as the predicted class) and ``label_ids``.

    Returns:
        Whatever ``metric.compute`` returns for those predictions/references.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

In [22]:
from transformers import ViTForImageClassification

# Class names of the "style" column; the list position is the class id.
labels = prepared_ds["train"].features["style"].names

# Load the pretrained backbone with a classification head sized to the number
# of styles. Both label maps use the stringified class id, as before.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=dict(zip(map(str, range(len(labels))), labels)),
    label2id=dict(zip(labels, map(str, range(len(labels))))),
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
from transformers import TrainingArguments

# Training configuration. Evaluation and checkpointing both run every 100
# steps, so the two schedules line up for `load_best_model_at_end=True`.
training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  # `eval_strategy` is the current name of the deprecated `evaluation_strategy`.
  eval_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  # Keep the 'image' column: the on-the-fly transform needs it to build pixel_values.
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)



In [31]:
from transformers import Trainer

# Wire the model, arguments, data and metric function together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    # `tokenizer=` is deprecated in recent transformers releases (it triggers a
    # warning on this call); `processing_class` is its replacement.
    processing_class=processor,
)


  trainer = Trainer(


In [32]:
# Run fine-tuning; the returned object carries the training metrics in `.metrics`.
train_results = trainer.train()
# Save the model, then log and persist the training metrics and the trainer state.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
# Score the held-out split. `prepared_ds` only has "train" and "test" keys (there
# is no "validation"), and evaluation has to go through `test_ds`, the view of the
# test split that has the preprocessing transform attached.
metrics = trainer.evaluate(test_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)