diff --git a/dspy/retrieve/neo4j_rm.py b/dspy/retrieve/neo4j_rm.py index f71dbb1cbc..bcd576325c 100644 --- a/dspy/retrieve/neo4j_rm.py +++ b/dspy/retrieve/neo4j_rm.py @@ -1,5 +1,5 @@ import os -from typing import Any, List, Optional, Union +from typing import Any, List, Optional, Union, Callable import backoff from openai import ( @@ -108,6 +108,7 @@ def __init__( retrieval_query: str = None, embedding_provider: str = "openai", embedding_model: str = "text-embedding-ada-002", + embedding_function: Optional[Callable] = None, ): super().__init__(k=k) self.index_name = index_name @@ -136,7 +137,7 @@ def __init__( ) as e: raise ConnectionError("Failed to connect to Neo4j database") from e - self.embedder = Embedder(provider=embedding_provider, model=embedding_model) + self.embedder = embedding_function or Embedder(provider=embedding_provider, model=embedding_model) def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> Prediction: if not isinstance(query_or_queries, list):