In [None]:
%pip install neo4j pymilvus numpy scipy langchain langchain-core langchain-openai tqdm

In [None]:
from neo4j import GraphDatabase
import json
import numpy as np
from collections import defaultdict
from scipy.sparse import csr_matrix
from pymilvus import MilvusClient
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from tqdm import tqdm
import os

In [None]:
# === Neo4j Connection Setup ===
# Connection details are read from the environment when set, so credentials do
# not have to live in the notebook. The fallbacks reproduce the original
# local-development values.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "123456789")  # TODO: drop this fallback outside local dev

uri = NEO4J_URI  # kept for any later cell that still refers to `uri`
driver = GraphDatabase.driver(uri, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [None]:
# === OpenAI Setup ===
# Use OPENAI_API_KEY from the environment when present. Only prompt for it when
# missing, so an existing key is never overwritten by a placeholder string.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass

    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

# temperature=0 keeps reranking and answers as repeatable as the API allows.
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")

In [None]:
# === Milvus Setup ===
# A local file path as the URI gives a file-backed Milvus instance (no server needed).
MILVUS_DB_PATH = "./milvus.db"
milvus_client = MilvusClient(uri=MILVUS_DB_PATH)

In [None]:
# === Function to delete all existing data (optional cleanup) ===
def delete_all_data(tx):
    """Remove every node in the database together with all of its relationships.

    Args:
        tx: Transaction handle supplied by the session's write helper.
    """
    wipe_query = "MATCH (n) DETACH DELETE n"
    tx.run(wipe_query)
    print("🗑️ Cleared existing data.")

# === Function to insert an entity ===
def insert_entity(tx, name, entity_type, singular, description):
    """Upsert an ``Entity`` node keyed by ``name`` and overwrite its attributes.

    Args:
        tx: Transaction handle supplied by the session's write helper.
        name: Value of the node's ``name`` property (the MERGE key).
        entity_type: Stored as ``e.entity_type``.
        singular: Stored as ``e.singular``.
        description: Stored as ``e.description``.
    """
    cypher_params = {
        "name": name,
        "entity_type": entity_type,
        "singular": singular,
        "description": description,
    }
    tx.run(
        "MERGE (e:Entity {name: $name}) "
        "SET e.entity_type = $entity_type, e.singular = $singular, e.description = $description",
        **cypher_params,
    )
    print(f"✅ Inserted Entity: {name}")

# === Function to insert a relationship ===
def insert_relationship(tx, entity_1, entity_2, rel_type, description):
    """Create (or reuse) a ``RELATION`` edge from ``entity_1`` to ``entity_2``.

    Both endpoints are looked up by ``Entity.name``. If either is missing, the
    MATCH yields no rows, nothing is written, and a warning is printed instead
    of a misleading success message. ``type`` and ``description`` are part of
    the MERGE pattern, so the same pair with a different description yields a
    separate edge. A null description is stored as ``''`` because Cypher
    cannot MERGE on a null property value.

    Args:
        tx: Transaction handle supplied by the session's write helper.
        entity_1: Name of the source entity.
        entity_2: Name of the target entity.
        rel_type: Value stored in the edge's ``type`` property.
        description: Value stored in the edge's ``description`` property.

    Returns:
        True if both endpoints exist (edge created or already present),
        False if at least one endpoint was not found.
    """
    query = (
        "MATCH (a:Entity {name: $entity_1}), (b:Entity {name: $entity_2}) "
        "MERGE (a)-[r:RELATION {type: $rel_type, description: coalesce($description, '')}]->(b) "
        # count() over zero matched rows still returns one row with 0.
        "RETURN count(r) AS matched"
    )
    record = tx.run(
        query,
        entity_1=entity_1,
        entity_2=entity_2,
        rel_type=rel_type,
        description=description,
    ).single()
    if record is None or record["matched"] == 0:
        print(f"⚠️ Skipped Relationship (endpoint not found): ({entity_1})-[:{rel_type}]->({entity_2})")
        return False
    print(f"🔗 Created Relationship: ({entity_1})-[:{rel_type}]->({entity_2})")
    return True

In [None]:
# === Load JSON data ===
# The path can be overridden with KG_DATA_PATH; the fallback is the original
# author's local file (file name kept verbatim, including its spelling).
KG_DATA_PATH = os.environ.get(
    "KG_DATA_PATH", "/Users/shiva/Documents/Regal/data/enities_realtions.json"
)
# JSON is UTF-8 by spec; don't depend on the platform default encoding.
with open(KG_DATA_PATH, "r", encoding="utf-8") as f:
    kg_data = json.load(f)

# === Insert all entities and relationships into Neo4j ===
# `execute_write` is the neo4j-driver 5.x name for the deprecated
# `write_transaction`. Each call runs the function in its own managed write
# transaction, forwarding the extra positional arguments to it.
with driver.session() as session:
    # Optional: Wipe database before inserting new data
    session.execute_write(delete_all_data)

    # Insert entities
    for entity in kg_data["entities"]:
        session.execute_write(
            insert_entity,
            entity["name"],
            entity["entity_type"],
            entity["singular"],
            entity["description"],
        )

    # Insert relationships (entities must exist first, since relationships MATCH them by name)
    for rel in kg_data["relationships"]:
        session.execute_write(
            insert_relationship,
            rel["entity_1"],
            rel["entity_2"],
            rel["relationship_type"],
            rel["description"],
        )

print("✅ All entities and relationships inserted into Neo4j.")

In [None]:
# === Extract triplets for Graph RAG ===
def extract_triplets_from_new_format(kg_data):
    """Turn the entity/relationship JSON into Graph RAG triplets and passages.

    Each relationship becomes a ``[entity_1, relationship_type, entity_2]``
    triplet plus a text passage. The passage is built from the triplet, from
    the non-empty descriptions of both endpoint entities, and from the
    relationship's own description when it is present and non-empty.

    Args:
        kg_data: Mapping with an ``"entities"`` list (items need ``"name"`` and
            ``"description"``) and a ``"relationships"`` list (items need
            ``"entity_1"``, ``"entity_2"`` and ``"relationship_type"``; they may
            carry ``"description"``).

    Returns:
        ``(triplets, triplets_with_passages)``. ``triplets`` lists the triplets
        in relationship order. ``triplets_with_passages`` is a parallel list of
        ``{"passage": str, "triplets": [triplet]}`` dicts.
    """
    # Later entries with the same name win, as with repeated dict assignment.
    entity_descriptions = {
        entity["name"]: entity["description"] for entity in kg_data["entities"]
    }

    triplets = []
    triplets_with_passages = []
    for rel in kg_data["relationships"]:
        entity_1 = rel["entity_1"]
        entity_2 = rel["entity_2"]
        relationship_type = rel["relationship_type"]

        triplet = [entity_1, relationship_type, entity_2]
        triplets.append(triplet)

        passage = f"{entity_1} {relationship_type} {entity_2}. "
        for name in (entity_1, entity_2):
            # Skip empty/None descriptions (mirroring the relationship-description
            # guard below). Drop trailing periods so the appended ". " does not yield "..".
            description = str(entity_descriptions.get(name) or "").strip().rstrip(".")
            if description:
                passage += f"Description of {name}: {description}. "
        if rel.get("description"):
            passage += f"Relationship details: {rel['description']}"

        triplets_with_passages.append({"passage": passage, "triplets": [triplet]})

    return triplets, triplets_with_passages

# Build the triplet list and one passage per relationship from the loaded JSON.
triplets, triplets_with_passages = extract_triplets_from_new_format(kg_data=kg_data)

In [None]:
# === Process triplets for Graph RAG ===
# Assign integer ids to entities and relations in first-seen order (and to
# passages by position). Record which relations touch each entity and which
# passages mention each relation. The dict id maps replace per-item
# `in list` / `list.index` scans, which made this loop O(n^2); the resulting
# lists and mappings are the same.
entityid_2_relationids = defaultdict(list)
relationid_2_passageids = defaultdict(list)

entities = []
relations = []
passages = []

entity_to_id = {}
relation_to_id = {}

for passage_id, dataset_info in enumerate(triplets_with_passages):
    passage, triplet_list = dataset_info["passage"], dataset_info["triplets"]
    passages.append(passage)

    for triplet in triplet_list:
        head, tail = triplet[0], triplet[2]
        for entity_name in (head, tail):
            if entity_name not in entity_to_id:
                entity_to_id[entity_name] = len(entities)
                entities.append(entity_name)

        relation = " ".join(triplet)
        if relation not in relation_to_id:
            relation_to_id[relation] = len(relations)
            relations.append(relation)
            # Link both endpoints only the first time a relation is seen.
            entityid_2_relationids[entity_to_id[head]].append(relation_to_id[relation])
            entityid_2_relationids[entity_to_id[tail]].append(relation_to_id[relation])

        relationid_2_passageids[relation_to_id[relation]].append(passage_id)

In [None]:
# === Create Milvus collections ===
# Embed a throwaway string once to learn the model's vector width.
_probe_embedding = embedding_model.embed_query("foo")
embedding_dim = len(_probe_embedding)

def create_milvus_collection(collection_name: str):
    """(Re)create ``collection_name`` for vectors of width ``embedding_dim``.

    Any existing collection with that name is dropped first, so its data is lost.
    The new collection is created with Milvus's "Strong" consistency level.
    """
    already_present = milvus_client.has_collection(collection_name=collection_name)
    if already_present:
        milvus_client.drop_collection(collection_name=collection_name)
    milvus_client.create_collection(
        collection_name=collection_name,
        consistency_level="Strong",
        dimension=embedding_dim,
    )

entity_col_name = "entity_collection"
relation_col_name = "relation_collection"
passage_col_name = "passage_collection"

# Fresh collections for entities, relations and passages (in that order).
for _collection_name in (entity_col_name, relation_col_name, passage_col_name):
    create_milvus_collection(_collection_name)

In [None]:
# === Insert data into Milvus ===
def milvus_insert(
    collection_name: str,
    text_list: list[str],
    batch_size: int = 512,
):
    """Embed ``text_list`` and insert it into ``collection_name`` in batches.

    Row ``i`` of ``text_list`` is stored with primary key ``i``, so each Milvus
    id equals the text's position in the list. The retrieval code uses hit ids
    directly as row indices into the adjacency matrices. Each inserted row
    carries ``id``, ``text`` and ``vector``.

    Args:
        collection_name: Target collection (must already exist).
        text_list: Texts to embed and store; an empty list inserts nothing.
        batch_size: Number of texts embedded and inserted per request.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    for start in tqdm(range(0, len(text_list), batch_size), desc="Inserting"):
        batch_texts = text_list[start : start + batch_size]
        batch_embeddings = embedding_model.embed_documents(batch_texts)
        batch_data = [
            {"id": start + offset, "text": text, "vector": vector}
            for offset, (text, vector) in enumerate(zip(batch_texts, batch_embeddings))
        ]
        milvus_client.insert(collection_name=collection_name, data=batch_data)

# Insert relations, entities, and passages into Milvus. Ids equal list
# positions, which the adjacency lookups depend on.
for _collection_name, _texts in (
    (relation_col_name, relations),
    (entity_col_name, entities),
    (passage_col_name, passages),
):
    milvus_insert(collection_name=_collection_name, text_list=_texts)

In [None]:
# === Build adjacency matrices ===
# entity_relation_adj[e, r] == 1 iff relation r involves entity e. It is built
# directly in sparse (COO -> CSR) form, so memory scales with the number of
# entity-relation links instead of len(entities) * len(relations).
_adj_rows = []
_adj_cols = []
for entity_id in range(len(entities)):
    # De-duplicate: a self-loop lists the same relation twice for one entity.
    # COO construction would sum duplicates, while the dense assignment stored 1.
    for relation_id in sorted(set(entityid_2_relationids[entity_id])):
        _adj_rows.append(entity_id)
        _adj_cols.append(relation_id)

entity_relation_adj = csr_matrix(
    (
        np.ones(len(_adj_rows)),
        (np.asarray(_adj_rows, dtype=np.int64), np.asarray(_adj_cols, dtype=np.int64)),
    ),
    shape=(len(entities), len(relations)),
)

# 1-hop neighbourhoods: entities sharing a relation / relations sharing an entity.
entity_adj_1_degree = entity_relation_adj @ entity_relation_adj.T
relation_adj_1_degree = entity_relation_adj.T @ entity_relation_adj

target_degree = 1

# Raise the 1-hop adjacencies to `target_degree` hops. `@` is matrix
# multiplication for scipy sparse matrices and sparse arrays alike, whereas `*`
# is element-wise for sparse arrays.
entity_adj_target_degree = entity_adj_1_degree
for _ in range(target_degree - 1):
    entity_adj_target_degree = entity_adj_target_degree @ entity_adj_1_degree
relation_adj_target_degree = relation_adj_1_degree
for _ in range(target_degree - 1):
    relation_adj_target_degree = relation_adj_target_degree @ relation_adj_1_degree

# Relations reachable from each entity within `target_degree` entity hops.
entity_relation_adj_target_degree = entity_adj_target_degree @ entity_relation_adj

In [None]:
# === Rerank relations function ===
def _parse_relation_ids(selected_lines, valid_ids) -> list[int]:
    """Extract relation ids such as ``"[7] ..."`` from the reranker's selection.

    Plain integers are accepted as-is. Items without a bracketed id are skipped.
    Ids not in ``valid_ids`` (e.g. hallucinated) and repeats are dropped; the
    first-seen order is kept.
    """
    import re

    valid = set(valid_ids)
    parsed: list[int] = []
    for line in selected_lines:
        if isinstance(line, int) and not isinstance(line, bool):
            candidate = line
        else:
            match = re.search(r"\[(\d+)\]", str(line))
            if match is None:
                continue
            candidate = int(match.group(1))
        if candidate in valid and candidate not in parsed:
            parsed.append(candidate)
    return parsed


def rerank_relations(
    query: str, relation_candidate_texts: list[str], relation_candidate_ids: list[int]
) -> list[int]:
    """Ask the LLM which candidate relations are most useful for ``query``.

    The prompt is one-shot: one worked example, then the real question with the
    candidates listed as ``[id] text`` lines. The model answers in JSON mode
    with a ``useful_relationships`` list.

    Args:
        query: The user's question.
        relation_candidate_texts: Relation texts, parallel to ``relation_candidate_ids``.
        relation_candidate_ids: Integer relation ids shown to the model.

    Returns:
        Selected relation ids in the model's relevance order, restricted to
        ``relation_candidate_ids``. Returns ``[]`` when there are no candidates
        (the LLM is not called) or when nothing parseable comes back.
    """
    if not relation_candidate_ids:
        return []

    query_prompt_one_shot_input = """I will provide you with a list of relationship descriptions. Your task is to select 3 relationships that may be useful to answer the given question. Please return a JSON object containing your thought process and a list of the selected relationships in order of their relevance.

Question:
When was the mother of the leader of the Third Crusade born?

Relationship descriptions:
[1] Eleanor was born in 1122.
[2] Eleanor married King Louis VII of France.
[3] Eleanor was the Duchess of Aquitaine.
[4] Eleanor participated in the Second Crusade.
[5] Eleanor had eight children.
[6] Eleanor was married to Henry II of England.
[7] Eleanor was the mother of Richard the Lionheart.
[8] Richard the Lionheart was the King of England.
[9] Henry II was the father of Richard the Lionheart.
[10] Henry II was the King of England.
[11] Richard the Lionheart led the Third Crusade.

"""
    # JSON mode requires the model to emit valid JSON, so the example answer is
    # serialized with json.dumps rather than written as a Python-dict repr.
    query_prompt_one_shot_output = json.dumps(
        {
            "thought_process": "To answer the question about the birth of the mother of the leader of the Third Crusade, I first need to identify who led the Third Crusade and then determine who his mother was. After identifying his mother, I can look for the relationship that mentions her birth.",
            "useful_relationships": [
                "[11] Richard the Lionheart led the Third Crusade",
                "[7] Eleanor was the mother of Richard the Lionheart",
                "[1] Eleanor was born in 1122",
            ],
        }
    )

    query_prompt_template = """Question:
{question}

Relationship descriptions:
{relation_des_str}

"""
    relation_des_str = "\n".join(
        f"[{relation_id}] {relation_text}"
        for relation_id, relation_text in zip(relation_candidate_ids, relation_candidate_texts)
    ).strip()

    # The example turns are message objects, not templates, so their braces are
    # never treated as template variables.
    rerank_prompts = ChatPromptTemplate.from_messages(
        [
            HumanMessage(query_prompt_one_shot_input),
            AIMessage(query_prompt_one_shot_output),
            HumanMessagePromptTemplate.from_template(query_prompt_template),
        ]
    )
    rerank_chain = (
        rerank_prompts
        | llm.bind(response_format={"type": "json_object"})
        | JsonOutputParser()
    )
    rerank_res = rerank_chain.invoke(
        {"question": query, "relation_des_str": relation_des_str}
    )
    selected_lines = (
        rerank_res.get("useful_relationships", []) if isinstance(rerank_res, dict) else []
    )
    return _parse_relation_ids(selected_lines, relation_candidate_ids)

In [None]:
# === Example query function ===
def _default_query_ner(query):
    """Crude NER fallback: the capitalized words of ``query``, with edge punctuation stripped."""
    import string

    mentions = []
    for word in query.split():
        token = word.strip(string.punctuation)
        if token and token[0].isupper():
            mentions.append(token)
    return mentions


def _candidate_relation_ids(query_embedding, query_ner_list, top_k):
    """Sorted relation ids reachable from the nearest relations and entities.

    Relations nearest to the query are expanded via ``relation_adj_target_degree``.
    Entities nearest to each mention are expanded via
    ``entity_relation_adj_target_degree``. An empty ``query_ner_list`` skips the
    entity search.
    """
    candidate_ids = set()

    relation_hits = milvus_client.search(
        collection_name=relation_col_name,
        data=[query_embedding],
        limit=top_k,
        output_fields=["id"],
    )[0]
    for hit in relation_hits:
        candidate_ids.update(
            relation_adj_target_degree[hit["entity"]["id"]].nonzero()[1].tolist()
        )

    if query_ner_list:
        ner_embeddings = [embedding_model.embed_query(mention) for mention in query_ner_list]
        # One result list per mention embedding.
        entity_hits_per_mention = milvus_client.search(
            collection_name=entity_col_name,
            data=ner_embeddings,
            limit=top_k,
            output_fields=["id"],
        )
        for hits in entity_hits_per_mention:
            for hit in hits:
                candidate_ids.update(
                    entity_relation_adj_target_degree[hit["entity"]["id"]].nonzero()[1].tolist()
                )

    # Sorted so the rerank prompt lists candidates in a deterministic order.
    return sorted(candidate_ids)


def _passages_for_relations(relation_ids, limit):
    """The first ``limit`` distinct passages linked to ``relation_ids``, in relation order."""
    seen = set()
    selected = []
    for relation_id in relation_ids:
        for passage_id in relationid_2_passageids.get(relation_id, ()):
            if passage_id not in seen:
                seen.add(passage_id)
                selected.append(passages[passage_id])
    return selected[:limit]


def query_graph_rag(query, query_ner_list=None, top_k=3, final_top_k=2):
    """Answer ``query`` with graph-expanded retrieval and with naive passage RAG.

    Graph path:
    1. Search Milvus for the relations nearest to the query and the entities
       nearest to each mention.
    2. Expand those hits through the precomputed adjacency matrices.
    3. Let the LLM rerank the candidate relations.
    4. Answer from the passages linked to the chosen relations.

    Naive path: answer from the ``final_top_k`` passages nearest to the query
    embedding. Both answers use the same prompt so they can be compared.

    Args:
        query: Natural-language question.
        query_ner_list: Entity mentions in ``query``. When ``None``, capitalized
            words (surrounding punctuation stripped) serve as a crude stand-in
            for NER. An empty list skips the entity search.
        top_k: Hits requested from each Milvus entity/relation search.
        final_top_k: Number of passages each method uses as context.

    Returns:
        Dict with ``passages_from_naive_rag``, ``passages_from_our_method``,
        ``answer_from_naive_rag`` and ``answer_from_our_method``.
    """
    if query_ner_list is None:
        query_ner_list = _default_query_ner(query)

    query_embedding = embedding_model.embed_query(query)

    relation_candidate_ids = _candidate_relation_ids(query_embedding, query_ner_list, top_k)
    relation_candidate_texts = [relations[relation_id] for relation_id in relation_candidate_ids]

    # Don't spend an LLM call reranking an empty candidate list.
    rerank_relation_ids = (
        rerank_relations(
            query,
            relation_candidate_texts=relation_candidate_texts,
            relation_candidate_ids=relation_candidate_ids,
        )
        if relation_candidate_ids
        else []
    )
    passages_from_our_method = _passages_for_relations(rerank_relation_ids, final_top_k)

    # Compare with naive RAG
    naive_passage_res = milvus_client.search(
        collection_name=passage_col_name,
        data=[query_embedding],
        limit=final_top_k,
        output_fields=["text"],
    )[0]
    passages_from_naive_rag = [res["entity"]["text"] for res in naive_passage_res]

    # Generate answers
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "human",
                """Use the following pieces of retrieved context to answer the question. If there is not enough information in the retrieved context to answer the question, just say that you don't know.
    
    Question: {question}
    Context: {context}
    
    Answer:""",
            )
        ]
    )
    rag_chain = prompt | llm | StrOutputParser()

    answer_from_naive_rag = rag_chain.invoke(
        {"question": query, "context": "\n".join(passages_from_naive_rag)}
    )
    answer_from_our_method = rag_chain.invoke(
        {"question": query, "context": "\n".join(passages_from_our_method)}
    )

    return {
        "passages_from_naive_rag": passages_from_naive_rag,
        "passages_from_our_method": passages_from_our_method,
        "answer_from_naive_rag": answer_from_naive_rag,
        "answer_from_our_method": answer_from_our_method,
    }

In [None]:
# === Test the Graph RAG system ===
# Entity mentions are passed explicitly rather than relying on the capitalized-word heuristic.
test_query = "What is the relationship between Reserve Bank and Foreign Exchange Management Act?"
test_query_entities = ["Reserve Bank", "Foreign Exchange Management Act"]
result = query_graph_rag(test_query, test_query_entities)

passages_report = (
    f"Passages retrieved from naive RAG: \n{result['passages_from_naive_rag']}\n\n"
    f"Passages retrieved from our method: \n{result['passages_from_our_method']}\n\n"
)
answers_report = (
    f"Answer from naive RAG: {result['answer_from_naive_rag']}\n\n"
    f"Answer from our method: {result['answer_from_our_method']}"
)
print(passages_report)
print(answers_report)

In [None]:
# Close Neo4j connection
# Shuts down the driver opened in the connection-setup cell. Re-run that cell
# before issuing further Neo4j queries in this kernel.
driver.close()