In this tutorial, we will prepare Milvus database for storing and searching nodes and edges of a graph.

In particular, we are using PrimeKG multimodal data from the BioBridge project.

In [3]:
# Load necessary libraries
import glob
import os
import pickle
import time

import cudf
import cupy as cp
import numpy as np
from langchain_openai import OpenAIEmbeddings
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    db,
    utility,
)
from tqdm import tqdm

### Setup OpenAI API Key for Re-Embedding

In [None]:
os.environ["OPENAI_API_KEY"] = "your_openai_api_key_here"

emb_model = OpenAIEmbeddings(
    model="text-embedding-ada-002", openai_api_key=os.environ["OPENAI_API_KEY"]
)

### Loading IBD BioBridge-PrimeKG Multimodal Data

First, we need to get the path to the directory containing the parquet files of nodes and edges.

For nodes and edges, we have a separate folder that contains its enrichment and embeddings.

In [6]:
# Load pickle of the graph data
with open(
    "../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal_pyg_graph.pkl",
    "rb",
) as f:
    graph = pickle.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def normalize_matrix(m, axis=1):
    """
    Normalize each row of a 2D matrix using CuPy.

    Parameters:
    m (cupy.ndarray): 2D matrix to normalize.

    Returns:
    cupy.ndarray: Normalized matrix.
    """
    norms = cp.linalg.norm(m, axis=axis, keepdims=True)
    return m / norms


def normalize_vector(v):
    """
    Normalize a vector using CuPy.

    Parameters:
    v (cupy.ndarray): Vector to normalize.

    Returns:
    cupy.ndarray: Normalized vector.
    """
    v = cp.asarray(v)
    norm = cp.linalg.norm(v)
    return v / norm

#### Nodes Preprocessing (including re-embedding)

In [8]:
# Convert the list of embeddings to a 2D CuPy array (N x D)
graph_desc_x_cp = cp.asarray(graph["desc_x"].tolist())

# Normalize all rows (vectors) using broadcasting
graph_desc_x_normalized = normalize_matrix(graph_desc_x_cp, axis=1)
graph_x_normalized = [normalize_vector(v).tolist() for v in graph["x"]]

# Convert the graph nodes to a cudf DataFrame
nodes_df = cudf.DataFrame(
    {
        "node_id": graph["node_id"],
        "node_name": graph["node_name"],
        "node_type": graph["node_type"],
        "desc": graph["desc"],
        "desc_emb": graph_desc_x_normalized.tolist(),
        "feat": graph["enriched_node"],
        "feat_emb": graph_x_normalized,
    }
)
nodes_df.reset_index(inplace=True)
nodes_df.rename(columns={"index": "node_index"}, inplace=True)
nodes_df.head(3)

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,"[0.02974936784975063, 0.05350021171537046, -0....",MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...,"[-0.0010794274069904548, -0.0028632148270051, ..."
1,1,IL10RB_(179),IL10RB,gene/protein,IL10RB belongs to gene/protein node. IL10RB is...,"[0.02842173040130417, 0.01986006372730412, -0....",MAWSLGSWLGGCLLVSALGMVPPPENVRMNSVNFKNILQWESPAFA...,"[-0.007157766077247574, 0.006195289622587354, ..."
2,2,GNA12_(192),GNA12,gene/protein,GNA12 belongs to gene/protein node. GNA12 is G...,"[0.003668847841835145, 0.051380571197126614, -...",MSGVVRTLSRCLLPAEAGGARERRAGSGARDAEREARRRSRDIDAL...,"[-0.001562959383761, -0.01338132129666802, -0...."


#### Optional: Nodes Re-embedding usin OpenAI

You may skip this section if you want to stick with 'nomic-embed-text' embedding model as default one.

In [9]:
# Checking embedding dimensions before proceeding
emb_dim = len(nodes_df["desc_emb"].iloc[0])
print(f"Embedding dimension: {emb_dim}")

Embedding dimension: 768


In [11]:
# For textual data, we will re-embed the descriptions using OpenAI embeddings
mini_batch_size = 100
desc_embeddings = []
for i in tqdm(
    range(0, nodes_df.shape[0], mini_batch_size), desc="Re-embedding descriptions"
):
    batch = nodes_df["desc"].to_pandas().tolist()[i : i + mini_batch_size]
    embeddings = emb_model.embed_documents(batch)
    desc_embeddings.extend(embeddings)
nodes_df["desc_emb"] = desc_embeddings

Re-embedding descriptions: 100%|██████████| 30/30 [00:27<00:00,  1.07it/s]


In [12]:
# Checking embeddings dimensions after re-embedding
emb_dim = len(nodes_df["desc_emb"].iloc[0])
print(f"Embedding dimension after re-embedding: {emb_dim}")

Embedding dimension after re-embedding: 1536


In [15]:
# Get the text-based nodes for re-embedding
text_based_df = nodes_df[
    nodes_df.node_type.isin(
        ["disease", "biological_process", "cellular_component", "molecular_function"]
    )
]

for nt, text_based_df_ in text_based_df.groupby("node_type"):
    print(f"Re-embedding {nt} nodes")
    # Checking embedding dimensions before proceeding
    emb_dim = len(text_based_df_["feat_emb"].iloc[0])
    print(f"Embedding dimension: {emb_dim}")
    print("---")

Re-embedding biological_process nodes
Embedding dimension: 768
---
Re-embedding cellular_component nodes
Embedding dimension: 768
---
Re-embedding disease nodes
Embedding dimension: 768
---
Re-embedding molecular_function nodes
Embedding dimension: 768
---


In [17]:
# Update textual pre-loaded embeddings with OpenAI embeddings
# Since the records of nodes has large amount of data, we will split them into mini-batches
mini_batch_size = 100
text_node_indexes = []
text_node_embeddings = []
for i in tqdm(
    range(0, text_based_df.shape[0], mini_batch_size), desc="Re-embedding text nodes"
):
    outputs = emb_model.embed_documents(
        text_based_df.to_pandas().feat.values.tolist()[i : i + mini_batch_size]
    )
    text_node_indexes.extend(
        text_based_df.to_pandas().node_index.values.tolist()[i : i + mini_batch_size]
    )
    text_node_embeddings.extend(outputs)
