In [1]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from transformers import pipeline, AutoModelForImageClassification, Trainer, TrainingArguments
from datasets import load_dataset

In [2]:
dataset = load_dataset("imagefolder", data_dir="./data/", split="train")

Resolving data files:   0%|          | 0/2523 [00:00<?, ?it/s]

Found cached dataset imagefolder (/Users/ryanmiville/.cache/huggingface/datasets/imagefolder/default-cc8a6490479c0dd7/0.0.0/37fbb85cc714a338bea574ac6c7d0b5be5aff46c1862c1989b20e0771199e93f)


In [3]:
dataset = dataset.train_test_split(test_size=0.2)


In [4]:
labels = dataset["train"].features["label"].names
label2id = {}
id2label = {}

for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [5]:
from transformers import AutoImageProcessor

model_name = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(model_name)

In [17]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])


In [18]:
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


In [14]:
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()


In [19]:
dataset = dataset.with_transform(transforms)

In [7]:
import evaluate

accuracy = evaluate.load("accuracy")


In [8]:
import numpy as np

def compute_metrics(eval_pred):
	predictions, labels = eval_pred
	predictions = np.argmax(predictions, axis=1)
	return accuracy.compute(predictions=predictions, references=labels)

In [9]:
model = AutoModelForImageClassification.from_pretrained(
    model_name,
		num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
training_args = TrainingArguments(
    output_dir="flannel_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)


In [20]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

In [21]:
trainer.train()



  0%|          | 0/93 [00:00<?, ?it/s]

{'loss': 0.5437, 'learning_rate': 5e-05, 'epoch': 0.31}
{'loss': 0.2452, 'learning_rate': 4.3975903614457834e-05, 'epoch': 0.63}
{'loss': 0.1282, 'learning_rate': 3.7951807228915666e-05, 'epoch': 0.94}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 0.10317094624042511, 'eval_accuracy': 0.9841584158415841, 'eval_runtime': 189.6832, 'eval_samples_per_second': 2.662, 'eval_steps_per_second': 0.169, 'epoch': 0.98}
{'loss': 0.1124, 'learning_rate': 3.192771084337349e-05, 'epoch': 1.26}
{'loss': 0.0637, 'learning_rate': 2.5903614457831325e-05, 'epoch': 1.57}
{'loss': 0.0532, 'learning_rate': 1.9879518072289157e-05, 'epoch': 1.89}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 0.07660099864006042, 'eval_accuracy': 0.9841584158415841, 'eval_runtime': 189.5913, 'eval_samples_per_second': 2.664, 'eval_steps_per_second': 0.169, 'epoch': 1.98}
{'loss': 0.0354, 'learning_rate': 1.3855421686746989e-05, 'epoch': 2.2}
{'loss': 0.0515, 'learning_rate': 7.83132530120482e-06, 'epoch': 2.52}
{'loss': 0.032, 'learning_rate': 1.8072289156626506e-06, 'epoch': 2.83}


  0%|          | 0/32 [00:00<?, ?it/s]

{'eval_loss': 0.06595849990844727, 'eval_accuracy': 0.9861386138613861, 'eval_runtime': 195.0077, 'eval_samples_per_second': 2.59, 'eval_steps_per_second': 0.164, 'epoch': 2.93}
{'train_runtime': 7539.2373, 'train_samples_per_second': 0.803, 'train_steps_per_second': 0.012, 'train_loss': 0.13687530560519107, 'epoch': 2.93}


TrainOutput(global_step=93, training_loss=0.13687530560519107, metrics={'train_runtime': 7539.2373, 'train_samples_per_second': 0.803, 'train_steps_per_second': 0.012, 'train_loss': 0.13687530560519107, 'epoch': 2.93})