Run local inference with the base ViT `.pth` checkpoint file

In [9]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import ViTForImageClassification
import io

# FairFace label names. The position in this list is the class index that
# the model's output logits are decoded with.
# NOTE(review): the order must match the label encoding used when the
# checkpoint was trained -- confirm against the training code.
fairface_classes = [
    "White",
    "Black",
    "Latino_Hispanic",
    "East Asian",
    "Southeast Asian",
    "Indian",
    "Middle Eastern",
]

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalize each channel with the statistics below.
# NOTE(review): stated to be the same as training -- confirm that the
# training pipeline used these exact mean/std values.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)

# Load model
# Derive the head size from the label list so the two cannot drift apart.
num_classes = len(fairface_classes)

# Resolve the target device once instead of repeating the expression.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ignore_mismatched_sizes=True lets the checkpoint's classifier head be
# swapped for a freshly initialized `num_classes`-way head. transformers
# prints a "newly initialized" warning for classifier.weight/bias when this
# happens; the state dict loaded below then overwrites the model weights.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source. If the installed torch supports it, consider passing
# weights_only=True here.
model.load_state_dict(torch.load("vit_fairface_best.pth", map_location="cpu"))

# Inference mode: disables dropout and other training-only behaviour.
model.eval()
model = model.to(device)

def infer_image_from_path(image_path, top_k=5):
    """Classify one image file and return its most probable classes.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        top_k: Maximum number of predictions to return. It is clamped to
            the number of classes the model outputs, so requesting more
            classes than exist does not raise.

    Returns:
        A ``(classes, probs)`` tuple. ``classes`` is a list of names taken
        from ``fairface_classes``, ordered from most to least probable, and
        ``probs`` is a 1-D NumPy array with the matching softmax
        probabilities.
    """
    # Close the file handle deterministically. convert() returns a new,
    # fully loaded image, so it stays usable after the ``with`` block.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Put the input wherever the model currently lives rather than
    # re-deriving the device from CUDA availability.
    model_device = next(model.parameters()).device
    input_tensor = transform(image).unsqueeze(0).to(model_device)

    with torch.no_grad():
        logits = model(input_tensor).logits
        probs = torch.softmax(logits, dim=1)
        # topk raises if k exceeds the size of the class dimension.
        k = min(top_k, probs.size(1))
        top_probs, top_indices = probs.topk(k, dim=1)

    # Batch size is 1, so flattening yields one entry per returned class.
    top_probs = top_probs.cpu().numpy().flatten()
    top_indices = top_indices.cpu().numpy().flatten()
    top_classes = [fairface_classes[i] for i in top_indices]
    return top_classes, top_probs

# Example usage:
image_path = "asian.jpg"  # <-- Replace with your local image path
top5_classes, top5_probs = infer_image_from_path(image_path)

# Report the first returned prediction, then the full returned list.
best_label, best_prob = top5_classes[0], top5_probs[0]
print("Top-1 Prediction:", best_label, f"({best_prob*100:.2f}%)")
print("Top-5 Predictions:")
for label, confidence in zip(top5_classes, top5_probs):
    print(f"  {label}: {confidence*100:.2f}%")

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([7, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Top-1 Prediction: White (91.29%)
Top-5 Predictions:
  White: 91.29%
  Middle Eastern: 4.95%
  Indian: 2.81%
  East Asian: 0.72%
  Black: 0.17%
