In [1]:
import numpy as np
from tqdm import tqdm

from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)



In [2]:
vecs = np.load("data/als_vecs.npy")
n_vecs, dim = vecs.shape

In [3]:
connections.connect("default", host="localhost", port="19530")

In [4]:
if utility.has_collection("tracks"):
    utility.drop_collection("tracks")

In [5]:
fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=False),
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=dim)
]
schema = CollectionSchema(fields, "")
tracks = Collection("tracks", schema)

In [6]:
tracks

<Collection>:
-------------
<name>: tracks
<description>: 
<schema>: {'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'embeddings', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 128}}]}

In [7]:
batch_size = 1000

for i in tqdm(range(0, n_vecs, batch_size)):
    entities = [list(range(i, min(i + batch_size, n_vecs))), vecs[i:min(i + batch_size, n_vecs)].tolist()]
    insert_result = tracks.insert(entities)
    
tracks.flush()  

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170/170 [00:03<00:00, 47.50it/s]


In [8]:
utility.has_collection("tracks")

True

In [9]:
n_vecs, tracks.num_entities

(169542, 169542)

In [10]:
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
tracks.create_index("embeddings", index)
tracks.load()

In [11]:
def similar(id):
    vectors_to_search = [vecs[id].tolist()]
    search_params = {
        "metric_type": "L2",
        "params": {"nprobe": 10},
    }
    resp = tracks.search(vectors_to_search, "embeddings", search_params, limit=10, output_fields=["id", "embeddings"])
    return [r.entity.get('id') for r in resp[0]]

In [12]:
similar(0)

[0, 34574, 6761, 14766, 74274, 46600, 19795, 3331, 30335, 36532]

In [13]:
similar(1)

[1, 21776, 18594, 9976, 17919, 8285, 12400, 6736, 18560, 13668]

In [14]:
similar(2)

[2, 1189, 1488, 20638, 10329, 51061, 16045, 18576, 1116, 52169]

In [18]:
res = tracks.query(
  expr = "id in [0]",
  offset = 0,
  limit = 1, 
  output_fields = ["id", "embeddings"],
)
res[0]['embeddings']

[-0.017452085,
 0.021967735,
 0.022460308,
 -0.021715969,
 0.027727226,
 0.0207262,
 0.016178701,
 0.010607587,
 0.03436831,
 0.015008648,
 -0.01537897,
 0.030006101,
 0.03922965,
 0.0012650349,
 0.0010564083,
 -0.032457393,
 0.036207158,
 0.0019614426,
 -0.010713417,
 -0.03149738,
 -0.039326146,
 0.0016342064,
 -0.022431495,
 0.012658413,
 0.044591963,
 0.0015363134,
 0.028252825,
 -0.0070731505,
 -0.008412766,
 -0.008343729,
 -0.0021883226,
 0.07450654,
 0.029140215,
 0.0082035065,
 0.055540353,
 -0.03285737,
 0.027390094,
 -0.020342346,
 0.02483456,
 0.014269789,
 0.02332753,
 0.02702579,
 0.03391621,
 -0.0028902378,
 -0.0071795243,
 0.027912192,
 0.003600293,
 0.0495467,
 -0.033584736,
 0.02751022,
 -0.011300132,
 -0.0022968533,
 -0.0010127255,
 0.001966205,
 -0.010524735,
 0.019233627,
 0.027640305,
 -0.00066645717,
 0.031088293,
 -0.02983251,
 0.037521902,
 0.07228329,
 0.0012775419,
 0.021634575,
 0.016209813,
 0.040289737,
 0.018820524,
 0.008956291,
 -0.0040700827,
 -0.0196164