In [None]:
import os

from elasticsearch import Elasticsearch
from sentence_transformers import SentenceTransformer
from PIL import Image

In [None]:
# Elasticsearch endpoint used by get_elasticsearch(). Can be overridden via the
# ELASTICSEARCH_HOST environment variable; otherwise falls back to the original
# hard-coded URL.
ELASTICSEARCH_HOST = os.environ.get("ELASTICSEARCH_HOST", "http://elasticsearch:9200")

In [None]:
## Define Utility Functions

In [None]:
def get_elasticsearch(hosts=ELASTICSEARCH_HOST):
    """Return a new Elasticsearch client connected to ``hosts``.

    ``hosts`` defaults to the value of ELASTICSEARCH_HOST at the time this
    function was defined. It is forwarded unchanged to the client constructor.
    """
    client = Elasticsearch(hosts=hosts)
    return client

In [None]:
def create_index(elasticsearch: Elasticsearch, index_name:str, mappings:dict=None, settings:dict=None, overwrite_existing=False):
    """Create ``index_name`` with the given mappings and settings.

    Args:
        elasticsearch: Connected client used for all index operations.
        index_name: Name of the index to create.
        mappings: Index mappings; ``None`` means an empty mapping.
        settings: Index settings; ``None`` means empty settings.
        overwrite_existing: If True, an existing index with the same name is
            deleted first. If False, an existing index is an error.

    Raises:
        Exception: If the index already exists and ``overwrite_existing`` is False.
    """
    # None defaults avoid sharing one mutable dict between calls.
    create_index_body = {
        "settings": settings if settings is not None else {},
        "mappings": mappings if mappings is not None else {},
    }

    # elasticsearch-py 8.x API methods accept keyword arguments only, so the
    # index name must be passed as ``index=`` rather than positionally.
    if elasticsearch.indices.exists(index=index_name):
        if not overwrite_existing:
            raise Exception(f"Index {index_name} already exists")
        elasticsearch.indices.delete(index=index_name)

    elasticsearch.indices.create(index=index_name, body=create_index_body)
            

In [None]:
# Module-level CLIP model shared by the embedding helpers below. It is loaded
# once, when this cell runs.
embedding_model = SentenceTransformer(model_name_or_path='clip-ViT-B-32')
def generate_image_embedding(image:Image.Image):
    """Encode a PIL image with the module-level ``embedding_model``.

    Returns whatever ``SentenceTransformer.encode`` returns for a single input.
    With default arguments this is presumably a 1-D numpy vector; verify if
    encode options change.
    """
    # ``Image`` is the PIL.Image module; the image type is ``Image.Image``.
    return embedding_model.encode(image)

def generate_image_embedding_from_path(image_path:str):
    """Open the image file at ``image_path`` and return its embedding.

    The underlying file handle is closed before returning, including when
    encoding raises.
    """
    # Image.open is lazy and keeps the file open until the image is closed;
    # the context manager releases the handle once the embedding is computed.
    with Image.open(image_path) as img:
        return generate_image_embedding(img)


In [None]:
def insert_document(elasticsearch:Elasticsearch, index_name:str, document:dict):
    """Index a single ``document`` into ``index_name``.

    The client's response is discarded; this function returns None.
    """
    request = {"index": index_name, "document": document}
    elasticsearch.index(**request)

def bulk_insert_documents(elasticsearch:Elasticsearch, index_name:str, documents:list[dict], ):
    """Index every document in ``documents`` into ``index_name`` with one bulk request.

    Args:
        elasticsearch: Connected client.
        index_name: Target index for every document.
        documents: Documents to index; each gets its own ``index`` action.

    Returns:
        The bulk API response, or None when ``documents`` is empty (no request
        is sent in that case).

    Raises:
        Exception: If the bulk response reports that any item failed.
    """
    # The bulk body alternates an action line and the document source.
    inserts = []
    for doc in documents:
        inserts.extend(({"index": {"_index": index_name}}, doc))

    # An empty bulk body is rejected by Elasticsearch; there is nothing to do.
    if not inserts:
        return None

    response = elasticsearch.bulk(body=inserts)

    # The bulk API does not raise on per-item failures; it sets ``errors`` and
    # records an ``error`` entry on each failed item. Surface these instead of
    # silently dropping documents.
    if response["errors"]:
        failures = [
            result
            for item in response["items"]
            for result in item.values()
            if "error" in result
        ]
        first_error = failures[0]["error"] if failures else "unknown"
        raise Exception(
            f"Bulk insert into {index_name} failed for {len(failures)} document(s); "
            f"first error: {first_error}"
        )
    return response

def image_search(image:Image.Image, elasticsearch:Elasticsearch, index_name:str, image_embedding_field_name:str, k=3, num_candidates=100, size=3):
    """Run a kNN search on ``index_name`` for documents similar to ``image``.

    Args:
        image: Query image; it is embedded with generate_image_embedding.
        elasticsearch: Connected client.
        index_name: Index to search.
        image_embedding_field_name: Vector field the kNN clause targets.
        k: Number of nearest neighbours requested in the kNN clause.
        num_candidates: Candidate count passed to the kNN clause.
        size: Number of hits returned by the search request.

    Returns:
        The raw response from ``elasticsearch.search``.
    """
    image_embedding = generate_image_embedding(image)
    query_body = {
        "knn": {
            "field": image_embedding_field_name,
            "k": k,
            "num_candidates": num_candidates,
            # Convert to plain Python floats so the request body is
            # JSON-serializable regardless of serializer numpy support.
            "query_vector": [float(value) for value in image_embedding],
            "boost": 100,
        }
    }

    # Previously passed the undefined name ``body`` here, which raised NameError.
    res = elasticsearch.search(index=index_name, body=query_body, size=size)
    return res