Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions dspy/retrieve/chromadb_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
from chromadb.config import Settings
from chromadb.utils import embedding_functions
except ImportError:
chromadb = None

if chromadb is None:
raise ImportError(
"The chromadb library is required to use ChromadbRM. Install it with `pip install dspy-ai[chromadb]`",
)
Expand All @@ -46,6 +43,7 @@ class ChromadbRM(dspy.Retrieve):
persist_directory (str): chromadb persist directory
embedding_function (Optional[EmbeddingFunction[Embeddable]]): Optional function to use to embed documents. Defaults to DefaultEmbeddingFunction.
k (int, optional): The number of top passages to retrieve. Defaults to 7.
client (Optional[chromadb.Client]): Optional chromadb client provided by the user. Defaults to None.

Returns:
dspy.Prediction: An object containing the retrieved passages.
Expand All @@ -54,12 +52,25 @@ class ChromadbRM(dspy.Retrieve):
Below is a code snippet that shows how to use this as the default retriever:
```python
llm = dspy.OpenAI(model="gpt-3.5-turbo")
# using default chromadb client
retriever_model = ChromadbRM('collection_name', 'db_path')
dspy.settings.configure(lm=llm, rm=retriever_model)
# to test the retriever with "my query"
retriever_model("my query")
```

Below is a code snippet that shows how to use a chromadb client provided by the user:
```python
import chromadb
llm = dspy.OpenAI(model="gpt-3.5-turbo")
# say you have a chromadb running on a different port
client = chromadb.HttpClient(host='localhost', port=8889)
retriever_model = ChromadbRM('collection_name', 'db_path', client=client)
dspy.settings.configure(lm=llm, rm=retriever_model)
# to test the retriever with "my query"
retriever_model("my query")
```

Below is a code snippet that shows how to use this in the forward() function of a module
```python
self.retrieve = ChromadbRM('collection_name', 'db_path', k=num_passages)
Expand All @@ -73,9 +84,10 @@ def __init__(
embedding_function: Optional[
EmbeddingFunction[Embeddable]
] = ef.DefaultEmbeddingFunction(),
client: Optional[chromadb.Client] = None,
k: int = 7,
):
self._init_chromadb(collection_name, persist_directory)
self._init_chromadb(collection_name, persist_directory, client=client)
self.ef = embedding_function

super().__init__(k=k)
Expand All @@ -84,22 +96,26 @@ def _init_chromadb(
self,
collection_name: str,
persist_directory: str,
client: Optional[chromadb.Client] = None,
) -> chromadb.Collection:
"""Initialize chromadb and return the loaded index.

Args:
collection_name (str): chromadb collection name
persist_directory (str): chromadb persist directory
client (Optional[chromadb.Client]): chromadb client provided by the user; if not given, a persistent client is created from persist_directory


Returns:
Returns: collection per collection_name
"""

self._chromadb_client = chromadb.Client(
Settings(
persist_directory=persist_directory,
is_persistent=True,
),
if client:
self._chromadb_client = client
else:
self._chromadb_client = chromadb.Client(
Settings(
persist_directory=persist_directory,
is_persistent=True,
),
)
self._chromadb_collection = self._chromadb_client.get_or_create_collection(
name=collection_name,
Expand Down