In [1]:
import sys
sys.path.append("../..")

import html
import pickle
from io import BytesIO

import clip
import ipywidgets as widgets
import numpy
import requests
import torch
from IPython.display import display, HTML
from PIL import Image
from pyspark.sql import SparkSession

from aips import get_engine

# Choose where CLIP inference runs: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# CLIP ViT-B/32 model plus the image preprocessing transform that matches it.
model, preprocess = clip.load("ViT-B/32", device=device)

# Search-engine client (collections + vector search) and the Spark session used to build DataFrames for indexing.
engine = get_engine()
spark = SparkSession.builder.appName("AIPS").getOrCreate()

100%|████████████████████████████████████████| 338M/338M [11:02<00:00, 534kiB/s]


In [2]:
# Shallow-clone the TMDB dataset repo, but only if ./tmdb does not exist yet.
![ ! -d 'tmdb' ] && git clone --depth 1 https://github.com/ai-powered-search/tmdb.git
# Bring an existing checkout up to date.
! cd tmdb && git pull
# Extract the embeddings cache. From inside ./tmdb, '../../../data/tmdb/' resolves to
# '../../data/tmdb/' relative to the working directory, which is where read() loads pickles from.
! cd tmdb && mkdir -p '../../../data/tmdb/' && tar -xvf movies_with_image_embeddings.tgz -C '../../../data/tmdb/'

