In [1]:
import sys
sys.path.append("../..")

from aips import get_engine
from IPython.display import display, HTML
from pyspark.sql import SparkSession
import ipywidgets as widgets
from PIL import Image
import pickle
import requests
import numpy
import torch
import clip
from io import BytesIO

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

engine = get_engine()
spark = SparkSession.builder.appName("AIPS").getOrCreate()

In [2]:
![ ! -d 'tmdb' ] && git clone --depth 1 https://github.com/ai-powered-search/tmdb.git
! cd tmdb && git pull
! cd tmdb && mkdir -p '../../../data/tmdb/' && tar -xvf movies_with_image_embeddings.tgz -C '../../../data/tmdb/'

Already up to date.
movies_with_image_embeddings.pickle


## Listing 15.14

In [3]:
def normalize_embedding(embedding):
    return numpy.divide(embedding,
      numpy.linalg.norm(embedding,axis=0)).tolist()

def read(cache_name):
    cache_file_name = f"../../data/tmdb/{cache_name}.pickle"
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)

def generate_tmdb_with_embeddings_index():
    movies = read("movies_with_image_embeddings")
    embeddings = movies["image_embeddings"]
    normalized_embeddings = [ normalize_embedding(embedding) for embedding in embeddings ]
    collection = engine.create_collection("tmdb_with_embeddings")
    movies_dataframe = spark.createDataFrame(
        zip(movies["movie_ids"], movies["titles"], 
            movies["image_ids"], normalized_embeddings),
        schema=["movie_id", "title", "image_id", "image_embedding"])
    collection.write(movies_dataframe)

In [4]:
generate_tmdb_with_embeddings_index()

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection
Status: Success
Successfully written 7549 documents


## Listing 15.15

In [4]:
def load_image(full_path, log=False):   
    try:
        if full_path.startswith("http"):
            response = requests.get(full_path)
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(full_path)
        if log: print("File Found")
        return image
    except:
        if log: print(f"No Image Available {full_path}")
        return []      

def movie_search(query_embedding, limit=8):
    collection = engine.get_collection("tmdb_with_embeddings")
    request = {
        "query_vector": query_embedding,
        "query_field": "image_embedding",
        "limit": limit,
        "quantization_size": "FLOAT32"}
    return collection.vector_search(**request)
    
def normalize_embedding(embedding):
    return numpy.divide(embedding,
      numpy.linalg.norm(embedding,axis=0)).tolist()

def encode_text(text):
    text = clip.tokenize([text]).to(device)    
    text_features = model.encode_text(text)
    embedding = text_features.tolist()[0] 
    normalized_embedding = normalize_embedding(embedding)
    return embedding
    
def encode_image(image_file):
    image = load_image(image_file)
    inputs = preprocess(image).unsqueeze(0).to(device)
    embedding = model.encode_image(inputs).tolist()[0]
    normalized_embedding = normalize_embedding(embedding)
    return embedding

def encode_text_and_image(text_query, image_file):    
    text_embedding = encode_text(text_query)
    image_embedding = encode_image(image_file)  
    return numpy.average((normalize_embedding(
        [text_embedding, image_embedding])), axis=0).tolist()

## Listing 15.16

In [5]:
def get_html(movies_documents):
    css = """
      <style type="text/css">
        .results { 
          margin-top: 15px; 
          display: flex; 
          flex-wrap: wrap; 
          justify-content: space-evenly; }
        .results .result { height: 250px; margin-bottom: 5px; }
      </style>"""
    
    results_html = ""
    for movie in movies_documents:
        image_file = f"http://image.tmdb.org/t/p/w780/{movie['image_id']}.jpg"
        movie_link = f"https://www.themoviedb.org/movie/{movie['movie_id']}"
        img_html = f"<img title='{movie['title']}' class='result' src='{image_file}'>"
        results_html += f"<a href='{movie_link}' target='_blank'>{img_html}</a>"
    return f"{css}<div class='results'>{results_html}</div>"
   
def display_results(search_results):    
    output = widgets.Output()
    with output:
        display(HTML(get_html(search_results["docs"]))) 
    display(widgets.HBox(layout=widgets.Layout(justify_content="center")), output)   

def search_and_display(text_query="", image_query=None):
    if image_query:
        if text_query:
            query_embedding = encode_text_and_image(text_query, image_query)
        else:
            query_embedding = encode_image(image_query)
    else:
        query_embedding = encode_text(text_query)
    display_results(movie_search(query_embedding))

# Figure 15.5

In [6]:
search_and_display(text_query="singing in the rain")

HBox(layout=Layout(justify_content='center'))

Output()

