<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/llama_index_neo4j_custom_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Defining a Custom Property Graph Retriever

This guide shows you how to define a custom retriever against a property graph.

It is more involved than using our out-of-the-box graph retrievers, but allows you to have granular control over the retrieval process so that it's better tailored for your application.

We show you how to define an advanced retrieval flow by directly leveraging the property graph store. We use an LLM to extract the entities mentioned in a query, run vector-based context retrieval over the graph for each entity, and then combine the retrieved triplets into a single context for answer synthesis.

In [None]:
!pip install --quiet llama-index llama-index-graph-stores-neo4j llama-index-program-openai llama-index-llms-openai

## Setup and Build the Property Graph

In [None]:
import nest_asyncio

nest_asyncio.apply()

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "sk-"

#### Load News

In [None]:
import pandas as pd

news = pd.read_csv("https://raw.githubusercontent.com/tomasonjo/blog-datasets/main/news_articles.csv")
news.head()

| Unnamed: 0 | title | date | text |
|---|---|---|---|
| 0 | Chevron: Best Of Breed | 2031-04-06T01:36:32.000000000+00:00 | JHVEPhoto Like many companies in the O&G secto... |
| 1 | FirstEnergy (NYSE:FE) Posts Earnings Results | 2030-04-29T06:55:28.000000000+00:00 | FirstEnergy (NYSE:FE – Get Rating) posted its ... |
| 2 | Dáil almost suspended after Sinn Féin TD put p... | 2023-06-15T14:32:11.000000000+00:00 | The Dáil was almost suspended on Thursday afte... |
| 3 | Epic’s latest tool can animate hyperrealistic ... | 2023-06-15T14:00:00.000000000+00:00 | Today, Epic is releasing a new tool designed t... |
| 4 | EU to Ban Huawei, ZTE from Internal Commission... | 2023-06-15T13:50:00.000000000+00:00 | The European Commission is planning to ban equ... |


In [None]:
from llama_index.core import Document

documents = [Document(text=f"{row['title']}: {row['text']}") for i, row in news.iterrows()]

#### Define Default LLMs

In [None]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o", temperature=0.0)
embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

In [None]:
from llama_index.graph_stores.neo4j import Neo4jPGStore

# Replace with your own Neo4j connection details
username = "neo4j"
password = "<your-password>"
url = "bolt://localhost:7687"


graph_store = Neo4jPGStore(
    username=username,
    password=password,
    url=url,
)

#### Build the Property Graph

In [None]:
from typing import Literal
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor

# best practice: use upper-case names for relation types
entities = Literal["Person", "Location", "Organization", "Product", "Event"]
relations = Literal[
    "SUPPLIER_OF", "COMPETITOR", "PARTNERSHIP", "ACQUISITION", "WORKS_AT",
    "SUBSIDIARY", "BOARD_MEMBER", "CEO", "PROVIDES",
    # also referenced in validation_schema below
    "HAS_EVENT", "IN_LOCATION", "HAPPENED_AT",
]

# define which entities can have which relations
validation_schema = {
    "Person": ["WORKS_AT", "BOARD_MEMBER", "CEO", "HAS_EVENT"],
    "Organization": ["SUPPLIER_OF", "COMPETITOR", "PARTNERSHIP", "ACQUISITION", "WORKS_AT", "SUBSIDIARY", "BOARD_MEMBER", "CEO", "PROVIDES", "HAS_EVENT", "IN_LOCATION"],
    "Product": ["PROVIDES"],
    "Event": ["HAS_EVENT", "IN_LOCATION"],
    "Location": ["HAPPENED_AT", "IN_LOCATION"]
}

kg_extractor = SchemaLLMPathExtractor(
    llm=llm,
    possible_entities=entities,
    possible_relations=relations,
    kg_validation_schema=validation_schema,
    # if false, allows for values outside of the schema
    # useful for using the schema as a suggestion
    strict=False,
)
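As we understand `SchemaLLMPathExtractor`, the `Literal` types define the structured output the LLM fills in, so a relation that appears in `validation_schema` but not in `relations` most likely can never be emitted. The hypothetical helper `schema_mismatches` below is a minimal sketch for catching such drift; it only uses `typing.get_args` to read the values out of each `Literal`.

```python
from typing import Any, Dict, List, Literal, Set, get_args


def schema_mismatches(
    entities: Any, relations: Any, validation_schema: Dict[str, List[str]]
) -> Dict[str, Set[str]]:
    """Report schema keys/values that are not declared in the Literal types."""
    allowed_entities = set(get_args(entities))
    allowed_relations = set(get_args(relations))
    used_relations = {rel for rels in validation_schema.values() for rel in rels}
    return {
        # entity labels used as schema keys but missing from `entities`
        "unknown_entities": set(validation_schema) - allowed_entities,
        # relation types referenced in the schema but missing from `relations`
        "unknown_relations": used_relations - allowed_relations,
    }


# Example with a deliberately inconsistent toy schema
toy_entities = Literal["Person", "Organization"]
toy_relations = Literal["WORKS_AT", "CEO"]
toy_schema = {"Person": ["WORKS_AT", "HAS_EVENT"], "Organization": ["CEO"], "Place": []}
report = schema_mismatches(toy_entities, toy_relations, toy_schema)
```

Run it on the `entities`, `relations`, and `validation_schema` defined above; any non-empty set points to a name that should either be added to the corresponding `Literal` or dropped from the schema.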

In [9]:
from llama_index.core import PropertyGraphIndex

NUMBER_OF_ARTICLES = 100

index = PropertyGraphIndex.from_documents(
    documents[:NUMBER_OF_ARTICLES],
    kg_extractors=[kg_extractor],
    llm=llm,
    embed_model=embed_model,
    property_graph_store=graph_store,
    show_progress=True,
)

Parsing nodes:   0%|          | 0/100 [00:00<?, ?it/s]

Extracting paths from text with schema: 100%|██████████| 100/100 [02:41<00:00,  1.61s/it]
Generating embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
Generating embeddings: 100%|██████████| 19/19 [00:01<00:00, 10.02it/s]


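## Entity Disambiguation

The LLM-based extractor often creates several nodes for the same real-world entity (for example, `Bank of America` and `Bank of America Corp.`). To find candidate duplicates, we first create a vector index over the entity embeddings, then group entities whose embeddings are highly similar and whose names either contain one another or differ by only a few characters.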

In [10]:
graph_store.structured_query("""
CREATE VECTOR INDEX entity IF NOT EXISTS
FOR (m:`__Entity__`)
ON m.embedding
OPTIONS {indexConfig: {
 `vector.dimensions`: 1536,
 `vector.similarity_function`: 'cosine'
}}
""")

[]
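The index above uses cosine similarity. As far as we understand the Neo4j vector-index documentation, scores are normalized to the range [0, 1], with cosine reported as (1 + cos θ) / 2; treat this as an assumption to verify against your Neo4j version. Under that assumption, the `similarity_threshold = 0.9` used in the next cell corresponds to a raw cosine similarity of 0.8. The hypothetical helpers below make the conversion explicit.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Plain cosine similarity in [-1, 1]."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def neo4j_cosine_score(a: Sequence[float], b: Sequence[float]) -> float:
    # Assumption: the vector index reports cosine as (1 + cos) / 2, in [0, 1]
    return (1 + cosine_similarity(a, b)) / 2


def raw_cosine_from_score(score: float) -> float:
    # Invert the assumed normalization to read a score cutoff as raw cosine
    return 2 * score - 1
```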

In [11]:
similarity_threshold = 0.9
data = graph_store.structured_query("""
MATCH (e:__Entity__)
CALL {
  WITH e
  CALL db.index.vector.queryNodes('entity', 10, e.embedding)
  YIELD node, score
  WITH node, score
  WHERE score > toFloat($cutoff)
      AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name) OR apoc.text.distance(toLower(node.name), toLower(e.name)) < 5)
  WITH node, score
  ORDER BY node.name
  RETURN collect(node) AS nodes
}
WITH distinct nodes
WHERE size(nodes) > 1
RETURN distinct [n in nodes | n.name] AS duplicates
""", param_map={'cutoff': similarity_threshold})
for row in data:
    print(row)

{'duplicates': ['balance sheet', 'balance sheet report']}
{'duplicates': ['Earnings per Share', 'earnings per share', 'earnings per share (EPS)']}
{'duplicates': ['MetaHuman', 'MetaHuman Animator']}
{'duplicates': ['Vivo X90', 'Vivo X90 Pro', 'Vivo X90s']}
{'duplicates': ['Bank of America', 'Bank of America Corp.']}
{'duplicates': ['banking services', 'corporate banking services']}
{'duplicates': ['investment banking', 'investment banking services']}
{'duplicates': ['Yoto', 'Yoto Player']}
{'duplicates': ['dividend', 'dividend payments', 'dividends']}
{'duplicates': ['shareholder return', 'total shareholder return']}
{'duplicates': ['Star Ocean The Second Story R', 'Star Ocean: The Second Story', 'Star Ocean: The Second Story R', 'Star Ocean: The Second Story Remake']}
{'duplicates': ['Star Ocean First Departure', 'Star Ocean First Departure R', 'Star Ocean: First Departure R']}
{'duplicates': ['JPMorgan', 'JPMorgan Chase & Co.']}
{'duplicates': ['Paytm', 'Paytm wallet']}
{'duplicates'
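Vector similarity alone would also pair related but distinct entities, so the query additionally requires string similarity: one lower-cased name must contain the other, or their `apoc.text.distance` (Levenshtein distance) must be below 5. The sketch below mirrors only that string check in plain Python, which can be handy for tuning the cutoff on a few names before running it against the graph; `levenshtein` and `names_match` are hypothetical helpers, and the vector-score condition is deliberately left out.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance (insert, delete, substitute)."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            current.append(
                min(
                    previous[j] + 1,  # delete char_a
                    current[j - 1] + 1,  # insert char_b
                    previous[j - 1] + (char_a != char_b),  # substitute (0 if equal)
                )
            )
        previous = current
    return previous[-1]


def names_match(a: str, b: str, max_distance: int = 5) -> bool:
    """Mirror the string part of the Cypher WHERE clause (case-insensitive)."""
    a, b = a.lower(), b.lower()
    return a in b or b in a or levenshtein(a, b) < max_distance
```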

In [12]:
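# Merge each group of duplicates into a single node.
# Note: this currently fails (see the error below). The groups can overlap,
# so once a node has been merged away, a later group still references it
# and APOC raises a NotFoundException.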
graph_store.structured_query("""
MATCH (e:__Entity__)
CALL {
  WITH e
  CALL db.index.vector.queryNodes('entity', 10, e.embedding)
  YIELD node, score
  WITH node, score
  WHERE score > toFloat($cutoff)
      AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name) OR apoc.text.distance(toLower(node.name), toLower(e.name)) < 5)
  WITH node, score
  ORDER BY node.name
  RETURN collect(node) AS nodes
}
WITH distinct nodes
WHERE size(nodes) > 1
CALL apoc.refactor.mergeNodes(nodes)
YIELD node
RETURN count(*)
""", param_map={'cutoff': similarity_threshold})

ClientError: {code: Neo.ClientError.Procedure.ProcedureCallFailed} {message: Failed to invoke procedure `apoc.refactor.mergeNodes`: Caused by: org.neo4j.graphdb.NotFoundException: Node 408 not found}
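One way around this error, sketched here under our own assumptions rather than as the notebook's method, is to first collapse overlapping duplicate lists into disjoint groups with a union-find, then run one merge query per group that re-matches nodes by `name`, so no query ever holds a reference to a node that an earlier merge deleted. `disjoint_merge_groups`, `merge_duplicates`, and `MERGE_QUERY` are hypothetical names; the query assumes the `__Entity__` label and `name` property used above, plus APOC's `apoc.refactor.mergeNodes` with its `properties` and `mergeRels` options.

```python
from typing import Any, Dict, Iterable, List

# Hypothetical per-group merge query: nodes are re-matched by name each time.
MERGE_QUERY = """
MATCH (e:__Entity__) WHERE e.name IN $names
WITH collect(e) AS nodes
WHERE size(nodes) > 1
CALL apoc.refactor.mergeNodes(nodes, {properties: 'discard', mergeRels: true})
YIELD node
RETURN node.name AS name
"""


def disjoint_merge_groups(duplicate_lists: Iterable[List[str]]) -> List[List[str]]:
    """Union overlapping duplicate lists into disjoint groups (union-find)."""
    parent: Dict[str, str] = {}

    def find(name: str) -> str:
        parent.setdefault(name, name)
        while parent[name] != name:
            parent[name] = parent[parent[name]]  # path halving
            name = parent[name]
        return name

    for names in duplicate_lists:
        if not names:
            continue
        root = find(names[0])
        for other in names[1:]:
            other_root = find(other)
            if other_root != root:
                parent[other_root] = root  # attach the other tree under `root`

    groups: Dict[str, List[str]] = {}
    for name in list(parent):
        groups.setdefault(find(name), []).append(name)
    # singletons need no merge
    return [sorted(group) for group in groups.values() if len(group) > 1]


def merge_duplicates(graph_store: Any, groups: List[List[str]]) -> None:
    """Run one merge query per disjoint group via `structured_query`."""
    for names in groups:
        graph_store.structured_query(MERGE_QUERY, param_map={"names": names})
```

With the duplicates found earlier, usage would look like `merge_duplicates(graph_store, disjoint_merge_groups(row["duplicates"] for row in data))`. The union is transitive, so chains such as `Vivo X90` / `Vivo X90 Pro` / `Vivo X90s` collapse into one node; review the groups before merging, since some matches (for example `Yoto` and `Yoto Player`) may be distinct entities.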

## Retrieval

In [13]:
from typing import Any, List, Optional

from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.retrievers import CustomPGRetriever, VectorContextRetriever
from llama_index.core.vector_stores.types import VectorStore
from llama_index.program.openai import OpenAIPydanticProgram

class Entities(BaseModel):
    """List of named entities in the text such as names of people, organizations, concepts, and locations"""
    names: Optional[List[str]]


prompt_template_entities = """
Extract all named entities such as names of people, organizations, concepts, and locations
from the following text:
{text}
"""

class MyCustomRetriever(CustomPGRetriever):
    """Custom retriever that extracts entities from the query and retrieves graph context for each."""

    def init(
        self,
        ## vector context retriever params
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        include_text: bool = True,
        **kwargs: Any,
    ) -> None:
        """Uses any kwargs passed in from class constructor."""
        self.entity_extraction = OpenAIPydanticProgram.from_defaults(
            output_cls=Entities, prompt_template_str=prompt_template_entities
        )
        # Neo4jPGStore supports vector queries natively, so entity embeddings
        # are searched inside the graph store; `vector_store` is not passed on.
        self.vector_retriever = VectorContextRetriever(
            self.graph_store,
            include_text=self.include_text,
            embed_model=embed_model,
            similarity_top_k=similarity_top_k,
            path_depth=path_depth,
        )

    def custom_retrieve(self, query_str: str) -> str:
        """Extract entities from the query and retrieve graph context for each one.

        Falls back to the raw query when no entities are detected.
        Could return `str`, `TextNode`, `NodeWithScore`, or a list of those.
        """
        entities = self.entity_extraction(text=query_str).names
        result_nodes = []
        if entities:
            print(f"Detected entities: {entities}")
            for entity in entities:
                result_nodes.extend(self.vector_retriever.retrieve(entity))
        else:
            result_nodes.extend(self.vector_retriever.retrieve(query_str))
        print([t.text for t in result_nodes])  # debug: show the retrieved triplets
        # Join the retrieved context into a single string (no reranking yet)
        final_text = "\n\n".join(
            [n.get_content(metadata_mode="llm") for n in result_nodes]
        )
        return final_text
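Because `custom_retrieve` runs one vector lookup per detected entity, the same triplet can come back more than once when two entities share neighbors. The framework-free sketch below reproduces the same control flow (per-entity fan-out, falling back to the raw query) and adds order-preserving deduplication. `fan_out_retrieve` is a hypothetical helper, and its `retrieve` argument stands in for `self.vector_retriever.retrieve` reduced to plain strings.

```python
from typing import Callable, List, Optional


def fan_out_retrieve(
    query: str,
    entities: Optional[List[str]],
    retrieve: Callable[[str], List[str]],
) -> List[str]:
    """Retrieve per entity (or for the raw query), keeping first occurrences only."""
    lookups = entities if entities else [query]  # same fallback as custom_retrieve
    seen = set()
    results: List[str] = []
    for lookup in lookups:
        for text in retrieve(lookup):
            if text not in seen:  # skip triplets already returned for another entity
                seen.add(text)
                results.append(text)
    return results
```

In the real retriever you could key the `seen` set on the node text, as here, or on `n.node.node_id` of each returned `NodeWithScore`.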

## Test out the Custom Retriever

Now let's initialize and test out the custom retriever against our data!

To build a full RAG pipeline, we use the `RetrieverQueryEngine` to combine our retriever with the LLM synthesis module; the property graph index uses the same engine under the hood.

In [17]:
custom_sub_retriever = MyCustomRetriever(
    index.property_graph_store,
    include_text=False,
    vector_store=index.vector_store,
    embed_model=embed_model
)

In [18]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(
    index.as_retriever(sub_retrievers=[custom_sub_retriever]), llm=llm
)

### Try out some Queries

In [21]:
response = query_engine.query("What do you know about Maliek Collins or Darragh O’Brien?")
print(str(response))

Detected entities: ['Maliek Collins', "Darragh O'Brien"]
['Maliek Collins -> WORKS_AT -> Houston Texans', 'Maliek Collins -> WORKS_AT -> Las Vegas Raiders', 'Houston Texans -> ACQUISITION -> Maliek Collins', 'Maliek Collins -> WORKS_AT -> Dallas Cowboys', 'Justin Jefferson -> WORKS_AT -> NFL', 'Deepak Maloo -> WORKS_AT -> GE Vernova', 'Patriots -> PARTNERSHIP -> Matthew Judon', 'Darragh O’Brien -> WORKS_AT -> State’s industrial relations process', 'Darragh O’Brien -> WORKS_AT -> Government', 'Darragh O’Brien -> WORKS_AT -> Minister for Housing', 'Seán Ó Fearghaíl -> WORKS_AT -> Ceann Comhairle', 'Pearse Doherty -> WORKS_AT -> Sinn Féin', 'Simon Flannery -> WORKS_AT -> Morgan Stanley']
Maliek Collins has worked for the Houston Texans, Las Vegas Raiders, and Dallas Cowboys. Darragh O’Brien is involved in the State’s industrial relations process, works for the Government, and holds the position of Minister for Housing.
