In [13]:

import sys
import warnings
sys.path.append("..")

from aips import *
import numpy
#User Interface for Search - not in book
from IPython.display import display, Markdown, HTML, clear_output
import ipywidgets as widgets
from PIL import Image
import imageio as iio
import pickle
import requests

#Multimodal Search Implementation
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
from io import BytesIO

# Load the CLIP model and its processor (used below both to tokenize text and
# to preprocess images). Both come from the same checkpoint name.
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# Some operations warn inside a loop and we only need to see the first warning:
# "once" prints the first occurrence of each warning, whereas "ignore" would
# silence every warning, including ones worth seeing.
warnings.filterwarnings("once")

engine = get_engine()

In [10]:
# Shallow-clone the dataset repository only if the 'tmdb' directory is missing,
# then pull to pick up any updates.
![ ! -d 'tmdb' ] && git clone --depth 1 https://github.com/ai-powered-search/tmdb.git
! cd tmdb && git pull
# Unpack the precomputed movie image embeddings. Seen from inside tmdb/,
# '../../data/tmdb/' is the same directory that read() opens as '../data/tmdb/'.
! cd tmdb && mkdir -p '../../data/tmdb/' && tar -xvf movie_image_embeddings.tgz -C '../../data/tmdb/'

Already up to date.
movie_image_embeddings.pickle


In [12]:
def read(cache_name="tmdb_movies", data_dir="../data/tmdb"):
    """Load a pickled cache file from the TMDB data directory.

    Args:
        cache_name: File stem of the cache; ".pickle" is appended.
        data_dir: Directory containing the cache files. The default is the
            directory the extraction cell unpacks into.

    Returns:
        The object stored in the pickle file.

    Security: pickle.load can execute arbitrary code while unpickling, so only
    point this at files you produced or obtained from a source you trust.
    """
    import os

    cache_file_name = os.path.join(data_dir, f"{cache_name}.pickle")
    with open(cache_file_name, "rb") as fd:
        return pickle.load(fd)
    
def generate_tmdb_with_embeddings_index():
    """Index every cached movie-embedding document into "tmdb_with_embeddings".

    Reads the "movie_image_embeddings" cache (a mapping), creates the
    collection, and adds the mapping's values as documents.
    """
    movie_docs = list(read("movie_image_embeddings").values())
    tmdb_collection = engine.create_collection("tmdb_with_embeddings")
    tmdb_collection.add_documents(movie_docs)

# Build the search collection; as the output shows, each run wipes and
# recreates "tmdb_with_embeddings".
generate_tmdb_with_embeddings_index()

Wiping "tmdb_with_embeddings" collection
Creating "tmdb_with_embeddings" collection
Status: Success

Adding Documents to 'tmdb_with_embeddings' collection


## Listing 15.14

