In [2]:
import torch
from PIL import Image
import open_clip
import numpy as np
model, _, preprocess = open_clip.create_model_and_transforms('ViT-bigG-14', pretrained='laion2b_s39b_b160k')
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')

In [7]:
# open_clip.list_pretrained()
import numpy as np
model.eval()
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: -1,755,400,191
Context length: 77
Vocab size: 49408


In [4]:
image = preprocess(Image.open("./data/images/ROCO_00001.jpg")).unsqueeze(0)
text = tokenizer(["MRI", "head", "CT", "cat"])

with torch.no_grad(), torch.cuda.amp.autocast():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs)  # prints: [[1., 0., 0.]]

Label probs: tensor([[9.9972e-01, 2.4901e-04, 2.6577e-05, 4.2612e-09]])


In [14]:
import pandas as pd
df = pd.read_csv('./data/roco_validation.csv', sep='\t')
img_key = 'filepath'
print(df.keys())
images = df[img_key].tolist()
print(images[:5])
caption_key = 'title'
captions = df[caption_key].tolist()

Index(['filepath', 'caption'], dtype='object')
['E:\\Work\\roco-dataset\\data\\validation\\radiology\\images\\ROCO_00020.jpg', 'E:\\Work\\roco-dataset\\data\\validation\\radiology\\images\\ROCO_00027.jpg', 'E:\\Work\\roco-dataset\\data\\validation\\radiology\\images\\ROCO_00059.jpg', 'E:\\Work\\roco-dataset\\data\\validation\\radiology\\images\\ROCO_00062.jpg', 'E:\\Work\\roco-dataset\\data\\validation\\radiology\\images\\ROCO_00068.jpg']


In [15]:
import open_clip
open_clip.list_pretrained()

[('RN50', 'openai'),
 ('RN50', 'yfcc15m'),
 ('RN50', 'cc12m'),
 ('RN50-quickgelu', 'openai'),
 ('RN50-quickgelu', 'yfcc15m'),
 ('RN50-quickgelu', 'cc12m'),
 ('RN101', 'openai'),
 ('RN101', 'yfcc15m'),
 ('RN101-quickgelu', 'openai'),
 ('RN101-quickgelu', 'yfcc15m'),
 ('RN50x4', 'openai'),
 ('RN50x16', 'openai'),
 ('RN50x64', 'openai'),
 ('ViT-B-32', 'openai'),
 ('ViT-B-32', 'laion400m_e31'),
 ('ViT-B-32', 'laion400m_e32'),
 ('ViT-B-32', 'laion2b_e16'),
 ('ViT-B-32', 'laion2b_s34b_b79k'),
 ('ViT-B-32', 'datacomp_m_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_clip_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_laion_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_image_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_text_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_basic_s128m_b4k'),
 ('ViT-B-32', 'commonpool_m_s128m_b4k'),
 ('ViT-B-32', 'datacomp_s_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_clip_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_laion_s13m_b4k'),
 ('ViT-B-32', 'commonpool_s_image_s13m_b4k'),
 ('ViT-B-32', 'commo