In [None]:
import torch
from PIL import Image
import open_clip

# If you're on an Apple-silicon Mac you might see faster performance using mps; set MAC = False to force the CPU.
MAC = True
# Only pick "mps" when this PyTorch build/machine can actually use it, so the
# notebook still runs (on CPU) everywhere else instead of failing at model creation.
device = "mps" if MAC and torch.backends.mps.is_available() else "cpu"

# model initialization
# The same architecture name is used for both the model and its tokenizer,
# so keep it in one place to stop the two from drifting apart.
MODEL_NAME = 'ViT-B-32'
PRETRAINED = 'laion2b_s34b_b79k'
model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED, device=device)
model.eval()  # switch the model to evaluation mode before running inference
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

In [None]:
# put in any image you want
image_fp = "elephant.webp"
# Open the file in a context manager so its handle is released as soon as the
# image has been preprocessed; unsqueeze(0) adds a leading dimension of size 1
# and the result is moved to the selected device.
with Image.open(image_fp) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# Candidate labels the image is scored against; edit this list freely.
cls_options = ["an elephant", "a dog", "a cat"]

# Tokenize every label, then move the tokens to the device the model was created on.
text = tokenizer(cls_options)
text = text.to(device)

# Inference only, so gradient tracking is disabled.
# NOTE: the previous torch.cuda.amp.autocast() context was dropped: `device` in
# this notebook is only ever "mps" or "cpu", so a CUDA-specific autocast context
# never applied to these tensors and merely produced warnings.
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # Scale each embedding to unit length so the dot product below is a cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Multiply the similarities by 100 to sharpen the distribution, then
    # softmax across the labels to turn them into probabilities.
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Print every label with its probability for the (single) image, to two decimals.
scores = text_probs[0].tolist()
for label, score in zip(cls_options, scores):
    print(f"{label}: {score:.2f}")