In [None]:
from transformers import ViTFeatureExtractor, ViTForImageClassification
import torch
import json
from tqdm import tqdm
from PIL import Image as PILImage
# search by background
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import faiss

model = ViTForImageClassification.from_pretrained("D:\image_search_engine_ai-end/google_vit_base/")
feature_extractor = ViTFeatureExtractor.from_pretrained("D:\image_search_engine_ai-end/google_vit_base/")

with open('../sources/metadata/image_database.json', 'r') as f:
    image_database = json.load(f)



In [13]:
all_images = (
    image_database[0]['images'] + 
    image_database[1]['images'] + 
    image_database[2]['images'] + 
    image_database[3]['images']
    )

In [22]:
for image_info in tqdm(all_images[:100], desc="Processing Images", unit='image'):
    try:
        image_path = image_info['path']
        inputs = feature_extractor(images=PILImage.open(image_path), return_tensors="pt")
        outputs = model(**inputs)
        predicted_class = outputs.logits.argmax(-1).item()
        image_info['background_class'] = model.config.id2label[predicted_class]
    except Exception as e:
        print(f"Error extracting background classes from {image_path} : {str(e)}")

Processing Images:   0%|          | 0/100 [00:00<?, ?image/s]

Processing Images:  51%|█████     | 51/100 [00:25<00:15,  3.21image/s]

Error extracting background classes from D:\image_search_engine_ai-end\folder\camera\000054.jpg : [Errno 2] No such file or directory: 'D:\\image_search_engine_ai-end\\folder\\camera\\000054.jpg'


Processing Images:  86%|████████▌ | 86/100 [00:36<00:04,  3.22image/s]

Error extracting background classes from D:\image_search_engine_ai-end\folder\camera\000092.jpg : [Errno 2] No such file or directory: 'D:\\image_search_engine_ai-end\\folder\\camera\\000092.jpg'


Processing Images: 100%|██████████| 100/100 [00:40<00:00,  2.47image/s]


In [None]:
from concurrent.futures import ThreadPoolExecutor

def process_single_image(image_info, feature_extractor, model):
    progress_bar.update(1)
    try:
        image_path = image_info['path']
        inputs = feature_extractor(images=PILImage.open(image_path), return_tensors="pt")
        outputs = model(**inputs)
        predicted_class = outputs.logits.argmax(-1).item()
        image_info['background_class'] = model.config.id2label[predicted_class]
    except Exception as e:
        image_info['background_class'] = "non"
        
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {
        executor.submit(process_single_image, img, feature_extractor, model): img 
        for img in all_images
    }
    progress_bar = tqdm(total=len(futures), desc="Processing images", unit="image")
    
progress_bar.close()

Processing images: 100%|█████████▉| 1717/1721 [11:09<00:01,  2.57image/s]


In [None]:
with open('../sources/metadata/image_database.json', 'w') as f:
    json.dump(image_database, f, indent=4)

In [None]:
class_indices = {} 
for idx, img_info in enumerate(all_images):
    class_name = img_info['background_class']
    if class_name not in class_indices:
        class_indices[class_name] = []
    class_indices[class_name].append(idx)

with open('../sources/metadata/image_background_pair.json', 'w') as f:
            json.dump(class_indices, f, indent=4)

In [None]:
# get background options
background_categories = class_indices

In [None]:
# filter all image by background class
class_name = 'gown'
output = [all_images[i]["path"] for i in class_indices[class_name]]
for i in output:
    print(i)


In [None]:
# search image by background
def search_similar(query_embedding, class_name=None, k=10):
        query_embedding = np.array(
            query_embedding).astype('float32').reshape(1, -1)
            
        indices = class_indices[class_name]
        class_embeddings = np.array(
            [all_images[i]["feature"] for i in indices]).astype('float32')
        
        temp_index = faiss.IndexFlatL2(512)
        temp_index.add(class_embeddings)
        
        _, I = temp_index.search(query_embedding, k=min(k, len(indices)))
        return [all_images[indices[i]] for i in I[0]]



query_img_info = all_images[0]
output = search_similar(
    query_img_info['feature'],
    class_name="gown",
    k=10
)
for i in output:
    print(i['path'])