## Imports

In [None]:
from opensearchpy import OpenSearch
from IPython.display import display, HTML
import imgkit
import shutil
import os
from query import search

import sys
import fasttext
from query import QueryClassification

sys.path.append(os.path.abspath(os.path.join("..")))
from week3.transform_query import transform_query

## Setup

In [None]:
# Connection settings for the OpenSearch node queried by this notebook.
host = "localhost"
port = 9200
base_url = f"http://{host}:{port}/"

# "__file__" is a string literal here, so os.path.dirname() yields "" and
# both directories resolve relative to the current working directory.
FILE_DIR = os.path.abspath(os.path.dirname("__file__"))
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))
IMAGE_OUTPUT_DIR = os.path.join(ROOT_DIR, "week4/out")

client = OpenSearch(
    hosts=[dict(host=host, port=port)],
    use_ssl=False,
    http_compress=True,  # enables gzip compression for request bodies
)

# Load the fastText model that is later handed to QueryClassification.
query_classification_model_file_path = os.path.join(
    ROOT_DIR, "datasets/fasttext/labeled_queries_model.bin"
)
query_classification_model = fasttext.load_model(query_classification_model_file_path)

## Output

In [None]:
def render_comparision(user_query, render_to_img):
    """Render a three-column comparison of search strategies for one query.

    Runs ``search`` three times for ``user_query`` (plain query, vector
    search, vector search combined with query classification) and lays the
    hits out side by side as HTML.

    Reads the module-level ``client``, ``source``, ``size`` and
    ``query_classification`` values, plus ``IMAGE_OUTPUT_DIR`` when an image
    is written.

    Args:
        user_query: The query text. It is shown (HTML-escaped) in the page
            heading and used verbatim as the image file name.
        render_to_img: If true, write ``<user_query>.jpg`` into
            ``IMAGE_OUTPUT_DIR`` via imgkit; otherwise display the HTML
            inline with IPython.
    """
    # Imported locally so this block brings its own dependency into scope.
    # Only ``escape`` is imported, keeping the name ``html`` unbound here.
    from html import escape

    def run_search(**extra):
        # ``search`` returns three values; only the first (the response) is
        # used. ``size`` is now passed on every call so all three columns
        # request the same number of hits (previously only the vector search
        # passed it).
        response, _, _ = search(
            client, user_query, source=source, size=size, **extra
        )
        return response

    def hit_to_html(hit):
        hit_source = hit["_source"]
        # Indexed values are escaped before interpolation so characters such
        # as & < > " cannot break the surrounding markup.
        image_url = escape(str(hit_source["image"][0]), quote=True)
        name = escape(str(hit_source["name"][0]))
        return f"""<div style="display: flex; align-items:center; max-width:100%; overflow:hidden; padding: 10px;">
                        <img style="max-width: 200px; max-height: 80px; margin-right:10px;" src="{image_url}"/>
                        <span style="font-size: 16px; color: black;">{name}</span>
                    </div>"""

    def hits_to_html(response):
        return "\n".join(hit_to_html(hit) for hit in response["hits"]["hits"])

    base_query_html = hits_to_html(run_search())
    vector_search_html = hits_to_html(run_search(vector_search=True))
    vector_search_with_query_classification_html = hits_to_html(
        run_search(query_classification=query_classification, vector_search=True)
    )

    page_html = f"""<head>
        <style>
        .container {{background: white;}}
        h4 {{font-size: 18px; color: black; background: white;}}
        h3 {{font-size: 20px; color: black; text-align:center; border-bottom:1px solid;}}
        section {{display: inline-block; width:33%; border-right:1px dashed;}}
        </style>
    </head>
    <body>
        <div class="container">
            <h4>Search results for: "{escape(user_query)}"</h4>
            <div>
                <section>
                    <h3>Base query</h3>
                    {base_query_html}
                </section>
                <section>
                    <h3>Vector search</h3>
                    {vector_search_html}
                </section>
                <section>
                    <h3>Vector with query classification</h3>
                    {vector_search_with_query_classification_html}
                </section>
            </div>
        </div>
    </body>
    """
    if render_to_img:
        # Create the output directory on first use and write straight into
        # it, instead of writing to the working directory and moving the
        # file afterwards (which failed when the directory was missing).
        os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
        imgkit.from_string(
            page_html,
            os.path.join(IMAGE_OUTPUT_DIR, f"{user_query}.jpg"),
            options={
                "format": "jpeg",
            },
        )
    else:
        display(HTML(page_html))


# config: module-level values read by render_comparision
size = 10
source = ["name", "shortDescription", "image"]
test_queries = ["Ipad", "Touchpad", "camera", "Bed"]

# Classifier settings used for the "vector with query classification" column.
query_classification = QueryClassification(
    model=query_classification_model,
    transform_query=transform_query,
    label_prefix="__label__",
    threshold=0.5,
    prediction_count=5,
)

# Produce one comparison image per test query.
for user_query in test_queries:
    render_comparision(user_query, True)