In [1]:
import sys
sys.path.append("../..")
from aips import get_engine
from IPython.display import display, HTML
from pyspark.sql import SparkSession
import ipywidgets as widgets
from PIL import Image
import pickle
import requests
import numpy
import torch
import clip
from io import BytesIO

# Run CLIP on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP ViT-B/32 model together with its matching image-preprocessing
# transform; both are used below to embed text and image queries.
model, preprocess = clip.load("ViT-B/32", device=device)

# Search-engine handle (used below via get_collection/create_collection) and
# the Spark session used to build and join the indexing DataFrames.
engine = get_engine()
spark = SparkSession.builder.appName("AIPS").getOrCreate()

In [2]:
# Clone (first run) or update the TMDB data repository, then unpack the pickled
# movie image embeddings into ../../../data/tmdb/ relative to the tmdb checkout,
# i.e. ../../data/tmdb/ relative to this notebook.
![ ! -d 'tmdb' ] && git clone --depth 1 https://github.com/ai-powered-search/tmdb.git
! cd tmdb && git pull
! cd tmdb && mkdir -p '../../../data/tmdb/' && tar -xvf movies_with_image_embeddings.tgz -C '../../../data/tmdb/'

Already up to date.
movies_with_image_embeddings.pickle


In [3]:
# Rebuild the lexical "tmdb" collection by running the chapter 10 setup notebook.
%run ../ch10/1.setup-the-movie-db.ipynb

Wiping "tmdb" collection
Creating "tmdb" collection
Status: Success
Adding LTR QParser for tmdb collection
Adding LTR Doc Transformer for tmdb collection
../../data/judgments.tgz already exists
../../data/movies.tgz already exists
Successfully written 65616 documents


## Listing 15.14

In [4]:
from aips.spark import create_view_from_collection
from aips.spark.dataframe import from_sql
from pyspark.sql.functions import split, regexp_replace, col

def normalize_embedding(embedding):
    """Scale an embedding to unit (L2) length and return it as a Python list.

    The norm is taken along axis 0, so a 1-D vector is normalized as a whole,
    while a 2-D input has each column normalized independently.

    A zero-length vector (norm of 0) is returned as zeros. Dividing it would
    otherwise produce NaNs, which would poison any downstream index or
    similarity computation.

    Args:
        embedding: Array-like vector (or 2-D array) of numbers.

    Returns:
        The normalized values as a (nested) Python list.
    """
    norm = numpy.linalg.norm(embedding, axis=0)
    # Substitute 1 for zero norms so all-zero vectors stay all zeros instead of NaN.
    # ones_like keeps the norm's dtype so float32 inputs are not upcast.
    safe_norm = numpy.where(norm == 0, numpy.ones_like(norm), norm)
    return numpy.divide(embedding, safe_norm).tolist()

def read(cache_name, data_dir="../../data/tmdb"):
    """Load the pickled cache file ``<data_dir>/<cache_name>.pickle``.

    Args:
        cache_name: Base name of the cache file, without the ``.pickle`` suffix.
        data_dir: Directory holding the cache. The default is where the setup
            cell above extracts the TMDB embeddings archive.

    Returns:
        The unpickled object.

    Security:
        ``pickle.load`` can execute arbitrary code while loading. Only read
        cache files that come from a trusted source.
    """
    cache_file_name = f"{data_dir}/{cache_name}.pickle"
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)

def tmdb_with_embeddings_dataframe():
    """Build a Spark DataFrame of movie images and their normalized embeddings.

    Reads the ``movies_with_image_embeddings`` cache and produces one row per
    cached image, with columns ``movie_id``, ``title``, ``image_id`` and
    ``image_embedding``. Each embedding is passed through normalize_embedding.
    """
    cached_movies = read("movies_with_image_embeddings")
    unit_embeddings = list(map(normalize_embedding, cached_movies["image_embeddings"]))
    rows = zip(cached_movies["movie_ids"],
               cached_movies["titles"],
               cached_movies["image_ids"],
               unit_embeddings)
    column_names = ["movie_id", "title", "image_id", "image_embedding"]
    return spark.createDataFrame(rows, schema=column_names)

