# Milvus db set up

In [1]:
from pymilvus import DataType, MilvusClient, connections, Collection
import pandas as pd

In [2]:
posts_df = pd.read_parquet('posts.parquet.gzip')

In [3]:
posts_data = posts_df[['id', 'bert_descr_vector', 'tfidf_descr_vector', 'image_vector', 'tags_vector']]
posts_data = posts_data.to_dict(orient='records')

for post in posts_data:
    for vector_name in ['bert_descr_vector', 'tfidf_descr_vector', 'image_vector', 'tags_vector']:
        post[vector_name] = post[vector_name].tolist()

In [4]:
connections.connect(
  alias="default", 
  host='localhost',
  port='19530'
)

In [5]:
m_client = MilvusClient("http://localhost:19530")

In [6]:
if m_client.has_collection('post'):
    m_client.drop_collection('post')

if m_client.has_collection('user'):
    m_client.drop_collection('user')


post_schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=False,
)

user_schema = MilvusClient.create_schema(
    auto_id=False,
    enable_dynamic_field=False,
)

post_schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)

user_schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
user_schema.add_field(field_name="weight_count", datatype=DataType.INT64)

for vector_name in ['bert_descr_vector', 'tfidf_descr_vector', 'image_vector', 'tags_vector']:
    post_schema.add_field(field_name=vector_name, datatype=DataType.FLOAT_VECTOR, dim=len(posts_data[0][vector_name]))
    user_schema.add_field(field_name=vector_name, datatype=DataType.FLOAT_VECTOR, dim=len(posts_data[0][vector_name]))
    


user_index_params = m_client.prepare_index_params()
post_index_params = m_client.prepare_index_params()

for index_params in [user_index_params, post_index_params]:
    
    for vector_name in ['bert_descr_vector', 'tfidf_descr_vector', 'image_vector', 'tags_vector']:
        
        index_params.add_index(
            field_name="id",
            index_type="STL_SORT"
        )
        
        index_params.add_index(
            field_name=vector_name, 
            index_type="IVF_FLAT",
            metric_type="COSINE",
            params={ "nlist": 128 }
        )


m_client.create_collection(
    collection_name="post",
    schema=post_schema,
    index_params=post_index_params
)


m_client.create_collection(
    collection_name="user",
    schema=user_schema,
    index_params=user_index_params
)


In [7]:
m_client.get_load_state(
    collection_name="post"
)

{'state': <LoadState: Loaded>}

In [8]:
m_client.get_load_state(
    collection_name="user"
)

{'state': <LoadState: Loaded>}

In [9]:
m_client.upsert(collection_name='post', data = posts_data)

{'upsert_count': 2825, 'cost': 0}

In [12]:
res = m_client.get(collection_name='post', ids=[0])

In [13]:
res

data: ["{'id': 0, 'bert_descr_vector': [-0.0035458505, -0.10077144, 0.41746825, 0.1076589, -0.15943623, -0.054564666, 0.0778803, 0.6498423, 0.011850202, -0.3243342, 0.28670463, -0.12073275, -0.06845608, 0.457403, -0.45648322, 0.09012329, -0.24971096, -0.15717524, -0.025247041, -0.07818464, 0.25989014, 0.02354228, -0.225789, 0.2652919, 0.119403444, -0.042693846, -0.028104033, 0.36277932, -0.09475521, 0.11448097, 0.33854038, -0.18189439, -0.06162821, -0.24900906, -0.18813363, -0.08581881, -0.22480376, -0.21522434, -0.040166207, -0.010925844, -0.5524461, -0.098143294, 0.23922765, 0.064411774, -0.023454921, 0.05113361, 0.21469028, -0.24828915, -0.4144125, -0.17601463, -0.3067145, 0.25860208, -0.118359044, -0.077738404, 0.00095992384, 0.5124423, -0.00503892, -0.27376622, -0.04734125, 0.17271571, 0.16690758, 0.34316328, 0.16758034, -0.32303983, 0.15322109, 0.16731152, -0.02238362, 0.31504568, -0.5500395, 0.21407409, -0.36378366, -0.19799359, 0.0030041337, -0.08712377, 0.1357536, -0.052167904