# Defining a Custom Property Graph Retriever

This guide shows you how to define a custom retriever against a property graph.

This is more involved than using our out-of-the-box graph retrievers, but it gives you granular control over the retrieval process, so you can tailor it to your application.

We show you how to define an advanced retrieval flow by directly leveraging the property graph store. We'll execute both vector search and text-to-cypher retrieval, and then combine the results through a reranking module.
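Text-to-cypher retrieval works roughly like this: an LLM is shown the graph schema and the user's question, asked to write a Cypher query, and that query is then executed against the graph store. The sketch below shows only the prompt-building step, using plain `str.format`. The template wording, its `{schema}`/`{question}` placeholder names, the toy schema string, and the `build_cypher_prompt` helper are all assumptions for illustration; they are not LlamaIndex's default text-to-cypher template.

```python
# Hypothetical template; LlamaIndex ships its own default text-to-cypher prompt.
TEXT_TO_CYPHER_TEMPLATE = (
    "Task: Generate a Cypher statement to query a graph database.\n"
    "Use only the relationship types and properties in this schema:\n"
    "{schema}\n"
    "Question: {question}\n"
    "Cypher query:"
)


def build_cypher_prompt(schema: str, question: str) -> str:
    """Fill the (hypothetical) template with the graph schema and the user question."""
    return TEXT_TO_CYPHER_TEMPLATE.format(schema=schema, question=question)


# Toy schema string standing in for what a real graph store would report.
schema = "(:PERSON)-[:WORKED_AT]->(:COMPANY)"
prompt = build_cypher_prompt(schema, "Where did Paul Graham work?")
print(prompt)
# An LLM would be expected to complete this prompt with something like:
# MATCH (p:PERSON)-[:WORKED_AT]->(c:COMPANY) RETURN c
```

Customizing this prompt is what the `text_to_cypher_template` parameter used later in this guide is for.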

In [None]:
%pip install llama-index llama-index-postprocessor-cohere-rerank llama-index-graph-stores-neo4j

## Setup and Build the Property Graph

In [None]:
import nest_asyncio
nest_asyncio.apply()

In [None]:
import os

os.environ["OPENAI_API_KEY"] = "sk-proj-..."
os.environ["COHERE_API_KEY"] = "..."

#### Load Paul Graham Essay

In [None]:
!mkdir -p 'data/paul_graham/'
!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'

In [None]:
from llama_index.core import SimpleDirectoryReader

documents = SimpleDirectoryReader("./data/paul_graham/").load_data()

#### Define Default LLM and Embedding Model

In [None]:
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-3.5-turbo", temperature=0.3)
embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

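#### Build the Property Graph

Text-to-cypher retrieval needs a graph store that can execute Cypher queries, which the default in-memory `SimplePropertyGraphStore` does not support. Here we use Neo4j; the connection details below are placeholders for a locally running instance, so replace them with your own.

In [None]:
from llama_index.core import PropertyGraphIndex
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore

graph_store = Neo4jPropertyGraphStore(
    username="neo4j",
    password="<your-password>",  # placeholder: use your own credentials
    url="bolt://localhost:7687",
)

index = PropertyGraphIndex.from_documents(
    documents,
    llm=llm,
    embed_model=embed_model,
    property_graph_store=graph_store,
    show_progress=True,
)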

## Define Custom Retriever

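Now we define a custom retriever by subclassing `CustomPGRetriever`. The base class handles construction; you implement `init` to set up any attributes you need and `custom_retrieve` to define the retrieval logic.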
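This division of labor is a template-method pattern: the base class constructor keeps the shared arguments, calls your `init` with whatever keyword arguments are left over, and its retrieve path calls your `custom_retrieve`. The pure-Python mock below imitates that control flow so it can run without LlamaIndex. `ToyCustomRetriever`, `EchoRetriever`, and the way results are wrapped into a list are hypothetical simplifications, not the library's code.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class ToyCustomRetriever(ABC):
    """Hypothetical stand-in for a base class like CustomPGRetriever."""

    def __init__(self, graph_store: Any, include_text: bool = True, **kwargs: Any) -> None:
        # The base class keeps the shared arguments itself...
        self.graph_store = graph_store
        self.include_text = include_text
        # ...and forwards every remaining keyword argument to the subclass hook.
        self.init(**kwargs)

    @abstractmethod
    def init(self, **kwargs: Any) -> None:
        """Subclasses set up their own attributes here."""

    @abstractmethod
    def custom_retrieve(self, query_str: str) -> str:
        """Subclasses implement the retrieval logic here."""

    def retrieve(self, query_str: str) -> List[str]:
        # Simplification: wrap the subclass's string result in a list of "nodes".
        return [self.custom_retrieve(query_str)]


class EchoRetriever(ToyCustomRetriever):
    def init(self, prefix: str = "ctx", **kwargs: Any) -> None:
        self.prefix = prefix

    def custom_retrieve(self, query_str: str) -> str:
        return f"{self.prefix}: {query_str}"


retriever = EchoRetriever(graph_store={}, include_text=True, prefix="graph")
print(retriever.retrieve("Interleaf"))  # ['graph: Interleaf']
```

Because the constructor consumes `graph_store` and `include_text` before calling `init`, those two names should not appear in a subclass's `init` signature.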

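#### 1. Initialization
In `init`, we set up two pre-existing property graph retrievers, `VectorContextRetriever` and `TextToCypherRetriever`, as well as the Cohere reranker. `init` receives every keyword argument passed to the constructor except `graph_store` and `include_text`, which the base class stores and exposes as `self.graph_store` and `self.include_text`.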
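As a rough mental model of vector-context retrieval: embed the query, pick the `similarity_top_k` graph nodes whose embeddings are most similar, then walk outward `path_depth` hops to collect the connected triplets. The sketch below uses hand-made 3-dimensional vectors and a toy adjacency list. The data, the `vector_context` function, and the breadth-first expansion are illustrative assumptions, not `VectorContextRetriever`'s actual implementation.

```python
import math
from typing import Dict, List, Tuple

# Toy embeddings for graph nodes (real ones come from the embedding model).
EMBEDDINGS: Dict[str, List[float]] = {
    "Interleaf": [0.9, 0.1, 0.0],
    "Viaweb": [0.8, 0.3, 0.1],
    "Lisp": [0.1, 0.9, 0.2],
    "Yahoo": [0.2, 0.2, 0.9],
}

# Toy triplets stored as an adjacency list: node -> [(relation, neighbor)].
GRAPH: Dict[str, List[Tuple[str, str]]] = {
    "Interleaf": [("USED", "Lisp")],
    "Viaweb": [("ACQUIRED_BY", "Yahoo"), ("WRITTEN_IN", "Lisp")],
    "Lisp": [],
    "Yahoo": [],
}


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def vector_context(
    query_vec: List[float], similarity_top_k: int = 2, path_depth: int = 1
) -> List[Tuple[str, str, str]]:
    # 1. Rank nodes by cosine similarity to the query and keep the top k.
    ranked = sorted(EMBEDDINGS, key=lambda n: cosine(query_vec, EMBEDDINGS[n]), reverse=True)
    frontier = ranked[:similarity_top_k]
    # 2. Expand outward path_depth hops, collecting (subject, relation, object) triplets.
    triplets: List[Tuple[str, str, str]] = []
    seen = set(frontier)
    for _ in range(path_depth):
        next_frontier = []
        for node in frontier:
            for rel, nbr in GRAPH[node]:
                triplets.append((node, rel, nbr))
                if nbr not in seen:  # only expand each node once
                    seen.add(nbr)
                    next_frontier.append(nbr)
        frontier = next_frontier
    return triplets


print(vector_context([1.0, 0.2, 0.0]))
```

With this query vector, "Interleaf" and "Viaweb" are the two nearest nodes, and one hop of expansion returns their three outgoing triplets; raising `path_depth` pulls in more distant context at the cost of a larger prompt.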

#### 2. Define `custom_retrieve`

We then define the `custom_retrieve` method. It sends the query to both retrievers, reranks the combined list of nodes with Cohere, and joins the top-ranked nodes into a single context string.
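The combine-then-rerank step can be pictured without any external services: pool the candidates from both retrievers, score each one against the query, keep the best `top_n`, and join the survivors into one context string. In the sketch below, a simple word-overlap score stands in for Cohere's reranking model, `toy_rerank` and `combine_and_rerank` are hypothetical names, and the deduplication of identical texts is an extra step this sketch adds; only the overall shape mirrors `custom_retrieve`.

```python
import re
from typing import List, Set, Tuple


def tokenize(text: str) -> Set[str]:
    """Lowercase word tokens."""
    return set(re.findall(r"[a-z0-9]+", text.lower()))


def toy_rerank(query: str, texts: List[str], top_n: int = 2) -> List[Tuple[float, str]]:
    """Score each text by the fraction of query words it contains (stand-in for a real reranker)."""
    q = tokenize(query)
    scored = [(len(q & tokenize(t)) / max(len(q), 1), t) for t in texts]
    # Highest score first; Python's sort is stable, so ties keep their pooled order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[:top_n]


def combine_and_rerank(
    query: str, vector_hits: List[str], cypher_hits: List[str], top_n: int = 2
) -> str:
    # Pool both candidate lists, dropping exact duplicates while keeping order.
    pooled = list(dict.fromkeys(vector_hits + cypher_hits))
    reranked = toy_rerank(query, pooled, top_n=top_n)
    # Join the surviving texts into a single context block.
    return "\n\n".join(text for _, text in reranked)


vector_hits = ["Interleaf made document software.", "Paul painted in Florence."]
cypher_hits = ["Viaweb was acquired by Yahoo.", "Interleaf made document software."]
print(combine_and_rerank("What happened at Interleaf and Viaweb?", vector_hits, cypher_hits))
```

A learned reranker replaces the overlap score with a model that reads the query and each candidate together, which is why it needs the query string at postprocessing time.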

In [None]:
from llama_index.core.retrievers import (
    CustomPGRetriever,
    VectorContextRetriever,
    TextToCypherRetriever,
)
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.prompts import PromptTemplate
from llama_index.core.llms import LLM
from llama_index.postprocessor.cohere_rerank import CohereRerank


from typing import Optional, Any, Union


class MyCustomRetriever(CustomPGRetriever):
    """Custom retriever with cohere reranking."""

    def init(
        self,
        ## vector context retriever params
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        ## text-to-cypher params
        llm: Optional[LLM] = None,
        text_to_cypher_template: Optional[Union[PromptTemplate, str]] = None,
        ## cohere reranker params
        cohere_api_key: Optional[str] = None,
        cohere_top_n: int = 2,
        **kwargs: Any,
    ) -> None:
        """Uses any kwargs passed in from class constructor.

        `graph_store` and `include_text` are consumed by the base class and
        are available here as `self.graph_store` and `self.include_text`.
        """

        self.vector_retriever = VectorContextRetriever(
            self.graph_store,
            include_text=self.include_text,
            embed_model=embed_model,
            vector_store=vector_store,
            similarity_top_k=similarity_top_k,
            path_depth=path_depth,
        )

        self.cypher_retriever = TextToCypherRetriever(
            self.graph_store,
            llm=llm,
            text_to_cypher_template=text_to_cypher_template,
            ## NOTE: you can attach other parameters here if you'd like
        )

        self.reranker = CohereRerank(api_key=cohere_api_key, top_n=cohere_top_n)

    def custom_retrieve(self, query_str: str) -> str:
        """Define custom retriever with reranking."""
        nodes_1 = self.vector_retriever.retrieve(query_str)
        nodes_2 = self.cypher_retriever.retrieve(query_str)
        # The reranker scores each node against the query, so pass the query in.
        reranked_nodes = self.reranker.postprocess_nodes(
            nodes_1 + nodes_2, query_str=query_str
        )

        # Join the top-ranked nodes into a single context string.
        final_text = "\n\n".join(
            [n.get_content(metadata_mode="llm") for n in reranked_nodes]
        )

        return final_text

    # optional async method
    # async def acustom_retrieve(self, query_str: str) -> str:
    #     ...

## Test out the Custom Retriever

Now let's initialize and test out the custom retriever against our data! 

To build a full RAG pipeline, we use `RetrieverQueryEngine` to combine our retriever with the LLM response synthesis module. The property graph index uses the same query engine under the hood.
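Conceptually, a retriever query engine does two things per query: call the retriever to get context, then hand that context and the question to the LLM to write an answer. The sketch below wires the two stages together with plain callables. `answer_query`, `fake_retriever`, and `fake_llm` are hypothetical placeholders, and the prompt wording is only loosely modeled on a typical QA prompt; none of this is `RetrieverQueryEngine`'s actual internals.

```python
from typing import Callable, List

QA_PROMPT = (
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Given the context and no prior knowledge, answer the query.\n"
    "Query: {query}\n"
    "Answer: "
)


def answer_query(
    query: str, retrieve: Callable[[str], List[str]], llm: Callable[[str], str]
) -> str:
    # Stage 1: retrieval produces context chunks for the query.
    context = "\n\n".join(retrieve(query))
    # Stage 2: synthesis fills a prompt with that context and asks the LLM.
    return llm(QA_PROMPT.format(context=context, query=query))


def fake_retriever(query: str) -> List[str]:
    return ["Viaweb was acquired by Yahoo."]


def fake_llm(prompt: str) -> str:
    # Placeholder "LLM" that just reports how much relevant context it received.
    return f"(answer based on {prompt.count('Yahoo')} mention(s) of Yahoo)"


print(answer_query("What happened at Viaweb?", fake_retriever, fake_llm))
```

Keeping the two stages separate is what lets you swap in a custom retriever without touching the synthesis step.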

In [None]:
custom_retriever = MyCustomRetriever(
    index.property_graph_store,
    include_text=True,
    vector_store=index.vector_store,
    embed_model=embed_model,
    llm=llm,
)

In [None]:
from llama_index.core.query_engine import RetrieverQueryEngine

query_engine = RetrieverQueryEngine.from_args(
    custom_retriever,
    llm=llm,
)

#### Try out a 'baseline'

We compare against a baseline that uses only the vector context retriever.
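When comparing the two engines, it can help to look at which context each retriever surfaced, not just the final answers. The hypothetical helper below works on plain lists of retrieved texts, which you would collect from each retriever yourself, and reports what is shared and what is unique to each side.

```python
from typing import Dict, List


def compare_contexts(custom: List[str], baseline: List[str]) -> Dict[str, List[str]]:
    """Split retrieved texts into shared and retriever-specific groups (order preserved)."""
    baseline_set, custom_set = set(baseline), set(custom)
    return {
        "shared": [t for t in custom if t in baseline_set],
        "only_custom": [t for t in custom if t not in baseline_set],
        "only_baseline": [t for t in baseline if t not in custom_set],
    }


report = compare_contexts(
    custom=["Interleaf made document software.", "Viaweb was acquired by Yahoo."],
    baseline=["Interleaf made document software.", "Paul painted in Florence."],
)
for key, texts in report.items():
    print(key, texts)
```

With the real objects, each list could be built as `[n.get_content() for n in retriever.retrieve(query)]`, since `retrieve` returns scored nodes.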

In [None]:
base_retriever = VectorContextRetriever(
    index.property_graph_store,
    include_text=True,
    embed_model=embed_model,
    vector_store=index.vector_store,
)
base_query_engine = index.as_query_engine(
    llm=llm,
    sub_retrievers=[base_retriever],
)

#### Try out some Queries

In [None]:
query_engine.query("What happened at Interleaf and Viaweb?")

In [None]:
base_query_engine.query("What happened at Interleaf and Viaweb?") 