From 0e96712fedf27e083344e0590b5cb3c5d6aa5c71 Mon Sep 17 00:00:00 2001 From: John Yearsley Date: Mon, 25 Dec 2023 18:25:04 -0700 Subject: [PATCH 1/3] (250) - Add MongoDB Atlas Retrieval Model --- README.md | 2 +- dspy/retrieve/mongodb_atlas_rm.py | 110 ++++++++++++++++++++++++++++++ setup.py | 3 +- 3 files changed, 113 insertions(+), 2 deletions(-) create mode 100644 dspy/retrieve/mongodb_atlas_rm.py diff --git a/README.md b/README.md index 44834ef363..571a94c1ae 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Or open our intro notebook in Google Colab: [ List[dict[str, Any]]: + return [ + { + "$vectorSearch": { + "index": index_name, + "path": "embedding", + "queryVector": query_vector, + "numCandidates": num_candidates, + "limit": limit, + } + }, + {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}}, + ] + + +class Embedder: + def __init__(self, provider: str, model: str): + if provider == "openai": + openai.api_key = os.getenv("OPENAI_API_KEY") + if not openai.api_key: + raise ValueError("Environment variable OPENAI_API_KEY must be set") + self.client = openai + self.model = model + + @backoff.on_exception( + backoff.expo, + ( + openai.error.RateLimitError, + openai.error.ServiceUnavailableError, + openai.error.APIError, + ), + max_time=15, + ) + def __call__(self, queries) -> Any: + embedding = self.client.Embedding.create(input=queries, model=self.model) + return [embedding["embedding"] for embedding in embedding["data"]] + + +class MongoDBAtlasRM(dspy.Retrieve): + def __init__( + self, + db_name: str, + collection_name: str, + index_name: str, + k: int = 5, + embedding_provider: str = "openai", + embedding_model: str = "text-embedding-ada-002", + ): + super().__init__(k=k) + self.db_name = db_name + self.collection_name = collection_name + self.index_name = index_name + self.username = os.getenv("ATLAS_USERNAME") + self.password = os.getenv("ATLAS_PASSWORD") + self.cluster_url = os.getenv("ATLAS_CLUSTER_URL") + if not self.username: + raise ValueError("Environment variable ATLAS_USERNAME must be set") + if not self.password: + raise ValueError("Environment variable ATLAS_PASSWORD must be set") + if not self.cluster_url: + raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set") + try: + self.client = MongoClient( + f"mongodb+srv://{self.username}:{self.password}@{self.cluster_url}/{self.db_name}" + "?retryWrites=true&w=majority" + ) + except ( + InvalidURI, + ConfigurationError, + ConnectionFailure, + ServerSelectionTimeoutError, + OperationFailure, + ) as e: + raise ConnectionError("Failed to connect to MongoDB Atlas") from e + + self.embedder = Embedder(provider=embedding_provider, model=embedding_model) + + def forward(self, query_or_queries: str) -> dspy.Prediction: + query_vector = self.embedder([query_or_queries]) + pipeline = build_vector_search_pipeline( + index_name=self.index_name, + query_vector=query_vector[0], + num_candidates=self.k * 10, + limit=self.k, + ) + contents = self.client[self.db_name][self.collection_name].aggregate(pipeline) + return dspy.Prediction(passages=list(contents)) diff --git a/setup.py b/setup.py index f5a867bc4c..4629d89d5e 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,8 @@ "pinecone": ["pinecone-client~=2.2.4"], "qdrant": ["qdrant-client~=1.6.2", "fastembed~=0.1.0"], "chromadb": ["chromadb~=0.4.14"], - "marqo": ["marqo"] + "marqo": ["marqo"], + "mongodb": ["pymongo~=3.12.0"], }, classifiers=[ "Development Status :: 3 - Alpha", From 8b2f7676568ec492ed6367b93ad8ae129fdb3ae1 Mon Sep 17 00:00:00 2001 From: John Yearsley Date: Mon, 25 Dec 2023 18:25:04 -0700 Subject: [PATCH 2/3] (250) - Add MongoDB Atlas Retrieval Model --- README.md | 2 +- dspy/retrieve/mongodb_atlas_rm.py | 115 ++++++++++++++++++++++++++++++ setup.py | 3 +- 3 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 dspy/retrieve/mongodb_atlas_rm.py diff --git a/README.md b/README.md index 44834ef363..571a94c1ae 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Or open our intro notebook in Google Colab: [ List[dict[str, Any]]: + return [ + { + "$vectorSearch": { + "index": index_name, + "path": "embedding", + "queryVector": query_vector, + "numCandidates": num_candidates, + "limit": limit, + } + }, + {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}}, + ] + + +class Embedder: + def __init__(self, provider: str, model: str): + if provider == "openai": + openai.api_key = os.getenv("OPENAI_API_KEY") + if not openai.api_key: + raise ValueError("Environment variable OPENAI_API_KEY must be set") + self.client = openai + self.model = model + + @backoff.on_exception( + backoff.expo, + ( + openai.error.RateLimitError, + openai.error.ServiceUnavailableError, + openai.error.APIError, + ), + max_time=15, + ) + def __call__(self, queries) -> Any: + embedding = self.client.Embedding.create(input=queries, model=self.model) + return [embedding["embedding"] for embedding in embedding["data"]] + + +class MongoDBAtlasRM(dspy.Retrieve): + def __init__( + self, + db_name: str, + collection_name: str, + index_name: str, + k: int = 5, + embedding_provider: str = "openai", + embedding_model: str = "text-embedding-ada-002", + ): + super().__init__(k=k) + self.db_name = db_name + self.collection_name = collection_name + self.index_name = index_name + self.username = os.getenv("ATLAS_USERNAME") + self.password = os.getenv("ATLAS_PASSWORD") + self.cluster_url = os.getenv("ATLAS_CLUSTER_URL") + if not self.username: + raise ValueError("Environment variable ATLAS_USERNAME must be set") + if not self.password: + raise ValueError("Environment variable ATLAS_PASSWORD must be set") + if not self.cluster_url: + raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set") + try: + self.client = MongoClient( + f"mongodb+srv://{self.username}:{self.password}@{self.cluster_url}/{self.db_name}" + "?retryWrites=true&w=majority" + ) + except ( + InvalidURI, + ConfigurationError, + ConnectionFailure, + ServerSelectionTimeoutError, + OperationFailure, + ) as e: + raise ConnectionError("Failed to connect to MongoDB Atlas") from e + + self.embedder = Embedder(provider=embedding_provider, model=embedding_model) + + def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction: + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + query_vector = self.embedder(queries) + pipeline = build_vector_search_pipeline( + index_name=self.index_name, + query_vector=query_vector[0], + num_candidates=self.k * 10, + limit=self.k, + ) + contents = self.client[self.db_name][self.collection_name].aggregate(pipeline) + return dspy.Prediction(passages=list(contents)) diff --git a/setup.py b/setup.py index f5a867bc4c..4629d89d5e 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,8 @@ "pinecone": ["pinecone-client~=2.2.4"], "qdrant": ["qdrant-client~=1.6.2", "fastembed~=0.1.0"], "chromadb": ["chromadb~=0.4.14"], - "marqo": ["marqo"] + "marqo": ["marqo"], + "mongodb": ["pymongo~=3.12.0"], }, classifiers=[ "Development Status :: 3 - Alpha", From 1864168e541ce5d3dd4c857417ef923a3e1da916 Mon Sep 17 00:00:00 2001 From: John Yearsley Date: Wed, 27 Dec 2023 22:30:27 -0700 Subject: [PATCH 3/3] Update openai imports and client methods to match upgraded python client --- dspy/retrieve/mongodb_atlas_rm.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/dspy/retrieve/mongodb_atlas_rm.py b/dspy/retrieve/mongodb_atlas_rm.py index 5bf9d6f4d5..6577a6c237 100644 --- a/dspy/retrieve/mongodb_atlas_rm.py +++ b/dspy/retrieve/mongodb_atlas_rm.py @@ -1,7 +1,13 @@ from typing import List, Optional, Union, Any import dspy import os -import openai +from openai import ( + OpenAI, + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, +) import backoff try: @@ -39,24 +45,25 @@ def build_vector_search_pipeline( class Embedder: def __init__(self, provider: str, model: str): if provider == "openai": - openai.api_key = os.getenv("OPENAI_API_KEY") - if not openai.api_key: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: raise ValueError("Environment variable OPENAI_API_KEY must be set") - self.client = openai + self.client = OpenAI() self.model = model @backoff.on_exception( backoff.expo, ( - openai.error.RateLimitError, - openai.error.ServiceUnavailableError, - openai.error.APIError, + APITimeoutError, + InternalServerError, + RateLimitError, + UnprocessableEntityError, ), max_time=15, ) def __call__(self, queries) -> Any: - embedding = self.client.Embedding.create(input=queries, model=self.model) - return [embedding["embedding"] for embedding in embedding["data"]] + embedding = self.client.embeddings.create(input=queries, model=self.model) + return [result.embedding for result in embedding.data] class MongoDBAtlasRM(dspy.Retrieve):