forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Neo4j Advanced RAG template (langchain-ai#12794)
Todo: - [x] Docs
- Loading branch information
Showing
10 changed files
with
1,887 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# neo4j-advanced-rag | ||
|
||
This template allows you to balance precise embeddings and context retention by implementing advanced retrieval strategies. | ||
|
||
## Strategies | ||
|
||
1. **Typical RAG**: | ||
- Traditional method where the exact data indexed is the data retrieved. | ||
2. **Parent retriever**: | ||
- Instead of indexing entire documents, data is divided into smaller chunks, referred to as Parent and Child documents. | ||
- Child documents are indexed for better representation of specific concepts, while parent documents is retrieved to ensure context retention. | ||
3. **Hypothetical Questions**: | ||
- Documents are processed to determine potential questions they might answer. | ||
- These questions are then indexed for better representation of specific concepts, while parent documents are retrieved to ensure context retention. | ||
4. **Summaries**: | ||
- Instead of indexing the entire document, a summary of the document is created and indexed. | ||
- Similarly, the parent document is retrieved in a RAG application. | ||
|
||
## Environment Setup | ||
|
||
You need to define the following environment variables | ||
|
||
``` | ||
OPENAI_API_KEY=<YOUR_OPENAI_API_KEY> | ||
NEO4J_URI=<YOUR_NEO4J_URI> | ||
NEO4J_USERNAME=<YOUR_NEO4J_USERNAME> | ||
NEO4J_PASSWORD=<YOUR_NEO4J_PASSWORD> | ||
``` | ||
|
||
## Populating with data | ||
|
||
If you want to populate the DB with some example data, you can run `python ingest.py`. | ||
The script process and stores sections of the text from the file `dune.txt` into a Neo4j graph database. | ||
First, the text is divided into larger chunks ("parents") and then further subdivided into smaller chunks ("children"), where both parent and child chunks overlap slightly to maintain context. | ||
After storing these chunks in the database, embeddings for the child nodes are computed using OpenAI's embeddings and stored back in the graph for future retrieval or analysis. | ||
For every parent node, hypothetical questions and summaries are generated, embedded, and added to the database. | ||
Additionally, a vector index for each retrieval strategy is created for efficient querying of these embeddings. | ||
|
||
*Note that ingestion can take a minute or two due to LLMs velocity of generating hypothetical questions and summaries.* | ||
|
||
## Usage | ||
|
||
To use this package, you should first have the LangChain CLI installed: | ||
|
||
```shell | ||
pip install -U "langchain-cli[serve]" | ||
``` | ||
|
||
To create a new LangChain project and install this as the only package, you can do: | ||
|
||
```shell | ||
langchain app new my-app --package neo4j-advanced-rag | ||
``` | ||
|
||
If you want to add this to an existing project, you can just run: | ||
|
||
```shell | ||
langchain app add neo4j-advanced-rag | ||
``` | ||
|
||
And add the following code to your `server.py` file: | ||
```python | ||
from neo4j_advanced_rag import chain as neo4j_advanced_chain | ||
|
||
add_routes(app, neo4j_advanced_chain, path="/neo4j-advanced-rag") | ||
``` | ||
|
||
(Optional) Let's now configure LangSmith. | ||
LangSmith will help us trace, monitor and debug LangChain applications. | ||
LangSmith is currently in private beta, you can sign up [here](https://smith.langchain.com/). | ||
If you don't have access, you can skip this section | ||
|
||
```shell | ||
export LANGCHAIN_TRACING_V2=true | ||
export LANGCHAIN_API_KEY=<your-api-key> | ||
export LANGCHAIN_PROJECT=<your-project> # if not specified, defaults to "default" | ||
``` | ||
|
||
If you are inside this directory, then you can spin up a LangServe instance directly by: | ||
|
||
```shell | ||
langchain serve | ||
``` | ||
|
||
This will start the FastAPI app with a server is running locally at | ||
[http://localhost:8000](http://localhost:8000) | ||
|
||
We can see all templates at [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) | ||
We can access the playground at [http://127.0.0.1:8000/neo4j-advanced-rag/playground](http://127.0.0.1:8000/neo4j-advanced-rag/playground) | ||
|
||
We can access the template from code with: | ||
|
||
```python | ||
from langserve.client import RemoteRunnable | ||
|
||
runnable = RemoteRunnable("http://localhost:8000/neo4j-advanced-rag") | ||
``` |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from langchain.chains.openai_functions import create_structured_output_chain | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.document_loaders import TextLoader | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
from langchain.graphs import Neo4jGraph | ||
from langchain.prompts import ChatPromptTemplate | ||
from langchain.pydantic_v1 import BaseModel, Field | ||
from langchain.text_splitter import TokenTextSplitter | ||
from neo4j.exceptions import ClientError | ||
|
||
txt_path = Path(__file__).parent / "dune.txt" | ||
|
||
graph = Neo4jGraph() | ||
|
||
# Embeddings & LLM models | ||
embeddings = OpenAIEmbeddings() | ||
embedding_dimension = 1536 | ||
llm = ChatOpenAI(temperature=0) | ||
|
||
# Load the text file | ||
loader = TextLoader(str(txt_path)) | ||
documents = loader.load() | ||
|
||
# Ingest Parent-Child node pairs | ||
parent_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24) | ||
child_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=24) | ||
parent_documents = parent_splitter.split_documents(documents) | ||
|
||
for i, parent in enumerate(parent_documents): | ||
child_documents = child_splitter.split_documents([parent]) | ||
params = { | ||
"parent_text": parent.page_content, | ||
"parent_id": i, | ||
"parent_embedding": embeddings.embed_query(parent.page_content), | ||
"children": [ | ||
{ | ||
"text": c.page_content, | ||
"id": f"{i}-{ic}", | ||
"embedding": embeddings.embed_query(c.page_content), | ||
} | ||
for ic, c in enumerate(child_documents) | ||
], | ||
} | ||
# Ingest data | ||
graph.query( | ||
""" | ||
MERGE (p:Parent {id: $parent_id}) | ||
SET p.text = $parent_text | ||
WITH p | ||
CALL db.create.setVectorProperty(p, 'embedding', $parent_embedding) | ||
YIELD node | ||
WITH p | ||
UNWIND $children AS child | ||
MERGE (c:Child {id: child.id}) | ||
SET c.text = child.text | ||
MERGE (c)<-[:HAS_CHILD]-(p) | ||
WITH c, child | ||
CALL db.create.setVectorProperty(c, 'embedding', child.embedding) | ||
YIELD node | ||
RETURN count(*) | ||
""", | ||
params, | ||
) | ||
# Create vector index for child | ||
try: | ||
graph.query( | ||
"CALL db.index.vector.createNodeIndex('parent_document', " | ||
"'Child', 'embedding', $dimension, 'cosine')", | ||
{"dimension": embedding_dimension}, | ||
) | ||
except ClientError: # already exists | ||
pass | ||
# Create vector index for parents | ||
try: | ||
graph.query( | ||
"CALL db.index.vector.createNodeIndex('typical_rag', " | ||
"'Parent', 'embedding', $dimension, 'cosine')", | ||
{"dimension": embedding_dimension}, | ||
) | ||
except ClientError: # already exists | ||
pass | ||
# Ingest hypothethical questions | ||
|
||
|
||
class Questions(BaseModel): | ||
"""Generating hypothetical questions about text.""" | ||
|
||
questions: List[str] = Field( | ||
..., | ||
description=( | ||
"Generated hypothetical questions based on " "the information from the text" | ||
), | ||
) | ||
|
||
|
||
questions_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
( | ||
"system", | ||
( | ||
"You are generating hypothetical questions based on the information " | ||
"found in the text. Make sure to provide full context in the generated " | ||
"questions." | ||
), | ||
), | ||
( | ||
"human", | ||
( | ||
"Use the given format to generate hypothetical questions from the " | ||
"following input: {input}" | ||
), | ||
), | ||
] | ||
) | ||
|
||
question_chain = create_structured_output_chain(Questions, llm, questions_prompt) | ||
|
||
for i, parent in enumerate(parent_documents): | ||
questions = question_chain.run(parent.page_content).questions | ||
params = { | ||
"parent_id": i, | ||
"questions": [ | ||
{"text": q, "id": f"{i}-{iq}", "embedding": embeddings.embed_query(q)} | ||
for iq, q in enumerate(questions) | ||
if q | ||
], | ||
} | ||
graph.query( | ||
""" | ||
MERGE (p:Parent {id: $parent_id}) | ||
WITH p | ||
UNWIND $questions AS question | ||
CREATE (q:Question {id: question.id}) | ||
SET q.text = question.text | ||
MERGE (q)<-[:HAS_QUESTION]-(p) | ||
WITH q, question | ||
CALL db.create.setVectorProperty(q, 'embedding', question.embedding) | ||
YIELD node | ||
RETURN count(*) | ||
""", | ||
params, | ||
) | ||
# Create vector index | ||
try: | ||
graph.query( | ||
"CALL db.index.vector.createNodeIndex('hypothetical_questions', " | ||
"'Question', 'embedding', $dimension, 'cosine')", | ||
{"dimension": embedding_dimension}, | ||
) | ||
except ClientError: # already exists | ||
pass | ||
|
||
# Ingest summaries | ||
|
||
summary_prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
( | ||
"system", | ||
( | ||
"You are generating concise and accurate summaries based on the " | ||
"information found in the text." | ||
), | ||
), | ||
( | ||
"human", | ||
("Generate a summary of the following input: {question}\n" "Summary:"), | ||
), | ||
] | ||
) | ||
|
||
summary_chain = summary_prompt | llm | ||
|
||
for i, parent in enumerate(parent_documents): | ||
summary = summary_chain.invoke({"question": parent.page_content}).content | ||
params = { | ||
"parent_id": i, | ||
"summary": summary, | ||
"embedding": embeddings.embed_query(summary), | ||
} | ||
graph.query( | ||
""" | ||
MERGE (p:Parent {id: $parent_id}) | ||
MERGE (p)-[:HAS_SUMMARY]->(s:Summary) | ||
SET s.text = $summary | ||
WITH s | ||
CALL db.create.setVectorProperty(s, 'embedding', $embedding) | ||
YIELD node | ||
RETURN count(*) | ||
""", | ||
params, | ||
) | ||
# Create vector index | ||
try: | ||
graph.query( | ||
"CALL db.index.vector.createNodeIndex('summary', " | ||
"'Summary', 'embedding', $dimension, 'cosine')", | ||
{"dimension": embedding_dimension}, | ||
) | ||
except ClientError: # already exists | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from neo4j_advanced_rag.chain import chain | ||
|
||
if __name__ == "__main__": | ||
original_query = "What is the plot of the Dune?" | ||
print( | ||
chain.invoke( | ||
{"question": original_query}, | ||
{"configurable": {"strategy": "parent_document"}}, | ||
) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from neo4j_advanced_rag.chain import chain | ||
|
||
__all__ = ["chain"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from operator import itemgetter | ||
|
||
from langchain.chat_models import ChatOpenAI | ||
from langchain.prompts import ChatPromptTemplate | ||
from langchain.pydantic_v1 import BaseModel | ||
from langchain.schema.output_parser import StrOutputParser | ||
from langchain.schema.runnable import ConfigurableField, RunnableParallel | ||
|
||
from neo4j_advanced_rag.retrievers import ( | ||
hypothetic_question_vectorstore, | ||
parent_vectorstore, | ||
summary_vectorstore, | ||
typical_rag, | ||
) | ||
|
||
template = """Answer the question based only on the following context: | ||
{context} | ||
Question: {question} | ||
""" | ||
prompt = ChatPromptTemplate.from_template(template) | ||
|
||
model = ChatOpenAI() | ||
|
||
retriever = typical_rag.as_retriever().configurable_alternatives( | ||
ConfigurableField(id="strategy"), | ||
default_key="typical_rag", | ||
parent_strategy=parent_vectorstore.as_retriever(), | ||
hypothetical_questions=hypothetic_question_vectorstore.as_retriever(), | ||
summary_strategy=summary_vectorstore.as_retriever(), | ||
) | ||
|
||
chain = ( | ||
RunnableParallel( | ||
{ | ||
"context": itemgetter("question") | retriever, | ||
"question": itemgetter("question"), | ||
} | ||
) | ||
| prompt | ||
| model | ||
| StrOutputParser() | ||
) | ||
|
||
|
||
# Add typing for input | ||
class Question(BaseModel): | ||
question: str | ||
|
||
|
||
chain = chain.with_types(input_type=Question) |
Oops, something went wrong.