In [1]:
%%capture
!pip install super-gradients==3.1.0
!pip install streamlit
!pip install faiss-cpu
!npm install localtunnel@2.0.2
!pip install gdown
!gdown 1-S1NsJtOiJMAp5AlxnyBXs5lOSZz_o4L

In [6]:
%%writefile feature_extraction.py
    
# feature_extraction.py
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel
from super_gradients.training import models


class FeatureExtractor:
    """Embed images and text queries with the FashionCLIP model.

    Both extract_* methods return a float32 NumPy array of shape ``(1, D)``,
    so the result can be passed directly to ``faiss.Index.search``.

    The vectors are NOT L2-normalised. The search indices are ``IndexFlatL2``
    over stored embeddings, so queries must be produced the same way those
    embeddings were. TODO: confirm the stored embeddings are unnormalised.
    """

    def __init__(self):
        # Load (or download and cache) the FashionCLIP weights and processor.
        self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
        self.preprocess = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")

    @staticmethod
    def _to_query_array(outputs):
        """Convert the first row of a model output tensor to a (1, D) float32 array."""
        first_row = outputs[0].detach().cpu().numpy()
        return np.asarray(first_row, dtype=np.float32).reshape(1, -1)

    def extract_image_features(self, img):
        """Return the FashionCLIP embedding of one image as a ``(1, D)`` float32 array.

        Args:
            img: a single image in any form ``CLIPProcessor.image_processor``
                accepts (the app passes a PIL image).
        """
        inputs = self.preprocess.image_processor(images=img, return_tensors="pt")
        # Inference only: do not build an autograd graph.
        with torch.no_grad():
            outputs = self.model.get_image_features(**inputs)
        return self._to_query_array(outputs)

    def extract_text_features(self, txt):
        """Return the FashionCLIP embedding of one text query as a ``(1, D)`` float32 array.

        Args:
            txt: a single query string.
        """
        # truncation=True cuts over-long queries to the tokenizer's maximum
        # length instead of overflowing the text encoder.
        inputs = self.preprocess.tokenizer(txt, truncation=True, return_tensors='pt')
        with torch.no_grad():
            outputs = self.model.get_text_features(**inputs)
        return self._to_query_array(outputs)
    
    
class ClothesExtractor:
    """Detect clothing items with a YOLO-NAS model and crop them out of an image.

    Args:
        class_dict: maps detector class ids to human-readable item names.
        model_path: path of the YOLO-NAS checkpoint to load.
        num_classes: number of classes the checkpoint was trained with.
    """

    def __init__(self, class_dict, model_path='average_model.pth', num_classes=44):
        self.model = models.get('yolo_nas_s',
                                num_classes=num_classes,
                                checkpoint_path=model_path)
        # Result of the most recent extract_clothing_item() call.
        self.outputs = {}
        self.class_dict = class_dict

    def extract_clothing_item(self, image, conf=0.5):
        """Crop the highest-confidence detection of each class out of ``image``.

        Args:
            image: array indexed as ``image[y, x, ...]`` with a ``.shape``. The
                app passes ``np.array`` of a PIL RGB image.
            conf: confidence threshold forwarded to ``self.model.predict``.

        Returns:
            A NEW dict per call, also stored in ``self.outputs``. It maps
            ``'original_image'`` to ``image``, and each detected item name to
            its cropped array. Results from earlier calls are not carried over.
            Boxes are clipped to the image, and boxes that end up empty are
            skipped. A class id missing from ``class_dict`` is named
            ``'class_<id>'``.
        """
        outputs = {'original_image': image}
        prediction = list(self.model.predict(image, conf=conf))[0].prediction

        labels = np.asarray(prediction.labels)
        confidences = np.asarray(prediction.confidence)
        height, width = image.shape[:2]

        for label in np.unique(labels):
            # Among all boxes of this class keep the most confident one
            # (the first one on ties).
            candidates = np.flatnonzero(labels == label)
            best = candidates[np.argmax(confidences[candidates])]
            x_min, y_min, x_max, y_max = (int(v) for v in prediction.bboxes_xyxy[best])

            # Clip to the image: a negative coordinate would otherwise wrap
            # around in the slice and yield a wrong or empty crop.
            top, bottom = max(0, y_min), min(height, y_max)
            left, right = max(0, x_min), min(width, x_max)
            if bottom <= top or right <= left:
                continue

            name = self.class_dict.get(int(label), f"class_{int(label)}")
            outputs[name] = image[top:bottom, left:right]

        self.outputs = outputs
        return outputs

