Skip to content

Commit

Permalink
refactor: Change name of retriever component as per documented naming…
Browse files Browse the repository at this point in the history
… convention
  • Loading branch information
prosto committed Jan 15, 2024
1 parent a7e3bf1 commit f79a952
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 17 deletions.
4 changes: 2 additions & 2 deletions examples/rag_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.generators import HuggingFaceTGIGenerator

from neo4j_haystack import Neo4jDocumentRetriever, Neo4jDocumentStore
from neo4j_haystack import Neo4jDocumentStore, Neo4jEmbeddingRetriever

# Load HF Token from environment variables.
HF_TOKEN = os.environ.get("HF_TOKEN")
Expand Down Expand Up @@ -46,7 +46,7 @@
"query_embedder",
SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", progress_bar=False),
)
rag_pipeline.add_component("retriever", Neo4jDocumentRetriever(document_store=document_store))
rag_pipeline.add_component("retriever", Neo4jEmbeddingRetriever(document_store=document_store))
rag_pipeline.add_component("prompt_builder", PromptBuilder(template=prompt_template))
rag_pipeline.add_component(
"llm",
Expand Down
6 changes: 3 additions & 3 deletions src/neo4j_haystack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from neo4j_haystack.components.neo4j_retriever import (
Neo4jDocumentRetriever,
from neo4j_haystack.components import (
Neo4jDynamicDocumentRetriever,
Neo4jEmbeddingRetriever,
)
from neo4j_haystack.document_stores import (
Neo4jClient,
Expand All @@ -12,6 +12,6 @@
"Neo4jDocumentStore",
"Neo4jClient",
"Neo4jClientConfig",
"Neo4jDocumentRetriever",
"Neo4jEmbeddingRetriever",
"Neo4jDynamicDocumentRetriever",
)
9 changes: 9 additions & 0 deletions src/neo4j_haystack/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from neo4j_haystack.components.neo4j_retriever import (
Neo4jDynamicDocumentRetriever,
Neo4jEmbeddingRetriever,
)

__all__ = (
"Neo4jEmbeddingRetriever",
"Neo4jDynamicDocumentRetriever",
)
10 changes: 5 additions & 5 deletions src/neo4j_haystack/components/neo4j_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@


@component
class Neo4jDocumentRetriever:
class Neo4jEmbeddingRetriever:
"""
A component for retrieving documents from Neo4jDocumentStore.
```py title="Retrieving documents assuming documents have been previously indexed"
from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersTextEmbedder
from neo4j_haystack import Neo4jDocumentStore, Neo4jDocumentRetriever
from neo4j_haystack import Neo4jDocumentStore, Neo4jEmbeddingRetriever
model_name = "sentence-transformers/all-MiniLM-L6-v2"
Expand All @@ -33,7 +33,7 @@ class Neo4jDocumentRetriever:
pipeline = Pipeline()
pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model_name_or_path=model_name))
pipeline.add_component("retriever", Neo4jDocumentRetriever(document_store=document_store))
pipeline.add_component("retriever", Neo4jEmbeddingRetriever(document_store=document_store))
pipeline.connect("text_embedder.embedding", "retriever.query_embedding")
result = pipeline.run(
Expand All @@ -60,7 +60,7 @@ def __init__(
return_embedding: bool = False,
):
"""
Create a Neo4jDocumentRetriever component.
Create a Neo4jEmbeddingRetriever component.
Args:
document_store: An instance of `Neo4jDocumentStore`.
Expand Down Expand Up @@ -101,7 +101,7 @@ def to_dict(self) -> Dict[str, Any]:
return data

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Neo4jDocumentRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "Neo4jEmbeddingRetriever":
"""
Deserialize this component from a dictionary.
"""
Expand Down
14 changes: 7 additions & 7 deletions tests/neo4j_haystack/components/test_neo4j_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from haystack import Document

from neo4j_haystack.components.neo4j_retriever import Neo4jDocumentRetriever
from neo4j_haystack.components.neo4j_retriever import Neo4jEmbeddingRetriever
from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore


Expand All @@ -21,7 +21,7 @@ def movie_document_store(


def test_retrieve_documents(movie_document_store: Neo4jDocumentStore, text_embedder: Callable[[str], List[float]]):
retriever = Neo4jDocumentRetriever(document_store=movie_document_store)
retriever = Neo4jEmbeddingRetriever(document_store=movie_document_store)

query_embedding = text_embedder(
"A young fella pretending to be a good citizen but actually planning to commit a crime"
Expand All @@ -42,7 +42,7 @@ def test_retrieve_documents(movie_document_store: Neo4jDocumentStore, text_embed
def test_retrieve_documents_with_filters(
movie_document_store: Neo4jDocumentStore, text_embedder: Callable[[str], List[float]]
):
retriever = Neo4jDocumentRetriever(document_store=movie_document_store)
retriever = Neo4jEmbeddingRetriever(document_store=movie_document_store)

query_embedding = text_embedder(
"A young fella pretending to be a good citizen but actually planning to commit a crime"
Expand All @@ -61,7 +61,7 @@ def test_retriever_to_dict():
doc_store = mock.create_autospec(Neo4jDocumentStore)
doc_store.to_dict.return_value = {"ds": "yes"}

retriever = Neo4jDocumentRetriever(
retriever = Neo4jEmbeddingRetriever(
document_store=doc_store,
filters={"field": "num", "operator": ">", "value": 10},
top_k=11,
Expand All @@ -71,7 +71,7 @@ def test_retriever_to_dict():
data = retriever.to_dict()

assert data == {
"type": "neo4j_haystack.components.neo4j_retriever.Neo4jDocumentRetriever",
"type": "neo4j_haystack.components.neo4j_retriever.Neo4jEmbeddingRetriever",
"init_parameters": {
"document_store": {"ds": "yes"},
"filters": {"field": "num", "operator": ">", "value": 10},
Expand All @@ -86,7 +86,7 @@ def test_retriever_to_dict():
@mock.patch.object(Neo4jDocumentStore, "from_dict")
def test_retriever_from_dict(from_dict_mock):
data = {
"type": "neo4j_haystack.components.neo4j_retriever.Neo4jDocumentRetriever",
"type": "neo4j_haystack.components.neo4j_retriever.Neo4jEmbeddingRetriever",
"init_parameters": {
"document_store": {"ds": "yes"},
"filters": {"field": "num", "operator": ">", "value": 10},
Expand All @@ -98,7 +98,7 @@ def test_retriever_from_dict(from_dict_mock):
doc_store = mock.create_autospec(Neo4jDocumentStore)
from_dict_mock.return_value = doc_store

retriever = Neo4jDocumentRetriever.from_dict(data)
retriever = Neo4jEmbeddingRetriever.from_dict(data)

assert retriever._document_store == doc_store
assert retriever._filters == {"field": "num", "operator": ">", "value": 10}
Expand Down

0 comments on commit f79a952

Please sign in to comment.