In [1]:
import os
import pickle
import shutil
import sys
import tarfile
from itertools import groupby

# The repository root must be on sys.path before the `aips` imports below.
sys.path.append("../..")

import clip
import requests
import torch
from PIL import Image
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from tqdm import tqdm

import aips.data_loaders.movies as movies
from aips.spark.dataframe import from_csv

# Base URL that image ids are appended to (as "<id>.jpg") when downloading.
remote_image_path = "http://image.tmdb.org/t/p/w780/"

# NOTE(review): `device` is computed but the CLIP model is deliberately left
# pinned to the CPU below, matching compute_image_embedding(), which also
# moves its inputs to "cpu". Switch both together if GPU inference is wanted.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device="cpu")

conf = SparkConf()
conf.set("spark.driver.memory", "8g")
conf.set("spark.executor.memory", "8g")
conf.set("spark.dynamicAllocation.enabled", "true")
# "spark.dynamicAllocation.executorMemoryOverhead" is not a Spark property and
# was silently ignored; the documented key for executor overhead is
# "spark.executor.memoryOverhead".
conf.set("spark.executor.memoryOverhead", "8g")
spark = SparkSession.builder.appName("AIPS").config(conf=conf).getOrCreate()

100%|███████████████████████████████████████| 338M/338M [00:16<00:00, 20.9MiB/s]


In [20]:
def dump(data, cache_name):
    """Pickle `data` to ../../data/tmdb/<cache_name>.pickle.

    The parent directory is created on demand and an existing cache file
    with the same name is overwritten.
    """
    target_path = f"../../data/tmdb/{cache_name}.pickle"
    parent_dir = os.path.dirname(target_path)
    os.makedirs(parent_dir, exist_ok=True)
    with open(target_path, "wb") as handle:
        pickle.dump(data, handle)

def dump_dataframe(movies_dataframe, cache_name):
    """Collect a DataFrame as a list of row dicts and pickle it via dump()."""
    def row_to_dict(row):
        return row.asDict()

    # Named `records` rather than `movies` so the module-level `movies`
    # import is not shadowed inside this function.
    records = movies_dataframe.rdd.map(row_to_dict).collect()
    dump(records, cache_name=cache_name)

