In [None]:
import pandas as pd

# Data source: https://www.kaggle.com/datasets/rounakbanik/the-movies-dataset/data
# The CSV is expected at dataset/movies_metadata.csv relative to the working directory.
MOVIES_CSV = 'dataset/movies_metadata.csv'

# low_memory=False parses the file in one pass so each column gets a single
# inferred dtype instead of a mixed-type warning.
movies = pd.read_csv(MOVIES_CSV, low_memory=False)
movies.shape

In [None]:
# Column index of the raw metadata frame (DataFrame.keys() is its column axis).
movies.keys()

In [None]:
from math import isnan
from pprint import pprint

# Fields that feed the embedding text; every kept movie must have all of them.
REQUIRED_FIELDS = ["overview", "release_date", "genres"]

# Keep the title (used later as the primary key) plus the embedding inputs.
trimmed_movies = movies[["title", *REQUIRED_FIELDS]]

unclean_movie_dict = trimmed_movies.to_dict('records')
print('{} movies'.format(len(unclean_movie_dict)))

# Drop rows missing any required field. dropna treats both NaN and None as
# missing, replacing the per-record `x == x` NaN trick. Rows with a missing
# title are still skipped later, when records are inserted.
movies_dict = trimmed_movies.dropna(subset=REQUIRED_FIELDS).to_dict('records')

print('{} movies'.format(len(movies_dict)))

In [None]:
from pymilvus import *

import os

# Milvus endpoint. Set the MILVUS_URI environment variable to point at another
# server instead of editing the notebook; the default is the original address.
milvus_uri = os.environ.get("MILVUS_URI", "http://192.168.1.106:19530")

connections.connect(uri=milvus_uri)

print("Connected!")

In [None]:
from pymilvus import FieldSchema, DataType, CollectionSchema, utility, Collection

COLLECTION_NAME = 'film_vectors'
PARTITION_NAME = 'movie'  # NOTE(review): not referenced by any cell shown here
# Must equal the size of the vectors produced by the sentence encoder
# ('all-MiniLM-L6-v2' emits 384-dimensional embeddings).
EMBEDDING_DIM = 384

# `title` doubles as the primary key. The field objects are named
# title_field/embedding_field (not id/field) so the builtin `id` is not
# shadowed for the rest of the kernel session.
title_field = FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=500, is_primary=True)
embedding_field = FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIM)

# enable_dynamic_field lets inserted rows carry extra keys (e.g. overview,
# release_date, genres) that are not declared in the schema.
schema = CollectionSchema(fields=[title_field, embedding_field], description="movie recommender: film vectors", enable_dynamic_field=True)

# Drop any existing collection so every run rebuilds it from scratch.
if utility.has_collection(COLLECTION_NAME):
    collection = Collection(COLLECTION_NAME)
    collection.drop()

collection = Collection(COLLECTION_NAME, schema)
print("Collection created.")

index_params = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",        # searches must use this same metric
    "params": {"nlist": 128},   # number of IVF clusters
}

collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

collection.flush()

print("Collection indexed!")

In [None]:
from sentence_transformers import SentenceTransformer
import ast

def build_genres(data):
    """Return the movie's genre names as a comma-separated string.

    ``data["genres"]`` is the raw string from movies_metadata.csv: a Python
    literal list of dicts, each with a ``"name"`` key. It is parsed with
    ``ast.literal_eval``, which evaluates literals only and never runs code.

    Args:
        data: Mapping with a ``"genres"`` key holding that string.

    Returns:
        e.g. ``"Animation, Comedy"``; ``""`` when the list is empty. Unlike
        the previous rsplit-based version, there is no trailing space.
    """
    entries = ast.literal_eval(data["genres"])
    return ", ".join(entry["name"] for entry in entries)

# Sentence encoder shared by indexing and search. Its output size must match
# the `dim` of the collection's FLOAT_VECTOR field.
transformer = SentenceTransformer('all-MiniLM-L6-v2')

def embed_movie(data):
    """Encode a movie record into a dense vector.

    Builds one descriptive text from the overview, release date and (when
    present) genres, then encodes it with ``transformer``.

    Args:
        data: Movie record with ``"overview"``, ``"release_date"`` and
            ``"genres"`` keys (the raw genres string accepted by build_genres).

    Returns:
        The result of ``transformer.encode`` for a single string (a 1-D
        numpy array with SentenceTransformer's default settings).
    """
    # strip() guards against trailing whitespace in the joined genre list.
    genres = build_genres(data).strip()
    text = "{} Released on {}.".format(data["overview"], data["release_date"])
    if genres:
        # Omit the clause entirely rather than embed "Genres are ." for movies
        # whose genre list is empty.
        text += " Genres are {}.".format(genres)
    return transformer.encode(text)

In [None]:
# Embed every cleaned movie and insert the rows into Milvus in batches.
# Each row dict carries `title` (primary key) and `embedding`; its other keys
# are stored as dynamic fields (the schema enables dynamic fields).
BATCH_SIZE = 50

j = 0          # number of movies embedded so far
batch = []     # rows waiting to be inserted

for movie_dict in movies_dict:
    # Title is the primary key, so a record without one cannot be stored.
    if pd.isnull(movie_dict['title']):
        continue

    try:
        movie_dict["embedding"] = embed_movie(movie_dict)
        batch.append(movie_dict)
        j += 1
        if j % BATCH_SIZE == 0:
            print("Embedded {} records".format(j))
            collection.insert(batch)
            print("Batch insert completed")
            batch = []
    except Exception as e:
        # Both placeholders need an argument; with only `e` the handler itself
        # raised IndexError and hid the real failure.
        print("Error inserting record {}, {}".format(movie_dict.get('title'), e))
        pprint(batch)
        break

# Insert whatever is left over from the last, partial batch.
if len(batch) > 0:
    collection.insert(batch)


print("Final batch completed")
print("Finished with {} embeddings".format(j))

In [None]:
# Load the collection so it can be searched.
collection.load()

topK = 4  # number of nearest movies returned per query
SEARCH_PARAM = {
    "metric_type": "L2",        # same metric the index was built with
    "params": {"nprobe": 20},   # IVF clusters probed per query
}


def embed_search(search_string):
    """Encode a free-text query with the same model used for the movies."""
    return transformer.encode(search_string)


def search_for_movies(search_string):
    """Return the ``topK`` nearest movies (by L2 distance) to ``search_string``.

    Each hit carries the ``title`` and ``overview`` output fields.
    """
    query_vectors = [embed_search(search_string)]
    return collection.search(
        query_vectors,
        "embedding",
        param=SEARCH_PARAM,
        limit=topK,
        expr=None,
        output_fields=['title', 'overview'],
    )

In [None]:
from pprint import pprint


# Example query: describe the film you are looking for in plain language.
search_string = "A comedy from the 1990s set in a hospital. The main characters are in their 20s and are trying to stop a vampire."
results = search_for_movies(search_string)

# `results` holds one list of hits per query vector (a single query here);
# print the title and overview of every returned movie.
for hit_list in results:
    for movie_hit in hit_list:
        fields = movie_hit.entity
        print(fields.get('title'))
        print(fields.get('overview'))
        print("-------------------------------")