In [1]:
from aips import get_engine
from aips.spark import create_view_from_collection, get_spark_session
from aips.spark.dataframe import from_sql
from pyspark.sql.window import Window
from pyspark.sql.functions import col, split, regexp_replace
from IPython.display import display, HTML
import ipywidgets as widgets
from PIL import Image
import pickle
import requests
import numpy
import torch
import clip
from io import BytesIO
import aips.indexer
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

from aips import set_engine; set_engine("opensearch")
engine = get_engine()
spark = get_spark_session()
aips.indexer.build_collection(engine, "tmdb")
aips.indexer.download_data_files("movies_with_image_embeddings")        

'data/tmdb/movies_with_image_embeddings.pickle'

## Listing 15.14

In [2]:
def normalize_embedding(embedding):
    return numpy.divide(embedding,
      numpy.linalg.norm(embedding,axis=0)).tolist()

def read(cache_name):
    cache_file_name = f"data/tmdb/{cache_name}.pickle"
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)

def tmdb_with_embeddings_dataframe():
    movies = read("movies_with_image_embeddings")
    embeddings = movies["image_embeddings"]
    normalized_embeddings = [normalize_embedding(e) for e in embeddings]
    movies_dataframe = spark.createDataFrame(
        zip(movies["movie_ids"], movies["titles"], 
            movies["image_ids"], normalized_embeddings),
        schema=["movie_id", "title", "image_id", "image_embedding"])
    return movies_dataframe

embeddings_dataframe = tmdb_with_embeddings_dataframe()
embeddings_collection = engine.create_collection("tmdb_with_embeddings")
embeddings_collection.write(embeddings_dataframe)

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection


7549

<Response [200]>
{'_shards': {'total': 2, 'successful': 1, 'failed': 0}}
Successfully written 7549 documents


### Setting up the lexical embedding index


In [3]:
def tmdb_lexical_embeddings_dataframe():
    lexical_tmdb_collection = engine.get_collection("tmdb")
    create_view_from_collection(lexical_tmdb_collection, "tmdb_lexical")
    embeddings_tmdb_collection = engine.get_collection("tmdb_with_embeddings")
    create_view_from_collection(embeddings_tmdb_collection, "tmdb_with_embeddings")
    
    dataframe = from_sql("SELECT id, title FROM tmdb_lexical")
    dataframe.printSchema()
    dataframe.show(1)

    dataframe = from_sql("SELECT * FROM tmdb_with_embeddings inner join tmdb_lexical on tmdb_lexical.id = tmdb_with_embeddings.movie_id")
    dataframe.printSchema()
    for d in dataframe.head(2):
        display(d)

    
    joined_collection_sql = f"""    
    SELECT lexical.id id, embeddings.movie_id movie_id, lexical.title title, lexical.overview overview, embeddings.image_embedding, embeddings.image_id
    FROM tmdb_with_embeddings embeddings
    INNER JOIN (SELECT DISTINCT image_id from (SELECT movie_id, MIN(image_id) image_id from tmdb_with_embeddings GROUP BY movie_id ORDER BY image_id ASC)) distinct_images on embeddings.image_id = distinct_images.image_id
    INNER JOIN tmdb_lexical lexical ON lexical.id = embeddings.movie_id
    ORDER by lexical.id asc"""

    joined_dataframe = from_sql(joined_collection_sql) 
    #joined_dataframe = joined_dataframe.withColumn("image_embedding", split(regexp_replace(col("image_embedding"),"\"",""), ","))
    
    joined_dataframe.printSchema()
    joined_dataframe.show(3)
    return joined_dataframe

lexical_embeddings = tmdb_lexical_embeddings_dataframe()

lexical_collection = engine.create_collection("tmdb_lexical_plus_embeddings")
lexical_collection.write(lexical_embeddings)

root
 |-- id: string (nullable = true)
 |-- title: string (nullable = true)

+------+--------------------+
|    id|               title|
+------+--------------------+
|250114|Nick Offerman: Am...|
+------+--------------------+
only showing top 1 row

