From a61c571879a027e6ccc7eb42257a2e366871c737 Mon Sep 17 00:00:00 2001 From: Xiao Cui Date: Fri, 19 Apr 2024 13:06:57 -0400 Subject: [PATCH 1/3] ChromadbRM: Provide optional client param --- dspy/retrieve/chromadb_rm.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/dspy/retrieve/chromadb_rm.py b/dspy/retrieve/chromadb_rm.py index 4dc93661d1..399bcdcc40 100644 --- a/dspy/retrieve/chromadb_rm.py +++ b/dspy/retrieve/chromadb_rm.py @@ -26,9 +26,6 @@ from chromadb.config import Settings from chromadb.utils import embedding_functions except ImportError: - chromadb = None - -if chromadb is None: raise ImportError( "The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`", ) @@ -73,9 +70,10 @@ def __init__( embedding_function: Optional[ EmbeddingFunction[Embeddable] ] = ef.DefaultEmbeddingFunction(), + client: Optional[chromadb.Client] = None, k: int = 7, ): - self._init_chromadb(collection_name, persist_directory) + self._init_chromadb(collection_name, persist_directory, client=client) self.ef = embedding_function super().__init__(k=k) @@ -84,22 +82,26 @@ def _init_chromadb( self, collection_name: str, persist_directory: str, + client: Optional[chromadb.Client] = None ) -> chromadb.Collection: """Initialize chromadb and return the loaded index. Args: collection_name (str): chromadb collection name persist_directory (str): chromadb persist directory + client (chromadb.Client): A chromadb client provided by user - - Returns: + Returns: collection per collection_name """ - self._chromadb_client = chromadb.Client( - Settings( - persist_directory=persist_directory, - is_persistent=True, - ), + if client: + self._chromadb_client = client + else: + self._chromadb_client = chromadb.Client( + Settings( + persist_directory=persist_directory, + is_persistent=True, + ), ) self._chromadb_collection = self._chromadb_client.get_or_create_collection( name=collection_name, From d48e8eba2d6be75f9bb06ec2be8f404b36b4e33d Mon Sep 17 00:00:00 2001 From: Xiao Cui Date: Fri, 19 Apr 2024 13:14:50 -0400 Subject: [PATCH 2/3] Update doc --- dspy/retrieve/chromadb_rm.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/dspy/retrieve/chromadb_rm.py b/dspy/retrieve/chromadb_rm.py index 399bcdcc40..fe26318d6d 100644 --- a/dspy/retrieve/chromadb_rm.py +++ b/dspy/retrieve/chromadb_rm.py @@ -43,6 +43,7 @@ class ChromadbRM(dspy.Retrieve): persist_directory (str): chromadb persist directory embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction. k (int, optional): The number of top passages to retrieve. Defaults to 7. + client(Optional[chromadb.Client]): Optional chromadb client provided by user, default to None Returns: dspy.Prediction: An object containing the retrieved passages. @@ -51,12 +52,25 @@ class ChromadbRM(dspy.Retrieve): Below is a code snippet that shows how to use this as the default retriever: ```python llm = dspy.OpenAI(model="gpt-3.5-turbo") + # using default chromadb client retriever_model = ChromadbRM('collection_name', 'db_path') dspy.settings.configure(lm=llm, rm=retriever_model) # to test the retriever with "my query" retriever_model("my query") ``` + Use provided chromadb client + ```python + import chromadb + llm = dspy.OpenAI(model="gpt-3.5-turbo") + # say you have a chromadb running on a different port + client = chromadb.HttpClient(host='localhost', port=8889) + retriever_model = ChromadbRM('collection_name', 'db_path', client=client) + dspy.settings.configure(lm=llm, rm=retriever_model) + # to test the retriever with "my query" + retriever_model("my query") + ``` + Below is a code snippet that shows how to use this in the forward() function of a module ```python self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages) @@ -89,7 +103,7 @@ def _init_chromadb( Args: collection_name (str): chromadb collection name persist_directory (str): chromadb persist directory - client (chromadb.Client): A chromadb client provided by user + client (chromadb.Client): chromadb client provided by user Returns: collection per collection_name """ From 76dd1246c6a3e73f9303316fcbba84eb38c7e16f Mon Sep 17 00:00:00 2001 From: Xiao Cui Date: Sat, 27 Apr 2024 20:33:10 -0400 Subject: [PATCH 3/3] Fix lint error --- dspy/retrieve/chromadb_rm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/retrieve/chromadb_rm.py b/dspy/retrieve/chromadb_rm.py index fe26318d6d..aa8589d751 100644 --- a/dspy/retrieve/chromadb_rm.py +++ b/dspy/retrieve/chromadb_rm.py @@ -96,7 +96,7 @@ def _init_chromadb( self, collection_name: str, persist_directory: str, - client: Optional[chromadb.Client] = None + client: Optional[chromadb.Client] = None, ) -> chromadb.Collection: """Initialize chromadb and return the loaded index.