Instead of a traditional MLP classifier, maybe a contrastive classifier that uses text embeddings of the class labels as priors would work well?
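
A contrastive classifier of this kind can be sketched as follows. Rather than learning a separate output weight vector per class, the logits are the cosine similarities between image features and a fixed set of label text embeddings, scaled by a temperature. Training then uses ordinary cross-entropy over those logits. This is a minimal NumPy sketch under stated assumptions: the random arrays stand in for real image features and label embeddings, the temperature of 0.07 is borrowed from CLIP's initial value rather than tuned, and `contrastive_logits` and `cross_entropy` are hypothetical helper names.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-8):
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def contrastive_logits(image_feats, label_embeds, temperature=0.07):
    """Logits of shape (N, C): cosine similarity between N image features and
    C fixed label text embeddings (the class "priors"), divided by a temperature.
    The temperature value is an assumption (CLIP's initial value), not tuned."""
    return l2_normalize(image_feats) @ l2_normalize(label_embeds).T / temperature

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy, computed stably by subtracting the row max."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

# Toy data (made up): random vectors stand in for real label text embeddings
# (C=5, D=16), and each "image feature" is its class embedding plus a little noise.
rng = np.random.default_rng(0)
label_embeds = rng.normal(size=(5, 16))
targets = np.array([0, 1, 2, 3, 4, 2])
image_feats = label_embeds[targets] + 0.01 * rng.normal(size=(6, 16))

logits = contrastive_logits(image_feats, label_embeds)
loss = cross_entropy(logits, targets)
predictions = logits.argmax(axis=1)  # predicted class = most similar label embedding
```

In a real model only the image side would be trained (for example a projection from the image backbone into the text embedding space), while the label embeddings stay frozen. This is what would let them act as priors.
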
Let's see what the text embeddings for the 1000 ImageNet classes look like. We'll collect the embeddings, reduce them to three dimensions, and plot them in 3D space. Hopefully each class will land somewhere distinct, and similar classes will embed close to each other.
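
The reduction step inside `notebook_helper.get_reduced` is not shown here, so whether it uses PCA, t-SNE, UMAP, or something else is unknown. As one possible approach, here is a PCA sketch using NumPy's SVD. `pca_reduce` is a hypothetical helper, and the random matrix stands in for the real (1000, D) embedding array.

```python
import numpy as np

def pca_reduce(embeddings, n_components=3):
    """Project (N, D) embeddings onto their top `n_components` principal components."""
    centered = embeddings - embeddings.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T

# Stand-in data (made up): 100 random 512-dimensional vectors instead of real embeddings.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(100, 512))
reduced = pca_reduce(embeddings, 3)  # shape (100, 3), one row per class
```

The three resulting columns could then be passed as `x`, `y`, and `z` to `plotly.express.scatter_3d`, with `hover_name` set to the class names.
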

**Result:** Zooming into a section of the cluster and hovering over the points shows that the text embeddings are semantically sensible: similar objects produce similar embeddings and therefore end up near each other in space.
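
Hovering only checks neighbors in the 3D projection, and the projection can distort distances. A complementary check is to look up nearest neighbors by cosine similarity in the full embedding space. The sketch below assumes `embeddings` is an (N, D) array aligned row-for-row with a `labels` list, as `output` and `labels` are in the code cell. `nearest_labels` is a hypothetical helper, and the 3-D toy vectors are made up for illustration.

```python
import numpy as np

def nearest_labels(embeddings, labels, query, k=3):
    """Return the k labels whose embeddings are most cosine-similar to `query`'s."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    idx = labels.index(query)          # first occurrence if a label repeats
    sims = normed @ normed[idx]        # cosine similarity to every label
    sims[idx] = -np.inf                # exclude the query itself
    top = np.argsort(-sims)[:k]        # indices of the k highest similarities
    return [labels[i] for i in top]

# Toy 3-D vectors (made up): cat-like classes point along the first axis,
# vehicle-like classes along the second.
labels = ["tabby", "tiger cat", "Persian cat", "pickup", "minivan"]
embeddings = np.array([
    [1.0, 0.1, 0.0],
    [0.9, 0.2, 0.1],
    [1.0, 0.0, 0.2],
    [0.1, 1.0, 0.0],
    [0.0, 0.9, 0.2],
])
print(nearest_labels(embeddings, labels, "tabby", k=2))  # ['tiger cat', 'Persian cat']
```

With the real data, a call such as `nearest_labels(output, labels, "tabby", k=5)` would list the classes whose text embeddings sit closest to "tabby", assuming that name appears in the label file. Because `labels.index` returns the first match, a class name that repeats after the comma split resolves to its first occurrence.
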

import requests
import torch
from PIL import Image
from IPython.display import display
from transformers import AutoProcessor, OwlViTForObjectDetection

import notebook_helper  # local helper module for dimensionality reduction and plotting

with open("assets/imagenet_classes.txt") as f:
    # ImageNet lists comma-separated synonyms for each class; keep the first name
    # as the main label and skip blank lines (e.g. from a trailing newline).
    labels = [line.split(",")[0].strip() for line in f.read().split("\n") if line.strip()]

processor = AutoProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

# The OWL-ViT detection forward pass expects an image alongside the text queries.
# This COCO image is only a placeholder, since we keep just the text embeddings.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=[labels], images=image, return_tensors="pt")
with torch.no_grad():
    # text_embeds has shape (batch=1, num_labels, embed_dim); drop the batch axis.
    output = model(**inputs).text_embeds.squeeze(0).numpy()

# Reduce to 3 dimensions and plot, showing each class name on hover.
reduced = notebook_helper.get_reduced(output, 3)
fig = notebook_helper.make_plot_3d(reduced, hover_labels=labels)
display(fig)