Overwriting feature_extraction.py


In [7]:
%%writefile data_processing.py
# data_processing.py

import time
import collections
import glob
import os
import pickle
import numpy as np
import faiss
import streamlit as st


def load_npz(file_path, allow_pickle=True):
    """Load every array stored in an ``.npz`` archive into a plain dict.

    Args:
        file_path: path to the ``.npz`` file.
        allow_pickle: forwarded to ``np.load``. The default stays True for
            compatibility, but unpickling object arrays can execute arbitrary
            code, so only use it on trusted files.

    Returns:
        dict mapping each array name in the archive to its loaded array.
    """
    # The `with` block closes the archive's file handle once every array has
    # been read into memory.
    with np.load(file_path, allow_pickle=allow_pickle) as archive:
        return {name: archive[name] for name in archive.files}

def load_dictionary(file_path):
    """Unpickle and return the object stored in ``file_path``.

    Security: ``pickle`` can run arbitrary code while loading, so only use this
    on trusted files.
    """
    with open(file_path, mode='rb') as pickled_file:
        loaded_object = pickle.load(pickled_file)
    return loaded_object

@st.cache_resource
def build_faiss_index(file_path):
    """Build an exact (brute-force) L2 FAISS index from stored embeddings.

    Args:
        file_path: ``.npz`` or ``.pkl`` file holding a mapping of image path ->
            embedding vector. Each value must be a single ``(D,)`` or
            ``(1, D)`` vector.

    Returns:
        ``(index, names)``. Row ``i`` of ``index`` is the embedding stored under
        the key whose last ``/``-separated component is ``names[i]``.

    Raises:
        ValueError: if the extension is unsupported, the file holds no
            embeddings, or the values are not one vector per key.

    Uses ``st.cache_resource`` rather than ``st.cache_data``, so one index
    object is shared per ``file_path``. ``st.cache_data`` would pickle and copy
    the whole index on every rerun.
    """
    extension = os.path.splitext(file_path)[1].lower()
    if extension == '.npz':
        features_vector_dict = load_npz(file_path)
    elif extension == '.pkl':
        features_vector_dict = load_dictionary(file_path)
    else:
        raise ValueError(f"Unsupported embeddings file (expected .npz or .pkl): {file_path}")

    if not features_vector_dict:
        raise ValueError(f"No embeddings found in {file_path}")

    keys = list(features_vector_dict)
    vectors = [np.asarray(features_vector_dict[key]) for key in keys]
    # Take the dimension from the LAST axis, so (D,) and (1, D) vectors both
    # work. shape[0] would be 1 for a (1, D) vector.
    d = vectors[0].shape[-1]
    # FAISS operates on C-contiguous float32 matrices of shape (n, d).
    xb = np.ascontiguousarray(np.stack(vectors).reshape(-1, d), dtype=np.float32)
    if xb.shape[0] != len(keys):
        # Rows must align 1:1 with keys, or search results map to wrong names.
        raise ValueError(
            f"Expected one {d}-dim vector per key in {file_path}, "
            f"got {xb.shape[0]} rows for {len(keys)} keys"
        )

    index = faiss.IndexFlatL2(d)
    index.add(xb)

    return index, [key.split('/')[-1] for key in keys]

# Constants
# Kaggle input locations: pre-computed FashionCLIP embeddings and image dataset.
EMBEDDINGS_DIR = "/kaggle/input/fashion-dataset-clip-embeddings"
DATASET_DIR = "/kaggle/input/fashion-dataset22"

