In [None]:
import clip
import os, sys
import numpy as np
import torch
import torch.hub
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from PIL import Image
sys.path.append('../../..')

In [None]:
def categories():
	"""Read ``flavors.txt`` from the ``interrogate`` directory three levels up.

	Returns:
		A list with one entry per line of the file, in file order, each with
		leading and trailing whitespace removed (blank lines become ``''``).
	"""
	flavors_file = os.path.join('../../../interrogate', 'flavors.txt')
	entries = []
	with open(flavors_file, 'r', encoding='utf-8') as handle:
		for raw_line in handle:
			entries.append(raw_line.strip())
	return entries
# Read the flavor list once; a later cell reuses it through the global `c`.
c = categories()

In [None]:


# Run on the GPU when one is present, otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Checkpoint name handed to clip.load() by load_clip_model().
clip_model_name = 'ViT-L/14'
# NOTE(review): machine-specific absolute path; not referenced by the visible cells.
clip_models_path = '/f/stablediffusion/stable-diffusion-webui/models/clip-interrogator'
# NOTE(review): not referenced by the visible cells either.
blip_image_eval_size = 384
# Tensors going into and coming out of the model are cast to this dtype.
dtype = torch.float32

# from modules/interrogate.py
def load_clip_model():
	"""Load the CLIP checkpoint named by ``clip_model_name`` for inference.

	Returns:
		A ``(model, preprocess)`` pair: the model switched to eval mode and
		moved to ``device``, and the second value returned by ``clip.load``.
	"""
	import clip
	net, transform = clip.load(clip_model_name)
	net.eval()
	return net.to(device), transform

def preprocess_img(preprocess, pil_image):
	"""Turn one image into a single-item batch on ``device``.

	``preprocess`` is applied to ``pil_image``; the result gets a new leading
	dimension, is cast to ``dtype`` and is moved to ``device``.
	"""
	tensor = preprocess(pil_image)
	batch = tensor.unsqueeze(0)
	return batch.type(dtype).to(device)

def encode_image(clip_model, clip_image):
	"""Embed a preprocessed image batch with ``clip_model``.

	Returns the output of ``clip_model.encode_image`` cast to ``dtype``.

	NOTE(review): the embedding is returned un-normalized, while ``rank``
	normalizes the text embeddings - confirm that asymmetry is intended.
	"""
	return clip_model.encode_image(clip_image).type(dtype)

def similarity(text_array, text_features, image_features, top_count=1):
	"""Score every label in ``text_array`` against the image embeddings.

	For each row of ``image_features`` the logits ``100 * row @ text_features.T``
	are soft-maxed over the last dimension; the per-row results are averaged
	and the ``top_count`` best labels are returned.

	Args:
		text_array: sequence of labels, indexable by integer position.
		text_features: 2-D tensor of text embeddings; row ``j`` is expected to
			belong to ``text_array[j]``. If it has a single row, that one score
			is broadcast to every label, so all labels tie.
		image_features: 2-D tensor of image embeddings, one row per image.
		top_count: number of results wanted; clamped to ``len(text_array)``.

	Returns:
		List of ``(label, percentage)`` tuples ordered from highest to lowest
		score, where ``percentage`` is the averaged probability times 100.
	"""
	# Same clamp as rank(): topk() raises when asked for more entries than exist.
	top_count = min(top_count, len(text_array))
	# Accumulate on the inputs' own device instead of a notebook-level global.
	scores = torch.zeros((1, len(text_array)), device=image_features.device)
	for row in image_features:
		scores += (100.0 * row.unsqueeze(0) @ text_features.T).softmax(dim=-1)
	scores /= image_features.shape[0]

	# detach() so .numpy() also works when the embeddings still track gradients.
	top_probs, top_labels = scores.detach().cpu().topk(top_count, dim=-1)
	percentages = top_probs[0].numpy() * 100
	return [(text_array[int(idx)], pct) for idx, pct in zip(top_labels[0], percentages)]

