Import section.

In [None]:
import gzip
import html
import os
from functools import lru_cache
from io import BytesIO

import ftfy
import ipyplot
import numpy as np
import regex as re
import requests
import PIL
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image

Load the CLIP model previously downloaded from [OpenAI](https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt). Replace the placeholder with the download location of the pre-trained model on your local machine.

In [None]:
# Deserialize the TorchScript archive first, then move it to the GPU and
# switch it to evaluation mode.
model = torch.jit.load("<put the full path of the pre-trained CLIP model here>")
model = model.cuda().eval()

Details about the model:

In [None]:
# The model stores these settings as single-element values; .item() turns
# each of them into a plain Python number.
input_resolution, context_length, vocab_size = (
    setting.item()
    for setting in (model.input_resolution, model.context_length, model.vocab_size)
)

# Sum the element counts of all parameter tensors and print the total with
# thousands separators.
print("Model parameters:", format(np.sum([int(np.prod(p.shape)) for p in model.parameters()]), ","))
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Define the data set's per-channel mean and standard deviation, which are used later to normalize the pixel intensities.

In [None]:
# Three-element data set statistics (one value per image channel), moved to
# the GPU so they can be applied directly to the image batch.
image_mean, image_std = (
    torch.tensor(channel_values).cuda()
    for channel_values in (
        [0.48145466, 0.4578275, 0.40821073],   # mean
        [0.26862954, 0.26130258, 0.27577711],  # standard deviation
    )
)

Define the image preprocessing: resize and center-crop the input images so that their resolution conforms to what the model expects, then convert them to tensors.

In [None]:
# Pipeline applied to every PIL image before it reaches the model.
preprocess = Compose(
    [
        # Bicubic resize driven by the model's input resolution.
        Resize(input_resolution, interpolation=Image.BICUBIC),
        # Cut out the central input_resolution x input_resolution square.
        CenterCrop(input_resolution),
        # Turn the PIL image into a tensor.
        ToTensor(),
    ]
)

Text needs to be preprocessed as well. Let's use a case-insensitive tokenizer whose gzipped BPE vocabulary must be downloaded beforehand from [OpenAI](https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz). Replace the placeholder in the *__init__* function of the class *SimpleTokenizer* below with the location of the downloaded gzipped vocabulary file on your local machine.

In [None]:
@lru_cache()
def bytes_to_unicode():
    """
    Build the byte -> unicode-character lookup table used by the BPE code.

    Returns a dict with one entry for each of the 256 possible byte values.
    Bytes in the ranges "!".."~", "¡".."¬" and "®".."ÿ" map to the character
    with the same code point. Every other byte (whitespace, control
    characters, ...) is mapped to a code point starting at 256, in ascending
    byte order, so that no byte ends up as a whitespace/control character
    that the BPE code cannot handle.

    The dict's insertion order is: the self-mapped bytes first, followed by
    the remapped ones. The result is cached, so every call returns the same
    dict object.
    """
    self_mapped = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    already_covered = set(self_mapped)
    remapped = [byte for byte in range(2 ** 8) if byte not in already_covered]

    table = {byte: chr(byte) for byte in self_mapped}
    # Shift the remaining bytes past the 8-bit range: 256, 257, ...
    table.update(
        (byte, chr(2 ** 8 + offset)) for offset, byte in enumerate(remapped)
    )
    return table


def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Args:
        word: A sequence of symbols (typically a tuple of variable-length
            strings). It must support slicing.

    Returns:
        A set of ``(left, right)`` tuples, one for each pair of neighbouring
        symbols. Empty and single-symbol words yield an empty set.
    """
    # Pairing the word with itself shifted by one yields every adjacent pair;
    # zip stops at the shorter input, so empty words are handled naturally.
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Repair the text with ftfy, decode HTML entities and trim the ends."""
    cleaned = ftfy.fix_text(text)
    # Two unescape passes, so doubly-escaped entities are decoded as well.
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()


def whitespace_clean(text):
    """Collapse every whitespace run into one space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()


class SimpleTokenizer(object):
    """Lower-casing byte-level BPE tokenizer built from a gzipped merges file.

    Text is cleaned, lower-cased, split with a regular expression, mapped
    byte-by-byte to unicode characters (see ``bytes_to_unicode``) and then
    merged into sub-word units according to the ranked merge rules.
    """

    def __init__(self, bpe_path: str = "<put the full path of the zipped tokenizer here>"):
        """Load the merge rules and build the vocabulary lookup tables.

        Args:
            bpe_path: Path of the gzip-compressed, UTF-8 encoded merges file
                (one whitespace-separated merge rule per line).
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # The context manager closes the gzip handle as soon as it is read.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the first line (presumably a header -- confirm against the
        # file) and keep at most the following 49152-256-2 merge rules.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order defines the token ids: byte symbols, byte symbols
        # carrying the end-of-word marker, merged symbols, special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank means the merge is applied earlier.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Pre-seeding the cache keeps the special tokens from being split.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply the merge rules to one token.

        Args:
            token: A non-empty string of byte-encoded characters.

        Returns:
            The resulting sub-word symbols joined by single spaces; the last
            symbol carries the ``</w>`` end-of-word marker. Results are
            memoized in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        # Start from single characters, marking the last one as end-of-word.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Pick the adjacent pair with the best (lowest) merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # No further occurrence of `first`: copy the tail as is.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Convert text into a list of vocabulary ids.

        The text is cleaned, whitespace-normalized and lower-cased before it
        is split into tokens and BPE-encoded.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Re-express the token's UTF-8 bytes with the printable alphabet.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Convert vocabulary ids back into text.

        End-of-word markers become spaces; byte sequences that are not valid
        UTF-8 are replaced rather than raising.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


