In this tutorial, we will prepare Milvus database for storing and searching nodes and edges of a graph.

In particular, we are using PrimeKG multimodal data from the BioBridge project.

In [1]:
# Load necessary libraries
import glob
import os
import time

import cudf
import cupy as cp
import hydra
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    db,
    utility,
)
from tqdm import tqdm

### Loading BioBridge-PrimeKG Multimodal Data

First, we need to get the path to the directory containing the parquet files of nodes and edges.

For nodes and edges, we have a separate folder that contains its enrichment and embeddings.

In [2]:
# Load hydra configuration
with hydra.initialize(
    version_base=None,
    config_path="../../../aiagents4pharma/talk2knowledgegraphs/configs",
):
    cfg = hydra.compose(
        config_name="config", overrides=["tools/multimodal_subgraph_extraction=default"]
    )
    cfg = cfg.tools.multimodal_subgraph_extraction
cfg

{'_target_': 'talk2knowledgegraphs.tools.multimodal_subgraph_extraction', 'ollama_embeddings': ['nomic-embed-text'], 'temperature': 0.1, 'streaming': False, 'topk': 5, 'topk_e': 5, 'cost_e': 0.5, 'c_const': 0.01, 'root': -1, 'num_clusters': 1, 'pruning': 'gw', 'verbosity_level': 0, 'node_id_column': 'node_id', 'node_attr_column': 'node_attr', 'edge_src_column': 'edge_src', 'edge_attr_column': 'edge_attr', 'edge_dst_column': 'edge_dst', 'node_colors_dict': {'gene/protein': '#6a79f7', 'molecular_function': '#82cafc', 'cellular_component': '#3f9b0b', 'biological_process': '#c5c9c7', 'drug': '#c4a661', 'disease': '#80013f'}, 'biobridge': {'source': '/mnt/blockstorage/biobridge_multimodal/', 'node_type': ['gene/protein', 'molecular_function', 'cellular_component', 'biological_process', 'drug', 'disease']}}

In [3]:
# You can set the source directory for biobridge data here
cfg.biobridge.source = "/mnt/blockstorage/biobridge_multimodal"

In [4]:
%%time

# Loop over nodes and edges
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(
            os.path.join(cfg.biobridge.source, element, stage, "*.parquet.gzip")
        )
        print(file_list)
        # Read and concatenate all dataframes in the folder
        # Except the edges embedding, which is too large to read in one go
        # We are using a chunk size to read the edges embedding in smaller parts instead
        if element == "edges" and stage == "embedding":
            # For edges embedding, only read two columns: triplet_index and edge_emb
            # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in file_list[:2]], ignore_index=True)
            # Loop by chunks
            # file_list = file_list[:2]
            chunk_size = 5
            graph_dict[element][stage] = []
            for i in range(0, len(file_list), chunk_size):
                chunk_files = file_list[i : i + chunk_size]
                chunk_df = cudf.concat(
                    [
                        cudf.read_parquet(f, columns=["triplet_index", "edge_emb"])
                        for f in chunk_files
                    ],
                    ignore_index=True,
                )
                graph_dict[element][stage].append(chunk_df)
        else:
            # For nodes and edges enrichment, read and concatenate all dataframes in the folder
            # This includes the nodes embedding, which is small enough to read in one go
            graph_dict[element][stage] = cudf.concat(
                [cudf.read_parquet(f) for f in file_list], ignore_index=True
            )

nodes enrichment
['/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip']
nodes embedding
['/mnt/blockstorage/biobridge_multimodal/nodes/embedding/drug.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/cellular_component.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/gene_protein.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/disease.parquet.gzip', '/mnt/blockstorage/biobridge_multimodal/nodes/embedding/biological_process.parquet.gzip', '/mnt/blockstorage/biobridge_mu

In [5]:
# Get nodes enrichment and embedding dataframes
nodes_enrichment_df = graph_dict["nodes"]["enrichment"]
nodes_embedding_df = graph_dict["nodes"]["embedding"]

# Get edges enrichment and embedding dataframes
edges_enrichment_df = graph_dict["edges"]["enrichment"]
edges_embedding_df = graph_dict["edges"][
    "embedding"
]  # !!consisted of a list of dataframes!!

In [6]:
# Merge nodes enrichment and embedding dataframes
merged_nodes_df = nodes_enrichment_df.merge(
    nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]], on="node_id", how="left"
)
# del nodes_enrichment_df, nodes_embedding_df  # Free memory

