In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

In [None]:
import glob

import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm


In [None]:

# Requires the CLIP package installed by the pip cell at the top of the notebook.
import clip
# Left as the bare last expression so the notebook displays the returned list
# of model names (one of which is passed to `clip.load` in the next cell).
clip.available_models()

In [None]:

# Batch sizes used by the later cells: how many text prompts / images are
# pushed through the model in a single forward pass.
text_chunk_size = 1000
img_chunk_size = 1024

# Load the ViT-B/32 checkpoint together with its image preprocessing callable,
# move the model to the GPU, and put it in evaluation mode.
model, preprocess = clip.load("ViT-B/32")
model.cuda()
model.eval()

# Architecture facts read straight off the loaded model.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters across all parameter tensors.
n_params = sum(p.numel() for p in model.parameters())

print("Model parameters:", f"{n_params:,}")
for caption, value in (
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(caption, value)

In [None]:

# Class list source:
# https://github.com/mf1024/ImageNet-datasets-downloader/blob/master/classes_in_imagenet.csv
# NOTE: '{csv path}' is a literal placeholder in a plain string (not an
# f-string); substitute the real directory before running this cell.
class_names = pd.read_csv('{csv path}/classes_in_imagenet.csv')['class_name']

# One natural-language prompt per class name.
text_descriptions = [f"This is a photo of a {label}" for label in class_names]

# Encode the prompts in slices of `text_chunk_size` so that only one slice is
# tokenized and pushed through the text encoder at a time.
chunk_starts = range(0, len(text_descriptions), text_chunk_size)
encoded_chunks = []
with torch.no_grad():
    for start in tqdm(chunk_starts):
        prompts = text_descriptions[start:start + text_chunk_size]
        tokens = clip.tokenize(prompts).cuda()
        embeddings = model.encode_text(tokens).float()
        # Scale every embedding to unit length along the feature axis.
        encoded_chunks.append(embeddings / embeddings.norm(dim=-1, keepdim=True))

# Stack the per-chunk results back into a single tensor, one row per prompt.
text_features = torch.cat(encoded_chunks, dim=0)

In [None]:
# NOTE: '{images path}' is a literal placeholder in a plain string; substitute
# a real glob pattern (e.g. a directory followed by "/*.jpg") before running.
img_paths = glob.glob('{images path}')

original_images = []  # never filled in this cell; kept so any later cell that references the name still finds it
images = []           # preprocessed images waiting to be classified
preds = []            # predicted class index (row of `text_features`) per image, in `img_paths` order


def _predict_batch(batch):
    """Classify one batch of preprocessed images.

    Args:
        batch: non-empty list of preprocessed images, as returned by `preprocess`.

    Returns:
        A list with one int per image: the index of the text prompt whose
        feature vector scores highest against that image.
    """
    image_input = torch.tensor(np.stack(batch)).cuda()
    with torch.no_grad():
        image_features = model.encode_image(image_input).float()
        image_features /= image_features.norm(dim=-1, keepdim=True)

        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        _, top_labels = text_probs.cpu().topk(1, dim=-1)
    return top_labels.flatten().tolist()


for img_path in tqdm(img_paths):
    # Load the image; the context manager closes the underlying file as soon
    # as the RGB copy has been preprocessed.
    with Image.open(img_path) as img:
        images.append(preprocess(img.convert("RGB")))

    # Run inference every time a full batch has been collected.
    if len(images) == img_chunk_size:
        preds += _predict_batch(images)
        images = []

# Flush the trailing partial batch. Without this, the last
# `len(img_paths) % img_chunk_size` images would never be classified and
# `preds` would be shorter than `img_paths`.
if images:
    preds += _predict_batch(images)
    images = []

assert len(preds) == len(img_paths), "every image must receive exactly one prediction"