Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ By default, DSPy depends on `openai==0.28`. However, if you install `openai>=1.0
For the optional Pinecone, Qdrant, [chromadb](https://github.com/chroma-core/chroma), or [marqo](https://github.com/marqo-ai/marqo) retrieval integration(s), include the extra(s) below:

```
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo]
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo] or [mongodb]
```

## 2) Syntax: You're in charge of the workflow—it's free-form Python code!
Expand Down
117 changes: 117 additions & 0 deletions dspy/retrieve/mongodb_atlas_rm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from typing import List, Optional, Union, Any
import dspy
import os
from openai import (
OpenAI,
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
)
import backoff

try:
from pymongo import MongoClient
from pymongo.errors import (
ConnectionFailure,
ConfigurationError,
ServerSelectionTimeoutError,
InvalidURI,
OperationFailure,
)
except ImportError:
raise ImportError(
"Please install the pymongo package by running `pip install dspy-ai[mongodb]`"
)


def build_vector_search_pipeline(
index_name: str, query_vector: List[float], num_candidates: int, limit: int
) -> List[dict[str, Any]]:
return [
{
"$vectorSearch": {
"index": index_name,
"path": "embedding",
"queryVector": query_vector,
"numCandidates": num_candidates,
"limit": limit,
}
},
{"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
]


class Embedder:
def __init__(self, provider: str, model: str):
if provider == "openai":
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("Environment variable OPENAI_API_KEY must be set")
self.client = OpenAI()
self.model = model

@backoff.on_exception(
backoff.expo,
(
APITimeoutError,
InternalServerError,
RateLimitError,
UnprocessableEntityError,
),
max_time=15,
)
def __call__(self, queries) -> Any:
embedding = self.client.embeddings.create(input=queries, model=self.model)
return [result.embedding for result in embedding.data]


class MongoDBAtlasRM(dspy.Retrieve):
def __init__(
self,
db_name: str,
collection_name: str,
index_name: str,
k: int = 5,
embedding_provider: str = "openai",
embedding_model: str = "text-embedding-ada-002",
):
super().__init__(k=k)
self.db_name = db_name
self.collection_name = collection_name
self.index_name = index_name
self.username = os.getenv("ATLAS_USERNAME")
self.password = os.getenv("ATLAS_PASSWORD")
self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
if not self.username:
raise ValueError("Environment variable ATLAS_USERNAME must be set")
if not self.password:
raise ValueError("Environment variable ATLAS_PASSWORD must be set")
if not self.cluster_url:
raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")
try:
self.client = MongoClient(
f"mongodb+srv://{self.username}:{self.password}@{self.cluster_url}/{self.db_name}"
"?retryWrites=true&w=majority"
)
except (
InvalidURI,
ConfigurationError,
ConnectionFailure,
ServerSelectionTimeoutError,
OperationFailure,
) as e:
raise ConnectionError("Failed to connect to MongoDB Atlas") from e

self.embedder = Embedder(provider=embedding_provider, model=embedding_model)

def forward(self, query_or_queries: str) -> dspy.Prediction:
query_vector = self.embedder([query_or_queries])
pipeline = build_vector_search_pipeline(
index_name=self.index_name,
query_vector=query_vector[0],
num_candidates=self.k * 10,
limit=self.k,
)
contents = self.client[self.db_name][self.collection_name].aggregate(pipeline)
return dspy.Prediction(passages=list(contents))
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"pinecone": ["pinecone-client~=2.2.4"],
"qdrant": ["qdrant-client~=1.6.2", "fastembed~=0.1.0"],
"chromadb": ["chromadb~=0.4.14"],
"marqo": ["marqo"]
"marqo": ["marqo"],
"mongodb": ["pymongo~=3.12.0"],
},
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down