In [None]:
# requirements
# !pip3 install -U protobuf
# !pip3 install -U grpcio-tools
!pip3 install -U pymilvus

In [None]:
import torch
import re
from urllib.parse import urlparse, parse_qs

In [None]:
from pymilvus import MilvusClient, DataType
from pymilvus.client import types


In [None]:
# All connection settings live in one URL: the host:port part is used as the
# server address, the `dropifexists` query flag controls whether an existing
# collection is dropped, and the fragment is the collection name.
url = 'milvus://localhost:19530?dropifexists#mytestcollection'

In [None]:
# Split the connection URL into its parts; the fragment carries the
# collection name and the query string carries the options.
o = urlparse(url, allow_fragments=True)
# keep_blank_values=True keeps a bare `?dropifexists` (no `=value`) as [''].
qargs = parse_qs(o.query, keep_blank_values=True)
print(o)
print(qargs)


def _flag_enabled(params, name):
    """Return True when the boolean query flag ``name`` is switched on.

    Args:
        params: Mapping of parameter name to a list of values, as produced by
            ``parse_qs(..., keep_blank_values=True)``.
        name: Name of the flag to look up.

    Returns:
        False when the flag is absent. True when the first value is empty,
        ``1``, or a case-insensitive prefix of ``true`` or ``yes``
        (``t``, ``tr``, ``tru``, ``true``, ``y``, ``ye``, ``yes``).
        Any other value yields False.
    """
    values = params.get(name)
    if not values:
        # Flag missing from the URL: disabled, instead of raising KeyError.
        return False
    value = values[0].lower()
    if value in ('', '1'):
        return True
    # Exact prefix match instead of a regex character class, so junk such as
    # "teee" or "yss" is no longer accepted as true.
    return 'true'.startswith(value) or 'yes'.startswith(value)


# parse dropifexists param
milvusuri = f'http://{o.netloc}'
drop_if_exists = _flag_enabled(qargs, 'dropifexists')
print(milvusuri)
print(drop_if_exists)


In [None]:
# Create the client for the Milvus server at `milvusuri`; every later cell
# talks to the server through this object.
client = MilvusClient(uri=milvusuri)

In [None]:
def create_collection():
    """Create the collection named by the URL fragment, including its indexes.

    Reads the module-level ``client`` and ``o``. The schema has an INT64
    primary key ``id`` (``auto_id=False``, so ids are supplied by the caller)
    and a FLOAT_VECTOR field ``embedding`` with ``dim=5``. An index is
    registered for each of the two fields and passed to ``create_collection``
    together with the schema and consistency level ``'Strong'``.
    """
    print(f"Creating collection '{o.fragment}'.")
    schema = client.create_schema(
        auto_id=False,
        # pymilvus reads the singular `enable_dynamic_field`; the plural
        # spelling used previously was not picked up, leaving the option off.
        enable_dynamic_field=True,
        description='my-demo',
    )
    schema.add_field(field_name='id', datatype=DataType.INT64, is_primary=True)
    schema.add_field(field_name='embedding', datatype=DataType.FLOAT_VECTOR, dim=5)
    index_params = client.prepare_index_params()
    index_params.add_index(
        index_name='ix_id',
        field_name='id',
        index_type='STL_SORT'
    )
    index_params.add_index(
        index_name='ix_embedding',
        field_name='embedding',
        index_type='FLAT', # IVF_FLAT, FLAT (HNSW most accurate, needs lots of memory)
        metric_type='IP', # Inner Product (DOT product)
        # params={ "nlist": 128 } # nlist for IVF_FLAT: rule-of-thumb: 4 × sqrt(n), where n is the total number of entities in a segment
    )
    # NOTE: omit index if used as vector store
    # Index can also be created afterwards with: client.create_index(collection_name=o.fragment, index_params=index_params)
    client.create_collection(
        collection_name=o.fragment,
        schema=schema,
        index_params=index_params,
        consistency_level='Strong'
    )


if not client.has_collection(o.fragment):
    # Nothing there yet: create it from scratch.
    create_collection()
else:
    print(f"Collection '{o.fragment}' exists.")
    if drop_if_exists:
        # Tear down in order: release, drop both indexes, drop the collection.
        print('dropping')
        client.release_collection(o.fragment)
        client.drop_index(collection_name=o.fragment, index_name='ix_id')
        client.drop_index(collection_name=o.fragment, index_name='ix_embedding')
        client.drop_collection(o.fragment)
        # Build the collection again under the same name.
        print('re-creating')
        create_collection()


In [None]:
# Ask the server for the load state and load the collection only when it is
# not already reported as loaded.
res = client.get_load_state(collection_name=o.fragment)
print(res)
print(type(res['state']))
print(res['state'])
if res['state'] is not types.LoadState.Loaded:
    client.load_collection(o.fragment)


In [None]:
# Random demo embeddings: 10,000 rows of 5 components each, matching the
# `dim=5` of the `embedding` field.
a = torch.rand(10_000, 5)
print(a.shape)

In [None]:
## add data as dict
# One entity per row of `a`; the row index doubles as the primary key.
# Converting the tensor once with tolist() replaces the per-row tensor
# indexing and removes the second hard-coded copy of the row count, so the
# items always match whatever size `a` has.
items = [{'id': i, 'embedding': row} for i, row in enumerate(a.tolist())]
response = client.upsert(
    collection_name=o.fragment,
    data=items,
)
print(response)


In [None]:
# delete some ids
# The entities to remove are selected with a filter expression on `id`.
res = client.delete(collection_name=o.fragment, filter="id in [4,5,6]")
print(res)

In [None]:
# retrieve vector with ID
# Ids 4 and 6 were deleted in the previous cell and 1746812548123 was never
# inserted, so presumably only ids 2 and 8 come back — verify against the
# printed result.
query_ids = [2, 4, 6, 8, 1746812548123]
res = client.get(
    collection_name=o.fragment,
    ids=query_ids,
    output_fields=['id', 'embedding'],
)
print(type(res))
print(res)
print(res.extra)
print(len(res))

In [None]:
# Stack the `embedding` of every returned entity into one float tensor.
b = torch.tensor([row['embedding'] for row in res], dtype=torch.float32)
print(b.shape)


In [None]:
# default query
# Same ids as the get-by-id cell, this time expressed as a filter on `id`;
# formatting the list inside the f-string yields e.g. "id in [2, 4, ...]".
res = client.query(
    collection_name=o.fragment,
    filter=f"id in {query_ids}",
    output_fields=['id', 'embedding'],
)
print(type(res))
print(res)
print(res.extra)
print(len(res))

In [None]:
# count
# An empty filter combined with the 'count(*)' output field requests the
# entity count instead of individual rows.
res = client.query(collection_name=o.fragment, filter="", output_fields=['count(*)'])
print(res)

In [None]:
# Show the statistics reported for the collection as the cell output.
client.get_collection_stats(collection_name=o.fragment)

In [None]:
# Show what describe_collection returns for the collection as the cell output.
client.describe_collection(collection_name=o.fragment)

In [None]:
# End of the walkthrough: close the client.
client.close()