In [1]:
# Importing Dependencies
import os
import json
import requests
import torch
import torch.nn.functional as F
from typing import List, Generator
from pymongo import MongoClient
from IPython.display import display, Image
import nltk
from nltk.corpus import stopwords
import tempfile
import cv2  # Replacing PIL with a faster library
from multiprocessing import Pool  # Importing Pool for multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed  # For multithreading

In [2]:
# Download the NLTK stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\Cypher\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
# Create a MongoDB client
client = MongoClient('mongodb://localhost:27017')  # replace with your connection string

# Create a database and a collection within the database
db = client['scene-sense']
embeddings_collection = db['embeddings']

# Create an index on the 'embedding' field
embeddings_collection.create_index("embedding")

'embedding_1'

In [4]:
def chunker(seq: list, size: int) -> Generator:
    """Yield chunks of data from a larger list."""
    if size <= 0:
        raise ValueError("Size must be a positive integer")
    for pos in range(0, len(seq), size):
        yield seq[pos:pos + size]

In [5]:
def prepare_files(directory: str) -> Generator:
    accepted_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"]

    # Get list of filenames from directory
    filenames = [f for f in os.listdir(directory) if any(f.endswith(ext) for ext in accepted_extensions)]

    # Fetch existing image paths from database
    image_paths_in_db = set(doc['image_path'] for doc in embeddings_collection.find({}, {'image_path': 1}))

    # Using multithreading to speed up the processing of the files
    with ThreadPoolExecutor() as executor:
        futures = []
        for filename in filenames:
            futures.append(executor.submit(process_image_file, filename, image_paths_in_db, directory))
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                yield result

In [6]:
def process_image_file(filename: str, image_paths_in_db: set, directory: str):
    image_path = os.path.join(directory, filename)
    if image_path not in image_paths_in_db:
        try:
            image = cv2.imread(image_path)
            resized_image = cv2.resize(image, (224, 224))
            temp_file_path = tempfile.mktemp(suffix=".png")
            cv2.imwrite(temp_file_path, resized_image)
            return image_path, temp_file_path
        except Exception as e:
            print(f"Error processing image '{image_path}': {str(e)}")
    return None

In [7]:
# Use ThreadPoolExecutor for send_images
def send_images(directory: str, batch_size: int = 100):
    with ThreadPoolExecutor() as executor:
        futures = []
        files_chunks = list(chunker(list(prepare_files(directory)), batch_size))
        for chunk_index, files_chunk in enumerate(files_chunks, 1):
            futures.append(executor.submit(process_image_chunk, chunk_index, files_chunk, len(files_chunks)))
        for future in as_completed(futures):
            print(future.result())

In [8]:
def process_image_chunk(chunk_index, files_chunk, total_chunks):
    files = []
    file_objects = []  # List to hold file objects
    for image_path, temp_file_path in files_chunk:
        file_obj = open(temp_file_path, 'rb')  # Store the file object in a variable
        file_objects.append(file_obj)  # Add the file object to the list
        files.append(('images', (image_path, file_obj, 'image/png')))
    print(f"Processing chunk {chunk_index}/{total_chunks} - {len(files_chunk)} images...")
    response = requests.post('http://148.113.143.16:9999/image_embeddings/', files=files)
    if response.status_code == 200:
        embeddings = response.json()['image_embeddings']
        documents = [{'image_path': image_path, 'embedding': embedding} for (image_path, _), embedding in zip(files_chunk, embeddings)]
        embeddings_collection.insert_many(documents)
        print(f"Chunk {chunk_index}/{total_chunks} processed successfully!")
    else:
        print(f'Error while processing chunk {chunk_index}/{total_chunks}: {response.text}')
    # Close all open files
    for file_obj in file_objects:
        file_obj.close()
    # Now it's safe to delete the files
    for _, temp_file_path in files_chunk:
        os.remove(temp_file_path)

In [9]:
def send_text(prompt):
    data = {"prompt": prompt}
    response = requests.post('http://148.113.143.16:9999/text_embeddings/', json=data)
    if response.status_code == 200:
        return response.json()
    else:
        print(f'Error while sending text: {response.text}')
        return None

In [10]:
# Search for similar images given a query in DB
def get_similar_images(text_embedding: str, similarity_threshold: float = 0.22) -> List[str]:
    # Processing text embeddings
    text_embedding = torch.tensor(text_embedding)
    # Get all embeddings from the database
    documents = embeddings_collection.find()
    # Compute similarities with all image embeddings and get all that are above the threshold
    similar_images = []
    for document in documents:
        image_path = document['image_path']
        image_embeddings = torch.tensor(document['embedding'])
        similarity = F.cosine_similarity(text_embedding, image_embeddings)
        if similarity.item() > similarity_threshold:
            similar_images.append((similarity.item(), image_path))
    # Sort by similarity score
    similar_images = sorted(similar_images, key=lambda x: x[0], reverse=True)
    # Return image paths
    return [path for sim, path in similar_images]

In [11]:
def compute_similarity_threshold(prompt: str, base_threshold: float = 0.23, increment: float = 0.015) -> float:
    words = [word for word in prompt.split() if word not in stop_words]
    num_words = len(words)
    return base_threshold + num_words * increment

In [12]:
send_images('D:\AIML\gimmick\images2000')

Error processing image 'D:\AIML\gimmick\images2000\↻𓂃 𝐄𝐔𝐏𝐇𝐎𝐑𝐈𝐀 ━━ 𝐓𝐨𝐤𝐲𝐨 𝐑𝐞𝐯𝐞𝐧𝐠𝐞𝐫𝐬, 𝖨𝗆𝖺𝗀𝗂𝗇𝖾𝗌.jpeg': OpenCV(4.7.0) D:\a\opencv-python\opencv-python\opencv\modules\imgproc\src\resize.cpp:4062: error: (-215:Assertion failed) !ssize.empty() in function 'cv::resize'



In [13]:
# Send a text prompt to the FastAPI server and store the returned embedding
prompt = "sunset"
text_embedding = send_text(prompt)
text_embedding = text_embedding.get('text_embedding')

In [14]:
# compute the similarity threshold based on the number of words in the prompt
similarity_threshold = compute_similarity_threshold(prompt)

In [15]:
# Get similar images from the database
get_similar_images(text_embedding, similarity_threshold)

['D:\\AIML\\gimmick\\images2000\\photo-1686882846945-4a07366aba60.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687166431515-123c0419a89b.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687222771254-ddaa3381e619.jpg',
 'D:\\AIML\\gimmick\\images2000\\premium_photo-1677693662922-3a8f7abe5f2b.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687252827323-6bb790696569.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687436583975-1ea5a34becc2.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1682687982204-f1a77dcc3067.jpg',
 'D:\\AIML\\gimmick\\images2000\\premium_photo-1670148434846-eab89d9dbdc0.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1686900248731-5a492eb9a5e2.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687175452114-3a4fc0664df3.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1687789903431-54cdeb56bb43.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1503162894963-8c333941fb01.jpg',
 'D:\\AIML\\gimmick\\images2000\\premium_photo-1681140029699-ee43ba816894.jpg',
 'D:\\AIML\\gimmick\\images2000\\photo-1