root
 |-- image_id: string (nullable = true)
 |-- movie_id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- image_embedding: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- cast: string (nullable = true)
 |-- directors: string (nullable = true)
 |-- genres: string (nullable = true)
 |-- id: string (nullable = true)
 |-- movie_image_ids: string (nullable = true)
 |-- overview: string (nullable = true)
 |-- poster_file: string (nullable = true)
 |-- poster_path: string (nullable = true)
 |-- release_date: timestamp (nullable = true)
 |-- release_year: double (nullable = true)
 |-- tagline: string (nullable = true)
 |-- title: string (nullable = true)
 |-- vote_average: float (nullable = true)
 

Py4JJavaError: An error occurred while calling o177.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 1 times, most recent failure: Lost task 0.0 in stage 9.0 (TID 131) (6317ec753586 executor driver): java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: scala.collection.convert.Wrappers$JListWrapper is not a valid external type for schema of string
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, cast), StringType, true), true, false, true) AS cast#110
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, directors), StringType, true), true, false, true) AS directors#111
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 2, genres), StringType, true), true, false, true) AS genres#112
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 3, id), StringType, true), true, false, true) AS id#113
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 4, movie_image_ids), StringType, true), true, false, true) AS movie_image_ids#114
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 5, overview), StringType, true), true, false, true) AS overview#115
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 6, poster_file), StringType, true), true, false, true) AS poster_file#116
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 7, poster_path), StringType, true), true, false, true) AS poster_path#117
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.sql.catalyst.util.DateTimeUtils$, TimestampType, anyToMicros, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 8, release_date), TimestampType, true), true, false, true) AS release_date#118
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 9, release_year), DoubleType, true) AS release_year#119
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 10, tagline), StringType, true), true, false, true) AS tagline#120
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 11, title), StringType, true), true, false, true) AS title#121
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 12, vote_average), FloatType, true) AS vote_average#122
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 13, vote_count), LongType, true) AS vote_count#123L
	at org.apache.spark.sql.errors.QueryExecutionErrors$.expressionEncodingError(QueryExecutionErrors.scala:1237)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:210)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:193)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.lang.RuntimeException: scala.collection.convert.Wrappers$JListWrapper is not a valid external type for schema of string
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.StaticInvoke_1$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:207)
	... 17 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: java.lang.RuntimeException: Error while encoding: java.lang.RuntimeException: scala.collection.convert.Wrappers$JListWrapper is not a valid external type for schema of string
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, cast), StringType, true), true, false, true) AS cast#110
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, directors), StringType, true), true, false, true) AS directors#111
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 2, genres), StringType, true), true, false, true) AS genres#112
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 3, id), StringType, true), true, false, true) AS id#113
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 4, movie_image_ids), StringType, true), true, false, true) AS movie_image_ids#114
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 5, overview), StringType, true), true, false, true) AS overview#115
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 6, poster_file), StringType, true), true, false, true) AS poster_file#116
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 7, poster_path), StringType, true), true, false, true) AS poster_path#117
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.sql.catalyst.util.DateTimeUtils$, TimestampType, anyToMicros, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 8, release_date), TimestampType, true), true, false, true) AS release_date#118
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 9, release_year), DoubleType, true) AS release_year#119
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 10, tagline), StringType, true), true, false, true) AS tagline#120
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 11, title), StringType, true), true, false, true) AS title#121
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 12, vote_average), FloatType, true) AS vote_average#122
if (assertnotnull(input[0, org.apache.spark.sql.Row, true]).isNullAt) null else validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 13, vote_count), LongType, true) AS vote_count#123L
	at org.apache.spark.sql.errors.QueryExecutionErrors$.expressionEncodingError(QueryExecutionErrors.scala:1237)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:210)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:193)
	at scala.collection.Iterator$$anon$10.next(Iterator.scala:461)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.lang.RuntimeException: scala.collection.convert.Wrappers$JListWrapper is not a valid external type for schema of string
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.StaticInvoke_1$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
	at org.apache.spark.sql.catalyst.encoders.ExpressionEncoder$Serializer.apply(ExpressionEncoder.scala:207)
	... 17 more


## Listing 15.15

In [None]:
def load_image(full_path, log=False):   
    try:
        if full_path.startswith("http"):
            response = requests.get(full_path)
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(full_path)
        if log: print("File Found")
        return image
    except:
        if log: print(f"No Image Available {full_path}")
        return []

