#Image retrieval with prompts with MetaCLIP

## Installing dependencies

In [None]:
! pip install transformers torch faiss-gpu datasets loguru

## Loading the model

In [1]:
import torch
from PIL import Image
from datasets import load_dataset
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification, AutoTokenizer
from tqdm import tqdm
#Only needed in Google colab
import torch._dynamo
torch._dynamo.config.suppress_errors = True
import faiss
import numpy as np

#Define device (GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

#Load CLIP model, processor and tokenizer
# /root/.cache/huggingface/hub/models--facebook--metaclip-b16-fullcc2.5b
embedding_size = 512
processor = AutoProcessor.from_pretrained("facebook/metaclip-b16-fullcc2.5b")
model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/metaclip-b16-fullcc2.5b",  torch_dtype=torch.float16).to(device)
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained("facebook/metaclip-b16-fullcc2.5b")

In [28]:
embedding_size = 768
processor = AutoProcessor.from_pretrained("facebook/metaclip-l14-fullcc2.5b")
model = AutoModelForZeroShotImageClassification.from_pretrained("facebook/metaclip-l14-fullcc2.5b",  torch_dtype=torch.float16).to(device)
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained("facebook/metaclip-l14-fullcc2.5b")

## Loading the dataset

## Extracting features of image

In [29]:
#Add a vector to FAISS index
def add_vector_to_index(embedding, index):
    #convert embedding to numpy
    vector = embedding.detach().cpu().numpy()
    #Convert to float32 numpy
    vector = np.float32(vector)
    #Normalize vector
    faiss.normalize_L2(vector)
    #Add to index
    index.add(vector)

#Extract features of a given image
def extract_features_clip(image):
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**inputs)
        return image_features

In [38]:
from pathlib import Path

#FAISS index
index = faiss.IndexFlatL2(embedding_size)

file_dir = '/opt/product/sd_test/test_datas/screw_cls'

import glob
# imgs = os.listdir(file_dir)
imgs = glob.glob(file_dir + '/**/*.bmp')
file_with_idx = {}
cls_with_idx = {}

#Process the dataset to extract all features and store in index
idx = 0 
for image_path in tqdm(imgs):
    image = Image.open(str(image_path))
    clip_features = extract_features_clip(image)
    print(clip_features.shape)
    add_vector_to_index(clip_features,index)
    file_with_idx[idx] = image_path
    cls = Path(image_path).parent.stem.split('_')[1]
    cls_with_idx[idx] = cls
    idx = idx + 1 

#Write index locally. Not needed after but can be useful for future retrieval
faiss.write_index(index,"metaclip.index")

  7%|▋         | 6/91 [00:00<00:03, 25.87it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 15%|█▌        | 14/91 [00:00<00:02, 32.84it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 20%|█▉        | 18/91 [00:00<00:02, 32.34it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 29%|██▊       | 26/91 [00:00<00:02, 27.97it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 36%|███▋      | 33/91 [00:01<00:01, 30.79it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 45%|████▌     | 41/91 [00:01<00:01, 34.05it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 54%|█████▍    | 49/91 [00:01<00:01, 36.30it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 63%|██████▎   | 57/91 [00:01<00:00, 37.23it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 71%|███████▏  | 65/91 [00:01<00:00, 37.05it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 80%|████████  | 73/91 [00:02<00:00, 38.26it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 89%|████████▉ | 81/91 [00:02<00:00, 38.22it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


 98%|█████████▊| 89/91 [00:02<00:00, 38.56it/s]

torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])
torch.Size([1, 768])


100%|██████████| 91/91 [00:02<00:00, 34.68it/s]

torch.Size([1, 768])





## Image retrieval with prompts

In [39]:
import os
from loguru import logger

test_img_dir = '/opt/product/sd_test/test_datas/test_screw'
img_paths = os.listdir(test_img_dir)


for img_path in img_paths:
    full_path = os.path.join(test_img_dir, img_path)
    img_cls = img_path.split('_')[0]
    input_image = Image.open(full_path)
    input_features = extract_features_clip(input_image)

    #Preprocess the vector before searching the FAISS index
    input_features_np = input_features.detach().cpu().numpy()
    input_features_np = np.float32(input_features_np)
    faiss.normalize_L2(input_features_np)

    #Search the top 5 images
    distances, indices = index.search(input_features_np, 1)
    best_match_idx = indices.flatten()[0]
    actual_cls = cls_with_idx[best_match_idx]
    logger.info(f'img_path {full_path}, match file {file_with_idx[best_match_idx]}, expect cls {img_cls}, actual cls {actual_cls}')
    ...

[32m2024-01-15 15:54:24.865[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/3_1.bmp, match file /opt/product/sd_test/test_datas/screw_cls/lx_3/1.bmp, expect cls 3, actual cls 3[0m
[32m2024-01-15 15:54:24.904[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/10_1.bmp, match file /opt/product/sd_test/test_datas/screw_cls/lx_10/2.bmp, expect cls 10, actual cls 10[0m


[32m2024-01-15 15:54:24.946[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/7_1.bmp, match file /opt/product/sd_test/test_datas/screw_cls/lx_7/9.bmp, expect cls 7, actual cls 7[0m
[32m2024-01-15 15:54:24.983[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/2_1.bmp, match file /opt/product/sd_test/test_datas/screw_cls/lx_2/4.bmp, expect cls 2, actual cls 2[0m
[32m2024-01-15 15:54:25.021[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/6_1.bmp, match file /opt/product/sd_test/test_datas/screw_cls/lx_6/1.bmp, expect cls 6, actual cls 6[0m
[32m2024-01-15 15:54:25.058[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m23[0m - [1mimg_path /opt/product/sd_test/test_datas/test_screw/4_1.bmp, match file /opt/product/sd_test/test