In [None]:
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from datasets import load_dataset, load_metric
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
from transformers import TrainingArguments
from transformers import Trainer

In [None]:
# Pick the best available accelerator instead of assuming Apple Metal (MPS):
# MPS is still preferred when present, then CUDA, and finally the CPU so the
# notebook also runs on machines without a supported GPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Load the first 10% of the CIFAR-10 train and test splits, each with its own
# cache directory and with verification checks disabled.
# NOTE(review): the cache directories are absolute, user-specific paths; anyone
# else running this notebook has to edit them first.
dataset_train = load_dataset(
    'cifar10',
    split='train[:10%]',
    verification_mode='no_checks',
    cache_dir='/Users/ykamoji/Documents/ImageDatabase/cifar10/train',
)
dataset_test = load_dataset(
    'cifar10',
    split='test[:10%]',
    verification_mode='no_checks',
    cache_dir='/Users/ykamoji/Documents/ImageDatabase/cifar10/test',
)

In [None]:
# Summarise both splits.
print(dataset_train)
print(dataset_test)

# Take the class count from the label feature's schema rather than from the
# label values present in this 10% slice: no need to materialise the whole
# column, and classes missing from the sample are still counted.
labels = dataset_train.features['label']
num_classes = labels.num_classes
num_classes, labels

In [None]:
# Show the 'img' field of the first training sample as the cell output.
dataset_train[0]['img']

In [None]:
# Integer class id of the first training sample, shown next to its class name.
first_label = dataset_train[0]['label']
first_label, labels.names[first_label]

In [None]:
# Checkpoint identifier; the same name is reused later to load the classifier.
model_name = 'google/vit-base-patch16-224-in21k'

# Image processor matching that checkpoint, cached under the local 'models/' directory.
feature_extractor = ViTImageProcessor.from_pretrained(
    model_name,
    cache_dir='models/',
)

In [None]:
# Display the loaded image processor (its repr is the cell output).
feature_extractor

In [None]:
# Compare the size of one raw image with the shape of the tensor the image
# processor produces for it.
sample_image = dataset_train[0]['img']
print(sample_image.size)

example = feature_extractor(sample_image, return_tensors='pt')
print(example['pixel_values'].shape)

In [None]:
def preprocess(batchImage):
    """Run the image processor over a batch and attach the labels.

    Args:
        batchImage: mapping with an 'img' entry (the images handed to
            `feature_extractor`) and a 'label' entry (copied through unchanged).

    Returns:
        The object returned by `feature_extractor(..., return_tensors='pt')`
        with an extra 'label' entry holding `batchImage['label']`.
    """
    images = batchImage['img']
    targets = batchImage['label']

    encoded = feature_extractor(images, return_tensors='pt')
    encoded['label'] = targets
    return encoded

In [None]:
# Attach `preprocess` as the transform of both splits via `with_transform`.
prepared_train, prepared_test = (
    split.with_transform(preprocess) for split in (dataset_train, dataset_test)
)

In [None]:
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    Args:
        batch: list of mappings, each with a 'pixel_values' tensor and a
            'label' value.

    Returns:
        dict with 'pixel_values' (the per-sample tensors stacked along a new
        leading dimension) and 'labels' (a tensor built from the per-sample
        'label' values).
    """
    pixel_list = []
    label_list = []
    for sample in batch:
        pixel_list.append(sample['pixel_values'])
        label_list.append(sample['label'])

    stacked_pixels = torch.stack(pixel_list)
    label_tensor = torch.tensor(label_list)
    return {'pixel_values': stacked_pixels, 'labels': label_tensor}

In [None]:
# Accuracy metric object, cached under the local 'metrics/' directory.
# NOTE(review): `datasets.load_metric` is reported as deprecated in recent
# `datasets` releases — confirm it is still available in the installed version.
metric = load_metric("accuracy", cache_dir='metrics/', trust_remote_code=True)


def compute_metrics(p):
    """Score a set of predictions with the module-level `metric`.

    Args:
        p: object exposing `predictions` (per-class scores; the arg-max is
            taken along axis 1) and `label_ids` (the reference labels).

    Returns:
        Whatever `metric.compute` returns for those predictions/references.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_classes, references=p.label_ids)