Cloning into 'tmdb'...
remote: Enumerating objects: 7, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (7/7), done.[K
remote: Total 7 (delta 0), reused 6 (delta 0), pack-reused 0[K
Receiving objects: 100% (7/7), 103.98 MiB | 545.00 KiB/s, done.
Updating files: 100% (5/5), done.
Already up to date.
movies_with_image_embeddings.pickle


## Listing 15.14

In [3]:
def normalize_embedding(embedding):
    """Scale an embedding (or each row of a batch of embeddings) to unit L2 norm.

    Args:
        embedding: a 1-D sequence of numbers, or a 2-D sequence whose rows are
            individual embeddings.

    Returns:
        A list (of lists, for 2-D input) with the same shape, where every vector
        has L2 norm 1. All-zero vectors are returned unchanged (zeros) rather
        than turning into NaN.
    """
    vectors = numpy.asarray(embedding)
    # Norm along the last axis with keepdims so each *vector* is normalized;
    # for 1-D input this is the same as normalizing the whole array.
    norms = numpy.linalg.norm(vectors, axis=-1, keepdims=True)
    # Avoid 0/0 -> NaN for all-zero vectors.
    norms = numpy.where(norms == 0, 1, norms)
    return (vectors / norms).tolist()

def read(cache_name, data_dir="../../data/tmdb"):
    """Load the pickled cache file ``<data_dir>/<cache_name>.pickle``.

    Args:
        cache_name: file stem of the cache, e.g. "movies_with_image_embeddings".
        data_dir: directory containing the cache. The default (relative to the
            working directory) is where the download cell extracts the TMDB data.

    Returns:
        The unpickled object.
    """
    cache_file_name = f"{data_dir}/{cache_name}.pickle"
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only read caches that come from a trusted source.
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)

def generate_tmdb_with_embeddings_index():
    """(Re)build the ``tmdb_with_embeddings`` collection from the cached pickle.

    Loads ``movies_with_image_embeddings`` via ``read`` and L2-normalizes every
    image embedding with ``normalize_embedding``. It then creates the collection
    and writes one row per movie with columns ``movie_id``, ``title``,
    ``image_id`` and ``image_embedding``.

    Raises:
        ValueError: if the cache's parallel per-movie sequences have different
            lengths. Without this check, ``zip`` would silently drop the
            unmatched tail. The check runs before the collection is created.
    """
    movies = read("movies_with_image_embeddings")
    fields = ("movie_ids", "titles", "image_ids", "image_embeddings")
    lengths = {field: len(movies[field]) for field in fields}
    if len(set(lengths.values())) > 1:
        raise ValueError(
            f"movies_with_image_embeddings has mismatched lengths: {lengths}")
    normalized_embeddings = [normalize_embedding(embedding)
                             for embedding in movies["image_embeddings"]]
    collection = engine.create_collection("tmdb_with_embeddings")
    movies_dataframe = spark.createDataFrame(
        zip(movies["movie_ids"], movies["titles"],
            movies["image_ids"], normalized_embeddings),
        schema=["movie_id", "title", "image_id", "image_embedding"])
    collection.write(movies_dataframe)

In [4]:
# (Re)build the tmdb_with_embeddings collection from the cached TMDB poster embeddings.
generate_tmdb_with_embeddings_index()

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection
Status: Success
Successfully written 7549 documents


## Listing 15.15

In [5]:
def load_image(full_path, log=False, timeout=30):
    """Open an image from an http(s) URL or a local file path.

    Args:
        full_path: URL (anything starting with "http") or filesystem path.
        log: if True, print whether the image could be loaded.
        timeout: seconds to wait for the HTTP server before giving up.

    Returns:
        A fully decoded PIL image, or ``[]`` if it could not be fetched or decoded.
    """
    try:
        if full_path.startswith("http"):
            response = requests.get(full_path, timeout=timeout)
            # Turn 4xx/5xx responses into an error instead of trying to decode an error page.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = Image.open(full_path)
        # Image.open is lazy: decode now so corrupt/truncated images fail inside
        # this try. For single-frame images, this also lets Pillow close the file.
        image.load()
        if log:
            print("File Found")
        return image
    except Exception:
        # Deliberately best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
        if log:
            print(f"No Image Available {full_path}")
        return []

def movie_search(query_embedding, limit=8):
    """Find the movies whose poster embeddings are nearest to ``query_embedding``.

    Args:
        query_embedding: query vector compared against the ``image_embedding`` field.
        limit: maximum number of hits to return.

    Returns:
        The engine's ``vector_search`` response (callers read its ``"docs"`` entry).
    """
    tmdb_collection = engine.get_collection("tmdb_with_embeddings")
    return tmdb_collection.vector_search(
        query_vector=query_embedding,
        query_field="image_embedding",
        limit=limit,
        quantization_size="FLOAT32",
    )
    
def normalize_embedding(embedding):
    """Scale an embedding (or each row of a batch of embeddings) to unit L2 norm.

    Same helper as in the Listing 15.14 cell. It is redefined here so this cell
    can run on its own; keep the two definitions in sync.

    Args:
        embedding: a 1-D sequence of numbers, or a 2-D sequence whose rows are
            individual embeddings.

    Returns:
        A list (of lists, for 2-D input) with the same shape, where every vector
        has L2 norm 1. All-zero vectors stay zero instead of becoming NaN.
    """
    vectors = numpy.asarray(embedding)
    # Normalize each vector (last axis), not each dimension across vectors.
    norms = numpy.linalg.norm(vectors, axis=-1, keepdims=True)
    norms = numpy.where(norms == 0, 1, norms)
    return (vectors / norms).tolist()

def encode_text(text):
    """Embed a text query with the CLIP text encoder.

    Args:
        text: the query string.

    Returns:
        The L2-normalized CLIP text embedding as a list of floats. Indexed poster
        embeddings are also normalized, so this puts the query on the same scale.
    """
    tokens = clip.tokenize([text]).to(device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    embedding = text_features.tolist()[0]
    # Previously the normalized vector was computed but the raw one was returned.
    return normalize_embedding(embedding)
    
def encode_image(image_file):
    """Embed an image (local path or http(s) URL) with the CLIP image encoder.

    Args:
        image_file: path or URL accepted by ``load_image``.

    Returns:
        The L2-normalized CLIP image embedding as a list of floats.

    Raises:
        ValueError: if the image cannot be loaded.
    """
    image = load_image(image_file)
    # load_image returns [] on failure instead of raising; stop here with a
    # clear message rather than crashing inside preprocess().
    if not isinstance(image, Image.Image):
        raise ValueError(f"Could not load image: {image_file}")
    inputs = preprocess(image).unsqueeze(0).to(device)
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        embedding = model.encode_image(inputs).tolist()[0]
    # Previously the normalized vector was computed but the raw one was returned.
    return normalize_embedding(embedding)

def encode_text_and_image(text_query, image_file):
    """Embed a combined text + image query as the mean of the two unit vectors.

    Args:
        text_query: text part of the query.
        image_file: path or URL of the image part of the query.

    Returns:
        The element-wise average of the separately L2-normalized text and image
        embeddings, as a list of floats.
    """
    # Normalize each vector on its own so both modalities carry equal weight.
    # Normalizing the stacked [text, image] pair column-wise (axis=0) would
    # instead rescale each dimension across the two vectors and mix them.
    text_embedding = normalize_embedding(encode_text(text_query))
    image_embedding = normalize_embedding(encode_image(image_file))
    return numpy.average([text_embedding, image_embedding], axis=0).tolist()

## Listing 15.16

In [6]:
def get_html(movies_documents):
    """Build an HTML grid of movie posters, each linking to its TMDB page.

    Args:
        movies_documents: iterable of search-result docs; each must provide
            ``image_id``, ``movie_id`` and ``title``.

    Returns:
        An HTML string: a ``<style>`` block followed by a ``results`` div that
        holds one linked ``<img>`` per movie, in input order.
    """
    css = """
      <style type="text/css">
        .results { 
          margin-top: 15px; 
          display: flex; 
          flex-wrap: wrap; 
          justify-content: space-evenly; }
        .results .result { height: 250px; margin-bottom: 5px; }
      </style>"""

    tiles = []
    for movie in movies_documents:
        # Escape everything placed in attributes: titles such as "Schindler's List"
        # would otherwise terminate the single-quoted attribute early.
        # https avoids mixed-content blocking when the notebook is served over https.
        image_file = html.escape(
            f"https://image.tmdb.org/t/p/w780/{movie['image_id']}.jpg")
        movie_link = html.escape(
            f"https://www.themoviedb.org/movie/{movie['movie_id']}")
        title = html.escape(str(movie['title']))
        img_html = f"<img title='{title}' class='result' src='{image_file}'>"
        tiles.append(f"<a href='{movie_link}' target='_blank'>{img_html}</a>")
    return f"{css}<div class='results'>{''.join(tiles)}</div>"
   
def display_results(search_results):
    """Render the ``docs`` of a search response as an HTML poster grid.

    The grid is drawn into an ``ipywidgets.Output`` area. A centered ``HBox``
    is displayed first, followed by that output area.
    """
    results_area = widgets.Output()
    with results_area:
        display(HTML(get_html(search_results["docs"])))
    # NOTE(review): this HBox has no children, so justify_content="center" has
    # nothing to center; the posters are rendered by `results_area`.
    centered_row = widgets.HBox(layout=widgets.Layout(justify_content="center"))
    display(centered_row, results_area)

def search_and_display(text_query="", image_query=None):
    """Embed a text and/or image query, run the vector search and show the posters.

    Args:
        text_query: text part of the query (ignored when falsy and an image is given).
        image_query: path or URL of an image; when falsy, only the text is used.
    """
    if not image_query:
        query_embedding = encode_text(text_query)
    elif text_query:
        query_embedding = encode_text_and_image(text_query, image_query)
    else:
        query_embedding = encode_image(image_query)
    hits = movie_search(query_embedding)
    display_results(hits)

# Figure 15.5

In [7]:
# Text-only query: the CLIP text embedding is matched against the indexed poster embeddings.
search_and_display(text_query="singing in the rain")

HBox(layout=Layout(justify_content='center'))

Output()

In [8]:
# Text-only query (singular "superhero").
search_and_display(text_query="superhero flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.6

In [9]:
# Same query with the plural "superheroes", to compare against the previous result set.
search_and_display(text_query="superheroes flying")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.7

In [10]:
# Image-only query. A non-http path is opened as a local file, so
# delorean-query.jpg is expected in the working directory.
search_and_display(image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()

# Figure 15.8

In [11]:
# Multimodal query: the text and image embeddings are combined into a single query vector.
search_and_display(text_query="superhero", image_query="delorean-query.jpg")

HBox(layout=Layout(justify_content='center'))

Output()