# In [None]:
import torch
import clip
from PIL import Image

# Load the CLIP model and its matching image preprocessing pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device=device)

# Load the image and preprocess it into a batch of one on `device`.
# The context manager closes the file handle once the pixels are consumed.
with Image.open("example_image.jpg") as image:
    image_input = preprocess(image).unsqueeze(0).to(device)

# Candidate captions to rank against the image. Softmax is taken across
# candidates, so it is only informative with more than one; append more
# alternatives here to make the ranking meaningful.
caption_prompt = "a photo of a cat sitting on a table"
candidate_captions = [caption_prompt, "a cat sitting on a table"]

# Tokenize every candidate in one batch; these token ids are what both
# encode_text and model(image, text) expect as text input.
captions = clip.tokenize(candidate_captions).to(device)

# Inference only: no_grad avoids building an autograd graph, and it is also
# required for the `.numpy()` call below (it fails on tensors that require grad).
with torch.no_grad():
    image_features = model.encode_image(image_input)
    caption_features = model.encode_text(captions)

    # L2-normalize so the dot product below is a true cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    caption_features = caption_features / caption_features.norm(dim=-1, keepdim=True)

    # Scale cosine similarities by CLIP's learned temperature, mirroring what
    # model(image, text) computes, without re-encoding the inputs.
    logits_per_image = model.logit_scale.exp() * image_features @ caption_features.T
    logits_per_text = logits_per_image.T

    # Row i holds the probability distribution over candidates for image i.
    similarity = logits_per_image.softmax(dim=-1)

probs = similarity.cpu().numpy()

# Pick the best candidate for the (single) image and map the index back to
# the original caption string rather than trying to decode token ids.
best_index = int(probs[0].argmax())
most_likely_caption = candidate_captions[best_index]

print("Most likely caption:", most_likely_caption)