In [None]:
# Training configuration, grouped by concern. Keyword values are unchanged.
training_args = TrainingArguments(
    # Where checkpoints and metrics are written; nothing is pushed to the Hub.
    output_dir="cifar",
    push_to_hub=False,
    # Optimisation.
    num_train_epochs=4,
    per_device_train_batch_size=16,
    learning_rate=2e-4,
    # Evaluate and checkpoint every 100 steps, log every 10.
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=10,
    # Keep at most two checkpoints and reload the best one when training ends.
    save_total_limit=2,
    load_best_model_at_end=True,
    # Leave dataset columns in place; `preprocess` reads 'img' and 'label'
    # itself through the dataset transform.
    remove_unused_columns=False,
)

In [None]:
# Class names in label-id order. Kept under its own name so `labels` (the
# ClassLabel feature used by other cells via `labels.names`) is not rebound to
# a plain list, which would break re-running those cells.
class_names = dataset_train.features['label'].names

# Load the pretrained backbone with a classification head sized for this
# dataset. The id<->name mappings are stored in the model config so saved
# checkpoints report class names instead of bare indices.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(class_names),
    id2label=dict(enumerate(class_names)),
    label2id={name: idx for idx, name in enumerate(class_names)},
    cache_dir='models/',
)

In [None]:
# Move the model to the selected device; as the last expression of the cell,
# the returned module is also displayed.
model.to(device)

In [None]:
# Wire the model, training configuration, data and metric function together.
trainer = Trainer(
    model=model,
    args=training_args,
    # Data: lazily preprocessed splits, batched by `collate_fn`.
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    data_collator=collate_fn,
    # Evaluation metric callback.
    compute_metrics=compute_metrics,
    # The image processor is passed through the `tokenizer` argument.
    tokenizer=feature_extractor,
)

In [None]:
# Run training; the result is kept because its `.metrics` are logged and
# saved in the following cell.
train_results = trainer.train()

In [None]:
# Metrics produced by the training run.
train_metrics = train_results.metrics

# Save the model, then log and save the training metrics.
trainer.save_model()
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)

# Finally save the trainer state.
trainer.save_state()

In [None]:
# Evaluate on the prepared test split, then log and save the results under
# the "eval" name.
metrics = trainer.evaluate(eval_dataset=prepared_test)
trainer.log_metrics(split="eval", metrics=metrics)
trainer.save_metrics(split="eval", metrics=metrics)

In [None]:
# Index the row first, then the column: `dataset_test["img"][0]` can decode
# every image in the split just to return the first one, whereas
# `dataset_test[0]["img"]` only touches a single row.
# The image is resized to 200x200 before being displayed and classified.
image = dataset_test[0]["img"].resize((200, 200))
image

In [None]:
# extract the actual label of the first image of the testing dataset
actual_label = dataset_test["label"][0]

labels = dataset_test.features['label']
actual_label, labels.names[actual_label]

In [None]:
# NOTE(review): this loads a separate checkpoint by name rather than the model
# trained above (which was saved under output_dir "cifar") — confirm intended.
model_name_or_path = 'LaCarnevali/vit-cifar10'

# Load the image processor and the classifier from the same checkpoint.
feature_extractor_finetuned = ViTImageProcessor.from_pretrained(model_name_or_path)
model_finetuned = ViTForImageClassification.from_pretrained(model_name_or_path)

In [None]:
# Preprocess the single test image with the processor of the loaded checkpoint.
inputs = feature_extractor_finetuned(image, return_tensors="pt")

# Inference only: run the forward pass without tracking gradients.
with torch.no_grad():
    outputs = model_finetuned(**inputs)
logits = outputs.logits

In [None]:
# Label feature of the test split, used to name the predicted class id.
# NOTE(review): this assumes the loaded checkpoint orders its classes the same
# way as the dataset's label feature — confirm.
labels = dataset_test.features['label']

# Index of the highest-scoring class along the last dimension, as a Python number.
predicted_label = torch.argmax(logits, dim=-1).item()
labels.names[predicted_label]