In [1]:
import os
import re
import warnings

import numpy as np
import requests
import torch
from IPython.display import display
from PIL import Image

warnings.filterwarnings("ignore")

In [2]:
from transformers import CLIPModel, T5Tokenizer

In [None]:
# Use the customized image preprocessor. Do not use the Hugging Face image processor; otherwise the results will differ.
from clip_processor import image_preprocess

In [None]:
os.makedirs('checkpoints', exist_ok=True)
# download the checkpoints using `wget`, we use vecapdfn_clip_h14_336 as an example
!wget https://docs-assets.developer.apple.com/ml-research/models/veclip/vecapdfn_clip_h14_336.zip -P checkpoints/
!unzip checkpoints/vecapdfn_clip_h14_336.zip -d checkpoints/

In [None]:
MODEL_DIR = "checkpoints/vecapdfn_clip_h14_336"

# Fail early with an actionable message if the checkpoint was never downloaded or unpacked.
# `from_pretrained` accepts either a local directory or a Hugging Face Hub model id, so a
# missing folder would otherwise be looked up on the Hub and fail with a less direct error.
if not os.path.isdir(MODEL_DIR):
    raise FileNotFoundError(
        f"Checkpoint directory '{MODEL_DIR}' not found; run the download cell first."
    )

# Load tokenizer and model.
# Note: the T5 tokenizer does not enforce a fixed maximum input length, so warnings about
# the sequence length being exceeded can generally be ignored.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
print(f"Loading model {MODEL_DIR} ...")
model = CLIPModel.from_pretrained(MODEL_DIR)

In [6]:
# text model
# Tokenize the candidate captions as one batch; padding=True pads them to a common length
# so they can be returned as a single PyTorch tensor.
texts = ["a photo of car", "a photo of two cats"]
text_inputs = tokenizer(texts, return_tensors="pt", padding=True)

# Inference only: without no_grad the forward pass records an autograd graph and keeps the
# intermediate activations alive, which costs memory and is never used in this notebook.
with torch.no_grad():
    text_outputs = model.text_model(**text_inputs)

In [None]:
# Dimensions of the text encoder's final hidden states.
text_outputs.last_hidden_state.size()

In [None]:
# Image crop size: use the numeric suffix of the checkpoint name when the pattern matches
# (e.g. "..._clip_h14_336" -> 336); otherwise fall back to the default of 224.
match = re.search(r'clip_h\d+_(\d+)', MODEL_DIR)
crop_size = int(match.group(1)) if match is not None else 224

# vision model
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the wait and surface HTTP errors here: without a timeout a stalled server hangs the
# cell, and without the status check an error page is handed to PIL, which then fails with
# an unrelated "cannot identify image file" message.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# Normalize to RGB so np.asarray always yields an (H, W, 3) array, even for grayscale,
# RGBA or palette inputs. NOTE(review): image_preprocess presumably expects 3 channels,
# confirm against clip_processor.
image = Image.open(response.raw).convert("RGB")
pixel_values = image_preprocess([np.asarray(image)], crop_size=crop_size)

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    vision_outputs = model.vision_model(pixel_values=pixel_values)

In [None]:
# Dimensions of the vision encoder's final hidden states.
vision_outputs.last_hidden_state.size()

In [None]:
# text-vision model
# Inference only: run the joint forward pass without recording an autograd graph.
with torch.no_grad():
    outputs = model(**text_inputs, pixel_values=pixel_values)
# Softmax over dim=1 normalizes each row of the image-to-text logits, so every image gets
# a probability distribution over the candidate texts.
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
print(probs)

In [None]:
display(image)
# Report, for the single image in the batch, the probability of each candidate caption.
image_probs = probs[0]
for caption_prob, caption in zip(image_probs, texts):
    print("Probability for '{}' is {:.2%}".format(caption, caption_prob))