### Normalization of the embeddings

Helper to normalize the vector/matrix.

In [2]:
def normalize_matrix(m, axis=1):
    """
    Normalize each row of a 2D matrix using CuPy.

    Parameters:
    m (cupy.ndarray): 2D matrix to normalize.

    Returns:
    cupy.ndarray: Normalized matrix.
    """
    norms = cp.linalg.norm(m, axis=axis, keepdims=True)
    return m / norms


def normalize_vector(v):
    """
    Normalize a vector using CuPy.

    Parameters:
    v (cupy.ndarray): Vector to normalize.

    Returns:
    cupy.ndarray: Normalized vector.
    """
    v = cp.asarray(v)
    norm = cp.linalg.norm(v)
    return v / norm

### Setup Milvus Database

In [3]:
# Configuration for Milvus
milvus_host = "localhost"
milvus_port = "19530"
milvus_uri = "http://localhost:19530"
milvus_token = "root:Milvus"
milvus_user = "root"
milvus_password = "Milvus"
milvus_database = "t2kg_primekg"

In [4]:
# Connect to Milvus
connections.connect(
    alias="default",
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password,
)

In [5]:
# Check if the database exists, create if it doesn't
if milvus_database not in db.list_database():
    db.create_database(milvus_database)

# Switch to the desired database
db.using_database(milvus_database)

In [11]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

In [12]:
# A helper function to chunk the data into smaller parts
# Utility: chunk generator
def chunked(data_list, chunk_size):
    for i in range(0, len(data_list), chunk_size):
        yield data_list[i : i + chunk_size]

#### Building Node Collection (Description Embedding)

In [13]:
%%time

# Configuration for Milvus collection
node_coll_name = f"{milvus_database}_nodes"

