In [65]:
import neo4j
from neo4j_graphrag_custom.kg_indexer import KGIndexer
from neo4j_graphrag_custom.kg_builder import GeminiLLM
from neo4j_graphrag.schema import get_schema
import os
import json
from dotenv import load_dotenv
from neo4j_graphrag.embeddings import SentenceTransformerEmbeddings
from neo4j_graphrag.retrievers import (
    VectorRetriever,
    VectorCypherRetriever,
    HybridRetriever,
    HybridCypherRetriever,
    Text2CypherRetriever
)
from pprint import pprint

# 0. Initial setup

Note: this notebook assumes that the Neo4j database it connects to already contains an *indexed* knowledge graph, i.e., one for which both the full-text index and the vector (embeddings) index have been created.

In [87]:
# Load configuration and setup

# Base directory used to locate the .env file and the JSON configuration files
script_dir = os.getcwd()

# script_dir = os.path.dirname(os.path.abspath(__file__))  # Uncomment if running as a script

# Load environment variables from a .env file (override=True lets the file win over existing variables)
dotenv_path = os.path.join(script_dir, '.env')
load_dotenv(dotenv_path, override=True)

# Open configuration files from JSON format.
# The encoding is explicit so that parsing does not depend on the platform's default locale encoding.
config_path = os.path.join(script_dir, 'kg_building_config.json')  # Configuration file of the knowledge graph builder
with open(config_path, 'r', encoding='utf-8') as kg_build_config_file:
    build_config = json.load(kg_build_config_file)
config_path = os.path.join(script_dir, 'kg_retrieval_config.json')  # Configuration file of the knowledge graph retriever
with open(config_path, 'r', encoding='utf-8') as kg_retr_config_file:
    retr_config = json.load(kg_retr_config_file)

# Neo4j connection
neo4j_uri = os.getenv('NEO4J_URI')
neo4j_username = os.getenv('NEO4J_USERNAME')
neo4j_password = os.getenv('NEO4J_PASSWORD')
gemini_api_key = os.getenv('GEMINI_API_KEY')  # Only used by the Text2CypherRetriever section

# Fail fast with a readable message instead of handing None values to the Neo4j driver
missing_env_vars = [
    env_name
    for env_name, env_value in (
        ('NEO4J_URI', neo4j_uri),
        ('NEO4J_USERNAME', neo4j_username),
        ('NEO4J_PASSWORD', neo4j_password),
    )
    if not env_value
]
if missing_env_vars:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(missing_env_vars)}. "
        f"Define them in the environment or in {dotenv_path}."
    )

driver = neo4j.GraphDatabase.driver(neo4j_uri, auth=(neo4j_username, neo4j_password))

In [4]:
# Create embedder (same model as the one configured for building the knowledge graph)
embedder = SentenceTransformerEmbeddings(model=build_config['embedder_config']['model_name'])

# List the indexes of the database and pick the first vector index and the first full-text index
indexer = KGIndexer(driver=driver)
existing_indexes = indexer.list_all_indexes()

vector_index_names = [index['name'] for index in existing_indexes if index['type'] == 'VECTOR']
fulltext_index_names = [index['name'] for index in existing_indexes if index['type'] == 'FULLTEXT']

# The retrievers below cannot work without these indexes, so stop here with an explicit
# message rather than with a bare IndexError from indexing an empty list.
if not vector_index_names:
    raise LookupError("No VECTOR index found in the database; index the knowledge graph before running this notebook.")
if not fulltext_index_names:
    raise LookupError("No FULLTEXT index found in the database; index the knowledge graph before running this notebook.")

embeddings_index_name = vector_index_names[0]
fulltext_index_name = fulltext_index_names[0]

Found 5 indexes in the database:

1. {'id': 2, 'name': '__entity__id', 'state': 'ONLINE', 'populationPercent': 100.0, 'type': 'RANGE', 'entityType': 'NODE', 'labelsOrTypes': ['__KGBuilder__'], 'properties': ['id'], 'indexProvider': 'range-1.0', 'owningConstraint': None, 'lastRead': neo4j.time.DateTime(2025, 6, 5, 14, 52, 4, 898000000, tzinfo=<UTC>), 'readCount': 3508}

