<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/llama_index_neo4j_custom_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Defining a Custom Property Graph Retriever

This guide shows you how to define a custom retriever against a property graph.

It is more involved than using our out-of-the-box graph retrievers, but it gives you granular control over the retrieval process, so you can tailor it to your application.

We show you how to define an advanced retrieval flow by directly leveraging the property graph store. We'll use an LLM to extract the named entities mentioned in a question, and then use those entities to retrieve the relevant context from the graph.

In [1]:
!pip install --quiet llama-index llama-index-graph-stores-neo4j llama-index-program-openai llama-index-llms-openai


## Setup and Build the Property Graph

In [2]:
import nest_asyncio

nest_asyncio.apply()

In [3]:
import os

os.environ["OPENAI_API_KEY"] = "sk-"

#### Load News

In [4]:
import pandas as pd

news = pd.read_csv("https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv")
news.head()

| Unnamed: 0 | title | date | text |
|---|---|---|---|
| 0 | Chevron: Best Of Breed | 2031-04-06T01:36:32.000000000+00:00 | JHVEPhoto Like many companies in the O&G secto... |
| 1 | FirstEnergy (NYSE:FE) Posts Earnings Results | 2030-04-29T06:55:28.000000000+00:00 | FirstEnergy (NYSE:FE – Get Rating) posted its ... |
| 2 | Dáil almost suspended after Sinn Féin TD put p... | 2023-06-15T14:32:11.000000000+00:00 | The Dáil was almost suspended on Thursday afte... |
| 3 | Epic’s latest tool can animate hyperrealistic ... | 2023-06-15T14:00:00.000000000+00:00 | Today, Epic is releasing a new tool designed t... |
| 4 | EU to Ban Huawei, ZTE from Internal Commission... | 2023-06-15T13:50:00.000000000+00:00 | The European Commission is planning to ban equ... |


In [6]:
from llama_index.core import Document

documents = [Document(text=f"{row['title']}: {row['text']}") for i, row in news.iterrows()]

#### Define Default LLMs

In [7]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o", temperature=0.3)
embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

In [8]:
from llama_index.graph_stores.neo4j import Neo4jPGStore

username="neo4j"
password="capture-debit-blanket"
url="bolt://44.202.206.163:7687"


graph_store = Neo4jPGStore(
    username=username,
    password=password,
    url=url,
)

#### Build the Property Graph

In [12]:
from typing import Literal
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor

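# best practice: use upper-case names for relation labels
entities = Literal["Person", "Location", "Organization", "Product", "Event"]
# include every relation that the validation schema below refers to
relations = Literal[
    "SUPPLIER_OF", "COMPETITOR", "PARTNERSHIP", "ACQUISITION", "WORKS_AT",
    "SUBSIDIARY", "BOARD_MEMBER", "CEO", "PROVIDES",
    "HAS_EVENT", "IN_LOCATION", "HAPPENED_AT",
]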

# define which entities can have which relations
validation_schema = {
    "Person": ["WORKS_AT", "BOARD_MEMBER", "CEO", "HAS_EVENT"],
    "Organization": ["SUPPLIER_OF", "COMPETITOR", "PARTNERSHIP", "ACQUISITION", "WORKS_AT", "SUBSIDIARY", "BOARD_MEMBER", "CEO", "PROVIDES", "HAS_EVENT", "IN_LOCATION"],
    "Product": ["PROVIDES"],
    "Event": ["HAS_EVENT", "IN_LOCATION"],
    "Location": ["HAPPENED_AT", "IN_LOCATION"]
}

kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    possible_entities=entities,
    possible_relations=relations,
    kg_validation_schema=validation_schema,
    # if false, allows for values outside of the schema
    # useful for using the schema as a suggestion
    strict=False,
)
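
To make the role of the validation schema concrete, here is a minimal, library-independent sketch of how a schema of this shape can filter extracted triples. It is an illustration only, not the logic `SchemaLLMPathExtractor` runs internally: the helpers `is_allowed` and `filter_triples` are hypothetical, the check looks only at the head entity's type, and the sample triples are made up.

```python
from typing import Dict, List, Tuple

# (head name, head entity type, relation, tail name)
Triple = Tuple[str, str, str, str]

# Same shape as the validation schema: entity type -> allowed relations.
schema: Dict[str, List[str]] = {
    "Person": ["WORKS_AT", "BOARD_MEMBER", "CEO", "HAS_EVENT"],
    "Product": ["PROVIDES"],
}


def is_allowed(head_type: str, relation: str, schema: Dict[str, List[str]]) -> bool:
    """Return True if `relation` is listed for `head_type` in the schema."""
    return relation in schema.get(head_type, [])


def filter_triples(
    triples: List[Triple], schema: Dict[str, List[str]], strict: bool
) -> List[Triple]:
    """Drop triples outside the schema when strict, keep everything otherwise."""
    if not strict:
        return list(triples)
    return [t for t in triples if is_allowed(t[1], t[2], schema)]


# Made-up sample data for illustration.
triples: List[Triple] = [
    ("Jane Doe", "Person", "CEO", "Acme"),
    ("Jane Doe", "Person", "SUPPLIER_OF", "Acme"),
]

print(filter_triples(triples, schema, strict=True))
print(filter_triples(triples, schema, strict=False))
```

With `strict=True` only the first triple survives in this sketch; with `strict=False` the schema acts as a suggestion and both are kept.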

In [13]:
from llama_index.core import PropertyGraphIndex

index = PropertyGraphIndex.from_documents(
    documents[:5],
    kg_extractors=[kg_extractor],
    llm=llm,
    embed_model=embed_model,
    property_graph_store=graph_store,
    show_progress=True,
)

Parsing nodes:   0%|          | 0/5 [00:00<?, ?it/s]

Extracting paths from text with schema: 100%|██████████| 5/5 [00:06<00:00,  1.25s/it]
Generating embeddings: 100%|██████████| 1/1 [00:00<00:00,  4.25it/s]
Generating embeddings: 100%|██████████| 1/1 [00:00<00:00,  2.54it/s]


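## Entity Disambiguation

The extractor can create several nodes for the same real-world entity. We first list entities whose embeddings are nearly identical, and then merge them.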

In [14]:
similarity_threshold = 0.9
data = graph_store.structured_query("""
MATCH (e:__Entity__)
CALL db.index.vector.queryNodes('vector', 5, e.embedding)
YIELD node, score
WITH e, node, score
WHERE score > toFloat($cutoff) AND id(e) < id(node)
WITH e, collect(node) AS nodes
RETURN [e.name] + [n in nodes | n.name] AS duplicates LIMIT 5
""", param_map={'cutoff': similarity_threshold})
for row in data:
    print(row)

{'duplicates': ['MetaHuman Animator', 'MetaHuman']}
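
The query above asks the vector index for each entity's nearest neighbours and keeps the pairs whose score exceeds the cutoff, using `id(e) < id(node)` so that each pair is reported once. The same idea can be sketched in plain Python with cosine similarity. The function names and the toy two-dimensional embeddings are assumptions for illustration, and the score reported by the database may be scaled differently from the raw cosine used here.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 for zero vectors)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def candidate_duplicates(
    embeddings: Dict[str, List[float]], cutoff: float
) -> List[Tuple[str, str]]:
    """Return each pair of names whose similarity exceeds `cutoff`, once."""
    names = sorted(embeddings)
    pairs = []
    for i, first in enumerate(names):
        # Only look at later names, so (a, b) is never repeated as (b, a).
        for second in names[i + 1:]:
            if cosine(embeddings[first], embeddings[second]) > cutoff:
                pairs.append((first, second))
    return pairs


# Toy embeddings, invented for this example.
toy = {
    "MetaHuman": [1.0, 0.0],
    "MetaHuman Animator": [0.98, 0.2],
    "Chevron": [0.0, 1.0],
}

print(candidate_duplicates(toy, cutoff=0.9))
```

A high cutoff keeps the candidate list short, but near-identical names are not always the same entity, so it is worth reviewing the pairs before merging them.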


In [None]:
graph_store.structured_query("""
MATCH (e:__Entity__)
CALL db.index.vector.queryNodes('vector', 5, e.embedding)
YIELD node, score
WITH e, node, score
WHERE score > toFloat($cutoff) AND id(e) < id(node)
WITH e, collect(node) AS nodes
CALL apoc.refactor.mergeNodes([e] + nodes)
YIELD node
RETURN count(*)
""", param_map={'cutoff': similarity_threshold})

[{'count(*)': 0}]
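
Conceptually, merging a group of duplicate nodes means keeping one node, folding the others' properties into it, and re-pointing their relationships. The sketch below shows that idea on plain Python data structures. It is not a description of how `apoc.refactor.mergeNodes` is implemented: the `merge_nodes` helper is hypothetical, the property policy (the kept node's value wins) is an arbitrary choice for this example, and the sample graph is invented.

```python
from typing import Dict, List, Tuple

Edge = Tuple[str, str, str]  # (source name, relation, target name)


def merge_nodes(
    nodes: Dict[str, dict], edges: List[Edge], group: List[str]
) -> Tuple[Dict[str, dict], List[Edge]]:
    """Merge every node in `group` into the first one and rewire the edges."""
    keep, *drop = group
    # Copy the property dicts so the input graph is left untouched.
    merged = {name: dict(props) for name, props in nodes.items()}
    for name in drop:
        for key, value in merged.pop(name).items():
            # The kept node's existing value wins on conflicts.
            merged[keep].setdefault(key, value)

    rewired: List[Edge] = []
    for src, rel, dst in edges:
        src = keep if src in drop else src
        dst = keep if dst in drop else dst
        edge = (src, rel, dst)
        # Skip self-loops created by the merge and exact duplicate edges.
        if src != dst and edge not in rewired:
            rewired.append(edge)
    return merged, rewired


# Invented sample graph for illustration.
nodes = {
    "MetaHuman Animator": {"label": "Product"},
    "MetaHuman": {"label": "Product", "source": "news"},
    "Epic": {"label": "Organization"},
}
edges = [
    ("Epic", "PROVIDES", "MetaHuman Animator"),
    ("Epic", "PROVIDES", "MetaHuman"),
]

merged, rewired = merge_nodes(nodes, edges, ["MetaHuman Animator", "MetaHuman"])
print(merged)
print(rewired)
```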

## Retrieval

In [None]:
from typing import Any, List, Optional

from pydantic import BaseModel

from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.retrievers import CustomPGRetriever, VectorContextRetriever
from llama_index.core.vector_stores.types import VectorStore
from llama_index.program.openai import OpenAIPydanticProgram


class Entities(BaseModel):
    """List of named entities in the text such as names of people, organizations, concepts, and locations"""

    names: Optional[List[str]]


prompt_template_entities = """
Extract all named entities such as names of people, organizations, concepts, and locations
from the following text:
{text}
"""


class MyCustomRetriever(CustomPGRetriever):
    """Custom retriever that extracts entities from the query before searching."""

    def init(
        self,
        ## vector context retriever params
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        **kwargs: Any,
    ) -> None:
        """Uses any kwargs passed in from class constructor."""
        # Create a full-text index on entity names (no-op if it already exists)
        self.graph_store.structured_query(
            """CREATE FULLTEXT INDEX entities IF NOT EXISTS FOR (e:`__Entity__`) ON EACH [e.name];"""
        )
        # LLM program that returns the entities found in a piece of text
        self.entity_extraction = OpenAIPydanticProgram.from_defaults(
            output_cls=Entities, prompt_template_str=prompt_template_entities
        )
        # Assumption: the unfinished retrieval step is filled in with the
        # built-in VectorContextRetriever; swap in your own lookup if needed.
        self.vector_retriever = VectorContextRetriever(
            self.graph_store,
            include_text=self.include_text,
            embed_model=embed_model,
            vector_store=vector_store,
            similarity_top_k=similarity_top_k,
            path_depth=path_depth,
        )

    def custom_retrieve(self, query_str: str) -> str:
        """Retrieve graph context for the entities mentioned in the query.

        Could return `str`, `TextNode`, `NodeWithScore`, or a list of those.
        """
        entities = self.entity_extraction(text=query_str).names
        print(f"Detected entities: {entities}")
        # Search once per detected entity; fall back to the raw query if none
        nodes = []
        for term in entities or [query_str]:
            nodes.extend(self.vector_retriever.retrieve(term))
        final_text = "\n\n".join(
            [n.get_content(metadata_mode="llm") for n in nodes]
        )

        return final_text
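
The control flow of a retriever like this (extract entities, look each one up, fall back to the whole question when nothing is found) can be checked without a database or an LLM by passing the two dependencies in as plain callables. The following is a sketch under that assumption: `retrieve_with_entities`, `fake_extract`, and `fake_search` are hypothetical names that are not part of LlamaIndex.

```python
from typing import Callable, List, Optional


def retrieve_with_entities(
    query: str,
    extract: Callable[[str], Optional[List[str]]],
    search: Callable[[str], List[str]],
) -> str:
    """Search once per extracted entity, or once with the raw query."""
    entities = extract(query) or []
    lookups = entities if entities else [query]
    seen = set()
    chunks: List[str] = []
    for term in lookups:
        for chunk in search(term):
            # Different entities can return the same context; keep it once.
            if chunk not in seen:
                seen.add(chunk)
                chunks.append(chunk)
    return "\n\n".join(chunks)


def fake_extract(query: str) -> Optional[List[str]]:
    """Stand-in for the LLM entity extractor."""
    return ["Epic"] if "Epic" in query else None


def fake_search(term: str) -> List[str]:
    """Stand-in for a graph lookup."""
    return [f"context for {term}"]


print(retrieve_with_entities("What did Epic release?", fake_extract, fake_search))
print(retrieve_with_entities("anything new?", fake_extract, fake_search))
```

Keeping the extractor and the search step behind simple callables also makes it easy to try a different lookup strategy later without touching the control flow.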

## Test out the Custom Retriever

Now let's initialize and test out the custom retriever against our data!

To build a full RAG pipeline, we use the `RetrieverQueryEngine` to combine our retriever with the LLM synthesis module - this is also used under the hood for the property graph index.

In [None]:
custom_sub_retriever = MyCustomRetriever(
    index.property_graph_store,
    include_text=True,
    vector_store=index.vector_store,
    # use the same embedding model that was used to build the index
    embed_model=embed_model,
)

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(
    index.as_retriever(sub_retrievers=[custom_sub_retriever]), llm=llm
)

### Try out some Queries

In [None]:
response = query_engine.query("What new tool is Epic releasing?")
print(str(response))
