In [13]:
import tensorflow as tf
from transformers import ViTFeatureExtractor, TFViTForImageClassification
import tensorflow_datasets as tfds
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt



In [5]:
# Load a pretrained Vision Transformer model from Hugging Face (TensorFlow)
model_name = "google/vit-base-patch16-224"
model = TFViTForImageClassification.from_pretrained(model_name)

imagenette = tfds.load("imagenette/160px", split="train", as_supervised=False, batch_size=None)

All PyTorch model weights were used when initializing TFViTForImageClassification.

All the weights of TFViTForImageClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFViTForImageClassification for predictions without further training.


In [34]:

# Load and preprocess the image
def ViT_preprocessor(x):
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    x = x['image']
    x = tf.image.resize(x, (224, 224)) / 255.0
    x = feature_extractor(images=x, return_tensors="tf")
    x = x['pixel_values']
    return x
inputs = []
for i in imagenette.take(5):
    x = ViT_preprocessor(i)
    inputs.append(x)
    
inputs = tf.concat(inputs, axis=0)
inputs.shape

2024-09-16 21:24:26.144999: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


TensorShape([5, 3, 224, 224])

In [40]:
# Perform inference
outputs = model(inputs)
logits = outputs.logits

# Get the predicted class
predicted_class = tf.argmax(logits, axis=-1).numpy()
print(f"Predicted class ID: {predicted_class}")

# Get the class label (optional)
class_names = model.config.id2label
print(f"Predicted class: {[class_names[p] for p in predicted_class]}")


Predicted class ID: [111 644 644 644 644]
Predicted class: ['nematode, nematode worm, roundworm', 'matchstick', 'matchstick', 'matchstick', 'matchstick']
