<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/llama_index_neo4j_custom_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Defining a Custom Property Graph Retriever

This guide shows you how to define a custom retriever against a property graph.

It is more involved than using our out-of-the-box graph retrievers, but allows you to have granular control over the retrieval process so that it's better tailored for your application.

We show you how to define an advanced retrieval flow by directly leveraging the property graph store. We'll execute both vector search and text-to-cypher retrieval, and then combine the results through a reranking module.

In [1]:
!pip install --quiet llama-index llama-index-graph-stores-neo4j llama-index-program-openai llama-index-llms-openai

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.4/15.4 MB[0m [31m41.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m49.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m141.9/141.9 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m325.5/325.5 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m203.0/203.0 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ..

## Setup and Build the Property Graph

In [2]:
import nest_asyncio

nest_asyncio.apply()

In [3]:
from llama_index.graph_stores.neo4j import Neo4jPGStore

username="neo4j"
password="cradles-legends-assemblies"
url="bolt://18.212.1.253:7687"

graph_store = Neo4jPGStore(
    username=username,
    password=password,
    url=url,
)

In [4]:
import os

os.environ["OPENAI_API_KEY"] = "sk-"

#### Load News

In [5]:
import pandas as pd
from llama_index.core import Document

news = pd.read_csv("https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv")
documents = [Document(text=f"{row['title']}: {row['text']}") for i, row in news.iterrows()]
news.head()

Unnamed: 0,title,date,text
0,Chevron: Best Of Breed,2031-04-06T01:36:32.000000000+00:00,JHVEPhoto Like many companies in the O&G secto...
1,FirstEnergy (NYSE:FE) Posts Earnings Results,2030-04-29T06:55:28.000000000+00:00,FirstEnergy (NYSE:FE – Get Rating) posted its ...
2,Dáil almost suspended after Sinn Féin TD put p...,2023-06-15T14:32:11.000000000+00:00,The Dáil was almost suspended on Thursday afte...
3,Epic’s latest tool can animate hyperrealistic ...,2023-06-15T14:00:00.000000000+00:00,"Today, Epic is releasing a new tool designed t..."
4,"EU to Ban Huawei, ZTE from Internal Commission...",2023-06-15T13:50:00.000000000+00:00,The European Commission is planning to ban equ...


#### Define Default LLMs

In [6]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o", temperature=0.0)
embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

#### Build the Property Graph

In [7]:
from typing import Literal
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor

# best practice to use upper-case
entities = Literal["PERSON", "LOCATION", "ORGANIZATION", "PRODUCT", "EVENT"]
relations = Literal[
    "SUPPLIER_OF",
    "COMPETITOR",
    "PARTNERSHIP",
    "ACQUISITION",
    "WORKS_AT",
    "SUBSIDIARY",
    "BOARD_MEMBER",
    "CEO",
    "PROVIDES",
    "HAS_EVENT",
    "IN_LOCATION",
]

# define which entities can have which relations
validation_schema = {
    "Person": ["WORKS_AT", "BOARD_MEMBER", "CEO", "HAS_EVENT"],
    "Organization": [
        "SUPPLIER_OF",
        "COMPETITOR",
        "PARTNERSHIP",
        "ACQUISITION",
        "WORKS_AT",
        "SUBSIDIARY",
        "BOARD_MEMBER",
        "CEO",
        "PROVIDES",
        "HAS_EVENT",
        "IN_LOCATION",
    ],
    "Product": ["PROVIDES"],
    "Event": ["HAS_EVENT", "IN_LOCATION"],
    "Location": ["HAPPENED_AT", "IN_LOCATION"],
}

kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    possible_entities=entities,
    possible_relations=relations,
    kg_validation_schema=validation_schema,
    # if false, allows for values outside of the schema
    # useful for using the schema as a suggestion
    strict=True,
)


In [8]:
from llama_index.core import PropertyGraphIndex

NUMBER_OF_ARTICLES = 250

index = PropertyGraphIndex.from_documents(
    documents[:NUMBER_OF_ARTICLES],
    kg_extractors=[kg_extractor],
    llm=llm,
    embed_model=embed_model,
    property_graph_store=graph_store,
    show_progress=True,
)

Parsing nodes:   0%|          | 0/250 [00:00<?, ?it/s]

Extracting paths from text with schema: 100%|██████████| 250/250 [07:10<00:00,  1.72s/it]
Generating embeddings: 100%|██████████| 3/3 [00:02<00:00,  1.11it/s]
Generating embeddings: 100%|██████████| 46/46 [00:02<00:00, 16.68it/s]


## Entity disambiguation

In [9]:
graph_store.structured_query("""
CREATE VECTOR INDEX entity IF NOT EXISTS
FOR (m:`__Entity__`)
ON m.embedding
OPTIONS {indexConfig: {
 `vector.dimensions`: 1536,
 `vector.similarity_function`: 'cosine'
}}
""")

[]

In [16]:
# Just for inspection
similarity_threshold = 0.9
word_edit_distance = 5
data = graph_store.structured_query("""
MATCH (e:__Entity__)
CALL {
  WITH e
  CALL db.index.vector.queryNodes('entity', 10, e.embedding)
  YIELD node, score
  WITH node, score
  WHERE score > toFLoat($cutoff)
      AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name)
           OR apoc.text.distance(toLower(node.name), toLower(e.name)) < $distance)
      AND labels(e) = labels(node)
  WITH node, score
  ORDER BY node.name
  RETURN collect(node) AS nodes
}
WITH distinct nodes
WHERE size(nodes) > 1
WITH collect([n in nodes | n.name]) AS results
UNWIND range(0, size(results)-1, 1) as index
WITH results, index, results[index] as result
WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
        CASE WHEN index <> index2 AND
            size(apoc.coll.intersection(acc, results[index2])) > 0
            THEN apoc.coll.union(acc, results[index2])
            ELSE acc
        END
)) as combinedResult
WITH distinct(combinedResult) as combinedResult
// extra filtering
WITH collect(combinedResult) as allCombinedResults
UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
    WHERE x <> combinedResultIndex
    AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
)
RETURN combinedResult
""", param_map={'cutoff': similarity_threshold, 'distance': word_edit_distance})
for row in data:
    print(row)

{'combinedResult': ['American Fork', 'American Fork, Utah']}
{'combinedResult': ['XPeng', 'XPeng Inc', 'Xpeng']}
{'combinedResult': ['Bank of America', 'Bank of America Corp.']}
{'combinedResult': ['bilingual STEM tutoring', 'bilingual STEM tutoring organization to close the educational gap for underserved students']}
{'combinedResult': ['Houston Texans', 'Texans']}
{'combinedResult': ['Hyatt', 'Hyatt Hotels']}
{'combinedResult': ['Star Ocean', 'Star Ocean 2', 'Star Ocean 2 remake', 'Star Ocean The Second Story', 'Star Ocean: Second Evolution', 'Star Ocean: The Second Story', 'Star Ocean: The Second Story R']}
{'combinedResult': ['Star Ocean First Departure R', 'Star Ocean: First Departure R']}
{'combinedResult': ['Citigroup', 'Citigroup Inc.']}
{'combinedResult': ['Wells Fargo', 'Wells Fargo & Co.']}
{'combinedResult': ['JPMorgan', 'JPMorgan Asset Management', 'JPMorgan Chase', 'JPMorgan Chase & Co.']}
{'combinedResult': ['Delhi', 'New Delhi']}
{'combinedResult': ['CarMax Auto Owner T

In [21]:
graph_store.structured_query("""
MATCH (e:__Entity__)
CALL {
  WITH e
  CALL db.index.vector.queryNodes('entity', 10, e.embedding)
  YIELD node, score
  WITH node, score
  WHERE score > toFLoat($cutoff)
      AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name)
           OR apoc.text.distance(toLower(node.name), toLower(e.name)) < $distance)
      AND labels(e) = labels(node)
  WITH node, score
  ORDER BY node.name
  RETURN collect(node) AS nodes
}
WITH distinct nodes
WHERE size(nodes) > 1
WITH collect([n in nodes | n.name]) AS results
UNWIND range(0, size(results)-1, 1) as index
WITH results, index, results[index] as result
WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
        CASE WHEN index <> index2 AND
            size(apoc.coll.intersection(acc, results[index2])) > 0
            THEN apoc.coll.union(acc, results[index2])
            ELSE acc
        END
)) as combinedResult
WITH distinct(combinedResult) as combinedResult
// extra filtering
WITH collect(combinedResult) as allCombinedResults
UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
    WHERE x <> combinedResultIndex
    AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
)
CALL {
  WITH combinedResult
	UNWIND combinedResult AS name
	MATCH (e:__Entity__ {name:name})
	WITH e
	ORDER BY size(e.name) DESC // prefer longer names to remain after merging
	RETURN collect(e) AS nodes
}
CALL apoc.refactor.mergeNodes(nodes, {properties: {
    `.*`: 'discard'
}})
YIELD node
RETURN count(*)
""", param_map={'cutoff': similarity_threshold, 'distance': word_edit_distance})

[{'count(*)': 85}]

# Retrieval

In [None]:
from llama_index.core.retrievers import CustomPGRetriever, VectorContextRetriever
from llama_index.core.graph_stores import PropertyGraphStore
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.prompts import PromptTemplate
from llama_index.core.llms import LLM
from pydantic import BaseModel
from llama_index.program.openai import OpenAIPydanticProgram


from typing import Optional, Any, Union, List, Optional

class Entities(BaseModel):
    """List of named entities in the text such as names of people, organizations, concepts, and locations"""
    names: Optional[List[str]]


prompt_template_entities = """
Extract all named entities such as names of people, organizations, concepts, and locations
from the following text:
{text}
"""

class MyCustomRetriever(CustomPGRetriever):
    """Custom retriever with cohere reranking."""

    def init(
        self,
        ## vector context retriever params
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        include_text: bool = True,
        **kwargs: Any,
    ) -> None:
        """Uses any kwargs passed in from class constructor."""
        self.entity_extraction = OpenAIPydanticProgram.from_defaults(
    output_cls=Entities, prompt_template_str=prompt_template_entities
)
        self.vector_retriever = VectorContextRetriever(
            self.graph_store,
            include_text=self.include_text,
            embed_model=embed_model,
            #vector_store=vector_store,
            similarity_top_k=similarity_top_k,
            path_depth=path_depth,
        )

    def custom_retrieve(self, query_str: str) -> str:
        """Define custom retriever with reranking.

        Could return `str`, `TextNode`, `NodeWithScore`, or a list of those.
        """
        entities = self.entity_extraction(text=query_str).names
        result_nodes = []
        if entities:
            print(f"Detected entities: {entities}")
            for entity in entities:
                result_nodes.extend(self.vector_retriever.retrieve(entity))
        else:
            result_nodes.extend(self.vector_retriever.retrieve(query_str))
        ## TMP: please change
        final_text = "\n\n".join(
            [n.get_content(metadata_mode="llm") for n in result_nodes]
        )
        return final_text

## Test out the Custom Retriever

Now let's initialize and test out the custom retriever against our data!

To build a full RAG pipeline, we use the `RetrieverQueryEngine` to combine our retriever with the LLM synthesis module - this is also used under the hood for the property graph index.

In [None]:
custom_sub_retriever = MyCustomRetriever(
    index.property_graph_store,
    include_text=True,
    vector_store=index.vector_store,
    embed_model=embed_model
)

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(
    index.as_retriever(sub_retrievers=[custom_sub_retriever]), llm=llm
)

### Try out some Queries

In [None]:
response = query_engine.query("What do you know about Maliek Collins or Darragh O’Brien?")
print(str(response))

Detected entities: ['Maliek Collins', "Darragh O'Brien"]
Darragh O’Brien is associated with the Dáil and serves as the Minister for Housing. He was involved in a debate regarding retained firefighters, during which Sinn Féin TD John Brady placed an on-call pager in front of him, leading to a near suspension of the Dáil. O’Brien described Brady's actions as "an act of theatre" and expressed confidence that the dispute with firefighters could be resolved. He also encouraged unions to re-engage with the State’s industrial relations process.