def tmdb_lexical_embeddings_dataframe():
    """Join the lexical "tmdb" collection with the image-embedding collection.

    Registers Spark views over the "tmdb" and "tmdb_with_embeddings"
    collections. For each movie it keeps the image with the smallest image_id,
    then joins that image row to the movie's lexical fields on id = movie_id.

    Returns:
        A Spark DataFrame with columns id, movie_id, title, overview,
        image_embedding and image_id, ordered by id ascending.
    """
    lexical_tmdb_collection = engine.get_collection("tmdb")
    create_view_from_collection(lexical_tmdb_collection, "tmdb_lexical")
    embeddings_tmdb_collection = engine.get_collection("tmdb_with_embeddings")
    create_view_from_collection(embeddings_tmdb_collection, "tmdb_with_embeddings")
    
    # The inner subquery picks MIN(image_id) per movie_id. The outer join then
    # keeps only the embedding rows whose image_id is one of those picks.
    # The ORDER BY inside the subquery does not affect the result.
    joined_collection_sql = f"""    
    SELECT lexical.id id, embeddings.movie_id movie_id, lexical.title title, lexical.overview overview, embeddings.image_embedding, embeddings.image_id
    FROM tmdb_with_embeddings embeddings
    INNER JOIN (SELECT DISTINCT image_id from (SELECT movie_id, MIN(image_id) image_id from tmdb_with_embeddings GROUP BY movie_id ORDER BY image_id ASC)) distinct_images on embeddings.image_id = distinct_images.image_id
    INNER JOIN tmdb_lexical lexical ON lexical.id = embeddings.movie_id
    ORDER by lexical.id asc
    """
    # Leftover (unused): the joined collection is created by the caller instead.
    #collection = engine.create_collection("tmdb_lexical_plus_embeddings")
    joined_dataframe = from_sql(joined_collection_sql) 
    # NOTE(review): the view appears to expose image_embedding as a quoted,
    # comma-separated string. Double quotes are stripped and the value is split
    # back into an array. split() yields an array of strings, not floats —
    # confirm the target collection coerces them to numbers.
    joined_dataframe = joined_dataframe.withColumn("image_embedding", split(regexp_replace(col("image_embedding"),"\"",""), ","))
    return joined_dataframe
    
def create_embedding_indexes():
    """(Re)create and populate the two embedding-backed collections.

    "tmdb_with_embeddings" is written first because the joined
    "tmdb_lexical_plus_embeddings" DataFrame reads from it.
    """
    # Step 1: one document per movie image, carrying its normalized embedding.
    image_embeddings_df = tmdb_with_embeddings_dataframe()
    engine.create_collection("tmdb_with_embeddings").write(image_embeddings_df)

    # Step 2: one image per movie, joined with the movie's lexical fields.
    joined_df = tmdb_lexical_embeddings_dataframe()
    engine.create_collection("tmdb_lexical_plus_embeddings").write(joined_df)

In [5]:
# Build and index both embedding collections; the output reports document counts.
create_embedding_indexes()

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection
Status: Success
Successfully written 7549 documents
Wiping "tmdb_lexical_plus_embeddings" collection
Creating "tmdb_lexical_plus_embeddings" collection
Status: Success
Successfully written 908 documents


## Listing 15.15