def torch_gc():
	"""Release cached CUDA memory held by PyTorch on device ``cuda:0``.

	Returns immediately when CUDA is unavailable, so the call is safe on
	CPU-only machines.
	"""
	if not torch.cuda.is_available():
		return
	with torch.cuda.device('cuda:0'):
		torch.cuda.empty_cache()
		torch.cuda.ipc_collect()

def rank(clip_model, image_features, text_array, top_count=1):
	"""Rank the labels in ``text_array`` by how well they match the image.

	The labels are tokenized (with ``truncate=True``), embedded with
	``clip_model.encode_text``, divided by their L2 norm and then scored
	against ``image_features`` by ``similarity``.

	Args:
		clip_model: model object exposing ``encode_text``.
		image_features: 2-D tensor of image embeddings, one row per image.
		text_array: sequence of label strings.
		top_count: number of results wanted; clamped to ``len(text_array)``.

	Returns:
		List of ``(label, percentage)`` tuples as produced by ``similarity``.
	"""
	top_count = min(top_count, len(text_array))
	# Inference only. With autograd enabled and model parameters that require
	# grad, the scores would track gradients and their .numpy() conversion
	# would raise.
	with torch.no_grad():
		text_tokens = clip.tokenize(list(text_array), truncate=True).to(device)
		text_features = clip_model.encode_text(text_tokens).type(dtype)
		text_features /= text_features.norm(dim=-1, keepdim=True)
		# The scoring step is the same computation similarity() performs, so reuse it.
		return similarity(text_array, text_features, image_features, top_count)




In [None]:
# Load CLIP once and expose the model under both names used by later cells.
clip_model, preprocess = load_clip_model()
model = clip_model


In [None]:
# Open both reference images and embed each one once.
img_10 = Image.open('images/10.png')
img_30 = Image.open('images/30.png')

# Inference only: with autograd disabled the cached embeddings do not hold on
# to a computation graph of the image encoder.
with torch.no_grad():
	img_10_features = encode_image(model, preprocess_img(preprocess, img_10))
	img_30_features = encode_image(model, preprocess_img(preprocess, img_30))

# %%markdown
# ![title](images/10.png)
# ![title](images/30.png)


In [None]:
# `device` and `dtype` are taken from the configuration cell rather than being
# re-declared here, so there is a single place to change them.
clip_model = model

text = 'a photo of a cat wearing a pink hat on a blue rug'
text_array = text.split(' ')
# NOTE(review): this immediately replaces the per-word list above with the
# flavor list; confirm which of the two is meant to be scored.
text_array = c

# Everything below is inference only, so run it without autograd; otherwise
# each embedding kept in a notebook variable may also keep its graph alive.
with torch.no_grad():
	text_tokens = clip.tokenize(text, truncate=True).to(device)
	text_features = clip_model.encode_text(text_tokens).type(dtype)

	# NOTE(review): text_features is built from the single prompt string while
	# text_array holds many labels, so similarity() appears to broadcast one
	# score to every label - confirm this is intended.
	print(similarity(text_array, text_features, img_10_features, len(text_array)))

	# Second encoding of the same prompt tokens as text_features.
	text_features_single = clip_model.encode_text(text_tokens).type(dtype)

	# The prompt repeated four times; only tokenized here, not encoded.
	text_concat = ', '.join([text, text, text, text])
	text_tokens = clip.tokenize(text_concat, truncate=True).to(device)

	# One token tensor and one embedding per entry of text_array.
	text_token_list = [clip.tokenize(x, truncate=True).to(device) for x in text_array]
	empty_token = clip.tokenize('', truncate=True).to(device)
	text_feature_list = [clip_model.encode_text(t).type(dtype) for t in text_token_list]

In [None]:
# Element-wise mask marking which entries of text_features differ from zero.
t = torch.zeros_like(text_features, dtype=torch.float32)
text_features != t

In [None]:
# Shape of each embedding stored in text_feature_list.
[embedding.shape for embedding in text_feature_list]