In [None]:
# Importing Dependencies
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed  # For multithreading
from contextlib import ExitStack
from multiprocessing import Pool  # Importing Pool for multiprocessing
from typing import List, Generator
from urllib.parse import quote_plus

import cv2  # Replacing PIL with a faster library
import nltk
import requests
import torch
import torch.nn.functional as F
from bson.binary import Binary
from dotenv import load_dotenv
from IPython.display import display, Image
from nltk.corpus import stopwords
from pymongo import MongoClient
from pymongo import errors
from pymongo.server_api import ServerApi

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Make sure the NLTK stopword corpus is available. It is downloaded only when
# nltk cannot find it locally, so re-running the notebook needs no network access.
try:
    nltk.data.find('corpora/stopwords')
except LookupError:
    nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

In [None]:
# Load variables from a .env file into the process environment, then read the
# MongoDB Atlas credentials (each is None when the variable is not set).
load_dotenv()
username, password = (os.getenv(var_name) for var_name in ("MONGO_USERNAME", "MONGO_PASSWORD"))

In [None]:
# Build the MongoDB Atlas connection string from the credentials read above.
# Both values are percent-escaped. Otherwise characters such as '@', ':' or '/'
# in a password would be parsed as URI delimiters and corrupt the URI.
if not username or not password:
    raise RuntimeError("MONGO_USERNAME and MONGO_PASSWORD must be set (e.g. in a .env file)")
uri = (
    f'mongodb+srv://{quote_plus(username)}:{quote_plus(password)}'
    '@sample-images.qlezbxu.mongodb.net/?retryWrites=true&w=majority'
)

# Create a new client for the Atlas deployment (the ping below verifies it)
client = MongoClient(uri, server_api=ServerApi('1'))

In [None]:
# Ping the deployment. A failure is printed rather than raised, so the cell
# always completes. Check its output before running the upload cells.
try:
    client.admin.command('ping')
except Exception as exc:
    print(exc)
else:
    print("Pinged your deployment. You successfully connected to MongoDB!")

In [None]:
# Handles to the 'sample-images' database and its 'demo' collection, which holds
# the stored images together with their embeddings.
db = client.get_database('sample-images')
embeddings_collection = db.get_collection('demo')

In [None]:
# Directory containing the images to index. It can be overridden with the
# IMAGE_DIRECTORY environment variable (load_dotenv() above also reads it from .env).
# The raw string stops the Windows backslashes from being parsed as escape sequences.
image_directory = os.getenv("IMAGE_DIRECTORY", r"D:\AIML\Scene-Sense\sample_images")

In [None]:
def chunker(seq: list, size: int) -> Generator:
    """Lazily split ``seq`` into consecutive slices of at most ``size`` items.

    The last slice may be shorter. Slices have the same type as ``seq``.

    Raises:
        ValueError: if ``size`` is not positive. Because this is a generator,
            the check runs when iteration starts, not when it is called.
    """
    if size <= 0:
        raise ValueError("Size must be a positive integer")
    chunk_starts = range(0, len(seq), size)
    yield from (seq[start:start + size] for start in chunk_starts)

In [None]:
def prepare_files(directory: str) -> Generator:
    """Resize every not-yet-stored image in ``directory`` and yield the results.

    An image counts as already stored when its full path
    (``os.path.join(directory, filename)``) appears as ``image_path`` in
    ``embeddings_collection``. New images are resized concurrently by
    ``process_image_file``. The yielded ``(image_path, temp_file_path)`` pairs
    arrive in completion order, not directory order. Files that fail to process
    are skipped.
    """
    accepted_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")

    # Match extensions case-insensitively so names like "IMG_001.JPG" are not skipped,
    # and ignore directories whose names happen to end in an image extension.
    filenames = [
        f for f in os.listdir(directory)
        if f.lower().endswith(accepted_extensions) and os.path.isfile(os.path.join(directory, f))
    ]

    # Only documents that carry an image_path can be used for de-duplication.
    # Indexing doc['image_path'] on documents stored without it would raise KeyError.
    image_paths_in_db = {
        doc['image_path']
        for doc in embeddings_collection.find({'image_path': {'$exists': True}}, {'image_path': 1})
    }

    # Using multithreading to speed up the processing of the files
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(process_image_file, filename, image_paths_in_db, directory)
            for filename in filenames
        ]
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                yield result