# Define schema for the collection
# Leave out the feat and feat_emb fields for now
desc_emb_dim = len(merged_nodes_df.iloc[0]["desc_emb"].to_arrow().to_pylist()[0])
node_fields = [
    FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(
        name="node_name",
        dtype=DataType.VARCHAR,
        max_length=1024,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(
        name="node_type",
        dtype=DataType.VARCHAR,
        max_length=1024,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(
        name="desc",
        dtype=DataType.VARCHAR,
        max_length=40960,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
]
schema = CollectionSchema(
    fields=node_fields, description=f"Schema for collection {node_coll_name}"
)

# Create collection if it doesn't exist
if not utility.has_collection(node_coll_name):
    collection = Collection(name=node_coll_name, schema=schema)
else:
    collection = Collection(name=node_coll_name)

# Create indexes
collection.create_index(
    field_name="node_index",
    index_params={"index_type": "STL_SORT"},  # STL_SORT
    index_name="node_index_index",
)
# Create index for node_name, node_type, desc fields (inverted)
collection.create_index(
    field_name="node_name",
    index_params={"index_type": "INVERTED"},
    index_name="node_name_index",
)
collection.create_index(
    field_name="node_type",
    index_params={"index_type": "INVERTED"},
    index_name="node_type_index",
)
collection.create_index(
    field_name="desc", index_params={"index_type": "INVERTED"}, index_name="desc_index"
)
collection.create_index(
    field_name="desc_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
    index_name="desc_emb_index",
)

# Prepare data for insertion
# Normalize the embeddings
graph_desc_emb_cp = (
    cp.asarray(merged_nodes_df["desc_emb"].list.leaves)
    .astype(cp.float32)
    .reshape(merged_nodes_df.shape[0], -1)
)
graph_desc_emb_norm = normalize_matrix(graph_desc_emb_cp, axis=1)
data = [
    merged_nodes_df["node_index"].to_arrow().to_pylist(),
    merged_nodes_df["node_id"].to_arrow().to_pylist(),
    merged_nodes_df["node_name"].to_arrow().to_pylist(),
    merged_nodes_df["node_type"].to_arrow().to_pylist(),
    merged_nodes_df["desc"].to_arrow().to_pylist(),
    graph_desc_emb_norm.tolist(),  # Use normalized embeddings
]

# Insert data in batches
batch_size = 500
total = len(data[0])
for i in tqdm(range(0, total, batch_size)):
    batch = [col[i : i + batch_size] for col in data]
    collection.insert(batch)

# Flush to persist data
collection.flush()

# Get collection stats
print(collection.num_entities)

100%|██████████| 170/170 [01:52<00:00,  1.51it/s]


84981
CPU times: user 28.2 s, sys: 3.68 s, total: 31.9 s
Wall time: 2min 3s


In [14]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes
84981


In [9]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_nodes")

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index == 13814"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ["{'node_name': 'Copper', 'node_type': 'drug', 'desc': 'Copper belongs to drug node. Copper is a transition metal and a trace element in the body. It is important to the function of many enzymes including cytochrome c oxidase, monoamine oxidase and superoxide dismutase. Copper is commonly used in contraceptive intrauterine devices (IUD). Copper is absorbed from the gut via high affinity copper uptake protein and likely through low affinity copper uptake protein and natural resistance-associated macrophage protein-2. It is believed that copper is reduced to the Cu1+ form prior to transport. Once inside the enterocyte, it is bound to copper transport protein ATOX1 which shuttles the ion to copper transporting ATPase-1 on the golgi membrane which take up copper into the golgi apparatus. Once copper has been secreted by enterocytes into the systemic circulation it remain largely bound by ceruloplasmin (65-90%), albumin (18%), and alpha 2-macroglobulin (12%).  Copper is nearly entirel

In [10]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_nodes")

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index in [13814, 13815]"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ["{'node_type': 'drug', 'desc': 'Copper belongs to drug node. Copper is a transition metal and a trace element in the body. It is important to the function of many enzymes including cytochrome c oxidase, monoamine oxidase and superoxide dismutase. Copper is commonly used in contraceptive intrauterine devices (IUD). Copper is absorbed from the gut via high affinity copper uptake protein and likely through low affinity copper uptake protein and natural resistance-associated macrophage protein-2. It is believed that copper is reduced to the Cu1+ form prior to transport. Once inside the enterocyte, it is bound to copper transport protein ATOX1 which shuttles the ion to copper transporting ATPase-1 on the golgi membrane which take up copper into the golgi apparatus. Once copper has been secreted by enterocytes into the systemic circulation it remain largely bound by ceruloplasmin (65-90%), albumin (18%), and alpha 2-macroglobulin (12%).  Copper is nearly entirely bound by ceruloplasmi

In [15]:
%%time

vector_to_search = merged_nodes_df["desc_emb"].loc[2]


# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_nodes")

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = normalize_vector(merged_nodes_df["desc_emb"].iloc[2]).tolist()
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="desc_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"],
)

results

CPU times: user 140 ms, sys: 39.1 ms, total: 179 ms
Wall time: 521 ms


data: [[{'node_index': 11504, 'distance': 1.0000001192092896, 'entity': {'node_name': 'SSTR2', 'node_id': 'SSTR2_(11630)'}}, {'node_index': 28075, 'distance': 0.9640881419181824, 'entity': {'node_name': 'SSTR4', 'node_id': 'SSTR4_(34965)'}}, {'node_index': 13207, 'distance': 0.9590871334075928, 'entity': {'node_name': 'SSTR1', 'node_id': 'SSTR1_(13376)'}}, {'node_index': 10873, 'distance': 0.9556770920753479, 'entity': {'node_name': 'SSTR3', 'node_id': 'SSTR3_(10983)'}}, {'node_index': 11786, 'distance': 0.9445285797119141, 'entity': {'node_name': 'SSTR5', 'node_id': 'SSTR5_(11920)'}}, {'node_index': 12024, 'distance': 0.9246837496757507, 'entity': {'node_name': 'SST', 'node_id': 'SST_(12165)'}}, {'node_index': 15309, 'distance': 0.9021860361099243, 'entity': {'node_name': 'Somatostatin', 'node_id': 'Somatostatin_(15618)'}}, {'node_index': 10048, 'distance': 0.8932695388793945, 'entity': {'node_name': 'SCTR', 'node_id': 'SCTR_(10137)'}}, {'node_index': 78455, 'distance': 0.891120433807

In [16]:
merged_nodes_df.loc[2]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
2,11504,11630,SSTR2_(11630),SSTR2,gene/protein,SSTR2 belongs to gene/protein node. SSTR2 is s...,MDMADEPLNGSHTWLSIPFDLNGSVVSTNTSNQTEPYYDLTSNAVL...,"[-0.036013797, 0.002569047, -0.0049933116, -0....","[0.10192450135946274, -0.04799338057637215, 0...."


#### Building Node Collection (Node Type-specific Embedding)

Note that nodes information of the PrimeKG data is different for each node type, 
we are going to build a separate collection for each node type.

We will use the node type as the collection name.

In [15]:
%%time

# Loop over group enrichment nodes by node_type
for node_type, nodes_df in tqdm(merged_nodes_df.groupby("node_type")):
    print(f"Processing node type: {node_type}")

    # Milvus collection name for this node_type
    node_coll_name = f"{milvus_database}_nodes_{node_type.replace('/', '_')}"

    # Define collection schema
    desc_emb_dim = len(nodes_df.iloc[0]["desc_emb"].to_arrow().to_pylist()[0])
    feat_emb_dim = len(nodes_df.iloc[0]["feat_emb"].to_arrow().to_pylist()[0])
    node_fields = [
        FieldSchema(
            name="node_index", dtype=DataType.INT64, is_primary=True, auto_id=False
        ),
        FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(
            name="node_name",
            dtype=DataType.VARCHAR,
            max_length=1024,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(
            name="node_type",
            dtype=DataType.VARCHAR,
            max_length=1024,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(
            name="desc",
            dtype=DataType.VARCHAR,
            max_length=40960,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
        FieldSchema(
            name="feat",
            dtype=DataType.VARCHAR,
            max_length=40960,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=feat_emb_dim),
    ]
    schema = CollectionSchema(
        fields=node_fields, description=f"schema for collection {node_coll_name}"
    )

    # Create collection if not exists
    if not utility.has_collection(node_coll_name):
        collection = Collection(name=node_coll_name, schema=schema)
    else:
        collection = Collection(name=node_coll_name)

    # Create index for node_index field (scalar)
    collection.create_index(
        field_name="node_index",
        index_params={"index_type": "STL_SORT"},
        index_name="node_index_index",
    )
    # Create index for node_name, node_type, desc fields (inverted)
    collection.create_index(
        field_name="node_name",
        index_params={"index_type": "INVERTED"},
        index_name="node_name_index",
    )
    collection.create_index(
        field_name="node_type",
        index_params={"index_type": "INVERTED"},
        index_name="node_type_index",
    )
    collection.create_index(
        field_name="desc",
        index_params={"index_type": "INVERTED"},
        index_name="desc_index",
    )
    collection.create_index(
        field_name="desc_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
        index_name="desc_emb_index",
    )
    # Create index for feat_emb (vector)
    collection.create_index(
        field_name="feat_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
        index_name="feat_emb_index",
    )

    # Prepare data for insertion
    # Normalize the embeddings
    graph_desc_emb_cp = (
        cp.asarray(nodes_df["desc_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(nodes_df.shape[0], -1)
    )
    graph_desc_emb_norm = normalize_matrix(graph_desc_emb_cp, axis=1)
    graph_feat_emb_cp = (
        cp.asarray(nodes_df["feat_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(nodes_df.shape[0], -1)
    )
    graph_feat_emb_norm = normalize_matrix(graph_feat_emb_cp, axis=1)
    # Columns must be lists of values in order matching schema fields
    data = [
        nodes_df["node_index"].to_arrow().to_pylist(),
        nodes_df["node_id"].to_arrow().to_pylist(),
        nodes_df["node_name"].to_arrow().to_pylist(),
        nodes_df["node_type"].to_arrow().to_pylist(),
        nodes_df["desc"].to_arrow().to_pylist(),
        graph_desc_emb_norm.tolist(),  # Use normalized embeddings
        nodes_df["feat"].to_arrow().to_pylist(),
        graph_feat_emb_norm.tolist(),  # Use normalized embeddings
    ]

    # Batch insert data in chunks
    batch_size = 500
    total_rows = len(data[0])
    for i in tqdm(range(0, total_rows, batch_size)):
        batch = [col[i : i + batch_size] for col in data]
        collection.insert(batch)

    # Flush the collection to ensure data is persisted
    collection.flush()

    # Print collection stats (number of entities and segment info)
    stats = collection.num_entities
    print(f"Collection {node_coll_name} stats:")
    print(stats)

  0%|          | 0/6 [00:00<?, ?it/s]

Processing node type: biological_process


100%|██████████| 55/55 [00:48<00:00,  1.12it/s]
 17%|█▋        | 1/6 [00:59<04:57, 59.56s/it]

Collection t2kg_primekg_nodes_biological_process stats:
27409
Processing node type: cellular_component


100%|██████████| 9/9 [00:07<00:00,  1.24it/s]
 33%|███▎      | 2/6 [01:14<02:12, 33.03s/it]

Collection t2kg_primekg_nodes_cellular_component stats:
4011
Processing node type: disease


100%|██████████| 35/35 [00:30<00:00,  1.15it/s]
 50%|█████     | 3/6 [01:51<01:45, 35.14s/it]

Collection t2kg_primekg_nodes_disease stats:
17054
Processing node type: drug


100%|██████████| 14/14 [00:11<00:00,  1.18it/s]
 67%|██████▋   | 4/6 [02:08<00:56, 28.09s/it]

Collection t2kg_primekg_nodes_drug stats:
6759
Processing node type: gene/protein


100%|██████████| 38/38 [00:41<00:00,  1.09s/it]
 83%|████████▎ | 5/6 [02:58<00:35, 35.68s/it]

Collection t2kg_primekg_nodes_gene_protein stats:
18797
Processing node type: molecular_function


100%|██████████| 22/22 [00:19<00:00,  1.14it/s]
100%|██████████| 6/6 [03:24<00:00, 34.10s/it]

Collection t2kg_primekg_nodes_molecular_function stats:
10951
CPU times: user 1min 5s, sys: 8.56 s, total: 1min 14s
Wall time: 3min 24s





In [16]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes_disease
17054
Collection: t2kg_primekg_nodes_drug
6759
Collection: t2kg_primekg_nodes_gene_protein
18797
Collection: t2kg_primekg_nodes_molecular_function
10951
Collection: t2kg_primekg_nodes
84981
Collection: t2kg_primekg_nodes_biological_process
27409
Collection: t2kg_primekg_nodes_cellular_component
4011


In [12]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_nodes_gene_protein")

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = normalize_vector(merged_nodes_df["feat_emb"].iloc[2]).tolist()
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"],
)

results

CPU times: user 13.9 ms, sys: 366 μs, total: 14.2 ms
Wall time: 22.9 ms


data: [[{'node_index': 7072, 'distance': 0.9999998807907104, 'entity': {'node_id': 'SOX14_(7119)', 'node_name': 'SOX14'}}, {'node_index': 460, 'distance': 0.9957526922225952, 'entity': {'node_id': 'SOX2_(460)', 'node_name': 'SOX2'}}, {'node_index': 13546, 'distance': 0.9949035048484802, 'entity': {'node_id': 'SOX7_(13730)', 'node_name': 'SOX7'}}, {'node_index': 49952, 'distance': 0.9946167469024658, 'entity': {'node_id': 'SOX21_(57709)', 'node_name': 'SOX21'}}, {'node_index': 6972, 'distance': 0.9945659041404724, 'entity': {'node_id': 'SOX9_(7019)', 'node_name': 'SOX9'}}, {'node_index': 9392, 'distance': 0.9940434694290161, 'entity': {'node_id': 'LEF1_(9476)', 'node_name': 'LEF1'}}, {'node_index': 1678, 'distance': 0.9940416216850281, 'entity': {'node_id': 'SOX10_(1682)', 'node_name': 'SOX10'}}, {'node_index': 11064, 'distance': 0.9938754439353943, 'entity': {'node_id': 'TOX2_(11177)', 'node_name': 'TOX2'}}, {'node_index': 6659, 'distance': 0.993757426738739, 'entity': {'node_id': 'TOX

In [13]:
# Check the ground truth for the search
merged_nodes_df.loc[2]

Unnamed: 0,node_index,primekg_node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
2,7072,7119,SOX14_(7119),SOX14,gene/protein,SOX14 belongs to gene/protein node. SOX14 is S...,MSKPSDHIKRPMNAFMVWSRGQRRKMAQENPKMHNSEISKRLGAEW...,"[-0.02060385, -0.020443613, -0.017091982, -0.0...","[0.049131330102682114, 0.014837502501904964, 0..."


In [14]:
# Get node indices from the results
[n["node_index"] for n in results[0]]

[7072, 460, 13546, 49952, 6972, 9392, 1678, 11064, 6659, 5090]

In [15]:
# Get the cosine similarity scores
[n["distance"] for n in results[0]]

[0.9999998807907104,
 0.9957526922225952,
 0.9949035048484802,
 0.9946167469024658,
 0.9945659041404724,
 0.9940434694290161,
 0.9940416216850281,
 0.9938754439353943,
 0.993757426738739,
 0.9937493205070496]

#### Building Edge Collection

Subsquently, we are also building the edges collection in Milvus.

Note that the edges information of PrimeKG has massive records, so once again we are chunking the data to avoid memory issues.

In [17]:
%%time

# Define collection name
edge_coll_name = f"{milvus_database}_edges"

# Define schema
edge_fields = [
    FieldSchema(
        name="triplet_index", dtype=DataType.INT64, is_primary=True, auto_id=False
    ),
    FieldSchema(name="head_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="head_index", dtype=DataType.INT64),
    FieldSchema(name="tail_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="tail_index", dtype=DataType.INT64),
    FieldSchema(name="edge_type", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="display_relation", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
    FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=1536),
]
edge_schema = CollectionSchema(
    fields=edge_fields, description="Schema for edges collection"
)

# Create collection if not exists
if not utility.has_collection(edge_coll_name):
    collection = Collection(name=edge_coll_name, schema=edge_schema)
else:
    collection = Collection(name=edge_coll_name)

# Create indexes
collection.create_index(
    field_name="triplet_index",
    index_params={"index_type": "STL_SORT"},
    index_name="triplet_index_index",
)
collection.create_index(
    field_name="head_index",
    index_params={"index_type": "STL_SORT"},
    index_name="head_index_index",
)
collection.create_index(
    field_name="tail_index",
    index_params={"index_type": "STL_SORT"},
    index_name="tail_index_index",
)
collection.create_index(
    field_name="feat_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},
    index_name="feat_emb_index",
)  # AUTOINDEX

# Iterate over chunked edges embedding df
for edges_df in tqdm(edges_embedding_df):
    # Merge enrichment with embedding
    merged_edges_df = edges_enrichment_df.merge(
        edges_df[["triplet_index", "edge_emb"]], on="triplet_index", how="inner"
    )

    # Prepare data fields in column-wise format
    # Normalize the embeddings
    edges_edge_emb_cp = (
        cp.asarray(merged_edges_df["edge_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(merged_edges_df.shape[0], -1)
    )
    edges_edge_emb_norm = normalize_matrix(edges_edge_emb_cp, axis=1)
    data = [
        merged_edges_df["triplet_index"].to_arrow().to_pylist(),
        merged_edges_df["head_id"].to_arrow().to_pylist(),
        merged_edges_df["head_index"].to_arrow().to_pylist(),
        merged_edges_df["tail_id"].to_arrow().to_pylist(),
        merged_edges_df["tail_index"].to_arrow().to_pylist(),
        merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
        merged_edges_df["display_relation"].to_arrow().to_pylist(),
        merged_edges_df["feat"].to_arrow().to_pylist(),
        edges_edge_emb_norm.tolist(),  # Use normalized embeddings
    ]

    # Insert in chunks
    batch_size = 500
    for i in tqdm(range(0, len(data[0]), batch_size)):
        batch_data = [d[i : i + batch_size] for d in data]
        collection.insert(batch_data)

    # Flush to ensure persistence
    collection.flush()

    # Print collection stats
    print(collection.num_entities)

    time.sleep(5)  # Sleep to avoid overwhelming the server

100%|██████████| 500/500 [05:31<00:00,  1.51it/s]


250000


100%|██████████| 500/500 [05:34<00:00,  1.49it/s]


500000


100%|██████████| 500/500 [05:36<00:00,  1.49it/s]


750000


100%|██████████| 500/500 [05:34<00:00,  1.49it/s]


1000000


100%|██████████| 500/500 [05:33<00:00,  1.50it/s]


1250000


100%|██████████| 410/410 [04:34<00:00,  1.49it/s]


1454610


100%|██████████| 500/500 [05:36<00:00,  1.49it/s]


1704610


100%|██████████| 500/500 [05:34<00:00,  1.50it/s]


1954610


100%|██████████| 500/500 [05:33<00:00,  1.50it/s]


2204610


100%|██████████| 500/500 [05:34<00:00,  1.50it/s]


2454610


100%|██████████| 500/500 [05:36<00:00,  1.49it/s]]


2704610


100%|██████████| 500/500 [05:31<00:00,  1.51it/s]]


2954610


100%|██████████| 500/500 [05:36<00:00,  1.49it/s]]


3204610


100%|██████████| 500/500 [05:33<00:00,  1.50it/s]]


3454610


100%|██████████| 500/500 [05:34<00:00,  1.49it/s]]


3704610


100%|██████████| 400/400 [04:27<00:00,  1.50it/s]]


3904610


100%|██████████| 16/16 [1:35:30<00:00, 358.18s/it]

CPU times: user 23min 24s, sys: 2min 56s, total: 26min 20s
Wall time: 1h 35min 33s





In [6]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes_molecular_function
10951
Collection: t2kg_primekg_edges
3904610
Collection: t2kg_primekg_nodes
84981
Collection: t2kg_primekg_nodes_biological_process
27409
Collection: t2kg_primekg_nodes_cellular_component
4011
Collection: t2kg_primekg_nodes_disease
17054
Collection: t2kg_primekg_nodes_drug
6759
Collection: t2kg_primekg_nodes_gene_protein
18797


In [7]:
# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_edges")

# Load the collection into memory before query
collection.load()

In [8]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_edges")

# Load the collection into memory before query
collection.load()

# Query by expr on triplet_index
expr = "triplet_index == 0"
output_fields = ["triplet_index", "head_id", "tail_id", "edge_type", "feat", "feat_emb"]

results = collection.query(expr, output_fields=output_fields)
results

CPU times: user 5 ms, sys: 0 ns, total: 5 ms
Wall time: 564 ms


data: ["{'feat': 'PHYHIP (gene/protein) has a direct relationship of protein_protein:ppi with KIF15 (gene/protein).', 'feat_emb': [-0.01934238, 0.0011752498, 0.004431808, -0.033904973, -0.01556057, 0.011148459, 0.0011645806, 0.028704984, 0.012691384, -0.005199988, 0.020970657, 0.01423431, 0.009421695, 0.0033714569, -0.0027230997, -0.0106954295, 0.015586833, -0.015100975, 0.01833127, 0.014812087, -0.008581293, 0.019014098, 0.005584078, -0.036977693, 0.010511592, 0.034535274, 0.010833308, -0.026236303, -0.025107013, -0.037502944, 0.027575694, -0.019513085, -0.001110414, 0.015416126, 0.0045368583, -0.0110171465, 0.009841897, -0.021863585, 0.016400972, -0.0021190608, -0.0024818124, -0.0077540223, -0.006516399, 0.0013426737, -0.0038474659, -0.0059189256, 0.0074257404, -0.022231262, -0.017438343, 0.008292405, -0.014102997, 0.002123985, -0.012579769, -0.020051468, 0.0068545295, -0.0032368612, 0.0050850892, 0.00063440506, -0.028127206, -0.004412111, 0.0011465251, 0.0029463316, -0.008568162, -0

In [9]:
# Check the ground truth for the search
(
    results[0]["triplet_index"],
    results[0]["head_id"],
    results[0]["tail_id"],
    results[0]["edge_type"],
    results[0]["feat"],
)

(0,
 'PHYHIP_(0)',
 'KIF15_(8889)',
 'gene/protein|ppi|gene/protein',
 'PHYHIP (gene/protein) has a direct relationship of protein_protein:ppi with KIF15 (gene/protein).')

In [10]:
import numpy as np

In [11]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_edges")

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = normalize_vector(
    np.array(results[0]["feat_emb"])
).tolist()  # merged_edges_df["edge_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["head_id", "tail_id", "edge_type", "feat"],
)
results

CPU times: user 509 ms, sys: 154 ms, total: 664 ms
Wall time: 1.05 s


data: [[{'triplet_index': 0, 'distance': 1.000000238418579, 'entity': {'edge_type': 'gene/protein|ppi|gene/protein', 'feat': 'PHYHIP (gene/protein) has a direct relationship of protein_protein:ppi with KIF15 (gene/protein).', 'head_id': 'PHYHIP_(0)', 'tail_id': 'KIF15_(8889)'}}, {'triplet_index': 3069556, 'distance': 0.9815623760223389, 'entity': {'edge_type': 'gene/protein|ppi|gene/protein', 'feat': 'KIF15 (gene/protein) has a direct relationship of protein_protein:ppi with PHYHIP (gene/protein).', 'head_id': 'KIF15_(8889)', 'tail_id': 'PHYHIP_(0)'}}, {'triplet_index': 93582, 'distance': 0.9684444665908813, 'entity': {'edge_type': 'gene/protein|ppi|gene/protein', 'feat': 'PHYHIP (gene/protein) has a direct relationship of protein_protein:ppi with PRKD2 (gene/protein).', 'head_id': 'PHYHIP_(0)', 'tail_id': 'PRKD2_(9221)'}}, {'triplet_index': 305788, 'distance': 0.9657970666885376, 'entity': {'edge_type': 'gene/protein|ppi|gene/protein', 'feat': 'PHYHIP (gene/protein) has a direct relat

In [12]:
# Get node indices from the results
[n["triplet_index"] for n in results[0]]

[0, 3069556, 93582, 305788, 39303, 300682, 3154085, 3201926, 3334423, 110476]

In [13]:
# Get the cosine similarity scores
[n["distance"] for n in results[0]]

[1.000000238418579,
 0.9815623760223389,
 0.9684444665908813,
 0.9657970666885376,
 0.9635176658630371,
 0.9599261283874512,
 0.9597105979919434,
 0.9593513607978821,
 0.9588445425033569,
 0.9586011171340942]