In [1]:
import sys
sys.path.append("../..")
from aips import get_engine
from pyspark.sql import SparkSession
import pickle 
import numpy 
import torch
import clip
import time
import pandas
import random
import tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

engine = get_engine()
spark = SparkSession.builder.appName("AIPS").getOrCreate()

100%|███████████████████████████████████████| 338M/338M [00:19<00:00, 18.0MiB/s]


In [2]:
![ ! -d 'tmdb' ] && git clone --depth 1 https://github.com/ai-powered-search/tmdb.git
! cd tmdb && git pull
! cd tmdb && mkdir -p '../../../data/tmdb/' && tar -xvf movies_with_image_embeddings.tgz -C '../../../data/tmdb/'

Cloning into 'tmdb'...
remote: Enumerating objects: 7, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 7 (delta 0), reused 6 (delta 0), pack-reused 0[K
Receiving objects: 100% (7/7), 103.98 MiB | 18.46 MiB/s, done.
Already up to date.
movies_with_image_embeddings.pickle


In [3]:
def normalize_embedding(embedding):
    return numpy.divide(embedding,
      numpy.linalg.norm(embedding,axis=0)).tolist()

def read(cache_name):
    cache_file_name = f"../../data/tmdb/{cache_name}.pickle"
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)

def quantize(embeddings):
    embeddings = numpy.array(embeddings)
    quantized_embeddings = numpy.zeros_like(embeddings, dtype=numpy.int8)
    quantized_embeddings[embeddings > 0] = 1
    return quantized_embeddings.tolist()

def tmdb_with_embeddings_dataframe():
    movies = read("movies_with_image_embeddings")
    embeddings = movies["image_embeddings"]
    normalized_embeddings = [normalize_embedding(e) for e in embeddings]
    quantized_embeddings = [quantize(e) for e in normalized_embeddings]
    movie_dataframe = spark.createDataFrame(
        zip(movies["movie_ids"], movies["titles"], 
            movies["image_ids"], normalized_embeddings,
            quantized_embeddings),
        schema=["movie_id", "title", "image_id", "image_embedding",
                "image_binary_embedding"])
    return movie_dataframe
    
def encode_text(text):
    text = clip.tokenize([text]).to(device)
    text_features = model.encode_text(text).tolist()[0]
    return numpy.array(normalize_embedding(text_features))

In [4]:
movie_dataframe = tmdb_with_embeddings_dataframe()
embeddings_collection = engine.create_collection("tmdb_with_embeddings")
embeddings_collection.write(movie_dataframe)

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection
Status: Success
Successfully written 7549 documents


In [5]:
def column_list(dataframe, column):
    return numpy.array(dataframe.select(column).rdd.flatMap(lambda x: x).collect())

def sort_titles(scores, movies, limit=25):
    titles = column_list(movies, "title").tolist()
    binary_results = numpy.argsort(scores)[-limit:][::-1]
    ranked = [titles[id] for id in binary_results]
    return list(dict.fromkeys(ranked))

def numpy_rankings(query, limit=20):
    start_dotprod = time.time()

    embeddings = column_list(movie_dataframe, "image_embedding")
    query_embedding = encode_text(query)
    dot_prod_scores = numpy.dot(embeddings, query_embedding)

    stop_dotprod = time.time(); start_binary = time.time()

    quantized_embeddings = column_list(movie_dataframe, "image_binary_embedding")
    quantized_query = numpy.array(quantize(query_embedding)) 
    binary_scores = 1536 - numpy.logical_xor(quantized_embeddings,
                                             quantized_query).sum(axis=1)
    
    stop_binary = time.time()
    
    binary_results = sort_titles(binary_scores, movie_dataframe)
    full_results = sort_titles(dot_prod_scores, movie_dataframe)
    return {"binary_query_time": stop_binary - start_binary,
            "full_query_time": stop_dotprod - start_dotprod,
            "recall": len(set(full_results).intersection(set(binary_results))) / len(set(binary_results)),
            "binary_results": binary_results,
            "full_results": full_results}

In [6]:
def only_titles(response):
    return [d["title"] for d in response["docs"]]

def base_search_request(query_vector, field, quantization_size):
    return {"query": query_vector,
            "query_fields": [field],
            "return_fields": ["movie_id", "title", "score"],
            "limit": 25,
            "k": 1000,
            "quantization_size": quantization_size}