In [None]:
def process_image_file(filename: str, image_paths_in_db: set, directory: str):
    """Resize one image to 224x224 and save it as a temporary PNG.

    Args:
        filename: name of the image file inside ``directory``.
        image_paths_in_db: full image paths that are already stored and must be skipped.
        directory: directory containing ``filename``.

    Returns:
        ``(image_path, temp_file_path)`` on success. Returns ``None`` if the image
        is already stored or cannot be processed; the error is printed in that case.
        The caller is responsible for deleting ``temp_file_path``.
    """
    image_path = os.path.join(directory, filename)
    if image_path in image_paths_in_db:
        return None

    temp_file_path = None
    try:
        image = cv2.imread(image_path)
        # cv2.imread signals unreadable or undecodable files by returning None, not by raising
        if image is None:
            raise ValueError("file could not be read or decoded as an image")
        resized_image = cv2.resize(image, (224, 224))
        # mkstemp creates the file atomically (mktemp only returns a name and is race-prone)
        fd, temp_file_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        if not cv2.imwrite(temp_file_path, resized_image):
            raise OSError(f"could not write temporary file '{temp_file_path}'")
        return image_path, temp_file_path
    except Exception as e:
        print(f"Error processing image '{image_path}': {str(e)}")
        # Do not leave a half-written temp file behind
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        return None

In [None]:
# Upload the prepared images chunk by chunk, with chunks processed concurrently
def send_images(directory: str, batch_size: int = 1000):
    """Embed and store every image in ``directory`` that is not yet in the database.

    Images are prepared by ``prepare_files`` and grouped into chunks of
    ``batch_size``. Each chunk is handled by ``process_image_chunk`` in a worker
    thread. The first exception raised by any chunk is re-raised here once the
    running chunks finish.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised by ``chunker``).
    """
    files_chunks = list(chunker(list(prepare_files(directory)), batch_size))
    total_chunks = len(files_chunks)
    if total_chunks == 0:
        print("No new images to process.")
        return

    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(process_image_chunk, chunk_index, files_chunk, total_chunks)
            for chunk_index, files_chunk in enumerate(files_chunks, 1)
        ]
        for future in as_completed(futures):
            # process_image_chunk reports its own progress and returns nothing.
            # result() is called only to surface exceptions, not to print a value.
            future.result()

In [None]:
def process_image_chunk(chunk_index, files_chunk, total_chunks):
    """Embed one chunk of resized images via the embedding server and store them in MongoDB.

    Args:
        chunk_index: 1-based number of this chunk, used only in progress messages.
        files_chunk: list of ``(image_path, temp_file_path)`` pairs, as yielded by
            ``prepare_files``.
        total_chunks: total number of chunks, used only in progress messages.

    Each stored document contains three fields:
      - ``image_path``: the original path, so ``prepare_files`` can skip this image on later runs.
      - ``image``: the resized PNG bytes.
      - ``embedding``: the embedding returned by the server.

    The temporary files in ``files_chunk`` are always deleted before returning,
    even if the upload or the database insert fails.
    """
    try:
        print(f"Processing chunk {chunk_index}/{total_chunks} - {len(files_chunk)} images...")
        # ExitStack closes every uploaded file handle once the request is done,
        # including when opening a later file or the request itself raises.
        with ExitStack() as stack:
            files = [
                ('images', (image_path, stack.enter_context(open(temp_file_path, 'rb')), 'image/png'))
                for image_path, temp_file_path in files_chunk
            ]
            response = requests.post('http://148.113.143.16:9999/image_embeddings/', files=files)

        if response.status_code != 200:
            print(f'Error while processing chunk {chunk_index}/{total_chunks}: {response.text}')
            return

        embeddings = response.json()['image_embeddings']
        if len(embeddings) != len(files_chunk):
            # zip() below would otherwise silently drop the unmatched images
            print(f"Warning: chunk {chunk_index}/{total_chunks} sent {len(files_chunk)} images "
                  f"but received {len(embeddings)} embeddings")

        # Store the binary image data plus its original path (used for de-duplication)
        documents = []
        for (image_path, temp_file_path), embedding in zip(files_chunk, embeddings):
            with open(temp_file_path, 'rb') as image_file:
                documents.append({
                    'image_path': image_path,
                    'image': Binary(image_file.read()),
                    'embedding': embedding,
                })
        if not documents:
            # insert_many() raises on an empty document list
            return

        try:
            embeddings_collection.insert_many(documents)
            print(f"Chunk {chunk_index}/{total_chunks} processed successfully!")
        except errors.InvalidDocument as e:
            print(f"Failed to store the images of chunk {chunk_index}/{total_chunks} due to their sizes: {e}")
    finally:
        # All handles are closed at this point, so the temp files can be deleted safely
        for _, temp_file_path in files_chunk:
            try:
                os.remove(temp_file_path)
            except FileNotFoundError:
                pass