def movie_search(query_embedding, limit=8):
    collection = engine.get_collection("tmdb_with_embeddings")
    request = {
        "query": query_embedding,
        "query_fields": ["image_embedding"],
        "return_fields": ["movie_id", "title",
                          "image_id", "score"],
        "limit": limit,
        "quantization_size": "FLOAT32"}
    response = collection.search(**request)
    return response
    
def encode_text(text, normalize=True):
    text = clip.tokenize([text]).to(device)    
    text_features = model.encode_text(text)
    embedding = text_features.tolist()[0] 
    if normalize:
        embedding = normalize_embedding(embedding)
    return embedding
    
def encode_image(image_file, normalize=True):
    image = load_image(image_file)
    inputs = preprocess(image).unsqueeze(0).to(device)
    embedding = model.encode_image(inputs).tolist()[0]
    if normalize:
        embedding = normalize_embedding(embedding)
    return embedding

def encode_text_and_image(text_query, image_file):    
    text_embedding = encode_text(text_query, False)
    image_embedding = encode_image(image_file, False)  
    return numpy.average((normalize_embedding(
        [text_embedding, image_embedding])), axis=0).tolist()

## Listing 15.16

In [None]:
def get_html(search_results, show_fields=True, display_header=None):
    css = """
      <style type="text/css">
        .results { 
          margin-top: 15px; 
          display: flex; 
          flex-wrap: wrap; 
          justify-content: space-evenly; }
        .field { font-size: 24px; position: relative; float: left; }
        .title { font-size:32px; font-weight:bold; max-width:450px; word-wrap:break-all; line-height:32px; display: table-cell; vertical-align: bottom;}
        .results .result { height: 250px; margin-bottom: 5px; }
        .fields {height: 100%; }
      </style>"""
    
    header = ""
    if display_header: 
        header = f"<div class='field' style='width:100%; font-size:32px; margin-bottom:20px'>{display_header}</div>"
    
    results_html = ""
    for movie in search_results["docs"]:
        image_file = f"http://image.tmdb.org/t/p/w780/{movie['image_id']}.jpg"
        movie_link = f"https://www.themoviedb.org/movie/{movie['movie_id']}"
        img_html = f"<img title='{movie['title']}' class='result' src='{image_file}'>"
        if show_fields: results_html += f"<div class='fields'><div class='title' style='height:65px;'>{movie['title']}</div><div class='field'>( score: {movie['score']} )</div>"
        results_html += f"<div style='clear:left'><a class='title' href='{movie_link}' target='_blank'>{img_html}</a></div>"
        if show_fields: results_html += "</div>"
    return f"{css}{header}<div class='results' style='clear:left'>{results_html}</div>"
   
def display_results(search_results, show_fields=True, display_header=None):    
    output = widgets.Output()
    with output:
        display(HTML(get_html(search_results, show_fields, display_header))) 
    display(widgets.HBox(layout=widgets.Layout(justify_content="center")), output)   

In [None]:
def search_and_display(text_query="", image_query=None):
    if image_query:
        if text_query:
            query_embedding = encode_text_and_image(text_query, image_query)
        else:
            query_embedding = encode_image(image_query)
    else:
        query_embedding = encode_text(text_query)
    display_results(movie_search(query_embedding), show_fields=False)

# Figure 15.5

In [None]:
search_and_display(text_query="singing in the rain")

HBox(layout=Layout(justify_content='center'))

Output()

