In [1]:
import torch
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
from evaluate import load
import numpy as np
from transformers import DefaultDataCollator
from torchvision import transforms

from accelerate import Accelerator

In [2]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Pretrained Vision Transformer (ViT) checkpoint to fine-tune; its name also tags this run.
model_name = 'google/vit-base-patch16-224-in21k'
run_name = f'{model_name}_run0'

In [4]:
# build dataset
# Food-101 from the Hugging Face Hub (served from the local cache after the first download).
# The "train" and "validation" splits are used further down.
food_dataset = load_dataset(path="food101")

Found cached dataset food101 (/home/ref2156/.cache/huggingface/datasets/food101/default/0.0.0/7cebe41a80fb2da3f08fcbef769c8874073a86346f7fb96dc0847d4dfc318295)


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Class names of the training split, plus lookup tables in both directions.
# Class indices are stored as strings in both maps.
labels = food_dataset["train"].features["label"].names
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [6]:
# Image preprocessor (normalisation stats, target size) matching the checkpoint.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Pretrained backbone with a freshly initialised classification head that has
# one output per class; the label maps are passed through to the model config.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)
model = model.to(device)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Fine-tuning hyperparameters
epochs = 8               # number of passes over the training split
per_dev_batch_size = 16  # per-device batch size, used for both training and evaluation
lr = 5e-5                # learning rate handed to TrainingArguments
output_dir = './vit'     # base output location for the Trainer

In [8]:
# Stacks the per-example dicts produced by the dataset transform into PyTorch batch tensors.
data_collator = DefaultDataCollator(return_tensors="pt")

In [14]:
# Normalise with the checkpoint's own per-channel mean / std.
normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Crop size: a single int when the processor specifies a shortest edge,
# otherwise an explicit (height, width) pair.
if "shortest_edge" in feature_extractor.size:
    size = feature_extractor.size["shortest_edge"]
else:
    size = (feature_extractor.size["height"], feature_extractor.size["width"])

# Random crop-and-resize augmentation, then tensor conversion and normalisation.
# NOTE(review): RandomResizedCrop is stochastic, so this pipeline is only suitable
# for training data; evaluation should use a deterministic resize/crop.
img_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size),
    transforms.ToTensor(),
    normalize,
])


# Batch transform used with `with_transform`: turns raw images into model inputs on the fly.
def transform_data(examples):
    """Replace the "image" column of a batch with preprocessed "pixel_values".

    Each entry of ``examples["image"]`` is converted to RGB (it is expected to be a
    PIL image, given the ``.convert`` call) and passed through ``img_transforms``.
    The batch dict is modified in place and also returned.
    """
    pixel_values = [img_transforms(picture.convert("RGB")) for picture in examples["image"]]
    del examples["image"]
    examples["pixel_values"] = pixel_values
    return examples

In [15]:
# Attach on-the-fly preprocessing. The training split keeps the random-crop
# augmentation from `transform_data`. The validation split gets a deterministic
# resize + center crop, so evaluation accuracy (and the best-checkpoint selection
# that depends on it) does not change from run to run.
eval_img_transforms = transforms.Compose([
    transforms.Resize(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    normalize,
])


def transform_eval_data(examples):
    """Same contract as `transform_data`, but uses the deterministic `eval_img_transforms`."""
    examples["pixel_values"] = [eval_img_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


# with_transform returns a copy, so the DatasetDict loaded earlier is left untouched.
food_dataset = food_dataset.with_transform(transform_data)
food_dataset["validation"] = food_dataset["validation"].with_transform(transform_eval_data)

In [16]:
def compute_metrics(eval_pred):
    """Top-1 accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair as supplied by ``Trainer``.
            ``predictions`` holds per-class scores with classes on axis 1, and
            ``labels`` holds the integer class ids.

    Returns:
        dict: ``{"accuracy": float}``. The value is NaN for an empty evaluation set.
    """
    logits, labels = eval_pred
    labels = np.asarray(labels)
    if labels.size == 0:
        # Accuracy is undefined without samples; avoid numpy's empty-mean warning.
        return {"accuracy": float("nan")}
    predictions = np.argmax(logits, axis=1)
    # Computed directly with numpy, so no externally loaded metric object is needed.
    # The result is cast to a plain float so it serialises cleanly into Trainer logs.
    return {"accuracy": float(np.mean(predictions == labels))}

In [17]:
# Trainer configuration. Checkpoints are written to "<output_dir>/<run_name>".
# run_name contains the Hub namespace ("google/..."), so this creates nested folders.
training_args = TrainingArguments(
    output_dir=f"{output_dir}/{run_name}",
    # Keep columns the model's forward() does not accept (e.g. "image");
    # the on-the-fly dataset transform still needs them.
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    # Must match evaluation_strategy for load_best_model_at_end to work.
    save_strategy="epoch",
    learning_rate=lr,
    per_device_train_batch_size=per_dev_batch_size,
    # Effective train batch = per_dev_batch_size * 4 * number of devices.
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=per_dev_batch_size,
    num_train_epochs=epochs,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    # Key returned by compute_metrics (logged by the Trainer as "eval_accuracy").
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food_dataset["train"],
    eval_dataset=food_dataset["validation"],
    # Saved alongside checkpoints so they can be reloaded with their preprocessor.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [18]:
# Fine-tune from the pretrained weights (no checkpoint resumption); the returned
# TrainOutput is the cell's displayed result.
trainer.train(resume_from_checkpoint=None)



Epoch,Training Loss,Validation Loss
