# Run local inference with the fine-tuned SwinV2-base model on test images

In [5]:
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Class mapping for FairFace: a predicted index i is reported as
# fairface_classes[i].
# NOTE(review): presumably this order matches the label encoding used when
# the checkpoint was trained -- confirm against the training code.
fairface_classes = [
    "White", "Black", "Latino_Hispanic", "East Asian",
    "Southeast Asian", "Indian", "Middle Eastern"
]

# Single source of truth for the pretrained backbone, so the processor and the
# model can never be loaded from two different checkpoints by accident.
MODEL_ID = "microsoft/swinv2-base-patch4-window16-256"
CHECKPOINT_PATH = "swinv2_fairface_best.pth"

# Load processor for correct normalization
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
mean = processor.image_mean
std = processor.image_std

# Preprocessing (use processor's normalization and correct size)
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Load model. The head size is derived from the label list so the two cannot
# drift apart.
num_classes = len(fairface_classes)
# ignore_mismatched_sizes=True replaces the checkpoint's 1000-way classifier
# with a freshly initialised num_classes-way one; the "newly initialized"
# warning this prints is expected because the fine-tuned weights are loaded
# from CHECKPOINT_PATH immediately afterwards.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_ID,
    num_labels=num_classes,
    ignore_mismatched_sizes=True
)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True on a torch version that has it).
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))
model.eval()
# Resolve the target device once and reuse it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def infer_image_from_path_swin(image_path, top_k=5):
    """Classify one image file and return its most probable classes.

    The image is converted to RGB, run through the module-level ``transform``
    and ``model``, and the softmax probabilities are ranked.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (typically a path).
        top_k: Maximum number of predictions to return. Defaults to 5 and is
            clamped to the number of classes the model outputs.

    Returns:
        A ``(classes, probs)`` tuple: ``classes`` is a list of names taken
        from ``fairface_classes`` and ``probs`` is a 1-D NumPy array of the
        matching softmax probabilities, both ordered from most to least
        probable.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # convert() returns a new, fully loaded image, so the underlying file can
    # be closed as soon as the with-block exits.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')

    # Follow the model's actual device instead of re-deriving it from CUDA
    # availability, so the input always lands where the weights are.
    device = next(model.parameters()).device
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(input_tensor).logits
        probs = torch.softmax(logits, dim=1)
        # topk raises if k exceeds the class dimension, so clamp it.
        k = min(top_k, probs.shape[1])
        top_probs, top_indices = probs.topk(k, dim=1)
        top_probs = top_probs.cpu().numpy().flatten()
        top_indices = top_indices.cpu().numpy().flatten()
        top_classes = [fairface_classes[i] for i in top_indices]
    return top_classes, top_probs

# Example usage: classify one local image and print the ranked predictions.
image_path = "dark.jpg"  # <-- Replace with your local image path
top5_classes, top5_probs = infer_image_from_path_swin(image_path)

# Pair each label with its probability once; the first pair is the top-1 result.
ranked = list(zip(top5_classes, top5_probs))
best_label, best_prob = ranked[0]
print("Top-1 Prediction:", best_label, f"({best_prob*100:.2f}%)")
print("Top-5 Predictions:")
for label, p in ranked:
    print(f"  {label}: {p*100:.2f}%")

Some weights of Swinv2ForImageClassification were not initialized from the model checkpoint at microsoft/swinv2-base-patch4-window16-256 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([7, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Top-1 Prediction: Latino_Hispanic (99.88%)
Top-5 Predictions:
  Latino_Hispanic: 99.88%
  Black: 0.07%
  Indian: 0.04%
  East Asian: 0.01%
  White: 0.00%