# Embedding files, in order: extracted .npz, original .npz, extracted .pkl,
# original .pkl. app.py selects entries by position.
VCTORS_PATHS = [
    os.path.join(EMBEDDINGS_DIR, "extracted_features_vector_all.npz"),
    os.path.join(EMBEDDINGS_DIR, "original_features_vector_all.npz"),
    os.path.join(EMBEDDINGS_DIR, "extracted_features_vector_all.pkl"),
    os.path.join(EMBEDDINGS_DIR, "original_features_vector_all.pkl"),
]
# Correctly spelled alias. The misspelled name is kept because app.py imports it.
VECTORS_PATHS = VCTORS_PATHS

# .jpg files one sub-folder level deep (see get_paths_dict).
EXTRACTED_DIR = os.path.join(DATASET_DIR, 'extracted_images')
ORIGINAL_DIR = os.path.join(DATASET_DIR, 'original_images')
# NOTE(review): not referenced by any code visible in this notebook.
MAPPING_FILE_PATH = os.path.join(DATASET_DIR, 'name_mapping.json')

# Result-grid layout used by app.py.
NUM_ROWS = 2
NUM_COLS = 5
NUM_IMAGES = NUM_ROWS * NUM_COLS
IMAGES_PER_ROW = NUM_COLS
IMAGE_WIDTH = 200  # thumbnail width (and height) in pixels

# Filename -> full-path lookup tables for the extracted and original images
@st.cache_data
def get_paths_dict(EXTRACTED_DIR, ORIGINAL_DIR):
    """Map image file names to their full paths.

    Args:
        EXTRACTED_DIR: directory whose ``.jpg`` files sit one sub-folder level
            deep (glob pattern ``*/*.jpg``).
        ORIGINAL_DIR: flat directory. Every entry ``os.listdir`` returns is
            included.

    Returns:
        ``(original_image_paths_dict, cutted_image_paths_dict)``, both
        ``OrderedDict``s sorted by file name. Extracted images are keyed by base
        name only, so equal names in different sub-folders collide and the one
        globbed last wins.
    """
    def _sorted_by_name(name_to_path):
        return collections.OrderedDict(sorted(name_to_path.items()))

    crops_by_name = {}
    for crop_path in glob.glob(os.path.join(EXTRACTED_DIR, "*/*.jpg")):
        crops_by_name[crop_path.split('/')[-1]] = crop_path

    originals_by_name = {
        name: os.path.join(ORIGINAL_DIR, name) for name in os.listdir(ORIGINAL_DIR)
    }

    return _sorted_by_name(originals_by_name), _sorted_by_name(crops_by_name)

# Computed when this module is imported; the result is cached by st.cache_data.
original_image_paths_dict, cutted_image_paths_dict = get_paths_dict(EXTRACTED_DIR, ORIGINAL_DIR)

Overwriting data_processing.py


In [16]:
%%writefile app.py
# app.py


import streamlit as st
# Set the browser-tab title and use the full page width. Streamlit requires this
# to be the first Streamlit command of the script. It therefore runs before the
# project imports below, because importing data_processing already calls a
# Streamlit-cached function at import time.
st.set_page_config(page_title="Image Search Engine", layout="wide")

from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from feature_extraction import FeatureExtractor, ClothesExtractor
from data_processing import build_faiss_index, original_image_paths_dict, cutted_image_paths_dict, VCTORS_PATHS, NUM_ROWS, NUM_COLS, NUM_IMAGES, IMAGES_PER_ROW, IMAGE_WIDTH



def retrieve_similar_images(query_embedding, index, num_results=10):
    """Return the row ids of the nearest neighbours of each query vector.

    Args:
        query_embedding: array of shape ``(n_queries, d)``. It is converted to
            C-contiguous float32, as FAISS requires.
        index: FAISS index to search.
        num_results: neighbours wanted per query. Capped at ``index.ntotal``,
            because FAISS pads missing neighbours with id ``-1``, and ``-1``
            silently selects the LAST element of a Python list.

    Returns:
        int array of shape ``(n_queries, min(num_results, index.ntotal))``.
    """
    query = np.ascontiguousarray(query_embedding, dtype=np.float32)
    k = min(num_results, index.ntotal)
    _, closest_indices = index.search(query, k)
    return closest_indices