2. {'id': 3, 'name': 'embeddings_index', 'state': 'ONLINE', 'populationPercent': 100.0, 'type': 'VECTOR', 'entityType': 'NODE', 'labelsOrTypes': ['Chunk'], 'properties': ['embedding'], 'indexProvider': 'vector-2.0', 'owningConstraint': None, 'lastRead': None, 'readCount': 0}

3. {'id': 4, 'name': 'fulltext_index', 'state': 'ONLINE', 'populationPercent': 100.0, 'type': 'FULLTEXT', 'entityType': 'NODE', 'labelsOrTypes': ['Chunk'], 'properties': ['text'], 'indexProvider': 'fulltext-1.0', 'owningConstraint': None, 'lastRead': None, 'readCount': 0}

4. {'id': 0, 'name': 'index_343aff4e', 'state': 'ONLINE', 'populationPercent'

In [6]:
# Natural-language question reused as the query for every retriever in this notebook
sample_query_text = (
    "Which have been the most pressing security-related issues in Sudan in the last year? "
    "What are the future prospects for the country considering the current situation?"
)

# 1. Vector retriever

Similarity search using vector embeddings.

In [7]:
# Build the retriever that relies on vector similarity search only
v_retriever = VectorRetriever(
    driver=driver,
    # Vector index searched with the embedded query text
    index_name=embeddings_index_name,
    # Embedder applied to the query text before the vector search
    embedder=embedder,
    # Node properties returned with each hit, in addition to its similarity score
    # (cosine similarity by default)
    return_properties=['text'],
)

Now, let's check which information is retrieved with the sample query text and this vector retriever. This will NOT be the final output of GraphRAG, but an intermediate step where all of the relevant information is compiled according to the characteristics of the retriever. In this case, the vector retriever will:
1. Embed the query text with the embedder. At this point, it could be interesting to consider **hypothetical document embeddings** (see Arnault's slides of Advanced NLP, session 9).
2. Compute the cosine similarity of the embedded query text with the embeddings of the text.
3. Return the cosine similarity of the closest vectors, together with their text (if the property "text" is returned).

Now, let's inspect the raw results returned by Neo4j, using the lower-level `.get_search_results()` method (which is more robust than `.search()`).

In [25]:
# Run the vector search: no pre-computed vector is passed, so the retriever's embedder
# embeds the query text; the 5 best matches are requested.
v_results = v_retriever.get_search_results(
    query_text=sample_query_text,
    query_vector=None,
    top_k=5,
)

print("Raw search results:\n", v_results)

# Pretty-print every returned record as indented JSON, separated by a ruler line
for record in v_results.records:
    print("=" * 50 + "\n" + json.dumps(record.data(), indent=4))

Raw search results:
{
    "node": {
    },
    "nodeLabels": [
        "__KGBuilder__",
        "Chunk"
    ],
    "elementId": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:264",
    "id": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:264",
    "score": 0.8028373718261719
}
{
    "node": {
        "text": "Sudanese civil defense says all fires at major oil depots in government-controlled Port Sudan are now \"completely\" under control following numerous RSF attacks on petroleum reserves"
    },
    "nodeLabels": [
        "__KGBuilder__",
        "Chunk"
    ],
    "elementId": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:296",
    "id": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:296",
    "score": 0.78676438331604
}
{
    "node": {
        "text": "Sudanese media reports new drone strikes on unspecified areas in Port Sudan"
    },
    "nodeLabels": [
        "__KGBuilder__",
        "Chunk"
    ],
    "elementId": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:364",
    "id": "4:6c5b3cb6-25d4-4346-af2c-705e777

# 2. VectorCypherRetriever

Combines vector search with retrieval queries written in Cypher, Neo4j's graph query language, to traverse the graph and incorporate additional nodes and relationships. Below we create a retriever that obtains Chunk nodes via vector search and then traverses the entity graph up to 3 hops away from each chunk (query taken from [this article](https://neo4j.com/blog/news/graphrag-python-package/)).

In [44]:
# Cypher fragment passed as `retrieval_query` to the Cypher-based retrievers below.
# It starts from the `node` variable (aliased to `chunk`), follows FROM_CHUNK back to the
# linked entities, expands 1-2 further hops over any relationship type other than FROM_CHUNK,
# and returns a single string column named `info` with a "=== text ===" section followed by
# a "=== kg_rels ===" section.
# NOTE(review): the query calls apoc.text.join, so the APOC plugin must be available on the server.
# NOTE(review): this is a regular (non-raw) Python string, so the \n escapes below reach Neo4j
# as real newline characters inside the Cypher string literals.
retrieval_query = """
//1) Go out 2-3 hops in the entity graph and get relationships
WITH node AS chunk
MATCH (chunk)<-[:FROM_CHUNK]-()-[relList:!FROM_CHUNK]-{1,2}()
UNWIND relList AS rel

//2) collect relationships and text chunks
WITH collect(DISTINCT chunk) AS chunks, 
  collect(DISTINCT rel) AS rels

//3) format and return context
RETURN '=== text ===\n' + apoc.text.join([c in chunks | c.text], '\n---\n') + '\n\n=== kg_rels ===\n' +
  apoc.text.join([r in rels | startNode(r).name + ' - ' + type(r) + '(' + coalesce(r.details, '') + ')' +  ' -> ' + endNode(r).name ], '\n---\n') AS info
"""

In [42]:
# Show the retrieval query stored in the retriever configuration file
vc_retriever_config = retr_config['VectorCypherRetriever_config']
print(vc_retriever_config['retrieval_query'])

//1) Go out 2-3 hops in the entity graph and get relationships
WITH node AS chunk
MATCH (chunk)<-[:FROM_CHUNK]-()-[relList:!FROM_CHUNK]-{1,2}()
UNWIND relList AS rel

//2) collect relationships and text chunks
WITH collect(DISTINCT chunk) AS chunks,
 collect(DISTINCT rel) AS rels

//3) format and return context
RETURN '=== text ===\n' + apoc.text.join([c in chunks | c.text], '\n---\n') + '\n\n=== kg_rels ===\n' +
 apoc.text.join([r in rels | startNode(r).name + ' - ' + type(r) + '(' + coalesce(r.details, '') + ')' +  ' -> ' + endNode(r).name ], '\n---\n') AS info


In [45]:
# Create vector retriever
vc_retriever = VectorCypherRetriever(
    driver=driver,
    index_name=embeddings_index_name,  # Name of the vector index that will be used for similarity search with the embedded query text
    retrieval_query=retrieval_query, # Cypher query to retrieve the context surrounding the embeddings that are found for the results
    embedder=embedder  # Embedder to use for embedding the query text when doing a vector search
)

In [53]:
vc_results = vc_retriever.get_search_results(
    query_vector=None,  # The query vector is None because we will use the embedder to embed the query text
    query_text=sample_query_text,  # The query text to embed and search for
    top_k=5  # Number of results to return
)

# Print output: split the single `info` string into its text-chunk part and its relationship part
kg_rels_marker = '=== kg_rels ===\n'
if not vc_results.records:
    # Nothing to split; avoid an IndexError on records[0]
    print("No context was retrieved for this query.")
else:
    # `info` can be null when the Cypher string concatenation involves a null value
    vc_context = vc_results.records[0]['info'] or ''
    kg_rel_pos = vc_context.find(kg_rels_marker)
    if kg_rel_pos == -1:
        # Marker absent: treat the whole string as text context instead of slicing at -1
        kg_rel_pos = len(vc_context)
    print("# Text Chunk Context:\n")
    print(vc_context[:kg_rel_pos])
    print("\n# KG Context From Relationships:\n")
    print(vc_context[kg_rel_pos:])

# Text Chunk Context:

=== text ===
---
Sudanese civil defense says all fires at major oil depots in government-controlled Port Sudan are now "completely" under control following numerous RSF attacks on petroleum reserves
---
Sudanese media reports new drone strikes on unspecified areas in Port Sudan
---
Sudanese media report new drone attack on Port Sudan with no specific location given; air defense reported at work
---
Sudanese military, citing official, reports drone attack on Port Sudan targeted civilian facilities including air base and a cargo warehouse; unverified reports claim power outages in parts of city [corrects location struck]



# KG Context From Relationships:

=== kg_rels ===
Emirati - COOPERATED_WITH() -> RSF
---
UAE ships - COOPERATED_WITH() -> RSF
---
UAE ships - IS_FROM() -> Red Sea
---
UAE ships - IS_FROM() -> UAE
---
RSF - IS_FROM() -> Nyala
---
Atbara airport - IS_WITHIN() -> Nyala
---
Nyala - IS_WITHIN() -> Sudan
---
Nyala - IS_WITHIN() -> Darfur
---
airstrike

i.e., the `VectorCypherRetriever` builds its context from two sources: the text chunks whose embeddings are most similar to the query embedding, and the graph relationships surrounding those chunks.

# 3. HybridRetriever

Combines vector and full-text search.

In [54]:
# Build the retriever that combines vector search and full-text search
hy_retriever = HybridRetriever(
    driver=driver,
    # Vector index searched with the embedded query text
    vector_index_name=embeddings_index_name,
    # Full-text index searched with the raw query text
    fulltext_index_name=fulltext_index_name,
    # Embedder applied to the query text before the vector search
    embedder=embedder,
    # Node properties returned with each hit, in addition to its score
    return_properties=['text'],
)

In [56]:
# Run the hybrid search: no pre-computed vector is passed, so the retriever's embedder
# embeds the query text; the 5 best matches are requested.
hy_results = hy_retriever.get_search_results(
    query_text=sample_query_text,
    query_vector=None,
    top_k=5,
    # 'linear' ranks by a linear combination of the vector and full-text scores;
    # the default, "naive", combines the scores without weighting them
    ranker='linear',
    # Weight of the vector score in the linear combination (0.5 = equal weighting)
    alpha=0.5,
)

print("Raw search results:\n", hy_results)

# Pretty-print every returned record as indented JSON, separated by a ruler line
for record in hy_results.records:
    print("=" * 50 + "\n" + json.dumps(record.data(), indent=4))

Raw search results:
{
    "node": {
    },
    "nodeLabels": [
        "__KGBuilder__",
        "Chunk"
    ],
    "elementId": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:264",
    "id": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:264",
    "score": 0.5
}
{
    "node": {
        "text": "Editor's note: We are aware of images circulating on social media of a large fire burning at multiple oil depots in Port Sudan, Sudan. The attack has not yet been reported on by Sudanese media, and the cause of the fires is still unclear. The fires come just one day after a series of RSF paramilitary drone strikes in the city, which were the first of the country's two year civil war. We are watching our sources for more information. - Owen"
    },
    "nodeLabels": [
        "__KGBuilder__",
        "Chunk"
    ],
    "elementId": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:434",
    "id": "4:6c5b3cb6-25d4-4346-af2c-705e7774defb:434",
    "score": 0.5
}
{
    "node": {
        "text": "Sudanese civil defense says

# 4. HybridCypherRetriever

Combines vector and full-text search with Cypher retrieval queries for additional graph traversal. 

In [57]:
# Create vector retriever
hyc_retriever = HybridCypherRetriever(
    driver=driver,
    vector_index_name=embeddings_index_name,  # Name of the vector index that will be used for similarity search with the embedded query text
    fulltext_index_name=fulltext_index_name,  # Name of the fulltext index that will be used for text search
    retrieval_query=retrieval_query, # Cypher query to retrieve the context surrounding the embeddings that are found for the results
    embedder=embedder  # Embedder to use for embedding the query text when doing a vector search
)

In [58]:
hyc_results = hyc_retriever.get_search_results(
    query_vector=None,  # The query vector is None because we will use the embedder to embed the query text
    query_text=sample_query_text,  # The query text to embed and search for
    top_k=5,  # Number of results to return
    ranker="linear",  # Ranker to use for ranking the results, 'linear' is a simple linear combination of the vector and text scores, "naive" is default value and just combines the scores without weighting them
    alpha=0.5  # Weighting factor for the vector score in the linear combination, 0.5 means equal weighting for vector and text scores
)

# Print output: split the single `info` string into its text-chunk part and its relationship part
kg_rels_marker = '=== kg_rels ===\n'
if not hyc_results.records:
    # Nothing to split; avoid an IndexError on records[0]
    print("No context was retrieved for this query.")
else:
    # `info` can be null when the Cypher string concatenation involves a null value
    hyc_context = hyc_results.records[0]['info'] or ''
    kg_rel_pos = hyc_context.find(kg_rels_marker)
    if kg_rel_pos == -1:
        # Marker absent: treat the whole string as text context instead of slicing at -1
        kg_rel_pos = len(hyc_context)
    print("# Text Chunk Context:\n")
    print(hyc_context[:kg_rel_pos])
    print("\n# KG Context From Relationships:\n")
    print(hyc_context[kg_rel_pos:])

# Text Chunk Context:

=== text ===
---
Editor's note: We are aware of images circulating on social media of a large fire burning at multiple oil depots in Port Sudan, Sudan. The attack has not yet been reported on by Sudanese media, and the cause of the fires is still unclear. The fires come just one day after a series of RSF paramilitary drone strikes in the city, which were the first of the country's two year civil war. We are watching our sources for more information. - Owen
---
Sudanese civil defense says all fires at major oil depots in government-controlled Port Sudan are now "completely" under control following numerous RSF attacks on petroleum reserves
---
Sudanese media reports new drone strikes on unspecified areas in Port Sudan
---
Sudanese media report new drone attack on Port Sudan with no specific location given; air defense reported at work



# KG Context From Relationships:

=== kg_rels ===
Emirati - COOPERATED_WITH() -> RSF
---
UAE ships - COOPERATED_WITH() -> RSF
--

# 5. Text2CypherRetriever

Converts natural language queries into Cypher queries to run against Neo4j. Does NOT search in text or perform similarity measures.

In [64]:
# Instantiate the Gemini LLM from the Text2CypherRetriever section of the retrieval configuration
t2c_llm_config = retr_config['Text2CypherRetriever_config']['llm']
llm = GeminiLLM(
    # Name of the model that will turn the query text into Cypher
    model_name=t2c_llm_config['model_name'],
    google_api_key=gemini_api_key,
    # Model parameters forwarded to the LLM
    model_params=t2c_llm_config['model_params'],
)

In [73]:
# Read the schema of the knowledge graph from the database
schema = get_schema(
    driver=driver,
    sanitize=False,
    # Enhanced schema: carries additional information such as example values
    is_enhanced=True,
)

print(schema)

Node properties:
- **Document**
  - `id`: STRING Example: "5a9f0108-8991-44ae-9dbb-263a1a40eff7"
  - `path`: STRING Example: "1"
  - `createdAt`: STRING Example: "2025-06-05T14:20:10.336471+00:00"
  - `source`: STRING Example: "https://www.unicef.org/press-releases/wfpunicef-hu"
  - `published_date`: STRING Example: "2025-06-03 15:28:28.271179+00:00"
- **Chunk**
  - `id`: STRING Example: "f830f947-5d6c-4a4b-94b5-a4154b66e0bf"
  - `index`: INTEGER Min: 0, Max: 0
  - `text`: STRING Example: "WFP and UNICEF now say five members were killed, s"
- **Actor**
  - `id`: STRING Example: "f830f947-5d6c-4a4b-94b5-a4154b66e0bf:0"
  - `name`: STRING Example: "WFP"
  - `chunk_index`: INTEGER Min: 0, Max: 0
  - `type`: STRING Example: "International Organization"
- **Event**
  - `id`: STRING Example: "f830f947-5d6c-4a4b-94b5-a4154b66e0bf:2"
  - `name`: STRING Example: "attack on aid convoy"
  - `chunk_index`: INTEGER Min: 0, Max: 0
  - `type`: STRING Example: "Attack"
  - `start_date`: STRING Availab

In [85]:
# Few-shot examples (user question -> Cypher query) handed to the Text2CypherRetriever
examples = [
    (
        "USER INPUT: 'What events happened in Sudan?'\n"
        "QUERY: MATCH (e:Event)-[:HAPPENED_IN]->(c:Country) "
        "WHERE c.name = 'Sudan' RETURN e.name, e.type"
    ),
    (
        "USER INPUT: 'Which actors participated in attacks?'\n"
        "QUERY: MATCH (a:Actor)-[:PARTICIPATED_IN]->(e:Event) "
        "WHERE e.type = 'Attack' RETURN a.name, a.type"
    ),
]

In [89]:
# Show the examples configuration stored in the retriever configuration file
t2c_retriever_config = retr_config['Text2CypherRetriever_config']
print(t2c_retriever_config['examples_config'])

{'include_examples': True, 'examples': ["USER INPUT: 'What events happened in Sudan?'\nQUERY: MATCH (e:Event)-[:HAPPENED_IN]->(c:Country) WHERE c.name = 'Sudan' RETURN e.name, e.type", "USER INPUT: 'Which actors participated in attacks?'\nQUERY: MATCH (a:Actor)-[:PARTICIPATED_IN]->(e:Event) WHERE e.type = 'Attack' RETURN a.name, a.type"]}


In [90]:
# Build the retriever that asks the LLM to translate the query text into Cypher
t2c_retriever = Text2CypherRetriever(
    driver=driver,
    # LLM that generates the Cypher query from the query text
    llm=llm,
    # Knowledge graph schema given to the LLM for query generation
    neo4j_schema=schema,
    # Example question/query pairs given to the LLM for query generation
    examples=examples,
    # No custom prompt is supplied
    custom_prompt=None,
)

In [94]:
t2c_results = t2c_retriever.get_search_results(
    query_text=sample_query_text,  # The natural language query used to generate the Cypher query
)

t2c_records = t2c_results.records
# The metadata may be missing, so read the generated Cypher defensively
generated_cypher = (t2c_results.metadata or {}).get('cypher', '<not available>')

# Print results in a more readable format
print("=" * 80)
print("TEXT-TO-CYPHER RETRIEVAL RESULTS")
print("=" * 80)
print(f"Query: {sample_query_text}")
print(f"Generated Cypher: {generated_cypher}")
print("-" * 80)

# The Cypher query is generated by the LLM, so the returned columns are not guaranteed.
# The grouped display is used only when the expected event columns are present.
expected_columns = {'e.name', 'e.type'}
has_event_columns = bool(t2c_records) and expected_columns <= set(t2c_records[0].keys())

events_by_type = {}
if has_event_columns:
    print(f"Found {len(t2c_records)} events in Sudan:")
    print("-" * 80)

    # Group events by type for better organization. Types are normalised to upper case
    # (the form they are displayed in) so that groups differing only in letter case are
    # merged and the alphabetical ordering is case-insensitive; a missing type is
    # reported as UNKNOWN instead of crashing on None.
    for record in t2c_records:
        event_type = str(record['e.type'] or 'Unknown').upper()
        events_by_type.setdefault(event_type, []).append(record['e.name'])

    # Display events grouped by type
    for event_type, events in sorted(events_by_type.items()):
        print(f"\n📍 {event_type} ({len(events)} events):")
        for i, event in enumerate(events, 1):
            print(f"   {i}. {event}")
else:
    # Unexpected (or no) columns: fall back to a generic dump of each record
    print(f"Found {len(t2c_records)} records:")
    print("-" * 80)
    for record in t2c_records:
        pprint(record.data())

print("\n" + "=" * 80)

TEXT-TO-CYPHER RETRIEVAL RESULTS
Query: Which have been the most pressing security-related issues in Sudan in the last year? What are the future prospects for the country considering the current situation?
Generated Cypher: MATCH (e:Event)-[:HAPPENED_IN]->(c:Country) WHERE c.name = 'Sudan' RETURN e.name, e.type
--------------------------------------------------------------------------------
Found 15 events in Sudan:
--------------------------------------------------------------------------------

📍 ATTACK (5 events):
   1. attack on aid convoy
   2. strike on a prison
   3. large fire
   4. drones attack
   5. Hospital Bombing

📍 CONFLICT (2 events):
   1. civil war
   2. clashes

📍 DEATHS (1 events):
   1. 12 deaths

📍 DISEASE CASES (1 events):
   1. 727 cholera cases

📍 GOVERNMENT ACTION (2 events):
   1. Government Dissolution
   2. Undertaking Responsibilities

📍 NATURAL DISASTER IMPACT (1 events):
   1. Impact of Heavy Rain

📍 ASSESSMENT (1 events):
   1. assessment of refugee con

# 6. Closing the driver connection

In [95]:
# Close the Neo4j driver created in the setup cell; the retrievers above hold a reference
# to this same driver object, so re-run the setup cell before using them again.
driver.close()