In [1]:
import time

import numpy as np
from pymilvus import (
  connections,
  utility,
  FieldSchema, CollectionSchema, DataType,
  Collection,
)

fmt = "\n=== {:30} ===\n"
search_latency_fmt = "search latency = {:.4f}s"
dim = 512

In [2]:
print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

print(fmt.format("Drop collection `cifar_10`"))
utility.drop_collection("cifar_10")

has = utility.has_collection("cifar_10")
print(f"Does collection cifar_10 exist in Milvus: {has}")


=== start connecting to Milvus     ===


=== Drop collection `cifar_10`     ===

Does collection cifar_10 exist in Milvus: False


In [3]:
fields = [
  FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=True),
  FieldSchema(name="index", dtype=DataType.INT64),
  FieldSchema(name="label", dtype=DataType.INT8),
  FieldSchema(name="embedded", dtype=DataType.FLOAT_VECTOR, dim=dim)
]

schema = CollectionSchema(fields, "cifar-10 colleation schema for demo")

print(fmt.format("Create collection `cifar-10`"))
cifar_10 = Collection("cifar_10", schema, consistency_level="Strong")


=== Create collection `cifar-10`   ===



In [4]:
import json

print(fmt.format("Start inserting entities"))
# Read JSON file and extract vectors and IDs
with open('output/output-limited-100.json', 'r') as file:
  data = json.load(file)
  rows = data['rows']
  ids = [row['index'] for row in rows]
  labels = [row['label'] for row in rows]
  embedded = [row['embedded'] for row in rows]

entities = [ids, labels, embedded]
insert_result = cifar_10.insert(entities)
print(f"Insert result: {insert_result}")
cifar_10.flush()
print(f"Number of entities in Milvus: {cifar_10.num_entities}")


=== Start inserting entities       ===

Insert result: (insert count: 1000, delete count: 0, upsert count: 0, timestamp: 446840929364475907, success count: 1000, err count: 0)
Number of entities in Milvus: 1000


In [5]:
print(fmt.format("Start Creating index IVF_FLAT"))
index = {
  "index_type": "IVF_FLAT",
  "metric_type": "COSINE",
  "params": {"nlist": 1024},
}

cifar_10.create_index("embedded", index)


=== Start Creating index IVF_FLAT  ===



Status(code=0, message=)

In [6]:
print(fmt.format("Start loading"))
cifar_10.load()


=== Start loading                  ===



In [7]:
import torch
import open_clip

# Load CLIP model
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32-quickgelu', pretrained='metaclip_400m')

# Function to generate image embeddings
def generate_text_embedding(text):
  text = open_clip.tokenize(text)
  with torch.no_grad():
    text_features = model.encode_text(text)
     text_features /= text_features.norm(dim=-1, keepdim=True)
  return text_features.flatten().numpy().tolist()

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# -----------------------------------------------------------------------------
# search based on vector similarity
search_params = {
  "metric_type": "COSINE",
  "params": {"nprobe": 32},
}

vectors_to_search


=== Start searching based on vector similarity ===



[[0.0004834487335756421,
  -0.006879542022943497,
  -0.01641673967242241,
  -0.01990366354584694,
  0.009547844529151917,
  0.011988690122961998,
  -0.006266705691814423,
  -0.020534906536340714,
  0.01713847741484642,
  0.003880884498357773,
  0.0015259037027135491,
  -0.014198416844010353,
  -0.01584639772772789,
  0.015597953461110592,
  -0.019504329189658165,
  0.0017762879142537713,
  -0.01606050319969654,
  0.011114765889942646,
  0.010196241550147533,
  0.018650181591510773,
  -0.03348131105303764,
  0.024296915158629417,
  -0.01291564665734768,
  -0.008845828473567963,
  -0.012448788620531559,
  -0.01745481789112091,
  -0.013791827484965324,
  -0.028261952102184296,
  -0.005722592119127512,
  -0.0023746322840452194,
  0.003261336823925376,
  -0.010689937509596348,
  -0.02255057916045189,
  -0.014502893202006817,
  -0.013779481872916222,
  -0.0027637931052595377,
  0.010926499031484127,
  0.02790708653628826,
  0.04291084036231041,
  0.013337207026779652,
  -0.008994746953248978

In [16]:
# Create a dictionary to map numbers to labels
number_to_label = {
  0: "airplane",
  1: "automobile",
  2: "bird",
  3: "cat",
  4: "deer",
  5: "dog",
  6: "frog",
  7: "horse",
  8: "ship",
  9: "truck"
}
# Function to retrieve label string based on number
def get_label(number):
  return number_to_label.get(number, "Not in label")

In [18]:
print(fmt.format("Start searching based on vector similarity"))
vectors_to_search = [generate_text_embedding("bird")]
a_label = []
top_k = 20
start_time = time.time()
result = cifar_10.search(data=vectors_to_search, anns_field="embedded", param=search_params, limit=top_k, output_fields=["label"])
end_time = time.time()

for hits in result:
  for hit in hits:
      print(f"hit: {hit}, Label: {get_label(hit.entity.get('label'))}")

print(search_latency_fmt.format(end_time - start_time))


=== Start searching based on vector similarity ===

hit: id: 446839055364414587, distance: 0.3052757978439331, entity: {'label': 2}, Label: bird
hit: id: 446839055364415001, distance: 0.30502647161483765, entity: {'label': 2}, Label: bird
hit: id: 446839055364414889, distance: 0.3032063841819763, entity: {'label': 2}, Label: bird
hit: id: 446839055364415039, distance: 0.3015752136707306, entity: {'label': 2}, Label: bird
hit: id: 446839055364415018, distance: 0.30047014355659485, entity: {'label': 2}, Label: bird
hit: id: 446839055364415351, distance: 0.30019211769104004, entity: {'label': 2}, Label: bird
hit: id: 446839055364415233, distance: 0.2999524176120758, entity: {'label': 2}, Label: bird
hit: id: 446839055364414609, distance: 0.2994062602519989, entity: {'label': 2}, Label: bird
hit: id: 446839055364414881, distance: 0.29939037561416626, entity: {'label': 2}, Label: bird
hit: id: 446839055364414879, distance: 0.29905086755752563, entity: {'label': 2}, Label: bird
hit: id: 446