In [None]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Evaluation settings and the accumulators filled by the inference loop.
batch_size = 32
logits = []
labels = []

# Evaluate on the first 1000 examples of the CIFAR-10 test split.
dataset = load_dataset("cifar10", split="test[:1000]")

# The classifier and its image processor are loaded from the same checkpoint.
model_name = "nateraw/vit-base-patch16-224-cifar10"
model = AutoModelForImageClassification.from_pretrained(model_name)
model.eval()  # evaluation mode for inference
image_processor = AutoImageProcessor.from_pretrained(model_name)

# Walk the dataset in consecutive slices of `batch_size` examples, collecting
# the model's logits and the reference labels for every example.
for start in range(0, len(dataset), batch_size):
    batch = dataset[start : start + batch_size]

    # Turn the batch's images into a PyTorch tensor batch for the model.
    encoded = image_processor(batch["img"], return_tensors="pt")

    # Inference only, so gradient tracking is switched off for the forward pass.
    with torch.no_grad():
        outputs = model(encoded["pixel_values"])

    # Extending with the NumPy array appends one entry per example.
    batch_logits = outputs.logits.cpu().numpy()
    logits.extend(batch_logits)
    labels.extend(batch["label"])

# Predicted class = index of the largest logit along the last axis.
preds = np.asarray(logits).argmax(axis=-1)

# Accuracy of the predictions against the reference labels.
accuracy_metric = evaluate.load("accuracy")
accuracy = accuracy_metric.compute(references=labels, predictions=preds)
print(accuracy)

# Macro-averaged F1 over the same predictions and references.
f1_metric = evaluate.load("f1")
f1 = f1_metric.compute(references=labels, predictions=preds, average="macro")
print(f1)

# "exact_match" is loaded as a comparison module; here the predictions and the
# reference labels are passed as the two prediction sets to compare.
comparison = evaluate.load("exact_match", module_type="comparison")
exact_match = comparison.compute(predictions2=labels, predictions1=preds)
print(exact_match)

# "label_distribution" is a measurement module computed on the labels alone.
measurement = evaluate.load("label_distribution", module_type="measurement")
distribution = measurement.compute(data=labels)
print(distribution)
