diff --git a/dspy/retrieve/chromadb_rm.py b/dspy/retrieve/chromadb_rm.py index 4dc93661d1..aa8589d751 100644 --- a/dspy/retrieve/chromadb_rm.py +++ b/dspy/retrieve/chromadb_rm.py @@ -26,9 +26,6 @@ from chromadb.config import Settings from chromadb.utils import embedding_functions except ImportError: - chromadb = None - -if chromadb is None: raise ImportError( "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`", ) @@ -46,6 +43,7 @@ class ChromadbRM(dspy.Retrieve): persist_directory (str): chromadb persist directory embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction. k (int, optional): The number of top passages to retrieve. Defaults to 7. + client(Optional[chromadb.Client]): Optional chromadb client provided by user, default to None Returns: dspy.Prediction: An object containing the retrieved passages. @@ -54,12 +52,25 @@ class ChromadbRM(dspy.Retrieve): Below is a code snippet that shows how to use this as the default retriever: ```python llm = dspy.OpenAI(model="gpt-3.5-turbo") + # using default chromadb client retriever_model = ChromadbRM('collection_name', 'db_path') dspy.settings.configure(lm=llm, rm=retriever_model) # to test the retriever with "my query" retriever_model("my query") ``` + Use provided chromadb client + ```python + import chromadb + llm = dspy.OpenAI(model="gpt-3.5-turbo") + # say you have a chromadb running on a different port + client = chromadb.HttpClient(host='localhost', port=8889) + retriever_model = ChromadbRM('collection_name', 'db_path', client=client) + dspy.settings.configure(lm=llm, rm=retriever_model) + # to test the retriever with "my query" + retriever_model("my query") + ``` + Below is a code snippet that shows how to use this in the forward() function of a module ```python self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages) @@ -73,9 +84,10 @@ def __init__( embedding_function: Optional[ EmbeddingFunction[Embeddable] ] = ef.DefaultEmbeddingFunction(), + client: Optional[chromadb.Client] = None, k: int = 7, ): - self._init_chromadb(collection_name, persist_directory) + self._init_chromadb(collection_name, persist_directory, client=client) self.ef = embedding_function super().__init__(k=k) @@ -84,22 +96,26 @@ def _init_chromadb( self, collection_name: str, persist_directory: str, + client: Optional[chromadb.Client] = None, ) -> chromadb.Collection: """Initialize chromadb and return the loaded index. Args: collection_name (str): chromadb collection name persist_directory (str): chromadb persist directory + client (chromadb.Client): chromadb client provided by user - - Returns: + Returns: collection per collection_name """ - self._chromadb_client = chromadb.Client( - Settings( - persist_directory=persist_directory, - is_persistent=True, - ), + if client: + self._chromadb_client = client + else: + self._chromadb_client = chromadb.Client( + Settings( + persist_directory=persist_directory, + is_persistent=True, + ), ) self._chromadb_collection = self._chromadb_client.get_or_create_collection( name=collection_name,