In [7]:
search_and_display(text_query="superhero flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.6

In [8]:
search_and_display(text_query="superheroes flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.7

In [9]:
search_and_display(image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.8

In [10]:
search_and_display(text_query="superhero", image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Listing 15.17

In [12]:
from aips.spark import create_view_from_collection
from aips.spark.dataframe import from_sql
from pyspark.sql.functions import split, regexp_replace, col

def fix(broken_embedding_src):
    return f"array(replace({broken_embedding_src}, '\"', ''))"

#Create tmdb collection with text + signals
def combine_tmdb_lexical_and_embeddings_collections():
    lexical_tmdb_collection = engine.get_collection("tmdb")
    create_view_from_collection(lexical_tmdb_collection, "tmdb_lexical")
    embeddings_tmdb_collection = engine.get_collection("tmdb_with_embeddings")
    create_view_from_collection(embeddings_tmdb_collection, "tmdb_with_embeddings")
    
    joined_collection_sql = f"""
    SELECT lexical.*, embeddings.image_embedding, embeddings.image_id, embeddings.movie_id
    FROM tmdb_lexical lexical RIGHT JOIN tmdb_with_embeddings embeddings ON lexical.id = embeddings.movie_id    
    """
    
    collection = engine.create_collection("tmdb_lexical_plus_embeddings")
    joined_dataframe = from_sql(joined_collection_sql) 
    joined_dataframe = joined_dataframe.withColumn("image_embedding", split(regexp_replace(col("image_embedding"),"\"",""), ","))
    collection.write(joined_dataframe)
    collection.commit()
    return collection
    
collection = combine_tmdb_lexical_and_embeddings_collections()

Wiping "tmdb_lexical_plus_embeddings" collection
Creating "tmdb_lexical_plus_embeddings" collection
Status: Success
Successfully written 7549 documents


In [11]:
def get_html(movies_documents):
    css = """
      <style type="text/css">
        .results { 
          margin-top: 15px; 
          display: flex; 
          flex-wrap: wrap; 
          justify-content: space-evenly; }
        .results .result { height: 250px; margin-bottom: 5px; }
      </style>"""
    
    results_html = ""
    for movie in movies_documents:
        movie_title = f"{movie['title']}"
        image_file = f"http://image.tmdb.org/t/p/w780/{movie['image_id']}.jpg"
        movie_link = f"https://www.themoviedb.org/movie/{movie['movie_id']}"
        img_html = f"<img title='{movie['title']}' class='result' src='{image_file}'>"
        results_html += f"<div>{movie_title}<br/>(score: {movie['score']})<br/><a href='{movie_link}' target='_blank'>{img_html}</a></div>"
    return f"{css}<div class='results'>{results_html}</div>"
   
def display_results(search_results):    
    output = widgets.Output()
    with output:
        display(HTML(get_html(search_results["docs"]))) 
    display(widgets.HBox(layout=widgets.Layout(justify_content="center")), output)  

In [74]:
collection = engine.get_collection("tmdb_lexical_plus_embeddings")

query = "singing in the rain"
limit = 9

lexical_query = query
lexical_search = {
        "query": lexical_query,
        "query_fields": ["title","overview"],
        "return_fields": ["id", "title", "id", "image_id", "movie_id", "score"],
        "limit": limit,
        "query_parser": "edismax"
}
lexical_search_results = collection.search(**lexical_search) 

query_embedding = encode_text(query)
vector_search = {
        "query_vector": query_embedding,
        "query_field": "image_embedding",
        "return_fields": ["id", "title", "image_id", "movie_id", "score"],
        "limit": limit,
        "quantization_size": "FLOAT32"}
vector_search_results = collection.vector_search(**vector_search)

print(f"Lexical Query: {lexical_query}")
display_results(lexical_search_results)

print(f"Vector Query: {query_embedding[0:3]} ... {query_embedding[-3:]}")
display_results(vector_search_results)

Lexical Query: singing in the rain


HBox(layout=Layout(justify_content='center'))

Output()

Vector Query: [0.3824406564235687, 0.2455785572528839, -0.29758134484291077] ... [-0.32294002175331116, 0.47960010170936584, -0.7555630803108215]


HBox(layout=Layout(justify_content='center'))

Output()

In [100]:
from collections import Counter

self = collection
def hybrid_search(lexical_search_args, vector_search_args, algorithm={"name": "rrf", "k": 60}, limit=10):
    hybrid_search_results = None
    match algorithm.get("name"):
        case "rrf":
            k = 60
            if algorithm["k"]: k = algorithm["k"]
            lexical_search_results = self.search(**lexical_search_args)
            vector_search_results = self.vector_search(**vector_search_args)
            hybrid_search_scores = reciprocal_rank_fusion(k, 
                                       lexical_search_results["docs"], 
                                       vector_search_results["docs"])
            
            lexical_fields = {item["id"]: item for item in lexical_search_results["docs"]}
            vector_fields = {item["id"]: item for item in vector_search_results["docs"]}
            
            merged_search_docs = sorted([
                dict(lexical_fields[id], score=hybrid_search_scores[id]) \
                if id in lexical_fields \
                else dict(vector_fields[id], score=hybrid_search_scores[id]) \
                for id in hybrid_search_scores], key=lambda x: x["score"], reverse=True)
            
            #sorted(orig_list, key=lambda x: x.count, reverse=True)
                        
            hybrid_search_results = {"docs": merged_search_docs }
        case "rerank_lexical_with_vector":
            pass #need rerank implemented on coll
    return hybrid_search_results

def reciprocal_rank_fusion(k, *search_results):
    rrf_scores = Counter()
    for ranked_docs in search_results:
        rank = 0
        for doc in ranked_docs:
            rank += 1
            rrf_scores[doc["id"]] = rrf_scores[doc["id"]] + ( 1.0 / ( k + rank ) )
    return dict(rrf_scores)
    
# where
# k is a ranking constant
# q is a query in the set of queries
# d is a document in the result set of q
# result(q) is the result set of q
# rank( result(q), d ) is d's rank within the result(q) starting from 1

In [101]:
hybrid_search_results = hybrid_search(lexical_search, vector_search, algorithm={"name": "rrf", "k":60})
print(f"Lexical Query: {lexical_search['query']}")
print(f"Lexical Query: {vector_search['query_vector'][0:3]} ... {vector_search['query_vector'][-3:]}")
display_results(hybrid_search_results)


Lexical Query: singing in the rain
Lexical Query: [0.3824406564235687, 0.2455785572528839, -0.29758134484291077] ... [-0.32294002175331116, 0.47960010170936584, -0.7555630803108215]


HBox(layout=Layout(justify_content='center'))

Output()