def make_grid(rows, cols):
    """Lay out a ``rows`` x ``cols`` grid of Streamlit columns.

    Each row is created inside its own ``st.container``. Index the result as
    ``grid[row][col]``.
    """
    column_rows = []
    for _ in range(rows):
        with st.container():
            column_rows.append(st.columns(cols))
    return column_rows

def _show_similar_images(query_feature, index, image_paths_lookup, image_paths_dict):
    """Search ``index`` with ``query_feature`` and render up to NUM_IMAGES hits as a grid."""
    similar_indices = retrieve_similar_images(query_feature, index, NUM_IMAGES)
    # FAISS reports missing neighbours as -1. Never let that index the lookup list.
    similar_names = [image_paths_lookup[idx] for idx in similar_indices[0] if idx >= 0]
    similar_paths = [image_paths_dict[name] for name in similar_names]

    st.subheader("Similar Images")
    grid = make_grid(NUM_ROWS, NUM_COLS)
    for position, img_path in enumerate(similar_paths[:NUM_IMAGES]):
        row, col = divmod(position, IMAGES_PER_ROW)
        # A `with` block releases the file handle. resize() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as similar_image:
            thumbnail = similar_image.resize((IMAGE_WIDTH, IMAGE_WIDTH))
        grid[row][col].image(thumbnail, width=IMAGE_WIDTH)
    st.write("\n")