dic_text_embeddings = dict(zip(text_node_indexes, text_node_embeddings, strict=False))

Re-embedding text nodes: 100%|██████████| 22/22 [00:58<00:00,  2.65s/it]


In [19]:
# Replace the embeddings of the nodes with the updated embeddings for text-based nodes
nodes_df["feat_emb"] = nodes_df.to_pandas().apply(
    lambda x: dic_text_embeddings[x["node_index"]]
    if x["node_index"] in dic_text_embeddings
    else x["feat_emb"],
    axis=1,
)

In [20]:
# Get the text-based nodes for re-embedding
text_based_df = nodes_df[
    nodes_df.node_type.isin(
        ["disease", "biological_process", "cellular_component", "molecular_function"]
    )
]

for nt, text_based_df_ in text_based_df.groupby("node_type"):
    print(f"Re-embedding {nt} nodes")
    # Checking embedding dimensions before proceeding
    emb_dim = len(text_based_df_["feat_emb"].iloc[0])
    print(f"Embedding dimension: {emb_dim}")
    print("---")

Re-embedding biological_process nodes
Embedding dimension: 1536
---
Re-embedding cellular_component nodes
Embedding dimension: 1536
---
Re-embedding disease nodes
Embedding dimension: 1536
---
Re-embedding molecular_function nodes
Embedding dimension: 1536
---


#### Edges Preprocessing (including re-embedding)

In [21]:
# Convert the list of edge embeddings to a 2D CuPy array (M x D)
graph_edge_attr_cp = cp.asarray(graph["edge_attr"].tolist())

# Normalize all rows (vectors) using broadcasting
graph_edge_attr_normalized = normalize_matrix(graph_edge_attr_cp, axis=1)

# Convert the graph edges to a cudf DataFrame
edges_df = cudf.DataFrame(
    {
        "triplet_index": graph["triplet_index"],
        "head_id": graph["head_id"],
        "head_name": graph["head_name"],
        "tail_id": graph["tail_id"],
        "tail_name": graph["tail_name"],
        "display_relation": graph["display_relation"],
        "edge_type": graph["edge_type"],
        "edge_type_str": ["|".join(e) for e in graph["edge_type"]],
        "feat": graph["enriched_edge"],
        "edge_emb": graph_edge_attr_normalized.tolist(),
    }
)
edges_df = edges_df.merge(
    nodes_df[["node_index", "node_id"]],
    left_on="head_id",
    right_on="node_id",
    how="left",
)
edges_df.rename(columns={"node_index": "head_index"}, inplace=True)
edges_df.drop(columns=["node_id"], inplace=True)
edges_df = edges_df.merge(
    nodes_df[["node_index", "node_id"]],
    left_on="tail_id",
    right_on="node_id",
    how="left",
)
edges_df.rename(columns={"node_index": "tail_index"}, inplace=True)
edges_df.drop(columns=["node_id"], inplace=True)
edges_df.head(3)