Retrieve some local images. Set the *path* variable to the path of a local directory containing the images.

In [None]:
path = '<put the full path of the directory containing the images here>'

# Every entry of the directory is expected to be an image file.
files = os.listdir(path)

all_images = []
for f in files:
    # os.path.join uses the separator of the current platform; a hard-coded
    # backslash would only work on Windows.
    with PIL.Image.open(os.path.join(path, f)) as image_file:
        # convert() returns a fully loaded copy, so the file handle can be
        # released right away instead of staying open.
        image = image_file.convert("RGB")
    all_images.append(image)

print("Total images retrived: ", len(all_images))

Plot the retrieved images within the notebook using [ipyplot](https://github.com/karolzak/ipyplot).

In [None]:
# Gallery settings: upper bound on the number of images shown and the
# display width passed to ipyplot for each of them.
max_images, img_width = 105, 130
ipyplot.plot_images(all_images, max_images=max_images, img_width=img_width)

Preprocess the retrieved images and extract their normalized image features.

In [None]:
# Fail early with a clear message instead of an obscure stacking error.
if not all_images:
    raise ValueError("No images were loaded; check the `path` directory.")

# Run the preprocessing pipeline on every loaded image.
images = [preprocess(single_image) for single_image in all_images]

# torch.stack batches the tensors directly, avoiding a needless
# tensor -> NumPy -> tensor round trip.
image_input = torch.stack(images).cuda()
# Normalize per channel: [:, None, None] broadcasts the three statistics
# over the trailing two dimensions of every image.
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
# Scale every feature vector to unit L2 norm, so that a dot product with
# another unit vector is their cosine similarity.
image_features /= image_features.norm(dim=-1, keepdim=True)

Create the tokenizer and then define two functions for semantic search: the first extracts the text features from the search sentence, and the second returns the top N images ranked by semantic similarity.

In [None]:
# Build the tokenizer from the default vocabulary path configured in
# SimpleTokenizer.__init__ (the gzipped merges file is read here).
tokenizer = SimpleTokenizer()

def get_text_features_from_search_phrase(phrase):
    """Return the unit-norm text features of a search phrase.

    The phrase is tokenized, bracketed with the ``<|startoftext|>`` and
    ``<|endoftext|>`` tokens (the layout produced by the reference
    ``clip.tokenize`` helper), zero-padded to the model's context length
    and encoded with the model's text encoder.

    Args:
        phrase: The search sentence.

    Returns:
        A float tensor with one row holding the L2-normalized text features
        of the phrase.
    """
    start_token = tokenizer.encoder['<|startoftext|>']
    end_token = tokenizer.encoder['<|endoftext|>']
    max_length = int(model.context_length)

    tokens = [start_token] + tokenizer.encode(str(phrase)) + [end_token]
    if len(tokens) > max_length:
        # Over-long phrases are truncated, keeping the end marker as the
        # last token, instead of failing on the assignment below.
        tokens = tokens[:max_length - 1] + [end_token]

    # A single zero-padded row: the model expects fixed-length input.
    text_input = torch.zeros(1, max_length, dtype=torch.long)
    text_input[0, :len(tokens)] = torch.tensor(tokens)

    text_input = text_input.cuda()
    with torch.no_grad():
        features = model.encode_text(text_input).float()
        features /= features.norm(dim=-1, keepdim=True)

    return features

def get_top_semantic_similarity(similarity_list, image_list, max_number_of_results):
    """Pick the images with the highest similarity scores.

    Args:
        similarity_list: One similarity score per image, in the same order
            as ``image_list``.
        image_list: The candidate images.
        max_number_of_results: Maximum number of results to return.

    Returns:
        A ``(scores, top_images)`` pair of lists, ordered from the highest
        to the lowest score. Equal scores keep their original order.
    """
    ranked = sorted(
        enumerate(similarity_list),
        key=lambda entry: entry[1],
        reverse=True,
    )
    best = ranked[:max_number_of_results]
    scores = [score for _, score in best]
    top_images = [image_list[position] for position, _ in best]
    return scores, top_images

Set a search phrase.

In [None]:
# Free-text query to match against the images; replace it with your own sentence.
search_phrase = "set a search sentence here"

Extract the text features from the search phrase.

In [None]:
# Encode the search phrase with the model's text encoder.
text_features = get_text_features_from_search_phrase(search_phrase)

Calculate similarities.

In [None]:
# Text and image features are both L2-normalized, so each entry of this
# matrix product is a cosine similarity (rows: phrases, columns: images).
similarity = np.matmul(
    text_features.cpu().numpy(),
    image_features.cpu().numpy().T,
)

Get the top images having semantic similarity.

In [None]:
# Keep the row of scores belonging to the single search phrase.
# NOTE: this rebinds `similarity`, so re-running this cell without first
# re-running the similarity computation indexes into the 1-D row instead.
similarity = similarity[0]
# Number of best-matching images to retrieve.
max_number_of_results = 3
scores, imgs = get_top_semantic_similarity(similarity, all_images, max_number_of_results)

Print the scores and plot the top images having semantic similarity.

In [None]:
# Report the scores of the best matches and display the matching images.
print ("Similarity scores:", scores)
ipyplot.plot_images(imgs, img_width=300)