Skip to content

Commit

Permalink
feat: Retriever component for documents stored in Neo4j
Browse files Browse the repository at this point in the history
Includes both a standard embeddings retriever and a more advanced retriever based on plain Cypher queries
  • Loading branch information
prosto committed Jan 15, 2024
1 parent f95f5c8 commit b411ebc
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/neo4j_haystack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from neo4j_haystack.components.neo4j_retriever import (
Neo4jDocumentRetriever,
Neo4jDynamicDocumentRetriever,
)
from neo4j_haystack.document_stores import (
Neo4jClient,
Neo4jClientConfig,
Neo4jDocumentStore,
)

# Public API of the package: the document store, its client/config and the retriever components.
__all__ = (
    "Neo4jDocumentStore",
    "Neo4jClient",
    "Neo4jClientConfig",
    "Neo4jDocumentRetriever",
    "Neo4jDynamicDocumentRetriever",
)
Empty file.
377 changes: 377 additions & 0 deletions src/neo4j_haystack/components/neo4j_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,377 @@
from typing import Any, Dict, List, Optional, cast

from haystack import (
ComponentError,
Document,
component,
default_from_dict,
default_to_dict,
)

from neo4j_haystack.document_stores import Neo4jDocumentStore
from neo4j_haystack.document_stores.neo4j_client import Neo4jClient, Neo4jClientConfig