In [None]:
# Embed and upload every image in `image_directory` that is not yet stored in MongoDB
send_images(directory=image_directory, batch_size=1000)

In [None]:
def send_text(prompt):
    """POST ``prompt`` to the text-embedding endpoint and return the decoded JSON reply.

    For any status code other than 200, prints the server's response body and
    returns ``None``.
    """
    payload = {"prompt": prompt}
    reply = requests.post('http://148.113.143.16:9999/text_embeddings/', json=payload)
    if reply.status_code != 200:
        print(f'Error while sending text: {reply.text}')
        return None
    return reply.json()

In [None]:
def get_similar_images(text_embedding: List[float], similarity_threshold: float = 0.22):
    """Display every stored image whose embedding is close enough to ``text_embedding``.

    Args:
        text_embedding: embedding of the text prompt as returned by the
            text-embedding endpoint. It is flattened before use, so a flat list
            or a single-row nested list both work.
            NOTE(review): exact shape returned by the server not visible here — confirm.
        similarity_threshold: only images whose cosine similarity is strictly
            greater than this value are shown.

    Returns:
        None. Matching images are shown as matplotlib figures, most similar first.
    """
    query = torch.tensor(text_embedding, dtype=torch.float32).flatten()

    # Compute similarities with all stored image embeddings; fetch only the fields used here
    similar_images = []
    for document in embeddings_collection.find({}, {'image': 1, 'embedding': 1}):
        image_embedding = torch.tensor(document['embedding'], dtype=torch.float32).flatten()
        similarity = F.cosine_similarity(query, image_embedding, dim=0).item()
        if similarity > similarity_threshold:
            similar_images.append((similarity, document['image']))

    # Most similar first
    similar_images.sort(key=lambda pair: pair[0], reverse=True)

    for similarity, image_bytes in similar_images:
        # np.frombuffer replaces the deprecated binary mode of np.fromstring
        img_np = cv2.imdecode(np.frombuffer(image_bytes, dtype=np.uint8), cv2.IMREAD_COLOR)
        if img_np is None:
            print("Skipping a stored image that could not be decoded.")
            continue
        fig, ax = plt.subplots(figsize=(5, 5))
        ax.imshow(cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB))
        ax.set_title(f"Similarity: {similarity}")
        ax.axis("off")
        plt.show()
        # Free the figure so looping over many matches does not accumulate memory
        plt.close(fig)

In [None]:
def compute_similarity_threshold(prompt: str, base_threshold: float = 0.23, increment: float = 0.02,
                                 ignore_words=None) -> float:
    """Return a cosine-similarity threshold that grows with the prompt's length.

    The threshold is ``base_threshold + n * increment``, where ``n`` is the
    number of whitespace-separated words in ``prompt`` not in ``ignore_words``.
    Words are compared case-insensitively.

    Args:
        prompt: the text query.
        base_threshold: threshold for a prompt with no counted words.
        increment: amount added per counted word.
        ignore_words: collection of lower-case words to leave out of the count.
            Defaults to the notebook's NLTK ``stop_words`` set.
    """
    if ignore_words is None:
        ignore_words = stop_words
    # Lower-case each word before the lookup so "The" matches "the".
    # NLTK's English stopword list is assumed to be lower-case.
    counted_words = [word for word in prompt.split() if word.lower() not in ignore_words]
    return base_threshold + len(counted_words) * increment

In [None]:
# Send a text prompt to the FastAPI server and keep the returned embedding.
# send_text returns None on a non-200 reply, so fail here with a clear message
# instead of an AttributeError on None.
prompt = "car"
text_response = send_text(prompt)
if text_response is None or text_response.get('text_embedding') is None:
    raise RuntimeError(f"No 'text_embedding' received from the server for prompt {prompt!r}")
text_embedding = text_response['text_embedding']

In [None]:
# Longer prompts (not counting stopwords) get a stricter similarity threshold
similarity_threshold = compute_similarity_threshold(prompt=prompt)

In [None]:
# Display the stored images that match the prompt, most similar first
get_similar_images(text_embedding=text_embedding, similarity_threshold=similarity_threshold)

In [None]:
import urllib.request
from PIL import Image

In [None]:
# Retrieve the resource located at `url` and save it locally as `output_path`.
# The filename is kept in one variable so this comment and the code cannot drift apart.
url = "https://prodigalbookable.blob.core.windows.net/prodigalbookable/0013d4c8-5a4d-40a5-98b9-2445a55a4750.png"
output_path = "geeksforgeeks.png"
urllib.request.urlretrieve(url, output_path)