In [None]:
search_and_display(text_query="superhero flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.6

In [None]:
search_and_display(text_query="superheroes flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.7

In [None]:
search_and_display(image_query="chapters/ch15/delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.8

In [None]:
search_and_display(text_query="superhero", image_query="chapters/ch15/delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Listing 15.17

## Listing 15.17

In [None]:
# %load -s reciprocal_rank_fusion ../../engines/Collection.py
def reciprocal_rank_fusion(search_results, k = None):
    if k is None: k = 60
    scores = {}
    for ranked_docs in search_results:
        for rank, doc in enumerate(ranked_docs, 1):
            scores[doc["id"]] = scores.get(doc["id"], 0)  + (1.0 / (k + rank))
    sorted_scores = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))
    return sorted_scores


In [None]:
def get_display_header(lexical_request=None, vector_request=None):
    if lexical_request and vector_request:
        return "Hybrid Results:<br/>" \
               + f"  --Lexical Query: {lexical_request['query']}<br/>" \
               + f"  --Vector Query: [{vector_request['query'][0]}, {vector_request['query'][1]}, ... ]<br/>"
    elif lexical_request:
        return f"Lexical Query: {lexical_request['query']}<br/>"
    elif vector_request:
        return f"Vector Query: [{vector_request['query'][0]}, {vector_request['query'][1]}, ... ]<br/>"

In [None]:
collection = engine.get_collection("tmdb_lexical_plus_embeddings")

# Listing 15.18

In [None]:
over_request_limit = 15
base_query = {"return_fields": ["id", "title", "id", "image_id",
                                "movie_id", "score", "image_embedding"],
              "limit": over_request_limit,
              "order_by": [("score", "desc"), ("title", "asc")]}

def lexical_search_request(query_text):
    return {"query": query_text, 
            "query_fields": ["title", "overview"],   
            "default_operator": "OR",    
            **base_query}

def vector_search_request(query_embedding):
    return { "query": query_embedding, 
             "query_fields": ["image_embedding"],
             "quantization_size": "FLOAT32",
            **base_query}

def display_lexical_search_results(query_text):
    collection = engine.get_collection("tmdb_lexical_plus_embeddings")
    lexical_request = lexical_search_request(query_text)
    lexical_search_results = collection.search(**lexical_request)
    
    display_results(lexical_search_results, display_header= \
                    get_display_header(lexical_request=lexical_request))
    
def display_vector_search_results(query_text):
    collection = engine.get_collection("tmdb_lexical_plus_embeddings")
    query_embedding = encode_text(query_text)
    vector_request = vector_search_request(query_embedding)
    vector_search_results = collection.search(**vector_request)

    display_results(vector_search_results, display_header= \
                    get_display_header(vector_request=vector_request))

query = '"' + "singin' in the rain" + '"'
display_lexical_search_results(query)
display_vector_search_results(query)

ValueError: {'error': {'root_cause': [{'type': 'query_shard_exception', 'reason': 'No mapping found for [title] in order to sort on', 'index': 'tmdb_lexical_plus_embeddings', 'index_uuid': 'bmpdOsEMRnGG0mzcpEFvnQ'}], 'type': 'search_phase_execution_exception', 'reason': 'all shards failed', 'phase': 'query', 'grouped': True, 'failed_shards': [{'shard': 0, 'index': 'tmdb_lexical_plus_embeddings', 'node': 'Nm1R8o10QMyp92t1z6y5PA', 'reason': {'type': 'query_shard_exception', 'reason': 'No mapping found for [title] in order to sort on', 'index': 'tmdb_lexical_plus_embeddings', 'index_uuid': 'bmpdOsEMRnGG0mzcpEFvnQ'}}]}, 'status': 400}

## Listing 15.19

In [None]:
def display_hybrid_search_results(text_query, limit=10):
    lexical_request = lexical_search_request(text_query)
    vector_request = vector_search_request(encode_text(text_query))    
    hybrid_search_results = collection.hybrid_search(
          [lexical_request, vector_request], limit=10,
          algorithm="rrf", algorithm_params={"k": 60})
    display_results(hybrid_search_results,
                    display_header=get_display_header(lexical_request, vector_request))

## Listing 15.20

In [None]:
query = "the hobbit"
display_lexical_search_results(query)
display_vector_search_results(query)

HBox(layout=Layout(justify_content='center'))

Output()

HBox(layout=Layout(justify_content='center'))

Output()

In [None]:
display_hybrid_search_results(query)

HBox(layout=Layout(justify_content='center'))

Output()

## Listing 15.21

In [None]:
def lexical_vector_rerank(text_query, limit=10):
    lexical_request = lexical_search_request(text_query)
    vector_request = vector_search_request(encode_text(text_query))    
    hybrid_search_results = collection.hybrid_search(
        [lexical_request, vector_request],
        algorithm="lexical_vector_rerank", limit=limit)
    header = get_display_header(lexical_request, vector_request)
    display_results(hybrid_search_results, display_header=header)

In [None]:
lexical_vector_rerank("the hobbit")

HBox(layout=Layout(justify_content='center'))

Output()