def read(cache_name):
    """Load and return the object pickled at ../../data/tmdb/<cache_name>.pickle.

    Raises whatever `open` raises when the cache file is missing.
    Unpickling can execute arbitrary code, so only read cache files that
    were produced locally by this notebook.
    """
    source_path = f"../../data/tmdb/{cache_name}.pickle"
    with open(source_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload

def compress_file(file_name, root_path="../../data/tmdb/"):
    """Gzip-compress <root_path>/<file_name>.pickle into <root_path>/<file_name>.tgz.

    Args:
        file_name: Base name (no extension) of the pickle to archive.
        root_path: Directory holding the pickle; the archive is written next
            to it. Works with or without a trailing path separator.

    Raises:
        FileNotFoundError: If the pickle does not exist. The check happens
            before the archive is opened so no empty .tgz is left behind.
    """
    pickle_name = f"{file_name}.pickle"
    # os.path.join instead of string concatenation, so a root_path without a
    # trailing separator no longer yields "<root><file_name>.tgz".
    pickle_path = os.path.join(root_path, pickle_name)
    archive_path = os.path.join(root_path, f"{file_name}.tgz")
    if not os.path.exists(pickle_path):
        raise FileNotFoundError(f"Nothing to compress: {pickle_path} does not exist")
    with tarfile.open(archive_path, "w:gz") as tar:
        # arcname keeps only the file name, not the directory, in the archive.
        tar.add(pickle_path, arcname=pickle_name)

def load_image(file_name, load_remote, log=False):
    full_local_path = f"../../data/tmdb/large_movie_images/{file_name}.jpg"    
    try:
        exists = os.path.exists(full_local_path)
        if log: print(f"{full_local_path} exists: {exists}")
        if not exists and load_remote:
            remote_image_url = f"{remote_image_path}{file_name}.jpg"
            response = requests.get(remote_image_url, stream=True)
            with open(full_local_path, 'wb') as out_file:
                shutil.copyfileobj(response.raw, out_file)
            del response
            if log: print(f"Downloaded: {full_local_path}")
        image = Image.open(full_local_path)
        if log: print("File Found")
        return image
    except:
        if log: print(f"No Image Available {full_local_path}")
        return []

In [3]:
def download_movie_images(movie_file="tmdb_movies"):
    """Fetch every image referenced by the cached movie list `movie_file`.

    Each movie's "movie_image_ids" value is treated as a comma-separated
    string; movies with an empty or missing value are skipped.
    """
    catalog = read(movie_file)
    for entry in tqdm(catalog, total=len(catalog)):
        joined_ids = entry["movie_image_ids"]
        if not joined_ids:
            continue
        for image_id in joined_ids.split(","):
            load_image(image_id, True)
                
def compute_image_embedding(image_id, log=False):
    """Return the model's embedding for one image as a list of floats.

    The image is loaded (and downloaded if necessary) through load_image(),
    run through `preprocess`, and encoded with `model.encode_image`.

    Args:
        image_id: Image id understood by load_image().
        log: Print diagnostic messages on failure.

    Returns:
        The embedding as a plain Python list, or an empty list when the image
        is unavailable or processing fails (best effort; nothing is raised).
    """
    try:
        image = load_image(image_id, True, log=log)
        if isinstance(image, list):
            # load_image() signals "no image" with an empty list; stop here
            # instead of relying on preprocess() to raise on it.
            return []
        # Send the input wherever the model's weights live, so this keeps
        # working if the model is ever loaded on a GPU.
        first_param = next(iter(model.parameters()), None)
        target_device = first_param.device if first_param is not None else "cpu"
        inputs = preprocess(image).unsqueeze(0).to(target_device)
        # Inference only: skip building an autograd graph for every image.
        with torch.no_grad():
            return model.encode_image(inputs).tolist()[0]
    except Exception as e:
        if log: print(f"Image processing exception: {e}")
        return []

In [4]:
def generate_movie_data_with_image_ids(cache_name="tmdb_movies", log=False):
    """Merge movie image ids into the base movie data and cache the result.

    Reads ../../data/tmdb/movie_data.csv, collects the image ids found in its
    "path" column under the lower-cased "tooltip" value (presumably the movie
    title; confirm against the CSV), hands that mapping together with
    ../../data/tmdb.json to `movies.load_dataframe`, then pickles and
    compresses the resulting DataFrame under `cache_name`.

    Args:
        cache_name: Base name of the pickle/.tgz written under ../../data/tmdb/.
        log: Currently unused; kept so existing callers keep working.
    """
    title_movie_map_file = "../../data/tmdb/movie_data.csv"
    dataframe = from_csv(title_movie_map_file)
    movie_image_ids = {}
    # Accumulate per key instead of itertools.groupby: groupby only merges
    # *adjacent* rows, so with an unsorted CSV a repeated tooltip would
    # overwrite the ids collected for it earlier.
    for row in dataframe.collect():
        record = row.asDict()
        title_key = record["tooltip"].lower()
        # Last path segment minus its final 4 characters (assumes a ".jpg"-style
        # extension, as the original code did).
        image_id = record["path"].split("/")[-1][:-4]
        movie_image_ids.setdefault(title_key, []).append(image_id)

    movie_dataframe = movies.load_dataframe("../../data/tmdb.json", movie_image_ids)
    dump_dataframe(movie_dataframe, cache_name)
    compress_file(cache_name)

def generate_image_embeddings_data(cache_name="movies_with_image_embeddings", log=False):
    """Compute and cache an embedding for every image of every cached movie.

    Reads the "tmdb_movies" cache, calls compute_image_embedding() for each id
    in each movie's comma-separated "movie_image_ids", and keeps only the
    images for which a non-empty embedding came back.

    Args:
        cache_name: Base name of the pickle/.tgz written under ../../data/tmdb/.
        log: Forwarded to compute_image_embedding(); also enables printing of
            each embedding for debugging.

    Returns:
        A dict of four parallel lists: "movie_ids", "titles", "image_ids" and
        "image_embeddings" (one entry per successfully embedded image).
    """
    movie_data = read("tmdb_movies")
    movie_ids, titles, image_ids, embeddings = [], [], [], []
    for movie in tqdm(movie_data, total=len(movie_data)):
        if not movie["movie_image_ids"]:
            continue
        for image_id in movie["movie_image_ids"].split(","):
            embedding = compute_image_embedding(image_id, log=log)
            # Previously printed unconditionally, flooding the cell output
            # with one full vector per image; now opt-in via `log`.
            if log: print(embedding)
            if not embedding:
                continue
            movie_ids.append(movie["id"])
            titles.append(movie["title"])
            image_ids.append(image_id)
            embeddings.append(embedding)

    embeddings_data = { "movie_ids": movie_ids, "titles": titles, "image_ids": image_ids, "image_embeddings": embeddings }
    dump(embeddings_data, cache_name)
    compress_file(cache_name)
    return embeddings_data

In [5]:
#embeddings = generate_image_embeddings_data()