In [1]:
%cd ..

/Users/westford14/Desktop/projects/watchlist/watchlist-recommender


In [77]:
import json
import logging
import os
from ast import literal_eval
from collections import Counter, defaultdict
from typing import Any, List, Optional, Tuple

import faiss
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from tqdm import tqdm
from umap import UMAP

In [3]:
torch.set_num_threads(1)

In [78]:
class SimilarityMeasure:
    def __init__(
        self,
        embed_model_name: str = "all-MiniLM-L6-v2",
        embed_device: str = "cpu",
        embed_batch_size: int = 64,
        embed_max_seq_length: int = 512,
        umap_components: int = 20,
        umap_metric: str = "euclidean",
    ) -> None:
        self.embed_model_name = embed_model_name
        self.embed_device = embed_device
        self.embed_batch_size = embed_batch_size
        self.embed_max_seq_length = embed_max_seq_length

        self.umap_components = umap_components
        self.umap_metric = umap_metric

        self.embeddings = None
        self.faiss_index = None
        self.cluster_labels = None
        self.texts = None
        self.projections = None
        self.umap_mapper = None

        self.embed_model = SentenceTransformer(
            self.embed_model_name, device=self.embed_device
        )
        self.embed_model.max_seq_length = self.embed_max_seq_length

        self.index_to_id = None

    def fit(
        self,
        texts: List[str],
        ids: List[int],
        embeddings: Optional[Any] = None
    ) -> Tuple[List[Any], List[int]]:
        self.texts = texts

        if embeddings is None:
            logging.info("embedding texts...")
            self.embeddings = self.embed(texts)
        else:
            logging.info("using precomputed embeddings...")
            self.embeddings = embeddings

        logging.info("building faiss index...")
        self.faiss_index = self.build_faiss_index(self.embeddings)
        self.index_to_id = dict(zip(list(range(self.faiss_index.ntotal)), ids))
        logging.info("projecting with umap...")
        self.projections, self.umap_mapper = self.project(self.embeddings)
        return self.embeddings

    def infer(self, texts: List[str], top_k: int = 1) -> Tuple[List[int], List[Any]]:
        embeddings = self.embed(texts)
        _, neighbours = self.faiss_index.search(embeddings, top_k)

        return neighbours, embeddings

    def embed(self, texts: List[str]) -> List[Any]:
        embeddings = self.embed_model.encode(
            texts,
            batch_size=self.embed_batch_size,
            show_progress_bar=True,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

        return embeddings

    def project(self, embeddings: List[Any]) -> Tuple[List[Any], UMAP]:
        mapper = UMAP(n_components=self.umap_components, metric=self.umap_metric).fit(
            embeddings
        )
        return mapper.embedding_, mapper

    def build_faiss_index(self, embeddings: List[Any]) -> List[int]:
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        return index

    def save(self, folder: str) -> None:
        if not os.path.exists(folder):
            os.makedirs(folder)

        with open(f"{folder}/embeddings.npy", "wb") as f:
            np.save(f, self.embeddings)

        with open(f"{folder}/index_to_label.json", "w") as f:
            json.dump(self.index_to_label, f)

        faiss.write_index(self.faiss_index, f"{folder}/faiss.index")

    def load(self, folder: str) -> None:
        if not os.path.exists(folder):
            raise ValueError(f"The folder '{folder}' does not exsit.")

        with open(f"{folder}/embeddings.npy", "rb") as f:
            self.embeddings = np.load(f)

        with open(f"{folder}/index_to_label.json", "r") as f:
            self.index_to_label = json.load(f)

        self.faiss_index = faiss.read_index(f"{folder}/faiss.index")

### Cleaning

In [65]:
movies = pd.read_csv("data/movies_metadata.csv")

  movies = pd.read_csv("data/movies_metadata.csv")


In [66]:
movies = movies[["id", "original_title", "overview", "genres"]]
movies = movies.rename(columns={
    "id": "movie_id",
    "original_title": "title"
})

In [67]:
genres = [
    'Animation',
    'Comedy',
    'Family',
    'Adventure',
    'Fantasy',
    'Romance',
    'Drama',
    'Action',
    'Crime',
    'Thriller',
    'Horror',
    'History',
    'Science Fiction',
    'Mystery',
    'War',
    'Foreign',
    'Music',
    'Documentary',
    'Western'
]

In [68]:
def cleaner(x, genres=genres):
    ret = []
    for y in x:
        if y in genres:
            ret.append(y)
    return ret

In [69]:
movies["genres"] = movies["genres"].apply(lambda x: [y["name"] for y in literal_eval(x)])
movies["genres"] = movies["genres"].apply(cleaner)
movies["text"] = movies["title"] + " " + movies["overview"] + " " + movies["genres"].apply(lambda x: " ".join(x))

In [70]:
movies["movie_id"] = pd.to_numeric(movies["movie_id"], errors="coerce")
movies = movies.dropna()

In [71]:
train_data = movies[["movie_id", "text"]]
train_data = train_data.dropna()

In [72]:
movie_dict = movies.set_index("movie_id")["title"].to_dict()

In [73]:
train_data.head()

Unnamed: 0,movie_id,text
0,862.0,"Toy Story Led by Woody, Andy's toys live happi..."
1,8844.0,Jumanji When siblings Judy and Peter discover ...
2,15602.0,Grumpier Old Men A family wedding reignites th...
3,31357.0,"Waiting to Exhale Cheated on, mistreated and s..."
4,11862.0,Father of the Bride Part II Just when George B...


In [79]:
len(train_data)

44509

In [80]:
sim_measure = SimilarityMeasure(umap_components=20)

In [81]:
sim_measure.fit(train_data["text"].values.tolist(), ids=train_data["movie_id"].values.tolist())

Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 696/696 [06:42<00:00,  1.73it/s]


array([[ 0.04767714, -0.01259631,  0.07834646, ...,  0.03473042,
         0.08336627,  0.05029294],
       [ 0.04895022,  0.09311377, -0.01963529, ..., -0.00398075,
        -0.09352873, -0.02675354],
       [-0.06341458, -0.00527815, -0.01627528, ...,  0.01918863,
        -0.01144707, -0.00563937],
       ...,
       [-0.02502596, -0.02891008, -0.05267037, ..., -0.05833787,
        -0.01803897, -0.04464827],
       [-0.02990002,  0.06295763, -0.08454257, ..., -0.04886457,
        -0.00978885, -0.03164884],
       [-0.04454112,  0.0115746 , -0.09546685, ..., -0.06198069,
        -0.01385289, -0.10803522]], shape=(44509, 384), dtype=float32)

In [82]:
movies.head()

Unnamed: 0,movie_id,title,overview,genres,text
0,862.0,Toy Story,"Led by Woody, Andy's toys live happily in his ...","[Animation, Comedy, Family]","Toy Story Led by Woody, Andy's toys live happi..."
1,8844.0,Jumanji,When siblings Judy and Peter discover an encha...,"[Adventure, Fantasy, Family]",Jumanji When siblings Judy and Peter discover ...
2,15602.0,Grumpier Old Men,A family wedding reignites the ancient feud be...,"[Romance, Comedy]",Grumpier Old Men A family wedding reignites th...
3,31357.0,Waiting to Exhale,"Cheated on, mistreated and stepped on, the wom...","[Comedy, Drama, Romance]","Waiting to Exhale Cheated on, mistreated and s..."
4,11862.0,Father of the Bride Part II,Just when George Banks has recovered from his ...,[Comedy],Father of the Bride Part II Just when George B...


In [85]:
preds, _ = sim_measure.infer([train_data[train_data["movie_id"].isin([8844])]["text"].iloc[0]], top_k=10)

Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 41.03it/s]


In [86]:
for i, pred in enumerate(preds[0]):
    movie_id = train_data.iloc[pred]["movie_id"]
    title = movies[movies["movie_id"].isin([movie_id])]["title"]
    print(f"Top {i + 1}: {title}")

Top 1: 1    Jumanji
Name: title, dtype: object
Top 2: 892    The Wizard of Oz
Name: title, dtype: object
Top 3: 40472    Over the Garden Wall
Name: title, dtype: object
Top 4: 12416    The Spiderwick Chronicles
Name: title, dtype: object
Top 5: 7011    Peter Pan
Name: title, dtype: object
Top 6: 43887    George of the Jungle 2
Name: title, dtype: object
Top 7: 1978    Peter Pan
Name: title, dtype: object
Top 8: 18209    The Gruffalo
Name: title, dtype: object
Top 9: 2682    Big
Name: title, dtype: object
Top 10: 14859    Tom Thumb
Name: title, dtype: object