def engine_rankings(query, log=False):
    collection = engine.get_collection("tmdb_with_embeddings")
    query_embedding = encode_text(query)    
    quantized_query = numpy.zeros_like(query_embedding, dtype=numpy.int8)
    quantized_query[query_embedding > 0] = 1

    binary_request = base_search_request(quantized_query.tolist(),
                                         "image_binary_embedding",
                                         "BINARY")
    start_dotprod = time.time()    
    binary_results = only_titles(collection.search(**binary_request))
    stop_dotprod = time.time()

    reranked_request = binary_request
    reranked_request["rerank_query"] = {
        "query": query_embedding.tolist(),
        "query_fields": ["image_embedding"],
        "k": 100,
        "rerank_quantity": 100,
        "quantization_size": "FLOAT32"}
    
    if log: print(json.dumps(reranked_request, indent=2))        
    start_reranked = time.time()    
    full_results = only_titles(collection.search(**reranked_request))
    stop_reranked = time.time()
    return {"binary_query_time": stop_reranked - start_reranked,
            "full_query_time": stop_dotprod - start_dotprod,
            "recall": len(set(full_results).intersection(set(binary_results))) / len(set(binary_results)),
            "binary_results": binary_results,
            "full_results": full_results}

In [7]:
query = "The Hobbit"
engine_scores = engine_rankings(query)
numpy_scores = numpy_rankings(query)
results = pandas.DataFrame(zip(engine_scores["binary_results"], numpy_scores["binary_results"],
                          engine_scores["full_results"], numpy_scores["full_results"]),
                          columns=["quantized solr", "quantized numpy",
                                   "dotprod solr", "dotprod numpy"])
print(f"Search engine binary search time: {engine_scores['binary_query_time']}")
print(f"Search engine full search time: {engine_scores['full_query_time']}")
print(f"Numpy binary search time: {numpy_scores['binary_query_time']}")
print(f"Numpy full search time: {numpy_scores['full_query_time']}")
results

Search engine binary search time: 0.04891395568847656
Search engine full search time: 0.08017492294311523
Numpy binary search time: 0.9986176490783691
Numpy full search time: 1.0766046047210693


Unnamed: 0,quantized solr,quantized numpy,dotprod solr,dotprod numpy
0,The Hobbit: The Desolation of Smaug,The Hobbit: The Desolation of Smaug,The Lord of the Rings: The Fellowship of the Ring,The Lord of the Rings: The Fellowship of the Ring
1,The Lord of the Rings: The Fellowship of the Ring,The Lord of the Rings: The Fellowship of the Ring,The Hobbit: An Unexpected Journey,The Hobbit: An Unexpected Journey
2,The Hobbit: The Desolation of Smaug,The Hobbit: The Battle of the Five Armies,The Princess Bride,The Princess Bride
3,Klaus,Klaus,The Hobbit: The Battle of the Five Armies,The Hobbit: The Battle of the Five Armies
4,The Hobbit: The Battle of the Five Armies,The Goonies,The Hobbit: The Battle of the Five Armies,The Hobbit: The Desolation of Smaug
5,The Goonies,The Hobbit: An Unexpected Journey,The Hobbit: An Unexpected Journey,The Lord of the Rings: The Two Towers
6,The Hobbit: The Desolation of Smaug,Labyrinth,The Lord of the Rings: The Fellowship of the Ring,The Lord of the Rings: The Return of the King
7,The Hobbit: An Unexpected Journey,The Lord of the Rings: The Return of the King,The Hobbit: The Desolation of Smaug,Guardians of the Galaxy Vol. 2
8,The Hobbit: The Battle of the Five Armies,Frozen II,The Lord of the Rings: The Fellowship of the Ring,The Last Samurai


In [8]:
random.seed(1234)

titles = column_list(movie_dataframe, "title")
random.shuffle(titles)

def mean_accuracy(f):
    return numpy.mean([f(q)["recall"] for q in tqdm.tqdm(titles[:25])])

print(f"Average quantized recall for numpy: {mean_accuracy(numpy_rankings)}")
print(f"Average quantized recall for engine: {mean_accuracy(engine_rankings)}")

100%|██████████| 25/25 [01:01<00:00,  2.44s/it]


Average quantized recall for numpy: 0.3250562238049114


100%|██████████| 25/25 [00:01<00:00, 16.52it/s]

Average quantized recall for engine: 0.5765688793490651



