fix: Return number of written documents from write_documents as per Protocol for document stores

- Using constants for the default Neo4j client configuration
- Introducing a `verify_connectivity` argument to control whether the Neo4j connectivity check runs
prosto committed Feb 8, 2024
1 parent 88e89bb commit f421ed5
Showing 1 changed file with 31 additions and 12 deletions.
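To make the behaviour change concrete, here is a minimal usage sketch (not part of the commit). It assumes a local Neo4j instance at `bolt://localhost:7687`, imports `Neo4jDocumentStore` from the module touched by this diff, and uses placeholder credentials.

```python
# Minimal sketch, not part of the diff: write_documents now reports how many
# documents were written, and verify_connectivity can skip the startup check.
# Assumes a reachable local Neo4j instance; credentials are placeholders.
from haystack import Document

from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore

store = Neo4jDocumentStore(
    url="bolt://localhost:7687",
    username="neo4j",
    password="neo4j-password",
    verify_connectivity=False,  # new flag introduced by this commit
)

written = store.write_documents([Document(content="Hello, Neo4j!")])
print(f"Wrote {written} document(s)")  # previously the method returned None
```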
43 changes: 31 additions & 12 deletions src/neo4j_haystack/document_stores/neo4j_store.py
@@ -9,6 +9,12 @@
from tqdm import tqdm

from neo4j_haystack.client import Neo4jClient, Neo4jClientConfig, Neo4jRecord
from neo4j_haystack.client.neo4j_client import (
DEFAULT_NEO4J_DATABASE,
DEFAULT_NEO4J_PASSWORD,
DEFAULT_NEO4J_URI,
DEFAULT_NEO4J_USERNAME,
)
from neo4j_haystack.document_stores.utils import get_batches_from_generator
from neo4j_haystack.metadata_filter import COMPARISON_OPS, FilterParser, OperatorAST
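With the connection defaults now backed by these constants, a store can be created with no connection arguments at all. A sketch of what that enables; the concrete constant values live in `neo4j_haystack/client/neo4j_client.py` and are not shown in this diff.

```python
# Sketch only: the constant-backed defaults make every connection parameter
# optional. Actual values of DEFAULT_NEO4J_URI and friends are defined in
# neo4j_haystack/client/neo4j_client.py, not in this diff.
from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore

store = Neo4jDocumentStore()  # uses DEFAULT_NEO4J_URI, DEFAULT_NEO4J_DATABASE,
                              # DEFAULT_NEO4J_USERNAME and DEFAULT_NEO4J_PASSWORD
```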

@@ -112,10 +118,10 @@ class Neo4jDocumentStore:

def __init__(
self,
url: str,
database: Optional[str] = "neo4j",
username: Optional[str] = None,
password: Optional[str] = None,
url: Optional[str] = DEFAULT_NEO4J_URI,
database: Optional[str] = DEFAULT_NEO4J_DATABASE,
username: Optional[str] = DEFAULT_NEO4J_USERNAME,
password: Optional[str] = DEFAULT_NEO4J_PASSWORD,
client_config: Optional[Neo4jClientConfig] = None,
index: str = "document-embeddings",
node_label: str = "Document",
@@ -126,6 +132,7 @@ def __init__(
create_index_if_missing: Optional[bool] = True,
recreate_index: Optional[bool] = False,
write_batch_size: int = 100,
verify_connectivity: Optional[bool] = True,
):
"""
Constructor method
@@ -157,14 +164,15 @@ def __init__(
index. Useful for testing purposes when a new DocumentStore initializes with a clean database state.
write_batch_size: Number of documents to write at once. When working with a large number of documents,
batching can help reduce the memory footprint.
verify_connectivity: If `True`, the connection to the database is verified with the provided credentials
when the Document Store is created.
Raises:
ValueError: In case similarity function specified is not supported
"""

super().__init__()

self.url = url
self.index = index
self.node_label = node_label
self.embedding_dim = embedding_dim
@@ -177,14 +185,17 @@ def __init__(
self.create_index_if_missing = create_index_if_missing
self.recreate_index = recreate_index
self.write_batch_size = write_batch_size
self.verify_connectivity = verify_connectivity

self.filter_parser = FilterParser()

if client_config and not client_config.url:
client_config.url = url
self.client_config = client_config or Neo4jClientConfig(url, database, username, password)
self.neo4j_client = Neo4jClient(self.client_config)
self.neo4j_client.verify_connectivity()

if verify_connectivity:
self.neo4j_client.verify_connectivity()

if recreate_index:
self.delete_index()
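A sketch of the two behaviours added in this hunk: a `Neo4jClientConfig` created without a URL picks up the store's `url` argument, and `verify_connectivity=False` defers the connectivity check. The keyword names on `Neo4jClientConfig` are assumed to mirror the positional call above and are not confirmed by this diff.

```python
# Sketch under assumptions: Neo4jClientConfig keyword parameters are assumed to
# match the positional call Neo4jClientConfig(url, database, username, password)
# and to allow the URL to be omitted.
from neo4j_haystack.client import Neo4jClientConfig
from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore

config = Neo4jClientConfig(database="neo4j", username="neo4j", password="neo4j-password")

store = Neo4jDocumentStore(
    url="bolt://localhost:7687",  # copied onto config because config.url is empty
    client_config=config,
    verify_connectivity=False,    # no connectivity check until the first query
)
```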
@@ -209,6 +220,7 @@ def to_dict(self) -> Dict[str, Any]:
create_index_if_missing=self.create_index_if_missing,
recreate_index=self.recreate_index,
write_batch_size=self.write_batch_size,
verify_connectivity=self.verify_connectivity,
)

data["init_parameters"]["client_config"] = self.client_config.to_dict()
@@ -222,7 +234,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "Neo4jDocumentStore":
"""
client_config = Neo4jClientConfig.from_dict(data["init_parameters"]["client_config"])
data["init_parameters"]["client_config"] = client_config
data["init_parameters"]["url"] = client_config.url

return default_from_dict(cls, data)
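A sketch of the updated serialization round trip: `verify_connectivity` now travels in `init_parameters`, and `from_dict` no longer re-injects `url`, since the URL is carried by the nested `client_config` payload.

```python
# Sketch, assuming a reachable Neo4j instance for the constructor call.
from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore

store = Neo4jDocumentStore(verify_connectivity=False)

data = store.to_dict()
assert "verify_connectivity" in data["init_parameters"]  # newly serialized flag
assert "client_config" in data["init_parameters"]        # nested config carries the URL

restored = Neo4jDocumentStore.from_dict(data)            # url is read from client_config
```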

@@ -249,21 +260,25 @@ def write_documents(
self,
documents: List[Document],
policy: DuplicatePolicy = DuplicatePolicy.NONE,
):
) -> int:
"""
Writes documents to the DocumentStore.
Args:
documents: List of `haystack.Document`. If they already contain the embeddings, we'll index
them right away in Neo4j. If not, you can later call `update_embeddings` to create and index them.
policy: How to handle duplicate documents. Parameter options:
- skip: Ignore the duplicates documents.
- overwrite: Update any existing documents with the same ID when adding documents.
- fail: An error is raised if the document ID of the document being added already exists
- `SKIP`: Ignore duplicate documents.
- `OVERWRITE`: Update any existing documents with the same ID when adding documents.
- `FAIL`: An error is raised if the ID of a document being added already exists.
Raises:
DuplicateDocumentError: Raised when a duplicate document is encountered (with `policy` set to `FAIL`).
ValueError: If the `documents` parameter is not a list of `haystack.Document` objects.
Returns:
Number of written documents.
"""

for doc in documents:
@@ -273,11 +288,12 @@ def write_documents(

if len(documents) == 0:
logger.warning("Calling Neo4jDocumentStore.write_documents() with an empty list")
return
return 0

batch_size = self.write_batch_size
document_objects = self._handle_duplicate_documents(documents, policy)

documents_written = 0
batched_documents = get_batches_from_generator(document_objects, batch_size)
with tqdm(
total=len(document_objects),
@@ -289,8 +305,11 @@ def write_documents(
records = [self._document_to_neo4j_record(doc) for doc in document_batch]
embedding_field = self.embedding_field
self.neo4j_client.merge_nodes(self.node_label, embedding_field, records)
documents_written += len(records)
progress_bar.update(batch_size)

return documents_written

def delete_documents(self, document_ids: List[str]) -> None:
"""
Deletes all documents with a matching document_ids from the DocumentStore.
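Finally, a sketch of the duplicate policies described in the updated docstring together with the new return value. The import paths for `DuplicatePolicy` and `DuplicateDocumentError` follow Haystack 2.x conventions and are an assumption, not part of this diff.

```python
# Sketch: the documented duplicate policies plus the new return value.
# Assumes a reachable local Neo4j instance; import paths for DuplicatePolicy
# and DuplicateDocumentError are assumed, not shown in this diff.
from haystack import Document
from haystack.document_stores.errors import DuplicateDocumentError
from haystack.document_stores.types import DuplicatePolicy

from neo4j_haystack.document_stores.neo4j_store import Neo4jDocumentStore

store = Neo4jDocumentStore(url="bolt://localhost:7687", verify_connectivity=False)
doc = Document(content="duplicate me")

written = store.write_documents([doc])                          # -> 1
store.write_documents([doc], policy=DuplicatePolicy.OVERWRITE)  # updates the existing node

try:
    store.write_documents([doc], policy=DuplicatePolicy.FAIL)   # same ID again
except DuplicateDocumentError:
    print("FAIL policy raised on the duplicate, as documented")
```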