def _show_extracted_items(extracted_items):
    """Render every extracted crop, labelled with its name, in a NUM_COLS-wide grid."""
    st.subheader("Extracted Clothing Items")
    num_rows = -(-len(extracted_items) // NUM_COLS)  # ceiling division
    grid = make_grid(num_rows, NUM_COLS)
    for position, (item_name, item_image) in enumerate(extracted_items.items()):
        row, col = divmod(position, NUM_COLS)
        grid[row][col].write(f"**{item_name}:**")
        grid[row][col].image(item_image, width=IMAGE_WIDTH)


def app_layout():
    """Render the search page: pick an index, then search by image or by text.

    Reads module-level globals created in this script's ``__main__`` block:
    ``original_index``/``cutted_index`` and their ``*_image_paths_lookup``
    lists, ``extractor`` (ClothesExtractor) and ``resnet_feature_extractor``
    (FeatureExtractor).
    """
    st.title("Image Search Engine")
    st.write("Search for similar images by uploading an image or entering a text query.")

    search_mode = st.radio("Search Mode", ("Image", "Text"))
    index_selection = st.radio("Select index", ("Original Images", "Cutted Images"))

    if index_selection == "Original Images":
        index, image_paths_lookup = original_index, original_image_paths_lookup
        image_paths_dict = original_image_paths_dict
    else:
        index, image_paths_lookup = cutted_index, cutted_image_paths_lookup
        image_paths_dict = cutted_image_paths_dict

    if search_mode == "Image":
        uploaded_image = st.file_uploader("Upload an image", type=['png', 'jpg', 'jpeg'])
        if uploaded_image is None:
            return

        query_image = Image.open(uploaded_image).convert('RGB')
        extracted_items = extractor.extract_clothing_item(np.array(query_image))
        _show_extracted_items(extracted_items)

        # Let the user choose one item (or all of them) as the search query.
        selected_item = st.selectbox("Select an item for similarity search",
                                     list(extracted_items.keys()) + ['All'])
        search_all = selected_item == 'All'
        item_names = list(extracted_items.keys()) if search_all else [selected_item]

        for item_name in item_names:
            selected_image = Image.fromarray(extracted_items[item_name])
            query_feature = resnet_feature_extractor.extract_image_features(selected_image)
            if search_all:
                # With several queries on one page, show which crop each result grid belongs to.
                st.subheader("Query Image")
                st.image(selected_image.resize((IMAGE_WIDTH, IMAGE_WIDTH)))
            _show_similar_images(query_feature, index, image_paths_lookup, image_paths_dict)

    elif search_mode == "Text":
        text_query = st.text_input("Enter a text query")
        if text_query:
            query_feature = resnet_feature_extractor.extract_text_features(text_query)
            st.subheader("Text Query")
            st.write(text_query)
            _show_similar_images(query_feature, index, image_paths_lookup, image_paths_dict)

# Detector class id -> item name for the 44-class YOLO-NAS checkpoint.
clothing_dict = {0: 'Suitcase', 1: 'Miniskirt', 2: 'Tie', 3: 'Luggage & bags', 4: 'Shoe', 5: 'Belt', 6: 'Outerwear', 7: 'Dress', 8: 'Earrings',
                 9: 'Bracelet', 10: 'Necklace', 11: 'Brassiere', 12: 'Footwear', 13: 'Satchel', 14: 'Bowtie', 15: 'Top', 16: 'Pants', 17: 'Sunglasses',
                 18: 'Swimwear', 19: 'Clothing', 20: 'Glove', 21: 'Skirt', 22: 'High heels', 23: 'Underpants', 24: 'Fedora', 25: 'Sun hat', 26: 'Sock',
                 27: 'Wallet', 28: 'Scarf', 29: 'Watch', 30: 'Umbrella', 31: 'Glasses', 32: 'Boot', 33: 'Basket', 34: 'Backpack', 35: 'Bag', 36: 'Hat',
                 37: 'Coat', 38: 'Sandal', 39: 'Shorts', 40: 'Jeans', 41: 'Shirt', 42: 'Handbag', 43: 'Jacket'}


@st.cache_resource
def _load_feature_extractor():
    """Load FashionCLIP once per server process.

    Without the cache, it would be reloaded on every Streamlit rerun, i.e. on
    every widget interaction.
    """
    return FeatureExtractor()


@st.cache_resource
def _load_clothes_extractor(class_dict):
    """Load the YOLO-NAS clothes detector once per ``class_dict``."""
    return ClothesExtractor(class_dict=class_dict)


if __name__ == "__main__":
    import copy

    # Despite the name, this is the FashionCLIP extractor. The name is kept
    # because app_layout() reads this global.
    resnet_feature_extractor = _load_feature_extractor()

    # The cached detector is shared by all sessions. A shallow copy keeps the
    # loaded model but gets its own `outputs` dict, so detections from an
    # earlier image or another session never leak into this rerun.
    extractor = copy.copy(_load_clothes_extractor(clothing_dict))
    extractor.outputs = {}

    # VCTORS_PATHS[3] / [2] are the .pkl embeddings of the original / extracted images.
    original_index, original_image_paths_lookup = build_faiss_index(VCTORS_PATHS[3])
    cutted_index, cutted_image_paths_lookup = build_faiss_index(VCTORS_PATHS[2])
    app_layout()

Overwriting app.py


In [None]:
# Start the Streamlit app in the background (it serves on port 8501, as the log
# below shows), and expose that port through a public localtunnel URL.
!streamlit run app.py & npx localtunnel --port 8501


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
[0m
your url is: https://five-apples-count.loca.lt
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Network URL: [0m[1mhttp://172.19.2.2:8501[0m
[34m  External URL: [0m[1mhttp://35.201.204.155:8501[0m
[0m
The console stream is logged into /root/sg_logs/console.log
[2023-06-22 10:57:28] INFO - crash_tips_setup.py - Crash tips is enabled. You can set your environment variable to CRASH_HANDLER=FALSE to disable it
[2023-06-22 10:57:35] INFO - loader.py - Loading faiss with AVX2 support.
[2023-06-22 10:57:35] INFO - loader.py - Successfully loaded faiss with AVX2 support.
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_i