@component
class Neo4jDocumentRetriever:
    """
    A component for retrieving documents from Neo4jDocumentStore.

    Settings given to the constructor (`filters`, `top_k`, `scale_score`, `return_embedding`) act as defaults;
    each of them can be overridden per call through the arguments of `run`.

    ```py title="Retrieving documents assuming documents have been previously indexed"
    from typing import List

    from haystack import Document, Pipeline
    from haystack.components.embedders import SentenceTransformersTextEmbedder

    from neo4j_haystack import Neo4jDocumentStore, Neo4jDocumentRetriever

    model_name = "sentence-transformers/all-MiniLM-L6-v2"

    # Document store with default credentials
    document_store = Neo4jDocumentStore(
        url="bolt://localhost:7687",
        embedding_dim=384,  # same as the embedding model
    )

    pipeline = Pipeline()
    pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model_name_or_path=model_name))
    pipeline.add_component("retriever", Neo4jDocumentRetriever(document_store=document_store))
    pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

    result = pipeline.run(
        data={
            "text_embedder": {"text": "Query to be embedded"},
            "retriever": {
                "top_k": 5,
                "filters": {"field": "release_date", "operator": "==", "value": "2018-12-09"},
            },
        }
    )

    # Obtain retrieved documents from pipeline execution
    documents: List[Document] = result["retriever"]["documents"]
    ```
    """

    def __init__(
        self,
        document_store: Neo4jDocumentStore,
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 10,
        scale_score: bool = True,
        return_embedding: bool = False,
    ):
        """
        Create a Neo4jDocumentRetriever component.

        Args:
            document_store: An instance of `Neo4jDocumentStore`.
            filters: A dictionary with filters to narrow down the search space.
            top_k: The maximum number of documents to retrieve.
            scale_score: Whether to scale the scores of the retrieved documents or not.
            return_embedding: Whether to return the embedding of the retrieved Documents.

        Raises:
            ValueError: If `document_store` is not an instance of `Neo4jDocumentStore`.
        """

        if not isinstance(document_store, Neo4jDocumentStore):
            msg = "document_store must be an instance of Neo4jDocumentStore"
            raise ValueError(msg)

        self._document_store = document_store

        self._filters = filters
        self._top_k = top_k
        self._scale_score = scale_score
        self._return_embedding = return_embedding

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Returns:
            Dictionary with the component's init parameters, where `document_store` is represented by the
            dictionary produced by the store's own `to_dict`.
        """
        data = default_to_dict(
            self,
            document_store=self._document_store,
            filters=self._filters,
            top_k=self._top_k,
            scale_score=self._scale_score,
            return_embedding=self._return_embedding,
        )
        # Swap the store instance for its serialized form so the result holds plain data only.
        data["init_parameters"]["document_store"] = self._document_store.to_dict()

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Neo4jDocumentRetriever":
        """
        Deserialize this component from a dictionary.

        Note:
            `data` is modified in place: the serialized `document_store` entry is replaced with the
            reconstructed `Neo4jDocumentStore` instance.

        Args:
            data: Dictionary previously produced by `to_dict`.

        Returns:
            The deserialized component.
        """
        document_store = Neo4jDocumentStore.from_dict(data["init_parameters"]["document_store"])
        data["init_parameters"]["document_store"] = document_store
        return default_from_dict(cls, data)

    @component.output_types(documents=List[Document])
    def run(
        self,
        query_embedding: List[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: Optional[int] = None,
        scale_score: Optional[bool] = None,
        return_embedding: Optional[bool] = None,
    ):
        """
        Run the Embedding Retriever on the given input data.

        Args:
            query_embedding: Embedding of the query.
            filters: A dictionary with filters to narrow down the search space. When omitted (or empty) the
                filters given at construction time are used.
            top_k: The maximum number of documents to return. When omitted (or `0`) the value given at
                construction time is used.
            scale_score: Whether to scale the scores of the retrieved documents or not. When `None` the value
                given at construction time is used; an explicit `False` is honoured.
            return_embedding: Whether to return the embedding of the retrieved Documents. When `None` the value
                given at construction time is used; an explicit `False` is honoured.

        Returns:
            A dictionary with a single `documents` key holding the retrieved documents.
        """
        # Boolean overrides must be compared against `None`: `False or default` would silently discard an
        # explicit `False` whenever the configured default is `True`.
        if scale_score is None:
            scale_score = self._scale_score
        if return_embedding is None:
            return_embedding = self._return_embedding

        docs = self._document_store.query_by_embedding(
            query_embedding=query_embedding,
            filters=filters or self._filters,
            top_k=top_k or self._top_k,
            scale_score=scale_score,
            return_embedding=return_embedding,
        )

        return {"documents": docs}


@component
class Neo4jDynamicDocumentRetriever:
    """
    A component for retrieving Documents from Neo4j database using plain Cypher query.

    This component gives flexible way to retrieve data from Neo4j by running arbitrary Cypher query along with query
    parameters. Query parameters can be supplied in a pipeline from other components (or pipeline data).

    See the following documentation on how to compose Cypher queries with parameters:

    - [Overview of Cypher query syntax](https://neo4j.com/docs/cypher-manual/current/queries/)
    - [Cypher Query Parameters](https://neo4j.com/docs/cypher-manual/current/syntax/parameters/)

    Above are resources which will help understand better Cypher query syntax and parameterization. Under the hood
    [Neo4j Python Driver](https://neo4j.com/docs/python-manual/current/) is used to query database and fetch results.
    You might be interested in the following documentation:

    - [Query the database](https://neo4j.com/docs/python-manual/current/query-simple/)
    - [Query parameters](https://neo4j.com/docs/python-manual/current/query-simple/#query-parameters)
    - [Data types and mapping to Cypher types](https://neo4j.com/docs/python-manual/current/data-types/)

    Note:
        Please consider data types mappings in Cypher query when working with parameters. Neo4j Python Driver handles
        type conversions/mappings. Specifically you can figure out in the documentation of the driver how to work with
        temporal types (e.g. `DateTime`).

    Query execution results will be mapped/converted to `haystack.Document` type. See more details in the
    [RETURN clause](https://neo4j.com/docs/cypher-manual/current/clauses/return/) documentation. There are two
    ways how Documents are being composed from query results.

    (1) Converting documents from [nodes](https://neo4j.com/docs/cypher-manual/current/clauses/return/#return-nodes)

    ```py title="Convert Neo4j `node` to `haystack.Document`"
    client_config = Neo4jClientConfig(
        "bolt://localhost:7687", database="neo4j", username="neo4j", password="passw0rd"
    )

    retriever = Neo4jDynamicDocumentRetriever(
        client_config=client_config, doc_node_name="doc", verify_connectivity=True
    )

    result = retriever.run(
        query="MATCH (doc:Document) WHERE doc.year > $year OR doc.year is NULL RETURN doc",
        parameters={"year": 2020}
    )
    documents: List[Document] = result["documents"]
    ```

    Please notice how `doc_node_name` attribute assumes `"doc"` node is going to be returned from the query.
    `Neo4jDynamicDocumentRetriever` will convert properties of the node (e.g. `id`, `content` etc) to
    `haystack.Document` type.

    (2) Converting documents from query output keys (e.g. column aliases)

    You might want to run a complex query which aggregates information from multiple sources (nodes) in Neo4j. In such
    case you might want to compose the final Document from the keys (column aliases) returned by the query.

    ```py title="Convert query output keys to `haystack.Document`"
    # Configuration with default settings
    client_config=Neo4jClientConfig()

    retriever = Neo4jDynamicDocumentRetriever(client_config=client_config, compose_doc_from_result=True)

    result = retriever.run(
        query=(
            "MATCH (doc:Document) "
            "WHERE doc.year > $year OR doc.year is NULL "
            "RETURN doc.id as id, doc.content as content, doc.year as year"
        ),
        parameters={"year": 2020},
    )
    documents: List[Document] = result["documents"]
    ```

    The above will produce Documents with `id`, `content` and `year`(meta) fields. Please notice
    `compose_doc_from_result` is set to `True` to enable such Document construction behavior.

    Below is an example of a pipeline which explores all ways how parameters could be supplied to the
    `Neo4jDynamicDocumentRetriever` component in the pipeline.

    ```py
    @component
    class YearProvider:
        @component.output_types(year_start=int, year_end=int)
        def run(self, year_start: int, year_end: int):
            return {"year_start": year_start, "year_end": year_end}

    # Configuration with default settings
    client_config=Neo4jClientConfig()

    retriever = Neo4jDynamicDocumentRetriever(
        client_config=client_config,
        runtime_parameters=["year_start", "year_end"],
    )

    query = (
        "MATCH (doc:Document) "
        "WHERE (doc.year >= $year_start and doc.year <= $year_end) AND doc.month = $month "
        "RETURN doc LIMIT $num_return"
    )

    pipeline = Pipeline()
    pipeline.add_component("year_provider", YearProvider())
    pipeline.add_component("retriever", retriever)
    pipeline.connect("year_provider.year_start", "retriever.year_start")
    pipeline.connect("year_provider.year_end", "retriever.year_end")

    result = pipeline.run(
        data={
            "year_provider": {"year_start": 2020, "year_end": 2021},
            "retriever": {
                "query": query,
                "parameters": {
                    "month": "02",
                    "num_return": 2,
                },
            },
        }
    )

    documents = result["retriever"]["documents"]
    ```

    Please notice the following from the example above:

    - `runtime_parameters` is a list of parameter names which are going to be input slots when connecting components
        in a pipeline. In our case `year_start` and `year_end` parameters flow from the `year_provider` component into
        `retriever`. The `query` uses those parameters in the `WHERE` clause.
    - `pipeline.run` specifies additional parameters to the `retriever` component which can be referenced in the
        `query`. If parameter names clash those provided in the pipeline's data take precedence.
    """

    def __init__(
        self,
        client_config: Neo4jClientConfig,
        runtime_parameters: Optional[List[str]] = None,
        doc_node_name: Optional[str] = "doc",
        compose_doc_from_result: Optional[bool] = False,
        verify_connectivity: Optional[bool] = False,
    ):
        """
        Create a Neo4jDynamicDocumentRetriever component.

        Args:
            client_config: Neo4j client configuration to connect to database (e.g. credentials and connection settings).
            runtime_parameters: list of input parameters/slots for connecting components in a pipeline.
            doc_node_name: the name of the variable which is returned from Cypher query which contains Document
                attributes (e.g. `id`, `content`, `meta` fields).
            compose_doc_from_result: If `True` Document attributes will be constructed from Cypher query outputs (keys).
                `doc_node_name` setting will be ignored in this case.
            verify_connectivity: If `True` will verify connectivity with Neo4j database configured by `client_config`.

        Raises:
            ComponentError: In case neither `compose_doc_from_result` nor `doc_node_name` are defined.
        """
        if not compose_doc_from_result and not doc_node_name:
            raise ComponentError(
                "Please specify how Document is being composed out of Neo4j query response. "
                "With `compose_doc_from_result` set to `True` documents will be created out of properties/keys "
                "returned by the query."
            )

        self._client_config = client_config
        self._runtime_parameters = runtime_parameters or []
        self._doc_node_name = doc_node_name
        self._compose_doc_from_result = compose_doc_from_result
        self._verify_connectivity = verify_connectivity

        self._neo4j_client = Neo4jClient(client_config)

        # setup inputs: the fixed `run` arguments plus one optional slot per runtime parameter name
        run_input_slots = {"query": str, "parameters": Optional[Dict[str, Any]]}
        kwargs_input_slots = {param: Optional[Any] for param in self._runtime_parameters}
        component.set_input_types(self, **run_input_slots, **kwargs_input_slots)

        # setup outputs
        component.set_output_types(self, documents=List[Document])

        if verify_connectivity:
            self._neo4j_client.verify_connectivity()

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        Returns:
            Dictionary with the component's init parameters, where `client_config` is represented by the
            dictionary produced by the configuration's own `to_dict`.
        """
        data = default_to_dict(
            self,
            runtime_parameters=self._runtime_parameters,
            doc_node_name=self._doc_node_name,
            compose_doc_from_result=self._compose_doc_from_result,
            verify_connectivity=self._verify_connectivity,
        )

        data["init_parameters"]["client_config"] = self._client_config.to_dict()

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Neo4jDynamicDocumentRetriever":
        """
        Deserialize this component from a dictionary.

        Note:
            `data` is modified in place: the serialized `client_config` entry is replaced with the
            reconstructed `Neo4jClientConfig` instance.

        Args:
            data: Dictionary previously produced by `to_dict`.

        Returns:
            The deserialized component.
        """
        client_config = Neo4jClientConfig.from_dict(data["init_parameters"]["client_config"])
        data["init_parameters"]["client_config"] = client_config
        return default_from_dict(cls, data)

    def run(self, query: str, parameters: Optional[Dict[str, Any]] = None, **kwargs):
        """
        Runs the arbitrary Cypher `query` with `parameters` and returns Documents.

        Args:
            query: Cypher query to run.
            parameters: Cypher query parameters which can be used as placeholders in the `query`.
            kwargs: Arbitrary parameters supplied in a pipeline execution from other component's output slots, e.g.
                `pipeline.connect("year_provider.year_start", "retriever.year_start")`, where `year_start` will be part
                of `kwargs`.

        Returns:
            A dictionary with a single `documents` key holding the retrieved documents.

        Raises:
            ComponentError: If Documents are composed from a node (`compose_doc_from_result` is not enabled) and a
                record returned by the query has no value under the configured `doc_node_name`.
        """
        # `parameters` is unpacked last so its values win over runtime (`kwargs`) values on name clashes.
        parameters_combined = {**kwargs, **(parameters or {})}

        documents: List[Document] = []
        neo4j_query_result = self._neo4j_client.query_nodes(query, parameters_combined)

        for record in neo4j_query_result:
            data = record.data()
            if self._compose_doc_from_result:
                document_dict = data
            else:
                document_dict = data.get(cast(str, self._doc_node_name))
                if document_dict is None:
                    # Fail with an actionable message instead of passing `None` to `Document.from_dict`.
                    raise ComponentError(
                        f"Query result does not contain `{self._doc_node_name}`. Please make sure the query returns "
                        "it (see `doc_node_name`) or set `compose_doc_from_result` to `True` to build Documents "
                        "from the keys returned by the query."
                    )
            documents.append(Document.from_dict(document_dict))

        return {"documents": documents}

0 comments on commit b411ebc

Please sign in to comment.