In [92]:
from pymilvus import model
from pymilvus import MilvusClient, DataType
import json
import math

In [93]:
client = MilvusClient("./milvus_demo.db")

client.drop_collection(collection_name="my_sparse_collection")

schema = client.create_schema(
    auto_id=True,
    enable_dynamic_fields=True,
)

schema.add_field(field_name="pk", datatype=DataType.VARCHAR, is_primary=True, max_length=100)
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=False, max_length=100)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=10000)
schema.add_field(field_name="embeddings", datatype=DataType.SPARSE_FLOAT_VECTOR)

{'auto_id': True, 'description': '', 'fields': [{'name': 'pk', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 100}, 'is_primary': True, 'auto_id': False}, {'name': 'id', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 100}}, {'name': 'text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 10000}}, {'name': 'embeddings', 'description': '', 'type': <DataType.SPARSE_FLOAT_VECTOR: 104>}], 'enable_dynamic_field': False}

In [94]:
index_params = client.prepare_index_params()

index_params.add_index(field_name="embeddings",
                       index_name="sparse_inverted_index",
                       index_type="SPARSE_INVERTED_INDEX",
                       metric_type="IP",
                       params={"drop_ratio_build": 0.2}
                       )

In [95]:
client.create_collection(
    collection_name="my_sparse_collection",
    schema=schema,
    index_params=index_params
)

In [96]:
embeddings_model = model.sparse.SpladeEmbeddingFunction(
    model_name="ibm-granite/granite-embedding-30m-sparse",
    device="cuda",
    batch_size=2,
    k_tokens_query=50,
    k_tokens_document=192
)

In [97]:
from docuverse.utils.embeddings.sparse_embedding_function import SparseEmbeddingFunction
embeddings_model1 = SparseEmbeddingFunction("ibm-granite/granite-embedding-30m-sparse",
                                            batch_size=1,
                                            doc_max_tokens=192,
                                            query_max_tokens=50,
                                            process_name="ingestion")
embeddings_model = embeddings_model1

=== done initializing model


In [98]:
# Prepare documents to be ingested
docs = [
    "Artificial intelligence was founded as an academic discipline in 1956.",
    "Alan Turing was the first person to conduct substantial research in AI.",
    "Born in Maida Vale, London, Turing was raised in southern England.",
]

# SpladeEmbeddingFunction.encode_documents returns sparse matrix or sparse array depending
# on the milvus-model version. reshape(1,-1) ensures the format is correct for ingestion.
doc_vector = [{"embeddings": doc_emb.reshape(1,-1),
               "text": doc_text,
               "id": f"item_{i}"}
              for i, (doc_emb, doc_text) in enumerate(zip(embeddings_model.encode_documents(docs), docs))]


client.insert(
    collection_name="my_sparse_collection",
    data=doc_vector
)

Processed candidates: 100%|██████████| 3/3 [00:00<00:00, 186.67it/s]


{'insert_count': 3, 'ids': ['460565663105155084', '460565663105155085', '460565663105155086'], 'cost': 0}

In [99]:
# Prepare search parameters
search_params = {
    "params": {"drop_ratio_search": 0.2},  # Additional optional search parameters
}

# Prepare the query vector

queries = [
    "When was artificial intelligence founded",
    "Where was Turing born?",
    "Who was the first person to work in AI?"
]
answers = [
    'item0',
    'item2',
    'item1'
]
query_vector = embeddings_model.encode_documents(queries)

Processed candidates: 100%|██████████| 3/3 [00:00<00:00, 167.87it/s]


In [100]:
res = client.search(
    collection_name="my_sparse_collection",
    data=query_vector,
    limit=2, #top k documents to return
    output_fields=["id", "text", "embeddings"],
    search_params=search_params,
)

for r in res:
    print(r)

[{'id': '460565663105155084', 'distance': 12.359456062316895, 'entity': {'embeddings': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0, 6: 0.0, 7: 0.0, 8: 0.0, 9: 0.0, 10: 0.0, 11: 0.0, 12: 0.0, 13: 0.0, 14: 0.0, 15: 0.0, 16: 0.0, 17: 0.0, 18: 0.0, 19: 0.0, 20: 0.0, 21: 0.0, 22: 0.0, 23: 0.0, 24: 0.0, 25: 0.0, 26: 0.0, 27: 0.0, 28: 0.0, 29: 0.0, 30: 0.0, 31: 0.0, 32: 0.0, 33: 0.0, 34: 0.0, 35: 0.0, 36: 0.0, 37: 0.0, 38: 0.0, 39: 0.0, 40: 0.0, 41: 0.0, 42: 0.0, 43: 0.0, 44: 0.0, 45: 0.0, 46: 0.0, 47: 0.0, 48: 0.0, 49: 0.0, 50: 0.0, 51: 0.0, 52: 0.0, 53: 0.0, 54: 0.0, 55: 0.0, 56: 0.0, 57: 0.0, 58: 0.0, 59: 0.0, 60: 0.0, 61: 0.0, 62: 0.0, 63: 0.0, 64: 0.0, 65: 0.0, 66: 0.0, 67: 0.0, 68: 0.0, 69: 0.0, 70: 0.0, 71: 0.0, 72: 0.0, 73: 0.0, 74: 0.0, 75: 0.0, 76: 0.0, 77: 0.0, 78: 0.0, 79: 0.0, 80: 0.0, 81: 0.0, 82: 0.0, 83: 0.0, 84: 0.0, 85: 0.0, 86: 0.0, 87: 0.0, 88: 0.0, 89: 0.0, 90: 0.0, 91: 0.0, 92: 0.0, 93: 0.0, 94: 0.0, 95: 0.0, 96: 0.0, 97: 0.0, 98: 0.0, 99: 0.0, 100: 0.0, 101: 0.0, 10

In [101]:
ee=res[0][0].data['entity']['embeddings']
def get_vector(e):
    if hasattr(e, 'data'):
        e = e.data['entity']['embeddings']
    aa = sorted([(k, v) for k, v in ee.items() if math.fabs(v)>0.001], key=lambda x: x[1], reverse=True)
    return embeddings_model.model.convert_token_ids_to_tokens([aa])

In [102]:
get_vector(res[0][0].data['entity']['embeddings'])

[[('ĠAI', 1.6640625),
  ('Ġintelligence', 1.4921875),
  ('Ġartificial', 1.25),
  ('Ġdiscipline', 1.21875),
  ('Ġfounded', 1.0546875),
  ('Ġ1956', 1.03125),
  ('Ġinvention', 0.9765625),
  ('56', 0.71484375),
  ('Ġlearning', 0.69140625),
  ('Ġscientific', 0.68359375),
  ('Ġcomputer', 0.66015625),
  ('Ġacademic', 0.6171875),
  ('Ġuniversity', 0.578125),
  ('Ġrobot', 0.5703125),
  ('Ġestablishment', 0.55078125),
  ('Ġphilosophy', 0.54296875),
  ('A', 0.494140625),
  ('Ġbrain', 0.486328125),
  ('Ġmachine', 0.4453125),
  ('1960', 0.4453125),
  ('1950', 0.431640625),
  ('Ġalgorithm', 0.416015625),
  ('Ġscience', 0.384765625),
  ('Ġregression', 0.37890625),
  ('ĠDiscipline', 0.330078125),
  ('Ġcomput', 0.330078125),
  ('Ġinstitute', 0.306640625),
  ('Ġtechnology', 0.27734375),
  ('Ġautomatic', 0.265625),
  ('Ġphilosopher', 0.22265625),
  ('Ġclassification', 0.2041015625),
  ('ĠEvolution', 0.1845703125),
  ('Ġpublication', 0.11083984375),
  ('ĠIndia', 0.08251953125),
  ('history', 0.06787109375

In [103]:
query_vector[0].indices
# get_vector(doc_vector[0]['embeddings'])

array([    0,     1,     2,     3,     4,     5,     6,     7,     8,
           9,    10,    11,    12,    13,    14,    15,    16,    17,
          18,    19,    20,    21,    22,    23,    24,    25,    26,
          27,    28,    29,    30,    31,    32,    33,    34,    35,
          36,    37,    38,    39,    40,    41,    42,    43,    44,
          45,    46,    47,    48,    49,    50,    51,    52,    53,
          54,    55,    56,    57,    58,    59,    60,    61,    62,
          63,    64,    65,    66,    67,    68,    69,    70,    71,
          72,    73,    74,    75,    76,    77,    78,    79,    80,
          81,    82,    83,    84,    85,    86,    87,    88,    89,
          90,    91,    92,    93,    94,    95,    96,    97,    98,
          99,   100,   101,   102,   103,   104,   105,   106,   107,
         108,   109,   110,   111,   112,   113,   114,   115,   116,
         117,   118,   119,   120,   121,   122,   123,   124,   125,
         126,   127,

In [113]:
{f"{i}": float(query_vector[0].data[j]) for j, i in enumerate(query_vector[0].indices) if query_vector[0].data[j]>0}

{'250': 0.39453125,
 '806': 0.4765625,
 '2226': 0.1787109375,
 '2239': 0.69140625,
 '2316': 1.546875,
 '2866': 0.171875,
 '2900': 0.369140625,
 '3034': 0.73046875,
 '3563': 0.40625,
 '3742': 0.39453125,
 '4687': 1.65625,
 '4790': 1.2578125,
 '5423': 0.2353515625,
 '6441': 0.41015625,
 '7147': 0.61328125,
 '7350': 1.4140625,
 '8408': 0.2041015625,
 '9916': 0.61328125,
 '10561': 0.345703125,
 '11767': 0.03076171875,
 '14578': 0.1318359375,
 '16807': 0.138671875,
 '17194': 0.3359375,
 '20257': 0.0966796875,
 '26101': 1.0,
 '28034': 0.30078125,
 '29991': 0.08984375,
 '31024': 0.40625,
 '37283': 0.357421875,
 '39974': 0.1455078125}