In [3]:
import torch
from PIL import Image
import open_clip

from matplotlib.pyplot import figure, imshow, axis
from matplotlib.image import imread

# Notebook-wide tensor display: 3 digits of precision and plain decimal
# notation (no scientific notation) so printed scores are easy to compare.
torch.set_printoptions(precision=3, sci_mode=False)

In [4]:
def showImagesHorizontally(list_of_files):
    """Display the given image files side by side in a single row.

    Args:
        list_of_files: Sequence of image file paths readable by
            ``matplotlib.image.imread``. An empty sequence yields an
            empty figure.

    Returns:
        None. A new matplotlib figure is created with one subplot per file
        and the axes decorations hidden.
    """
    fig = figure()
    number_of_files = len(list_of_files)
    for position, path in enumerate(list_of_files, start=1):
        ax = fig.add_subplot(1, number_of_files, position)
        # Draw on this subplot explicitly instead of relying on pyplot's
        # implicit "current axes" state.
        # Note: a 2-D (single-channel) array is rendered with matplotlib's
        # default colormap; pass cmap='Greys_r' here for true grayscale.
        ax.imshow(imread(path))
        ax.axis('off')

# Load Model

In [5]:
# Show the model architecture names this open_clip install recognises;
# the names used in the loading cells below are taken from this list.
open_clip.list_models()

['coca_base',
 'coca_roberta-ViT-B-32',
 'coca_ViT-B-32',
 'coca_ViT-L-14',
 'convnext_base',
 'convnext_base_w',
 'convnext_base_w_320',
 'convnext_large',
 'convnext_large_d',
 'convnext_large_d_320',
 'convnext_small',
 'convnext_tiny',
 'convnext_xlarge',
 'convnext_xxlarge',
 'convnext_xxlarge_320',
 'EVA01-g-14',
 'EVA01-g-14-plus',
 'EVA02-B-16',
 'EVA02-E-14',
 'EVA02-E-14-plus',
 'EVA02-L-14',
 'EVA02-L-14-336',
 'mt5-base-ViT-B-32',
 'mt5-xl-ViT-H-14',
 'nllb-clip-base',
 'nllb-clip-base-siglip',
 'nllb-clip-large',
 'nllb-clip-large-siglip',
 'RN50',
 'RN50-quickgelu',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'RN101',
 'RN101-quickgelu',
 'roberta-ViT-B-32',
 'swin_base_patch4_window7_224',
 'ViT-B-16',
 'ViT-B-16-plus',
 'ViT-B-16-plus-240',
 'ViT-B-16-quickgelu',
 'ViT-B-16-SigLIP',
 'ViT-B-16-SigLIP-256',
 'ViT-B-16-SigLIP-384',
 'ViT-B-16-SigLIP-512',
 'ViT-B-16-SigLIP-i18n-256',
 'ViT-B-32',
 'ViT-B-32-256',
 'ViT-B-32-plus-256',
 'ViT-B-32-quickgelu',
 'ViT-bigG-14',
 'ViT-

In [6]:
# Base model: ViT-B-32 with the 'laion2b_s34b_b79k' pretrained weights.
# create_model_and_transforms returns (model, train_transform, eval_transform);
# only the eval transform is needed for inference, so the middle value is dropped.
model_b, _, preprocess_b = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
# The open_clip README switches to eval mode right after loading because the
# model comes back in train mode, which matters for architectures with
# BatchNorm or stochastic depth.
model_b.eval()
tokenizer_b = open_clip.get_tokenizer('ViT-B-32')

In [7]:
# Large model: ViT-L-14-336 with the 'openai' pretrained weights.
# Only the eval transform (third return value) is kept for inference.
model_l, _, preprocess_l = open_clip.create_model_and_transforms('ViT-L-14-336', pretrained='openai')
# The open_clip README switches to eval mode right after loading because the
# model comes back in train mode.
model_l.eval()
# Look the tokenizer up by the same name as the loaded model (previously
# 'ViT-B-32', a copy-paste slip) so the pairing stays correct if the model
# name is ever changed.
tokenizer_l = open_clip.get_tokenizer('ViT-L-14-336')

In [8]:
# Both models are scored against the same candidate captions, so the list is
# written once and passed through each model's own tokenizer.
candidate_labels = ["a diagram", "a lobster", "a cat"]
text_b, text_l = tokenizer_b(candidate_labels), tokenizer_l(candidate_labels)

# Show Images

<div style="display: flex; flex-direction: row;">
    <img src="./experiment_data/lobster_grayscale_0.png" style="width: 200px; height: auto; margin-right: 10px;">
    <img src="./experiment_data/lobster_blue_0.png" style="width: 200px; height: auto; margin-right: 10px;">
    <img src="./experiment_data/lobster_red_0.png" style="width: 200px; height: auto; margin-right: 10px;">
</div>

In [9]:
# The three colour variants of lobster image 0, in the order they are scored.
lobster_0_files = [
    f"./experiment_data/lobster_{variant}_0.png"
    for variant in ("grayscale", "blue", "red")
]

# Left disabled: the same three files are already rendered by the HTML
# markdown cell above.
# showImagesHorizontally(lobster_0_files)

In [10]:
# Zero-shot classification of each lobster image with the ViT-B-32 model.
#
# The caption embeddings do not depend on the image, so they are encoded and
# L2-normalised once here instead of on every loop iteration.
# NOTE(review): no .cuda()/.to(device) call is visible in this notebook, so the
# model presumably runs on CPU, where CUDA autocast has no effect - confirm.
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = model_b.encode_text(text_b)
    text_features /= text_features.norm(dim=-1, keepdim=True)

for filepath in lobster_0_files:
    # The context manager closes the file handle once preprocessing has read
    # the pixel data; unsqueeze(0) adds the batch dimension.
    with Image.open(filepath) as pil_image:
        image = preprocess_b(pil_image).unsqueeze(0)

    with torch.no_grad(), torch.cuda.amp.autocast():
        image_features = model_b.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Both sides are unit vectors, so this matrix product is the cosine
        # SIMILARITY (higher = closer match), not a distance. The variable
        # name `dist` is kept so any later cell that reads it still works.
        dist = image_features @ text_features.T
        # 100.0 is the fixed scale used in the open_clip README example; it
        # sharpens the softmax over the small range of cosine similarities.
        text_probs = (100.0 * dist).softmax(dim=-1)

    print(f"Label probs: {text_probs}, Cosine similarity: {dist}")

Label probs: tensor([[    0.015,     0.985,     0.000]]), Cosine distance: tensor([[0.226, 0.268, 0.169]])
Label probs: tensor([[    0.003,     0.997,     0.000]]), Cosine distance: tensor([[0.232, 0.292, 0.157]])
Label probs: tensor([[    0.000,     1.000,     0.000]]), Cosine distance: tensor([[0.235, 0.319, 0.164]])
