In [None]:
## * install pymilvus sentence-transformers library 
!pip install --no-cache-dir python-dotenv pymilvus sentence-transformers 

In [None]:
from dotenv import load_dotenv
import os
## * load variables from a .env file (if one is found) so os.getenv can read them below
load_dotenv()

## * Loading environment variables
## * An unset or empty variable falls back to the Milvus standalone defaults
## * (localhost:19530) instead of handing None / "" to the connection call.
MILVUS_PORT = os.getenv("MILVUS_PORT") or "19530"
MILVUS_HOST = os.getenv("MILVUS_HOST") or "localhost"

In [4]:
## * Milvus standalone must be running before this cell is executed, e.g.:
## *   bash standalone_embed.sh start
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections, utility

## * connect to the Milvus server at MILVUS_HOST:MILVUS_PORT
connections.connect(
    port=MILVUS_PORT,
    host=MILVUS_HOST,
)

In [19]:
COLLECTION_NAME = "movies_db"
## * the sentence-transformers model used in this notebook generates 384-dimensional embeddings
DIMENSION = 384

## * Generally, more dimensions means more information can be stored and more accurate results.
## * collection layout: auto-incremented INT64 primary key, movie title, embedding vector
primary_key = FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True)
title_field = FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200)
vector_field = FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
fields = [primary_key, title_field, vector_field]

schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)


In [20]:
## * build an IVF_FLAT index (L2 metric) on the embedding field, then load the collection
index_params = dict(
    metric_type='L2',
    index_type="IVF_FLAT",
    params=dict(nlist=1536),
)
collection.create_index(index_params=index_params, field_name="embedding")
collection.load()

In [21]:
## * inserting data
import csv
from sentence_transformers import SentenceTransformer

## * shared sentence-transformers model; its encode() output is what gets stored and searched
transformer = SentenceTransformer('all-MiniLM-L6-v2')

def csv_load(file, title_col=1, plot_col=7, encoding='utf-8'):
    """Yield ``(title, plot)`` pairs read from a comma-separated file.

    Args:
        file: Path of the CSV file to read.
        title_col: Zero-based, non-negative index of the title column.
        plot_col: Zero-based, non-negative index of the plot column.
        encoding: Text encoding used to open the file. Given explicitly so
            the result does not depend on the platform's locale default.

    Yields:
        Tuples ``(title, plot)`` for every row in which both cells exist and
        are non-empty.

    Note:
        Every row is treated as data; a header row, if the file has one, is
        yielded like any other row.
    """
    # A row must be at least this wide for both columns to be addressable.
    min_width = max(title_col, plot_col) + 1
    with open(file, newline='', encoding=encoding) as f:
        for row in csv.reader(f, delimiter=','):
            # Blank or truncated lines would raise IndexError below; skip them.
            if len(row) < min_width:
                continue
            title, plot = row[title_col], row[plot_col]
            if title == '' or plot == '':
                continue
            yield (title, plot)


def embed_insert(data):
    """Encode the texts in ``data[1]`` and insert them alongside ``data[0]``.

    ``data`` is a two-element sequence: ``data[0]`` holds the titles and
    ``data[1]`` the texts to embed. The titles and the resulting embeddings
    are inserted into the module-level ``collection`` as two columns.
    """
    titles, plots = data[0], data[1]
    vectors = list(transformer.encode(plots))
    collection.insert([titles, vectors])


In [None]:
## * processing embeddings in batch because it takes time to insert them.
import time  ## also used by the search cell further down

BATCH_SIZE = 128

## * data_batch[0] collects titles, data_batch[1] collects the matching plots
data_batch = [[],[]]

## * running total of rows read from the CSV
count = 0

for title, plot in csv_load('plots.csv'):
    data_batch[0].append(title)
    data_batch[1].append(plot)
    ## * count the row before reporting, so the printed total includes it
    count += 1
    if len(data_batch[0]) >= BATCH_SIZE:
        embed_insert(data_batch)
        data_batch = [[],[]]
        print(f"\ninserted... {count} movies")

## * insert whatever is left over from the last, partially filled batch
if data_batch[0]:
    embed_insert(data_batch)
    print(f"\ninserted... {count} movies")

collection.flush()

In [27]:
## * number of results requested per query (passed as `limit` to the search)
TOP_K = 3
## * free-text queries; each one is embedded and searched against the collection
search_terms = ['A movie about cars', 'A movie about monsters']

def embed_search(data):
    """Encode ``data`` with the shared ``transformer`` and return the embeddings as a list."""
    return list(transformer.encode(data))

## * embed all queries up front, then time one batched search over them
search_data = embed_search(search_terms)

## * perf_counter is a monotonic clock intended for measuring elapsed time
start = time.perf_counter()
res = collection.search(
    data=search_data,
    anns_field="embedding",
    param={},
    limit=TOP_K,
    output_fields=['title']
)
end = time.perf_counter()

## * res is iterated in step with search_terms: one group of hits per query
for term, hits in zip(search_terms, res):
    print('Title:', term)
    ## * this is the duration of the whole batched search, not of this single query
    print('Search Time:', end-start)
    print('Results:')
    for hit in hits:
        print( hit.entity.get('title'), '----', hit.distance)
    print()


Title: A movie about cars
Search Time: 0.023342132568359375
Results:
Red Line 7000 ---- 0.9104408025741577
Tomboy ---- 0.925471305847168
Quick Millions ---- 0.9283517003059387

Title: A movie about monsters
Search Time: 0.023342132568359375
Results:
The Monster Squad ---- 0.9770852327346802
The Butcher Boy ---- 0.9966024160385132
Maa Kasam ---- 1.0104992389678955

