Skip to content

Commit

Permalink
feat: Introducing execute_write method in Neo4jClient to run arbitr…
Browse files Browse the repository at this point in the history
…ary Cypher queries which modify data

- documenting `use_env` client configuration property to trigger logic which read values from environment variables if any
- creating constants for default configuration values
  • Loading branch information
prosto committed Feb 8, 2024
1 parent ee6f0b6 commit 88e89bb
Showing 1 changed file with 59 additions and 8 deletions.
67 changes: 59 additions & 8 deletions src/neo4j_haystack/client/neo4j_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ManagedTransaction,
Record,
Result,
ResultSummary,
Session,
unit_of_work,
)
Expand All @@ -34,6 +35,18 @@
"""Default variable name used in Cypher queries to match and return Documents, e.g.
`:::cypher match(doc:Document) where doc.id = $id return doc` where `doc` is a variable name."""

DEFAULT_NEO4J_URI = "bolt://localhost:7687"
"""Default URI to connect to neo4j instance, e.g. a local DB running in Docker container."""

DEFAULT_NEO4J_DATABASE = "neo4j"
"""Default Neo4j database name to connect to if not provided."""

DEFAULT_NEO4J_USERNAME = "neo4j"
"""Default Neo4j username to be used for authentication with Neo4j. Used to simplify local development."""

DEFAULT_NEO4J_PASSWORD = "neo4j"
"""Default Neo4j password to be used for authentication with Neo4j. Used to simplify local development."""

Neo4jRecord = Dict[str, Any]
"""Type alias for data items returned from Neo4j queries"""

Expand Down Expand Up @@ -97,16 +110,23 @@ class Neo4jClientConfig:
driver_config: Additional driver configuration.
session_config: Additional session configuration.
transaction_config: Additional transaction configuration (e.g. ``timeout``)
use_env: If `True` the following Driver attributes will be assigned from respective environment variables:
```py
>>> url = os.getenv("NEO4J_URI")
>>> database = os.getenv("NEO4J_DATABASE")
>>> username = os.getenv("NEO4J_USERNAME")
>>> password = os.getenv("NEO4J_PASSWORD")
```
Raises:
ValueError: In case conflicting auth credentials are provided - choose either username/password combination
or `driver_config.auth`.
"""

url: Optional[str] = field(default="bolt://localhost:7687")
database: Optional[str] = field(default="neo4j")
username: Optional[str] = field(default="neo4j")
password: Optional[str] = field(default="neo4j")
url: Optional[str] = field(default=DEFAULT_NEO4J_URI)
database: Optional[str] = field(default=DEFAULT_NEO4J_DATABASE)
username: Optional[str] = field(default=DEFAULT_NEO4J_USERNAME)
password: Optional[str] = field(default=DEFAULT_NEO4J_PASSWORD)

driver_config: Neo4jDriverConfig = field(default_factory=dict)
session_config: Neo4jSessionConfig = field(default_factory=dict)
Expand All @@ -117,7 +137,7 @@ class Neo4jClientConfig:

def __post_init__(self):
if self.use_env:
self.url = os.getenv("NEO4J_URL", self.url)
self.url = os.getenv("NEO4J_URI", self.url)
self.database = os.getenv("NEO4J_DATABASE", self.database)
self.username = os.getenv("NEO4J_USERNAME", self.username)
self.password = os.getenv("NEO4J_PASSWORD", self.password)
Expand Down Expand Up @@ -376,7 +396,7 @@ def _mgt_tx(tx: ManagedTransaction) -> None:
with self._begin_session() as session:
session.execute_write(_mgt_tx)

def merge_nodes(self, node_label: str, embedding_field: str, records: List[Neo4jRecord]) -> None:
def merge_nodes(self, node_label: str, embedding_field: str, records: List[Neo4jRecord]) -> ResultSummary:
"""
Creates or updates a node in neo4j representing a Document with all properties. Nodes are matched by "id",
if not found a new node will be created. See the following manuals:
Expand All @@ -396,7 +416,7 @@ def merge_nodes(self, node_label: str, embedding_field: str, records: List[Neo4j

@self._unit_of_work()
def _mgt_tx(tx: ManagedTransaction):
tx.run(
result = tx.run(
f"""
WITH $records AS batch
UNWIND batch as row
Expand All @@ -410,9 +430,11 @@ def _mgt_tx(tx: ManagedTransaction):
""",
records=records,
)
summary = result.consume()
return summary

with self._begin_session() as session:
session.execute_write(_mgt_tx)
return session.execute_write(_mgt_tx)

def count_nodes(self, node_label: str, filter_ast: Optional[AST] = None) -> int:
"""
Expand Down Expand Up @@ -586,6 +608,35 @@ def _mgt_tx(tx: ManagedTransaction) -> List[Record]:

return [{**record.value(NODE_VAR), score_property: record.value("score")} for record in records]

def execute_write(
self,
query: str,
parameters: Optional[Dict[str, Any]] = None,
) -> tuple[ResultSummary, List[Dict[str, Any]]]:
"""
Runs an arbitrary write Cypher query with parameters.
Args:
query: Cypher query to run in Neo4j.
parameters: Query parameters which can be used as placeholders in the `query`.
Returns:
A tuple consisting of execution result summary (`neo4j.ResultSummary`) and data records (`dict`) if any.
"""

@self._unit_of_work()
def _mgt_tx(tx: ManagedTransaction):
result = tx.run(
query,
parameters=parameters,
)
records = result.data()
summary = result.consume()
return summary, records

with self._begin_session() as session:
return session.execute_write(_mgt_tx)

def update_node(self, node_label: str, doc_id: str, data: Dict[str, Any]) -> Optional[Neo4jRecord]:
"""
Updates a given node matched by the given id (`doc_id`). Properties are mutated by `+=` operator,
Expand Down

0 comments on commit 88e89bb

Please sign in to comment.