In [14]:
def load_image(full_path, log=False, timeout=30):
    """Load an image from a local file path or an http(s) URL.

    Args:
        full_path: Local path, or a URL (any value starting with "http").
        log: When True, print whether the image could be loaded.
        timeout: Seconds to wait on the remote server when loading a URL.

    Returns:
        A PIL image for URLs, whatever ``iio.imread`` returns for local paths,
        or an empty list when the image could not be loaded (callers test for
        that sentinel).
    """
    try:
        if full_path.startswith("http"):
            # Without a timeout an unresponsive host can block the kernel
            # indefinitely; raise_for_status turns 4xx/5xx error pages into a
            # load failure instead of handing their HTML body to PIL.
            response = requests.get(full_path, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        else:
            image = iio.imread(full_path)
    except Exception as error:  # best-effort loader: report and return the sentinel
        if log: print(f"No Image Available {full_path} ({error})")
        return []
    if log: print("File Found")
    return image
    
def movie_search(query_embedding, limit=8, log=False):
    """Run a vector search over the movie image embeddings.

    Args:
        query_embedding: Query vector compared against the "image_embeddings" field.
        limit: Maximum number of documents to return.
        log: When True, ask the engine to log the request and print the hit count.

    Returns:
        The engine's search response (a mapping with a "docs" list).
    """
    search_args = dict(query_vector=query_embedding,
                       query_field="image_embeddings",
                       limit=limit,
                       quantization_size="FLOAT32")
    if log:
        search_args["log"] = True
    tmdb_collection = engine.get_collection("tmdb_with_embeddings")
    results = tmdb_collection.vector_search(**search_args)
    if log:
        print(f"Vector search results {len(results['docs'])}")
    return results
    
def get_html(movies_documents):
    """Render movie search hits as a wrapping grid of poster images.

    Args:
        movies_documents: Iterable of dicts with "image_id", "movie_id" and
            "title" keys (the "docs" of a search response).

    Returns:
        An HTML string: a <style> block followed by a div of result images.
        Each image links to its TMDB movie page when the movie has a title.
    """
    from html import escape

    css = """
      <style type="text/css">
        .results {
          margin-top: 15px;
          display: flex;
          flex-wrap: wrap;
          justify-content: space-evenly; }
        .results .result { height: 250px; margin-bottom: 5px; }
      </style>"""

    result_fragments = []
    for movie in movies_documents:
        # Escape everything placed inside the single-quoted attributes: a title
        # such as "Singin' in the Rain" would otherwise end the attribute early.
        title = escape(str(movie["title"]), quote=True)
        image_file = escape(f"https://image.tmdb.org/t/p/w780/{movie['image_id']}.jpg", quote=True)
        movie_link = escape(f"https://www.themoviedb.org/movie/{movie['movie_id']}", quote=True)
        result_html = f"<img title='{title}' class='result' src='{image_file}'>"
        if movie["title"]:
            result_html = f"<a href='{movie_link}' target='_blank'>{result_html}</a>"
        result_fragments.append(result_html)
    return f"{css}<div class='results'>{''.join(result_fragments)}</div>"

def normalize_embedding(embedding):
    """Return ``embedding`` scaled to unit L2 norm, as a numpy array.

    The norm is taken along axis 0, so a 2-D input has each column scaled
    independently. A zero vector yields NaNs (division by zero).
    """
    length = numpy.linalg.norm(embedding, axis=0)
    return numpy.asarray(embedding) / length
                      
def compute_text_embedding(text):
    """Embed a text query with CLIP's text features and scale it to unit length.

    Args:
        text: The query string.

    Returns:
        A list of floats: the L2-normalized text embedding.
    """
    import torch

    # truncation=True clips over-long queries to the tokenizer's maximum
    # length instead of feeding the model more tokens than it supports.
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    # Inference only: no_grad skips building an autograd graph.
    with torch.no_grad():
        embedding = model.get_text_features(**inputs).tolist()[0]
    return normalize_embedding(embedding).tolist()

def compute_image_embedding(image_file, remote_url=""):
    """Embed an image with CLIP's image features and scale it to unit length.

    Args:
        image_file: Local path or http(s) URL accepted by load_image.
        remote_url: Unused; kept so existing callers keep working. (It used to
            be forwarded, by mistake, into load_image's `log` flag.)

    Returns:
        A list of floats (the L2-normalized image embedding), or an empty list
        if the image could not be loaded or processed.
    """
    import torch

    image = load_image(image_file)
    # load_image returns [] on failure; a successful URL load is a PIL image,
    # which has no len(), so test for the sentinel rather than calling len().
    if isinstance(image, list) and not image:
        return []
    try:
        # Inference only: no_grad skips building an autograd graph.
        with torch.no_grad():
            inputs = processor(images=[image], return_tensors="pt", padding=True)
            embedding = model.get_image_features(**inputs).tolist()[0]
    except Exception as error:  # best-effort: report and return the empty sentinel
        print(f"Exception in image processing: {error}")
        return []
    return normalize_embedding(embedding).tolist()

def text_to_image_search(query):
    """Search movie images using the CLIP embedding of a text query."""
    return movie_search(compute_text_embedding(query))

def image_to_image_search(image_file):
    """Search movie images using the CLIP embedding of a query image.

    Args:
        image_file: Local path or URL of the query image.

    Returns:
        The movie_search response.

    Raises:
        ValueError: if no embedding could be computed for the image (instead
            of sending an empty query vector to the search engine).
    """
    image_embedding = compute_image_embedding(image_file)
    if not image_embedding:
        raise ValueError(f"Could not compute an embedding for image {image_file!r}")
    return movie_search(image_embedding)

def text_and_image_to_image_search(text_query, image_file):
    """Search movie images with a query that blends text and an image.

    The normalized text and image embeddings are averaged element-wise into
    one pooled query vector.

    Args:
        text_query: Free-text part of the query.
        image_file: Local path or URL of the query image.

    Returns:
        The movie_search response.

    Raises:
        ValueError: if no embedding could be computed for the image (averaging
            with an empty list would otherwise fail with an opaque numpy error).
    """
    normalized_text_query_embedding = compute_text_embedding(text_query)
    normalized_image_embedding = compute_image_embedding(image_file)
    if not normalized_image_embedding:
        raise ValueError(f"Could not compute an embedding for image {image_file!r}")
    pooled_embedding = numpy.average(
        [normalized_text_query_embedding,
         normalized_image_embedding], axis=0).tolist()
    return movie_search(pooled_embedding)
    
def display_search_results(text_query="", image_query=None):
    """Run a text, image, or combined text+image search and render the hits.

    Args:
        text_query: Free-text query; combined with image_query when both are given.
        image_query: Local path or URL of a query image, or None.

    Raises:
        ValueError: if neither a text query nor an image query is provided.
    """
    if not text_query and not image_query:
        raise ValueError("Provide a text_query, an image_query, or both.")
    if image_query and text_query:
        image_results = text_and_image_to_image_search(text_query, image_query)
    elif image_query:
        image_results = image_to_image_search(image_query)
    else:
        image_results = text_to_image_search(text_query)
    display(HTML(get_html(image_results["docs"])))

In [15]:
#Figure 15.5
# Text-to-image search: the query text is embedded with CLIP and matched against
# the indexed movie image embeddings.
display_search_results(text_query="singing in the rain")

In [4]:
# Plural wording of the query; compare with the singular "superhero flying" query.
display_search_results(text_query="superheroes flying")

In [5]:
#Figure 15.6
# Singular wording of the "superheroes flying" query, to compare the two result sets.
display_search_results(text_query="superhero flying")

In [6]:
#Figure 15.7
# Image-to-image search using a local query image.
# NOTE(review): delorean-query.jpg does not appear in the tar listing printed by the
# extraction cell (only movie_image_embeddings.pickle) — confirm where it comes from.
display_search_results(image_query="../data/tmdb/delorean-query.jpg")

In [7]:
#Figure 15.8
# Combined search: the text and image embeddings are averaged into one query vector.
display_search_results(text_query="superhero", image_query="../data/tmdb/delorean-query.jpg")

## Listing 15.15


## Listing 15.16