Unnamed: 0,triplet_index,head_id,head_name,tail_id,tail_name,display_relation,edge_type,edge_type_str,feat,edge_emb,head_index,tail_index
0,8602,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IL10RB_(179),IL10RB,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.016838406414606846, 0.019238545922865967, -...",1455,1
1,8603,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IL12B_(6168),IL12B,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.018197947379867397, 0.03141968316046658, -0...",1455,59
2,8604,cytokine-mediated signaling pathway_(47242),cytokine-mediated signaling pathway,IRF5_(3646),IRF5,interacts with,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,cytokine-mediated signaling pathway (biologica...,"[0.018029207941198132, 0.019414354880667273, -...",1455,46


#### Optional: Edges Re-embedding using OpenAI

In [22]:
# Checking embeddings dimensions after re-embedding
emb_dim = len(edges_df["edge_emb"].iloc[0])
print(f"Embedding dimension after re-embedding: {emb_dim}")

Embedding dimension after re-embedding: 768


In [23]:
# For textual data, we will re-embed the descriptions using OpenAI embeddings
mini_batch_size = 100
edge_embeddings = []
for i in tqdm(range(0, edges_df.shape[0], mini_batch_size), desc="Re-embedding edges"):
    batch = edges_df["feat"].to_pandas().tolist()[i : i + mini_batch_size]
    embeddings = emb_model.embed_documents(batch)
    edge_embeddings.extend(embeddings)
edges_df["edge_emb"] = edge_embeddings

Re-embedding edges: 100%|██████████| 113/113 [01:37<00:00,  1.16it/s]


In [24]:
# Checking embeddings dimensions after re-embedding
emb_dim = len(edges_df["edge_emb"].iloc[0])
print(f"Embedding dimension after re-embedding: {emb_dim}")

Embedding dimension after re-embedding: 1536


### Storing dataframes

In [25]:
# Store the DataFrame into compressed parquet files
storage_path = (
    "../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/"
)
os.makedirs(storage_path, exist_ok=True)

# Nodes enrichment & embeddings
nodes_enrichment = (
    nodes_df[
        [
            "node_index",
            "node_id",
            "node_name",
            "node_type",
            "desc",
            "feat",
        ]
    ]
    .to_pandas()
    .copy()
)
os.makedirs(os.path.join(storage_path, "nodes", "enrichment"), exist_ok=True)
for nt, nodes_df_ in nodes_enrichment.groupby("node_type"):
    print(nt, nodes_df_.shape)
    nodes_df_.to_parquet(
        os.path.join(
            storage_path, "nodes", "enrichment", f"{nt.replace('/', '_')}.parquet.gzip"
        ),
        compression="gzip",
        index=False,
    )
print("Nodes enrichment saved.")
print("---")

nodes_embeddings = (
    nodes_df[["node_index", "node_id", "node_type", "desc_emb", "feat_emb"]]
    .to_pandas()
    .copy()
)
os.makedirs(os.path.join(storage_path, "nodes", "embedding"), exist_ok=True)
for nt, nodes_df_ in nodes_embeddings.groupby("node_type"):
    print(nt, nodes_df_.shape)
    nodes_df_[["node_index", "node_id", "desc_emb", "feat_emb"]].to_parquet(
        os.path.join(
            storage_path, "nodes", "embedding", f"{nt.replace('/', '_')}.parquet.gzip"
        ),
        compression="gzip",
        index=False,
    )
print("Nodes embeddings saved.")

biological_process (1615, 6)
cellular_component (202, 6)
disease (7, 6)
drug (748, 6)
gene/protein (102, 6)
molecular_function (317, 6)
Nodes enrichment saved.
---
biological_process (1615, 5)
cellular_component (202, 5)
disease (7, 5)
drug (748, 5)
gene/protein (102, 5)
molecular_function (317, 5)
Nodes embeddings saved.


In [26]:
# Edges enrichment & embeddings
edges_enrichment = (
    edges_df[
        [
            "triplet_index",
            "head_id",
            "head_index",
            "tail_id",
            "tail_index",
            "edge_type",
            "edge_type_str",
            "display_relation",
            "feat",
        ]
    ]
    .to_pandas()
    .copy()
)
os.makedirs(os.path.join(storage_path, "edges", "enrichment"), exist_ok=True)
edges_enrichment.to_parquet(
    os.path.join(storage_path, "edges", "enrichment", "edges.parquet.gzip"),
    compression="gzip",
    index=False,
)
print("Edges enrichment saved.")
print("---")

edges_embeddings = (
    edges_df[["triplet_index", "head_index", "tail_index", "edge_emb"]]
    .to_pandas()
    .copy()
)
os.makedirs(os.path.join(storage_path, "edges", "embedding"), exist_ok=True)
chunk_size = 1000
for i in range(0, edges_embeddings.shape[0], chunk_size):
    et = f"edges_{i // chunk_size}"
    edges_embeddings_chunk = edges_embeddings.iloc[i : i + chunk_size]
    # Save each chunk to a separate parquet file
    edges_embeddings_chunk.to_parquet(
        os.path.join(storage_path, "edges", "embedding", f"{et}.parquet.gzip"),
        compression="gzip",
        index=False,
    )
print("Edges embeddings saved.")

Edges enrichment saved.
---
Edges embeddings saved.


### Loading dataframes

In [27]:
# Set storage path for the dataframes
storage_path = (
    "../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/"
)

In [28]:
# Loop over nodes and edges
graph_dict = {}
for element in ["nodes", "edges"]:
    # Make an empty dictionary for each folder
    graph_dict[element] = {}
    for stage in ["enrichment", "embedding"]:
        print(element, stage)
        # Create the file pattern for the current subfolder
        file_list = glob.glob(
            os.path.join(storage_path, element, stage, "*.parquet.gzip")
        )
        print(file_list)
        # Read and concatenate all dataframes in the folder
        # Except the edges embedding, which is too large to read in one go
        # We are using a chunk size to read the edges embedding in smaller parts instead
        if element == "edges" and stage == "embedding":
            # For edges embedding, only read two columns: triplet_index and edge_emb
            # graph_dict[element][stage] = cudf.concat([cudf.read_parquet(f, columns=["triplet_index", "edge_emb"]) for f in file_list[:2]], ignore_index=True)
            # Loop by chunks
            # file_list = file_list[:2]
            chunk_size = 5
            graph_dict[element][stage] = []
            for i in range(0, len(file_list), chunk_size):
                chunk_files = file_list[i : i + chunk_size]
                chunk_df = cudf.concat(
                    [
                        cudf.read_parquet(f, columns=["triplet_index", "edge_emb"])
                        for f in chunk_files
                    ],
                    ignore_index=True,
                )
                graph_dict[element][stage].append(chunk_df)
        else:
            # For nodes and edges enrichment, read and concatenate all dataframes in the folder
            # This includes the nodes embedding, which is small enough to read in one go
            graph_dict[element][stage] = cudf.concat(
                [cudf.read_parquet(f) for f in file_list], ignore_index=True
            )

nodes enrichment
['../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/gene_protein.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/disease.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/biological_process.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/cellular_component.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/molecular_function.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/enrichment/drug.parquet.gzip']
nodes embedding
['../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/embedding/gene_protein.parquet.gzip', '../../../aiagents4pharma/talk2knowledgegraphs/tests/files/biobridge_multimodal/nodes/

In [29]:
# Get nodes enrichment and embedding dataframes
nodes_enrichment_df = graph_dict["nodes"]["enrichment"]
nodes_embedding_df = graph_dict["nodes"]["embedding"]

# Get edges enrichment and embedding dataframes
edges_enrichment_df = graph_dict["edges"]["enrichment"]
edges_embedding_df = graph_dict["edges"][
    "embedding"
]  # !!consisted of a list of dataframes!!

In [45]:
# Merge nodes enrichment and embedding dataframes
merged_nodes_df = nodes_enrichment_df.merge(
    nodes_embedding_df[["node_id", "desc_emb", "feat_emb"]], on="node_id", how="left"
)
# del nodes_enrichment_df, nodes_embedding_df  # Free memory

In [30]:
# Check dataframe
nodes_enrichment_df.head(3)

Unnamed: 0,node_index,node_id,node_name,node_type,desc,feat
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...
1,1,IL10RB_(179),IL10RB,gene/protein,IL10RB belongs to gene/protein node. IL10RB is...,MAWSLGSWLGGCLLVSALGMVPPPENVRMNSVNFKNILQWESPAFA...
2,2,GNA12_(192),GNA12,gene/protein,GNA12 belongs to gene/protein node. GNA12 is G...,MSGVVRTLSRCLLPAEAGGARERRAGSGARDAEREARRRSRDIDAL...


In [31]:
# Check dataframe
nodes_embedding_df.head(3)

Unnamed: 0,node_index,node_id,desc_emb,feat_emb
0,0,SMAD3_(144),"[-0.03699171170592308, -0.005479035433381796, ...","[-0.0010794274069904548, -0.0028632148270051, ..."
1,1,IL10RB_(179),"[-0.02927332930266857, -0.0068625640124082565,...","[-0.007157766077247574, 0.006195289622587354, ..."
2,2,GNA12_(192),"[-0.02188265137374401, -0.01718498021364212, -...","[-0.001562959383761, -0.01338132129666802, -0...."


In [32]:
# Check dataframe
edges_enrichment_df.head(3)

Unnamed: 0,triplet_index,head_id,head_index,tail_id,tail_index,edge_type,edge_type_str,display_relation,feat
0,8602,cytokine-mediated signaling pathway_(47242),1455,IL10RB_(179),1,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,interacts with,cytokine-mediated signaling pathway (biologica...
1,8603,cytokine-mediated signaling pathway_(47242),1455,IL12B_(6168),59,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,interacts with,cytokine-mediated signaling pathway (biologica...
2,8604,cytokine-mediated signaling pathway_(47242),1455,IRF5_(3646),46,"[biological_process, interacts with, gene/prot...",biological_process|interacts with|gene/protein,interacts with,cytokine-mediated signaling pathway (biologica...


In [33]:
# Check dataframes
len(edges_embedding_df)

3

In [34]:
# Check the first chunk of edges embedding
edges_embedding_df[0].head(3)

Unnamed: 0,triplet_index,edge_emb
0,7408,"[-0.030603667721152306, -0.020534928888082504,..."
1,7071,"[-0.028667431324720383, -0.004508184734731913,..."
2,7072,"[-0.04058527573943138, -0.010810465551912785, ..."


### Setup Milvus Database

In [91]:
# Configuration for Milvus
milvus_host = "localhost"
milvus_port = "19530"
milvus_uri = "http://localhost:19530"
milvus_token = "root:Milvus"
milvus_user = "root"
milvus_password = "Milvus"
milvus_database = "t2kg_primekg"

In [None]:
# Connect to Milvus
connections.connect(
    alias="default",
    host=milvus_host,
    port=milvus_port,
    user=milvus_user,
    password=milvus_password,
)

In [42]:
# Check if the database exists, create if it doesn't
if milvus_database not in db.list_database():
    db.create_database(milvus_database)

# Switch to the desired database
db.using_database(milvus_database)

In [43]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

    # Drop the collection if it exists
    if utility.has_collection(coll):
        print(f"Dropping collection: {coll}")
        utility.drop_collection(coll)

In [44]:
# A helper function to chunk the data into smaller parts
# Utility: chunk generator
def chunked(data_list, chunk_size):
    for i in range(0, len(data_list), chunk_size):
        yield data_list[i : i + chunk_size]

#### Building Node Collection (Description Embedding)

In [46]:
%%time

# Configuration for Milvus collection
node_coll_name = f"{milvus_database}_nodes"

# Define schema for the collection
# Leave out the feat and feat_emb fields for now
desc_emb_dim = len(merged_nodes_df.iloc[0]["desc_emb"].to_arrow().to_pylist()[0])
node_fields = [
    FieldSchema(name="node_index", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(
        name="node_name",
        dtype=DataType.VARCHAR,
        max_length=1024,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(
        name="node_type",
        dtype=DataType.VARCHAR,
        max_length=1024,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(
        name="desc",
        dtype=DataType.VARCHAR,
        max_length=40960,
        enable_analyzer=True,
        enable_match=True,
    ),
    FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
]
schema = CollectionSchema(
    fields=node_fields, description=f"Schema for collection {node_coll_name}"
)

# Create collection if it doesn't exist
if not utility.has_collection(node_coll_name):
    collection = Collection(name=node_coll_name, schema=schema)
else:
    collection = Collection(name=node_coll_name)

# Create indexes
collection.create_index(
    field_name="node_index",
    index_params={"index_type": "STL_SORT"},  # STL_SORT
    index_name="node_index_index",
)
# Create index for node_name, node_type, desc fields (inverted)
collection.create_index(
    field_name="node_name",
    index_params={"index_type": "INVERTED"},
    index_name="node_name_index",
)
collection.create_index(
    field_name="node_type",
    index_params={"index_type": "INVERTED"},
    index_name="node_type_index",
)
collection.create_index(
    field_name="desc", index_params={"index_type": "INVERTED"}, index_name="desc_index"
)
collection.create_index(
    field_name="desc_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
    index_name="desc_emb_index",
)

# Prepare data for insertion
# Normalize the embeddings
graph_desc_emb_cp = (
    cp.asarray(merged_nodes_df["desc_emb"].list.leaves)
    .astype(cp.float32)
    .reshape(merged_nodes_df.shape[0], -1)
)
graph_desc_emb_norm = normalize_matrix(graph_desc_emb_cp, axis=1)
data = [
    merged_nodes_df["node_index"].to_arrow().to_pylist(),
    merged_nodes_df["node_id"].to_arrow().to_pylist(),
    merged_nodes_df["node_name"].to_arrow().to_pylist(),
    merged_nodes_df["node_type"].to_arrow().to_pylist(),
    merged_nodes_df["desc"].to_arrow().to_pylist(),
    graph_desc_emb_norm.tolist(),  # Use normalized embeddings
]

# Insert data in batches
batch_size = 500
total = len(data[0])
for i in tqdm(range(0, total, batch_size)):
    batch = [col[i : i + batch_size] for col in data]
    collection.insert(batch)

# Flush to persist data
collection.flush()

# Get collection stats
print(collection.num_entities)

100%|██████████| 6/6 [00:04<00:00,  1.46it/s]


2991
CPU times: user 1.3 s, sys: 137 ms, total: 1.44 s
Wall time: 8.61 s


In [47]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes
2991


In [48]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index == 0"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ['{\'node_index\': 0, \'node_id\': \'SMAD3_(144)\', \'node_name\': \'SMAD3\', \'node_type\': \'gene/protein\', \'desc\': "SMAD3 belongs to gene/protein node. SMAD3 is SMAD family member 3. The SMAD family of proteins are a group of intracellular signal transducer proteins similar to the gene products of the Drosophila gene \'mothers against decapentaplegic\' (Mad) and the C. elegans gene Sma. The SMAD3 protein functions in the transforming growth factor-beta signaling pathway, and transmits signals from the cell surface to the nucleus, regulating gene activity and cell proliferation. This protein forms a complex with other SMAD proteins and binds DNA, functioning both as a transcription factor and tumor suppressor. Mutations in this gene are associated with aneurysms-osteoarthritis syndrome and Loeys-Dietz Syndrome 3. [provided by RefSeq, May 2022].", \'desc_emb\': [-0.036991715, -0.005479036, -0.03023007, -0.012918158, -0.02741491, 0.025599528, -0.024547132, 0.011050156, -0.0282

In [49]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Query by expr on node_index
expr = "node_index in [0, 1]"
output_fields = ["node_index", "node_id", "node_name", "node_type", "desc", "desc_emb"]

results = collection.query(expr, output_fields=output_fields)

print(results)

data: ['{\'node_name\': \'SMAD3\', \'node_type\': \'gene/protein\', \'desc\': "SMAD3 belongs to gene/protein node. SMAD3 is SMAD family member 3. The SMAD family of proteins are a group of intracellular signal transducer proteins similar to the gene products of the Drosophila gene \'mothers against decapentaplegic\' (Mad) and the C. elegans gene Sma. The SMAD3 protein functions in the transforming growth factor-beta signaling pathway, and transmits signals from the cell surface to the nucleus, regulating gene activity and cell proliferation. This protein forms a complex with other SMAD proteins and binds DNA, functioning both as a transcription factor and tumor suppressor. Mutations in this gene are associated with aneurysms-osteoarthritis syndrome and Loeys-Dietz Syndrome 3. [provided by RefSeq, May 2022].", \'desc_emb\': [-0.036991715, -0.005479036, -0.03023007, -0.012918158, -0.02741491, 0.025599528, -0.024547132, 0.011050156, -0.028256828, -0.0019304885, 0.034308106, 0.0019403548, 

In [50]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection(node_coll_name)

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = nodes_df["desc_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="desc_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"],
)

results

CPU times: user 5.9 ms, sys: 2.02 ms, total: 7.91 ms
Wall time: 11.6 ms


data: [[{'node_index': 0, 'distance': 0.9999999403953552, 'entity': {'node_id': 'SMAD3_(144)', 'node_name': 'SMAD3'}}, {'node_index': 2182, 'distance': 0.8984372019767761, 'entity': {'node_id': 'SMAD protein signal transduction_(101792)', 'node_name': 'SMAD protein signal transduction'}}, {'node_index': 2023, 'distance': 0.8962712287902832, 'entity': {'node_id': 'SMAD protein complex_(55601)', 'node_name': 'SMAD protein complex'}}, {'node_index': 2913, 'distance': 0.8886438012123108, 'entity': {'node_id': 'heteromeric SMAD protein complex_(124869)', 'node_name': 'heteromeric SMAD protein complex'}}, {'node_index': 1009, 'distance': 0.8831610083580017, 'entity': {'node_id': 'regulation of SMAD protein signal transduction_(41433)', 'node_name': 'regulation of SMAD protein signal transduction'}}, {'node_index': 2152, 'distance': 0.8783092498779297, 'entity': {'node_id': 'positive regulation of SMAD protein signal transduction_(101088)', 'node_name': 'positive regulation of SMAD protein si

In [52]:
nodes_df.loc[0]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,desc_emb,feat,feat_emb
0,0,SMAD3_(144),SMAD3,gene/protein,SMAD3 belongs to gene/protein node. SMAD3 is S...,"[-0.03699171170592308, -0.005479035433381796, ...",MSSILPFTPPIVKRLLGWKKGEQNGQEEKWCEKAVKSLVKKLKKTG...,"[-0.0010794274069904548, -0.0028632148270051, ..."


#### Building Node Collection (Node Type-specific Embedding)

Note that nodes information of the PrimeKG data is different for each node type, 
we are going to build a separate collection for each node type.

We will use the node type as the collection name.

In [53]:
%%time

# Loop over group enrichment nodes by node_type
for node_type, nodes_df in tqdm(merged_nodes_df.groupby("node_type")):
    print(f"Processing node type: {node_type}")

    # Milvus collection name for this node_type
    node_coll_name = f"{milvus_database}_nodes_{node_type.replace('/', '_')}"

    # Define collection schema
    desc_emb_dim = len(nodes_df.iloc[0]["desc_emb"].to_arrow().to_pylist()[0])
    feat_emb_dim = len(nodes_df.iloc[0]["feat_emb"].to_arrow().to_pylist()[0])
    node_fields = [
        FieldSchema(
            name="node_index", dtype=DataType.INT64, is_primary=True, auto_id=False
        ),
        FieldSchema(name="node_id", dtype=DataType.VARCHAR, max_length=1024),
        FieldSchema(
            name="node_name",
            dtype=DataType.VARCHAR,
            max_length=1024,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(
            name="node_type",
            dtype=DataType.VARCHAR,
            max_length=1024,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(
            name="desc",
            dtype=DataType.VARCHAR,
            max_length=40960,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(name="desc_emb", dtype=DataType.FLOAT_VECTOR, dim=desc_emb_dim),
        FieldSchema(
            name="feat",
            dtype=DataType.VARCHAR,
            max_length=40960,
            enable_analyzer=True,
            enable_match=True,
        ),
        FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=feat_emb_dim),
    ]
    schema = CollectionSchema(
        fields=node_fields, description=f"schema for collection {node_coll_name}"
    )

    # Create collection if not exists
    if not utility.has_collection(node_coll_name):
        collection = Collection(name=node_coll_name, schema=schema)
    else:
        collection = Collection(name=node_coll_name)

    # Create index for node_index field (scalar)
    collection.create_index(
        field_name="node_index",
        index_params={"index_type": "STL_SORT"},
        index_name="node_index_index",
    )
    # Create index for node_name, node_type, desc fields (inverted)
    collection.create_index(
        field_name="node_name",
        index_params={"index_type": "INVERTED"},
        index_name="node_name_index",
    )
    collection.create_index(
        field_name="node_type",
        index_params={"index_type": "INVERTED"},
        index_name="node_type_index",
    )
    collection.create_index(
        field_name="desc",
        index_params={"index_type": "INVERTED"},
        index_name="desc_index",
    )
    collection.create_index(
        field_name="desc_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
        index_name="desc_emb_index",
    )
    # Create index for feat_emb (vector)
    collection.create_index(
        field_name="feat_emb",
        index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},  # AUTOINDEX
        index_name="feat_emb_index",
    )

    # Prepare data for insertion
    # Normalize the embeddings
    graph_desc_emb_cp = (
        cp.asarray(nodes_df["desc_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(nodes_df.shape[0], -1)
    )
    graph_desc_emb_norm = normalize_matrix(graph_desc_emb_cp, axis=1)
    graph_feat_emb_cp = (
        cp.asarray(nodes_df["feat_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(nodes_df.shape[0], -1)
    )
    graph_feat_emb_norm = normalize_matrix(graph_feat_emb_cp, axis=1)
    # Columns must be lists of values in order matching schema fields
    data = [
        nodes_df["node_index"].to_arrow().to_pylist(),
        nodes_df["node_id"].to_arrow().to_pylist(),
        nodes_df["node_name"].to_arrow().to_pylist(),
        nodes_df["node_type"].to_arrow().to_pylist(),
        nodes_df["desc"].to_arrow().to_pylist(),
        graph_desc_emb_norm.tolist(),  # Use normalized embeddings
        nodes_df["feat"].to_arrow().to_pylist(),
        graph_feat_emb_norm.tolist(),  # Use normalized embeddings
    ]

    # Batch insert data in chunks
    batch_size = 500
    total_rows = len(data[0])
    for i in tqdm(range(0, total_rows, batch_size)):
        batch = [col[i : i + batch_size] for col in data]
        collection.insert(batch)

    # Flush the collection to ensure data is persisted
    collection.flush()

    # Print collection stats (number of entities and segment info)
    stats = collection.num_entities
    print(f"Collection {node_coll_name} stats:")
    print(stats)

  0%|          | 0/6 [00:00<?, ?it/s]

Processing node type: biological_process


100%|██████████| 4/4 [00:02<00:00,  1.67it/s]
 17%|█▋        | 1/6 [00:06<00:32,  6.52s/it]

Collection t2kg_primekg_nodes_biological_process stats:
1615
Processing node type: cellular_component


100%|██████████| 1/1 [00:00<00:00,  1.32it/s]
 33%|███▎      | 2/6 [00:11<00:22,  5.52s/it]

Collection t2kg_primekg_nodes_cellular_component stats:
202
Processing node type: disease


100%|██████████| 1/1 [00:00<00:00,  2.79it/s]
 50%|█████     | 3/6 [00:15<00:14,  5.00s/it]

Collection t2kg_primekg_nodes_disease stats:
7
Processing node type: drug


100%|██████████| 2/2 [00:01<00:00,  1.29it/s]
 67%|██████▋   | 4/6 [00:21<00:10,  5.49s/it]

Collection t2kg_primekg_nodes_drug stats:
748
Processing node type: gene/protein


100%|██████████| 1/1 [00:00<00:00,  2.78it/s]
 83%|████████▎ | 5/6 [00:26<00:05,  5.10s/it]

Collection t2kg_primekg_nodes_gene_protein stats:
102
Processing node type: molecular_function


100%|██████████| 1/1 [00:00<00:00,  1.27it/s]
100%|██████████| 6/6 [00:31<00:00,  5.20s/it]

Collection t2kg_primekg_nodes_molecular_function stats:
317
CPU times: user 1.77 s, sys: 238 ms, total: 2 s
Wall time: 31.2 s





In [54]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_nodes
2991
Collection: t2kg_primekg_nodes_biological_process
1615
Collection: t2kg_primekg_nodes_cellular_component
202
Collection: t2kg_primekg_nodes_molecular_function
317
Collection: t2kg_primekg_nodes_disease
7
Collection: t2kg_primekg_nodes_gene_protein
102
Collection: t2kg_primekg_nodes_drug
748


In [62]:
merged_nodes_df[merged_nodes_df.node_type == "gene/protein"]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
48,16,IL1R2_(1654),IL1R2,gene/protein,IL1R2 belongs to gene/protein node. IL1R2 is i...,MLRLYVLVMGVSAFTLQPAAHTGAARSCRFRGRHYKREFRLEGEPV...,"[-0.02899833954870701, 0.0021472955122590065, ...","[-0.0014144123640997008, -0.001413213325978187..."
49,17,HERC2_(1777),HERC2,gene/protein,HERC2 belongs to gene/protein node. HERC2 is H...,MPSESFCLAAQARLDSKWLKTDIQLAFTRDGLCGLWNEMVKDGEIV...,"[-0.014023775234818459, 0.009176542051136494, ...","[-0.007428297050351188, -0.010088199667370488,..."
50,18,FCGR2A_(1990),FCGR2A,gene/protein,FCGR2A belongs to gene/protein node. FCGR2A is...,MTMETQMSQNVCPRNLWLLQPLTVLLLLASADSQAAAPPKAVLKLE...,"[-0.030694980174303055, 0.008019879460334778, ...","[0.005696470124890766, -0.006343481862099802, ..."
51,19,CXCR1_(2012),CXCR1,gene/protein,CXCR1 belongs to gene/protein node. CXCR1 is C...,MSNITDPQMWDFDDLNFTGMPPADEDYSPCMLETETLNKYVVIIAY...,"[-0.027637897059321404, -0.0004926534020341933...","[0.010659271964210543, -0.0016738061039453167,..."
52,20,FN1_(2057),FN1,gene/protein,FN1 belongs to gene/protein node. FN1 is fibro...,MLRGPGPGLLLLAVQCLGTAVPSTGASKSKRQAQQMVQPQSPVAVS...,"[-0.03269859775900841, 0.0019264572765678167, ...","[-0.0009853045367468072, -0.004182791243864574..."
...,...,...,...,...,...,...,...,...
1168,848,IL17REL_(34781),IL17REL,gene/protein,IL17REL belongs to gene/protein node. IL17REL ...,MSRSVLEALTSSTAMQCVPSDGCAMLLRVRASITLHERLRGLEACA...,"[-0.02713669277727604, -0.011936459690332413, ...","[0.0003768516783435937, -0.005117051265290013,..."
1169,849,TAGAP_(34814),TAGAP,gene/protein,TAGAP belongs to gene/protein node. TAGAP is T...,MKLRSSHNASKTLNANNMETLIECQSEGDIKEHPLLASCESEDSIC...,"[-0.03348768502473831, -0.009191920049488544, ...","[-0.0010800918479102228, -0.003681655277429923..."
1170,850,DENND1B_(34887),DENND1B,gene/protein,DENND1B belongs to gene/protein node. DENND1B ...,MDCRTKANPDRTFDLVLKVKCHASENEDPVVLWKFPEDFGDQEILQ...,"[-0.02666034922003746, 0.005825655534863472, -...","[-0.000950883844682143, -0.0007957585650949377..."
1171,851,IL21_(34967),IL21,gene/protein,IL21 belongs to gene/protein node. IL21 is int...,MRSSPGNMERIVICLMVIFLGTLVHKSSSQGQDRHMIRMRQLIDIV...,"[-0.02861925959587097, 0.005471138749271631, -...","[0.002036418016520262, -0.001478201039299714, ..."


In [76]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_nodes_gene_protein")

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = normalize_vector(
    merged_nodes_df[merged_nodes_df.node_type == "gene/protein"]["feat_emb"].iloc[0]
).tolist()
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["node_id", "node_name"],
)

results

CPU times: user 8.64 ms, sys: 53 μs, total: 8.69 ms
Wall time: 20.8 ms


data: [[{'node_index': 16, 'distance': 1.000000238418579, 'entity': {'node_id': 'IL1R2_(1654)', 'node_name': 'IL1R2'}}, {'node_index': 82, 'distance': 0.9777935147285461, 'entity': {'node_id': 'IL18RAP_(11588)', 'node_name': 'IL18RAP'}}, {'node_index': 59, 'distance': 0.9770528078079224, 'entity': {'node_id': 'IL12B_(6168)', 'node_name': 'IL12B'}}, {'node_index': 4, 'distance': 0.9742822051048279, 'entity': {'node_id': 'VCAM1_(417)', 'node_name': 'VCAM1'}}, {'node_index': 64, 'distance': 0.9739426970481873, 'entity': {'node_id': 'IL2RA_(7059)', 'node_name': 'IL2RA'}}, {'node_index': 51, 'distance': 0.9738389253616333, 'entity': {'node_id': 'ICAM1_(4968)', 'node_name': 'ICAM1'}}, {'node_index': 75, 'distance': 0.9732135534286499, 'entity': {'node_id': 'TLR9_(10113)', 'node_name': 'TLR9'}}, {'node_index': 18, 'distance': 0.9724373817443848, 'entity': {'node_id': 'FCGR2A_(1990)', 'node_name': 'FCGR2A'}}, {'node_index': 73, 'distance': 0.9714288711547852, 'entity': {'node_id': 'ICOSLG_(945

In [80]:
# Check the ground truth for the search
merged_nodes_df[merged_nodes_df.node_type == "gene/protein"].iloc[0]

Unnamed: 0,node_index,node_id,node_name,node_type,desc,feat,desc_emb,feat_emb
48,16,IL1R2_(1654),IL1R2,gene/protein,IL1R2 belongs to gene/protein node. IL1R2 is i...,MLRLYVLVMGVSAFTLQPAAHTGAARSCRFRGRHYKREFRLEGEPV...,"[-0.02899833954870701, 0.0021472955122590065, ...","[-0.0014144123640997008, -0.001413213325978187..."


In [81]:
# Get node indices from the results
[n["node_index"] for n in results[0]]

[16, 82, 59, 4, 64, 51, 75, 18, 73, 31]

In [82]:
# Get the cosine similarity scores
[n["distance"] for n in results[0]]

[1.000000238418579,
 0.9777935147285461,
 0.9770528078079224,
 0.9742822051048279,
 0.9739426970481873,
 0.9738389253616333,
 0.9732135534286499,
 0.9724373817443848,
 0.9714288711547852,
 0.9704809188842773]

#### Building Edge Collection

Subsquently, we are also building the edges collection in Milvus.

Note that the edges information of PrimeKG has massive records, so once again we are chunking the data to avoid memory issues.

In [83]:
%%time

# Define collection name
edge_coll_name = f"{milvus_database}_edges"

# Define schema
edge_fields = [
    FieldSchema(
        name="triplet_index", dtype=DataType.INT64, is_primary=True, auto_id=False
    ),
    FieldSchema(name="head_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="head_index", dtype=DataType.INT64),
    FieldSchema(name="tail_id", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="tail_index", dtype=DataType.INT64),
    FieldSchema(name="edge_type", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="display_relation", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="feat", dtype=DataType.VARCHAR, max_length=40960),
    FieldSchema(name="feat_emb", dtype=DataType.FLOAT_VECTOR, dim=1536),
]
edge_schema = CollectionSchema(
    fields=edge_fields, description="Schema for edges collection"
)

# Create collection if not exists
if not utility.has_collection(edge_coll_name):
    collection = Collection(name=edge_coll_name, schema=edge_schema)
else:
    collection = Collection(name=edge_coll_name)

# Create indexes
collection.create_index(
    field_name="triplet_index",
    index_params={"index_type": "STL_SORT"},
    index_name="triplet_index_index",
)
collection.create_index(
    field_name="head_index",
    index_params={"index_type": "STL_SORT"},
    index_name="head_index_index",
)
collection.create_index(
    field_name="tail_index",
    index_params={"index_type": "STL_SORT"},
    index_name="tail_index_index",
)
collection.create_index(
    field_name="feat_emb",
    index_params={"index_type": "GPU_CAGRA", "metric_type": "IP"},
    index_name="feat_emb_index",
)  # AUTOINDEX

# Iterate over chunked edges embedding df
for edges_df in tqdm(edges_embedding_df):
    # Merge enrichment with embedding
    merged_edges_df = edges_enrichment_df.merge(
        edges_df[["triplet_index", "edge_emb"]], on="triplet_index", how="inner"
    )

    # Prepare data fields in column-wise format
    # Normalize the embeddings
    edges_edge_emb_cp = (
        cp.asarray(merged_edges_df["edge_emb"].list.leaves)
        .astype(cp.float32)
        .reshape(merged_edges_df.shape[0], -1)
    )
    edges_edge_emb_norm = normalize_matrix(edges_edge_emb_cp, axis=1)
    data = [
        merged_edges_df["triplet_index"].to_arrow().to_pylist(),
        merged_edges_df["head_id"].to_arrow().to_pylist(),
        merged_edges_df["head_index"].to_arrow().to_pylist(),
        merged_edges_df["tail_id"].to_arrow().to_pylist(),
        merged_edges_df["tail_index"].to_arrow().to_pylist(),
        merged_edges_df["edge_type_str"].to_arrow().to_pylist(),
        merged_edges_df["display_relation"].to_arrow().to_pylist(),
        merged_edges_df["feat"].to_arrow().to_pylist(),
        edges_edge_emb_norm.tolist(),  # Use normalized embeddings
    ]

    # Insert in chunks
    batch_size = 500
    for i in tqdm(range(0, len(data[0]), batch_size)):
        batch_data = [d[i : i + batch_size] for d in data]
        collection.insert(batch_data)

    # Flush to ensure persistence
    collection.flush()

    # Print collection stats
    print(collection.num_entities)

    time.sleep(5)  # Sleep to avoid overwhelming the server

100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


5000


100%|██████████| 9/9 [00:05<00:00,  1.62it/s]


9272


100%|██████████| 4/4 [00:02<00:00,  1.50it/s]


11272


100%|██████████| 3/3 [00:34<00:00, 11.59s/it]

CPU times: user 1.99 s, sys: 402 ms, total: 2.39 s
Wall time: 37 s





In [84]:
# List all collections
for coll in utility.list_collections():
    print(f"Collection: {coll}")

    # Load the collection to get stats
    collection = Collection(name=coll)
    print(collection.num_entities)

Collection: t2kg_primekg_edges
11272
Collection: t2kg_primekg_nodes_drug
748
Collection: t2kg_primekg_nodes
2991
Collection: t2kg_primekg_nodes_biological_process
1615
Collection: t2kg_primekg_nodes_cellular_component
202
Collection: t2kg_primekg_nodes_molecular_function
317
Collection: t2kg_primekg_nodes_disease
7
Collection: t2kg_primekg_nodes_gene_protein
102


In [85]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_edges")

# Load the collection into memory before query
collection.load()

# Query by expr on triplet_index
expr = "triplet_index == 0"
output_fields = ["triplet_index", "head_id", "tail_id", "edge_type", "feat", "feat_emb"]

results = collection.query(expr, output_fields=output_fields)
results

CPU times: user 10.6 ms, sys: 5.68 ms, total: 16.3 ms
Wall time: 1.52 s


data: ["{'tail_id': 'LTF_(3233)', 'edge_type': 'drug|carrier|gene/protein', 'feat': 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).', 'feat_emb': [-0.011582981, 0.0018133956, 0.0050554257, -0.022483429, -0.024731772, 0.0004805331, -0.0038308818, 0.02338009, -0.018321319, -0.0016402531, -0.0034862699, 0.010552491, -0.0005691955, 0.015858848, -0.0011275172, -0.0024306863, 0.028318414, -0.00040399912, -0.0013274257, -0.0055606337, -0.005219368, 0.027247775, 0.004904867, -0.012713844, -0.011161417, 0.01891017, 0.03642851, -0.03214595, -0.015577804, -0.012165141, 0.009816426, -0.0039312546, -0.02168045, -0.008250616, 0.008712329, -0.033591315, 0.010519033, 0.008431287, 0.028773436, -0.0077286786, 0.024785304, 0.010826842, 0.0058216024, 0.0050018937, -0.0031985354, -0.01072647, 0.021881195, -0.017317595, -0.040898427, 0.0023252952, -0.015136166, 0.011636513, -0.016086359, -0.005972161, -0.007989648, 0.005707847, 0.003285525, 0.0035431476, 0.0048

In [86]:
# Check the ground truth for the search
(
    results[0]["triplet_index"],
    results[0]["head_id"],
    results[0]["tail_id"],
    results[0]["edge_type"],
    results[0]["feat"],
)

(0,
 'Rose bengal_(14118)',
 'LTF_(3233)',
 'drug|carrier|gene/protein',
 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).')

In [87]:
%%time

# Assume node_coll_name is defined and collection exists
collection = Collection("t2kg_primekg_edges")

# Load the collection into memory before query
collection.load()

# Vector similarity search in Milvus
vector_to_search = np.array(
    results[0]["feat_emb"]
).tolist()  # merged_edges_df["edge_emb"].iloc[0]
search_params = {"metric_type": "IP"}
results = collection.search(
    data=[vector_to_search],
    anns_field="feat_emb",
    param=search_params,
    limit=10,
    output_fields=["head_id", "tail_id", "edge_type", "feat"],
)
results

CPU times: user 2.89 ms, sys: 0 ns, total: 2.89 ms
Wall time: 9.52 ms


data: [[{'triplet_index': 0, 'distance': 1.000000238418579, 'entity': {'edge_type': 'drug|carrier|gene/protein', 'feat': 'Rose bengal (drug) has a direct relationship of drug_protein:carrier with LTF (gene/protein).', 'head_id': 'Rose bengal_(14118)', 'tail_id': 'LTF_(3233)'}}, {'triplet_index': 5636, 'distance': 0.9769061207771301, 'entity': {'edge_type': 'gene/protein|carrier|drug', 'feat': 'LTF (gene/protein) has a direct relationship of drug_protein:carrier with Rose bengal (drug).', 'head_id': 'LTF_(3233)', 'tail_id': 'Rose bengal_(14118)'}}, {'triplet_index': 196, 'distance': 0.900477409362793, 'entity': {'edge_type': 'drug|target|gene/protein', 'feat': '3h-Indole-5,6-Diol (drug) has a direct relationship of drug_protein:target with LTF (gene/protein).', 'head_id': '3h-Indole-5,6-Diol_(18278)', 'tail_id': 'LTF_(3233)'}}, {'triplet_index': 199, 'distance': 0.898104727268219, 'entity': {'edge_type': 'drug|target|gene/protein', 'feat': 'alpha-D-Fucopyranose (drug) has a direct relat

In [88]:
# Get node indices from the results
[n["triplet_index"] for n in results[0]]

[0, 5636, 196, 199, 198, 197, 201, 5837, 202, 5832]

In [89]:
# Get the cosine similarity scores
[n["distance"] for n in results[0]]

[1.000000238418579,
 0.9769061207771301,
 0.900477409362793,
 0.898104727268219,
 0.8961387872695923,
 0.8955156803131104,
 0.8937671780586243,
 0.8916683197021484,
 0.8896037340164185,
 0.8886452913284302]