In [None]:
!pip install evaluate

In [None]:
from datasets import load_dataset, DatasetDict
from transformers import ViTImageProcessor, ViTForImageClassification, TrainingArguments, Trainer
import torch
import numpy as np
from evaluate import load

In [None]:
# 1. Load the dataset
# Fetches the "cifar10" dataset through `load_dataset`; the result is indexed
# by split name ("train", "test") in the cells below.
dataset = load_dataset("cifar10")

In [None]:
# Keep only the first rows of each split so the run stays small:
# 5000 training examples and 1000 test examples.
subset_sizes = {"train": 5000, "test": 1000}
dataset = DatasetDict(
    {
        split: dataset[split].select(range(count))
        for split, count in subset_sizes.items()
    }
)

In [None]:
# 2. Preprocessing (resize images for ViT)
# Checkpoint identifier; the same name is passed to the model loader further down.
model_name = "google/vit-base-patch16-224-in21k"
# Image processor loaded from that checkpoint; `transform` uses it on every batch.
processor = ViTImageProcessor.from_pretrained(model_name)

In [None]:
def transform(example_batch):
    """Turn a batch of raw examples into model-ready inputs.

    Runs the module-level ``processor`` over the images stored under the
    ``'img'`` key, requesting PyTorch tensors, and attaches the values found
    under the ``'label'`` key to the result as ``'labels'``.

    Args:
        example_batch: Mapping with an ``'img'`` entry (iterable of images)
            and a ``'label'`` entry.

    Returns:
        The processor output with an added ``'labels'`` entry.
    """
    images = list(example_batch['img'])
    encoded = processor(images, return_tensors='pt')
    encoded['labels'] = example_batch['label']
    return encoded

In [None]:
# Attach `transform` to the dataset via `with_transform`, so `prepared_ds`
# yields processor output (plus 'labels') instead of the raw rows.
# NOTE(review): presumably applied lazily on access rather than eagerly —
# confirm against the `datasets` documentation.
prepared_ds = dataset.with_transform(transform)

In [None]:
# 3. Load the model with the CIFAR-10 labels
# Class names come from the dataset's 'label' feature; their position in the
# list is the integer class id.
labels = dataset['train'].features['label'].names
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    # Use integer ids on both sides of the mapping. The previous version
    # stored ids as strings, which left `label2id` returning "3" rather than 3
    # for code that looks up a class id by name.
    id2label=dict(enumerate(labels)),
    label2id={name: idx for idx, name in enumerate(labels)},
)

In [None]:
# 4. Training configuration
training_args = TrainingArguments(
    # Where checkpoints go and how many of them are kept.
    output_dir="./vit-cifar10",
    save_total_limit=2,
    # Optimisation settings.
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    # Step-based schedule: log every 10 steps, evaluate and save every 100.
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=10,
    load_best_model_at_end=True,
    # Keep the dataset columns untouched (the transform reads 'img' and 'label').
    remove_unused_columns=False,
    # No Hub upload and no external experiment tracker.
    push_to_hub=False,
    report_to="none",
)

In [None]:
# 5. Initialise the Trainer
# Accuracy metric from the `evaluate` package (imported above as `load`).
accuracy_metric = load("accuracy")


def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    Args:
        eval_pred: Pair of (logits, reference labels) supplied by the Trainer.

    Returns:
        The dict produced by the accuracy metric's ``compute`` call.
    """
    logits, references = eval_pred
    # The predicted class is the index of the largest logit along the last axis.
    predictions = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predictions, references=references)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_ds['train'],
    eval_dataset=prepared_ds['test'],
    # NOTE(review): newer transformers releases appear to name this argument
    # `processing_class` — confirm against the installed version.
    tokenizer=processor,
    # Without this, evaluation steps report only the loss.
    compute_metrics=compute_metrics,
)

In [None]:
# Run training!
# Starts the training loop configured on `trainer`; as the last expression in
# the cell, the return value of `train()` is shown as the cell output.
trainer.train()