# In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
import numpy as np

# Load CIFAR-10 and resize to 224x224
def preprocess(example):
    """Turn one CIFAR-10 example into a ViT-ready ``(pixel_values, label)`` pair.

    Args:
        example: Feature dict as yielded by ``tfds.load(..., as_supervised=False)``;
            only the ``'image'`` and ``'label'`` entries are read.

    Returns:
        A tuple ``(image, label)``. ``image`` is a float32 tensor laid out
        channels-first as ``(channels, 224, 224)`` and scaled to ``[-1, 1]``;
        ``label`` is passed through unchanged.
    """
    image = tf.image.resize(example['image'], [224, 224])
    image = tf.cast(image, tf.float32)
    # The google/vit-base-patch16-224 checkpoint was trained with per-channel
    # mean 0.5 / std 0.5 applied to [0, 1] pixels, which is the same as mapping
    # 0..255 pixel values onto [-1, 1] (not [0, 1]).
    image = image / 127.5 - 1.0
    # Hugging Face TF vision models take channels-first pixel_values
    # (num_channels, height, width); the resized image is channels-last.
    image = tf.transpose(image, perm=[2, 0, 1])
    return image, example['label']

train_ds = tfds.load("cifar10", split="train", as_supervised=False)
test_ds = tfds.load("cifar10", split="test", as_supervised=False)

# Shuffle the raw (still small, un-resized) training examples so every epoch
# sees a different ordering; the buffer is refilled and reshuffled each epoch.
# The per-example preprocessing is run in parallel so it does not starve the
# accelerator; element order is still deterministic after the map.
train_ds = (
    train_ds
    .shuffle(10_000)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)
# The evaluation split is deliberately left unshuffled.
test_ds = (
    test_ds
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)
)

# Load the pretrained ViT checkpoint and its matching feature extractor.
# NOTE(review): feature_extractor is not used by the code visible here; it is
# kept in case later cells rely on it.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# The checkpoint ships with a 1000-class ImageNet classification head. Asking
# for num_labels=10 changes the head's shape, so the mismatched head weights
# must be skipped explicitly; the 10-way head is then freshly initialised and
# learned during fine-tuning while the backbone keeps its pretrained weights.
model = TFAutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Configure training: integer class labels are compared against the raw
# (un-softmaxed) logits returned by the model, hence from_logits=True.
model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

# Fine-tune for three passes over the training pipeline.
model.fit(x=train_ds, epochs=3)

# Score the held-out split; with one loss and one compiled metric, evaluate()
# yields the pair (loss, accuracy).
test_loss, test_acc = model.evaluate(x=test_ds)
print(f"Test Accuracy: {test_acc * 100:.2f}%")