In [6]:
def load_image(full_path, log=False, timeout=30):
    """Open an image from a local path or an http(s) URL.

    Args:
        full_path: Local file path, or a URL starting with "http".
        log: When True, print whether the image could be loaded.
        timeout: Seconds to wait for a remote download before giving up.

    Returns:
        A PIL image on success. On failure, an empty list (kept for
        compatibility with existing callers).
    """
    try:
        if full_path.startswith("http"):
            # The timeout stops a stalled server from hanging the kernel.
            # raise_for_status turns HTTP errors into a failed load instead of
            # handing an error page to PIL.
            response = requests.get(full_path, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(full_path)
        if log: print("File Found")
        return image
    except Exception:
        # Catch ordinary errors only, so KeyboardInterrupt / SystemExit can still
        # interrupt a slow download.
        if log: print(f"No Image Available {full_path}")
        return []

def movie_search(query_embedding, limit=8):
    """Vector-search the image embeddings of the "tmdb_with_embeddings" collection.

    Args:
        query_embedding: Query vector compared against the ``image_embedding`` field.
        limit: Maximum number of results to return.

    Returns:
        The collection's search response. Callers read its "docs" list.
    """
    search_request = dict(
        query=query_embedding,
        query_fields=["image_embedding"],
        return_fields=["movie_id", "title", "image_id", "score"],
        limit=limit,
        quantization_size="FLOAT32",
    )
    return engine.get_collection("tmdb_with_embeddings").search(**search_request)
    
def encode_text(text, normalize=True):
    """Embed a text query with the CLIP text encoder.

    Args:
        text: Query string to tokenize and encode.
        normalize: When True, scale the embedding to unit length.

    Returns:
        The embedding as a list of floats.
    """
    tokens = clip.tokenize([text]).to(device)
    # Inference only: no_grad avoids building an autograd graph for every query.
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    embedding = text_features.tolist()[0]
    if normalize:
        embedding = normalize_embedding(embedding)
    return embedding
    
def encode_image(image_file, normalize=True):
    """Embed an image (local path or URL) with the CLIP image encoder.

    Args:
        image_file: Local path or http(s) URL of the image.
        normalize: When True, scale the embedding to unit length.

    Returns:
        The embedding as a list of floats.

    Raises:
        ValueError: If the image cannot be loaded.
    """
    image = load_image(image_file)
    # load_image signals failure with an empty list. Fail with a clear message
    # instead of an obscure error from inside preprocess().
    if isinstance(image, list):
        raise ValueError(f"Could not load image: {image_file}")
    inputs = preprocess(image).unsqueeze(0).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        image_features = model.encode_image(inputs)
    embedding = image_features.tolist()[0]
    if normalize:
        embedding = normalize_embedding(embedding)
    return embedding

def encode_text_and_image(text_query, image_file):
    """Embed a combined text + image query as a single unit-length vector.

    Both modalities are embedded and normalized independently. The two
    vectors are averaged with equal weight, and the average is re-normalized.

    Returns:
        The combined embedding as a list of floats.
    """
    text_embedding = encode_text(text_query, True)
    image_embedding = encode_image(image_file, True)
    # Average first, then normalize the resulting vector. Normalizing the
    # stacked 2-D array along axis 0 would rescale every dimension across the
    # two embeddings independently and distort the combined direction.
    combined = numpy.average([text_embedding, image_embedding], axis=0)
    return normalize_embedding(combined)

## Listing 15.16

In [7]:
def get_html(search_results, show_fields=True, display_header=None):
    """Render search results as an HTML grid of movie posters.

    Args:
        search_results: Response dict whose "docs" list holds documents with
            "image_id", "movie_id", "title" and "score" keys.
        show_fields: When True, show each movie's title and score above its poster.
        display_header: Optional pre-rendered HTML shown above the grid. It is
            inserted as-is, not escaped.

    Returns:
        An HTML string containing the CSS, the header and the result grid.
    """
    from html import escape

    # NOTE(review): "word-wrap:break-all" is not a valid value (break-all
    # belongs to word-break), so browsers ignore that declaration.
    css = """
      <style type="text/css">
        .results { 
          margin-top: 15px; 
          display: flex; 
          flex-wrap: wrap; 
          justify-content: space-evenly; }
        .field { font-size: 24px; position: relative; float: left; }
        .title { font-size:32px; font-weight:bold; max-width:450px; word-wrap:break-all; line-height:32px; display: table-cell; vertical-align: bottom;}
        .results .result { height: 250px; margin-bottom: 5px; }
        .fields {height: 100%; }
      </style>"""

    header = ""
    if display_header:
        header = f"<div class='field' style='width:100%; font-size:32px; margin-bottom:20px'>{display_header}</div>"

    parts = []
    for movie in search_results["docs"]:
        # Escape document values. Titles such as "Singin' in the Rain" contain
        # quotes that would otherwise end the single-quoted attributes early.
        title = escape(str(movie["title"]))
        image_id = escape(str(movie["image_id"]))
        movie_id = escape(str(movie["movie_id"]))
        score = escape(str(movie["score"]))
        image_file = f"http://image.tmdb.org/t/p/w780/{image_id}.jpg"
        movie_link = f"https://www.themoviedb.org/movie/{movie_id}"
        img_html = f"<img title='{title}' class='result' src='{image_file}'>"
        if show_fields:
            parts.append(f"<div class='fields'><div class='title' style='height:65px;'>{title}</div><div class='field'>( score: {score} )</div>")
        parts.append(f"<div style='clear:left'><a class='title' href='{movie_link}' target='_blank'>{img_html}</a></div>")
        if show_fields:
            parts.append("</div>")
    results_html = "".join(parts)
    return f"{css}{header}<div class='results' style='clear:left'>{results_html}</div>"
   
def display_results(search_results, show_fields=True, display_header=None):
    """Render search results into an Output widget and display it in the notebook.

    NOTE(review): the HBox is displayed empty, next to (not around) the output
    widget, so its centering layout does not apply to the results. Confirm
    whether ``widgets.HBox([output], ...)`` was intended.
    """
    rendered = HTML(get_html(search_results, show_fields, display_header))
    results_output = widgets.Output()
    with results_output:
        display(rendered)
    centered_box = widgets.HBox(layout=widgets.Layout(justify_content="center"))
    display(centered_box, results_output)

def search_and_display(text_query="", image_query=None):
    """Embed a text and/or image query, search movie posters, and show the results.

    With only text, the text embedding is used. With only an image, the
    image embedding is used. With both, the combined text + image embedding
    is used. Results are shown without title/score fields.
    """
    if not image_query:
        query_embedding = encode_text(text_query)
    elif text_query:
        query_embedding = encode_text_and_image(text_query, image_query)
    else:
        query_embedding = encode_image(image_query)
    matches = movie_search(query_embedding)
    display_results(matches, show_fields=False)

# Figure 15.5

In [8]:
# Text-only query: embed the phrase with CLIP and search the poster embeddings.
search_and_display(text_query="singing in the rain")

HBox(layout=Layout(justify_content='center'))

Output()

In [9]:
# Text-only query using the singular "superhero".
search_and_display(text_query="superhero flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.6

In [10]:
# Same idea phrased in the plural, to see how a small wording change shifts results.
search_and_display(text_query="superheroes flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.7

In [11]:
# Image-only query: search with the CLIP embedding of a local image file.
search_and_display(image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.8

In [12]:
# Combined query: text and image embeddings merged by encode_text_and_image.
search_and_display(text_query="superhero", image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

## Listing 15.17

In [13]:
# Temporary: expose the joined collection as a module-level `self` (and as
# `collection`). Code written as if it were a collection method (calling
# `self.search(...)`) can then run in this notebook until those functions
# move into the collection class.
self = collection = engine.get_collection("tmdb_lexical_plus_embeddings") #temporary until functions are moved into collection

In [14]:
def reciprocal_rank_fusion(k, *search_results):
    """Fuse several ranked result lists with Reciprocal Rank Fusion (RRF).

    Every list a document appears in contributes ``1 / (k + rank)`` to its
    score, with ranks starting at 1. Contributions are summed per document id.

    Args:
        k: Constant added to each rank. Larger values flatten the advantage
            of top-ranked positions.
        *search_results: Ranked lists of documents, each a dict with an "id".

    Returns:
        A dict mapping document id to fused score, ordered from highest to
        lowest score. Ties keep first-seen order.
    """
    fused = {}
    for ranking in search_results:
        for position, document in enumerate(ranking, start=1):
            doc_id = document["id"]
            fused[doc_id] = fused.get(doc_id, 0) + 1.0 / (k + position)
    ranked_ids = sorted(fused, key=fused.__getitem__, reverse=True)
    return {doc_id: fused[doc_id] for doc_id in ranked_ids}

In [15]:
def get_display_header(lexical_search=None, vector_search=None):
    """Build an HTML header describing the query (or queries) behind a result set.

    Args:
        lexical_search: Lexical search request dict (uses its "query").
        vector_search: Vector search request dict. Only the first two
            components of its "query" vector are shown.

    Returns:
        A "Hybrid Results" header when both are given, a single-query header
        when only one is given, or None when neither is.
    """
    def _vector_summary(search):
        # Embeddings are long; show just the first two components.
        values = search["query"]
        return f"[{values[0]}, {values[1]}, ... ]<br/>"

    if not lexical_search and not vector_search:
        return None
    if not vector_search:
        return f"Lexical Query: {lexical_search['query']}<br/>"
    if not lexical_search:
        return "Vector Query: " + _vector_summary(vector_search)
    return ("Hybrid Results:<br/>"
            + f"  --Lexical Query: {lexical_search['query']}<br/>"
            + "  --Vector Query: " + _vector_summary(vector_search))

## Listing 15.18

In [16]:
def base_search(limit=15):
    """Return the request options shared by the lexical and vector searches.

    Args:
        limit: Number of results to request. The default of 15 deliberately
            over-requests beyond the 8 results displayed, so fusing several
            result lists has extra candidates to draw on.

    Returns:
        A dict with "return_fields", "limit" and "order_by" entries.
    """
    return {
        # Each field listed once (the original list repeated "id").
        "return_fields": ["id", "title", "image_id", "movie_id", "score", "image_embedding"],
        "limit": limit,
        "order_by": [("score", "desc"), ("title", "asc")],
    }

def lexical_search_from_text_query(query_text):
    """Build a keyword search over title and overview, OR-ing the query terms."""
    return {"query": query_text,
            "query_fields": ["title", "overview"],
            "default_operator": "OR",
            **base_search()}

def vector_search_from_embedding(query_embedding):
    """Build a vector search of ``query_embedding`` against the image_embedding field."""
    return {"query": query_embedding,
            "query_fields": ["image_embedding"],
            "quantization_size": "FLOAT32",
            **base_search()}


def display_lexical_search_results(query_text):
    """Run a keyword search over title/overview and display the hits with a header."""
    lexical_search = lexical_search_from_text_query(query_text)
    results = engine.get_collection("tmdb_lexical_plus_embeddings").search(**lexical_search)
    header = get_display_header(lexical_search=lexical_search)
    display_results(results, display_header=header)

def display_vector_search_results(query_text):
    """Embed ``query_text`` with CLIP, search the image embeddings, and display the hits."""
    collection = engine.get_collection("tmdb_lexical_plus_embeddings")
    vector_search = vector_search_from_embedding(encode_text(query_text))
    results = collection.search(**vector_search)
    header = get_display_header(vector_search=vector_search)
    display_results(results, display_header=header)

# Wrap the title in double quotes (intended as a phrase query for the lexical
# engine). The vector search embeds the same string, quotes included.
query = '"' + "singin' in the rain" + '"'
display_lexical_search_results(query)
display_vector_search_results(query)

HBox(layout=Layout(justify_content='center'))

Output()

HBox(layout=Layout(justify_content='center'))

Output()

In [17]:
def merge_hybrid_results(hybrid_search_scores, *search_results):
    """Attach fused scores to the full documents from the original result lists.

    Only documents present in ``hybrid_search_scores`` are kept. When a
    document appears in several lists, fields from the earliest list win, and
    later lists only add fields that are missing.

    Returns:
        ``{"docs": [...]}`` in the order of ``hybrid_search_scores``. Each
        document's "score" is replaced by its fused score.

    Raises:
        KeyError: If a scored id does not occur in any of the result lists.
    """
    merged_docs = {}
    for docs in search_results:
        for doc in docs:
            key = doc["id"]
            if key not in hybrid_search_scores:
                continue  # not part of the fused ranking
            # Earlier-seen fields override this doc's fields.
            merged_docs[key] = {**doc, **merged_docs.get(key, {})}
    return {"docs": [{**merged_docs[key], "score": fused_score}
                     for key, fused_score in hybrid_search_scores.items()]}

In [18]:
def reciprocal_rank_fusion_hybrid_search(searches=None, limit=None, algorithm_params=None,
                                         collection=None):
    """Run several searches against one collection and fuse their rankings with RRF.

    Args:
        searches: Search request dicts, each passed as keyword arguments to
            ``collection.search``.
        limit: If truthy, keep only the top ``limit`` fused documents.
        algorithm_params: Optional settings. ``"k"`` overrides the RRF
            constant, which defaults to 60.
        collection: Collection to search. Defaults to the
            "tmdb_lexical_plus_embeddings" collection, so the function no
            longer depends on a module-level ``self`` global.

    Returns:
        ``{"docs": [...]}`` with documents ordered by fused score.
    """
    # None defaults avoid shared mutable default arguments.
    searches = searches or []
    algorithm_params = algorithm_params or {}
    k = algorithm_params.get("k", 60)
    if collection is None:
        collection = engine.get_collection("tmdb_lexical_plus_embeddings")

    search_results = [collection.search(**search)["docs"] for search in searches]

    hybrid_search_scores = reciprocal_rank_fusion(k, *search_results)
    scored_docs = merge_hybrid_results(hybrid_search_scores, *search_results)

    if limit:
        # Slicing past the end is harmless, so no explicit length check is needed.
        scored_docs["docs"] = scored_docs["docs"][:limit]
    return scored_docs

In [19]:
def hybrid_search(searches=[], limit=None, algorithm="rrf", algorithm_params={}):
    hybrid_search_results = None
    match algorithm:
        case "rrf":
            hybrid_search_results = reciprocal_rank_fusion_hybrid_search(
                                      searches=searches, limit=limit, 
                                      algorithm_params=algorithm_params)
        case "rerank":
            pass #need rerank implemented on collection.search(...)
    return hybrid_search_results

## Listing 15.19

In [20]:
def display_hybrid_search_results(text_query, limit=8):
    """Run lexical and CLIP vector searches for the same text, fuse them, and display the top ``limit``.

    Fusion uses hybrid_search's default algorithm ("rrf").
    """
    searches = [lexical_search_from_text_query(text_query),
                vector_search_from_embedding(encode_text(text_query))]
    fused_results = hybrid_search(searches, limit=limit)
    header = get_display_header(*searches)
    display_results(fused_results, display_header=header)
    
# Hybrid (RRF-fused) results for the same quoted "singin' in the rain" query.
display_hybrid_search_results(query)

HBox(layout=Layout(justify_content='center'))

Output()

In [21]:
# Compare lexical, vector, and hybrid results for a new query.
query="the hobbit"
display_lexical_search_results(query)
display_vector_search_results(query)
display_hybrid_search_results(query)

HBox(layout=Layout(justify_content='center'))

Output()

HBox(layout=Layout(justify_content='center'))

Output()

HBox(layout=Layout(justify_content='center'))

Output()