From 2818fce29e4ce306345a6deed6ba888ec66c9ebf Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 18:27:18 +0530 Subject: [PATCH 1/8] feat(dspy): add support for vector, hybrid and fulltext search in azure ai search --- dspy/retrieve/azureaisearch_rm.py | 293 ++++++++++++++++++++++++------ 1 file changed, 241 insertions(+), 52 deletions(-) diff --git a/dspy/retrieve/azureaisearch_rm.py b/dspy/retrieve/azureaisearch_rm.py index 6847ada5ea..90e00624cf 100644 --- a/dspy/retrieve/azureaisearch_rm.py +++ b/dspy/retrieve/azureaisearch_rm.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Union +import warnings +from typing import Any, Callable, List, Optional, Union import dspy from dsp.utils.utils import dotdict @@ -7,13 +8,22 @@ from azure.core.credentials import AzureKeyCredential from azure.search.documents import SearchClient from azure.search.documents._paging import SearchItemPaged - from azure.search.documents.models import QueryType + from azure.search.documents.models import QueryType, VectorFilterMode, VectorizedQuery except ImportError: raise ImportError( "You need to install azure-search-documents library" - "Please use the command: pip install azure-search-documents", + "Please use the command: pip install azure-search-documents==11.6.0b1", ) +try: + import openai +except ImportError: + warnings.warn( + "`openai` is not installed. Install it with `pip install openai` to use AzureOpenAI embedding models.", + category=ImportWarning, + ) + + class AzureAISearchRM(dspy.Retrieve): """ @@ -24,7 +34,11 @@ class AzureAISearchRM(dspy.Retrieve): search_api_key (str): The API key for accessing the Azure AI Search service. search_index_name (str): The name of the search index in the Azure AI Search service. field_text (str): The name of the field containing text content in the search index. This field will be mapped to the "content" field in the dsp framework. + field_vector (Optional[str]): The name of the field containing vector content in the search index. Defaults to None. k (int, optional): The default number of top passages to retrieve. Defaults to 3. + azure_openai_client (Optional[openai.AzureOpenAI]): An instance of the AzureOpenAI client. Defaults to None. + openai_embed_model (Optional[str]): The name of the OpenAI embedding model. Defaults to "text-embedding-ada-002". + embedding_func (Optional[Callable]): A function for generating embeddings. Defaults to None. semantic_ranker (bool, optional): Whether to use semantic ranking. Defaults to False. filter (str, optional): Additional filter query. Defaults to None. query_language (str, optional): The language of the query. Defaults to "en-Us". @@ -32,6 +46,10 @@ class AzureAISearchRM(dspy.Retrieve): use_semantic_captions (bool, optional): Whether to use semantic captions. Defaults to False. query_type (Optional[QueryType], optional): The type of query. Defaults to QueryType.FULL. semantic_configuration_name (str, optional): The name of the semantic configuration. Defaults to None. + is_vector_search (Optional[bool]): Whether to enable vector search. Defaults to False. + is_hybrid_search (Optional[bool]): Whether to enable hybrid search. Defaults to False. + is_fulltext_search (Optional[bool]): Whether to enable fulltext search. Defaults to True. + vector_filter_mode (Optional[VectorFilterMode]): The vector filter mode. Defaults to None. 
Examples: Below is a code snippet that demonstrates how to instantiate and use the AzureAISearchRM class: @@ -48,16 +66,32 @@ class AzureAISearchRM(dspy.Retrieve): search_service_name (str): The name of the Azure AI Search service. search_api_key (str): The API key for accessing the Azure AI Search service. search_index_name (str): The name of the search index in the Azure AI Search service. - field_text (str): The name of the field containing text content in the search index. endpoint (str): The endpoint URL for the Azure AI Search service. + field_text (str): The name of the field containing text content in the search index. + field_vector (Optional[str]): The name of the field containing vector content in the search index. + azure_openai_client (Optional[openai.AzureOpenAI]): An instance of the AzureOpenAI client. + openai_embed_model (Optional[str]): The name of the OpenAI embedding model. + embedding_func (Optional[Callable]): A function for generating embeddings. credential (AzureKeyCredential): The Azure key credential for accessing the service. client (SearchClient): The Azure AI Search client instance. + semantic_ranker (bool): Whether to use semantic ranking. + filter (str): Additional filter query. + query_language (str): The language of the query. + query_speller (str): The speller mode. + use_semantic_captions (bool): Whether to use semantic captions. + query_type (Optional[QueryType]): The type of query. + semantic_configuration_name (str): The name of the semantic configuration. + is_vector_search (Optional[bool]): Whether to enable vector search. + is_hybrid_search (Optional[bool]): Whether to enable hybrid search. + is_fulltext_search (Optional[bool]): Whether to enable fulltext search. + vector_filter_mode (Optional[VectorFilterMode]): The vector filter mode. Methods: forward(query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction: Search for the top passages corresponding to the given query or queries. azure_search_request( + self, key_content: str, client: SearchClient, query: str, @@ -68,13 +102,38 @@ class AzureAISearchRM(dspy.Retrieve): query_speller: str, use_semantic_captions: bool, query_type: QueryType, - semantic_configuration_name: str + semantic_configuration_name: str, + is_vector_search: bool, + is_hybrid_search: bool, + is_fulltext_search: bool, + field_vector: str, + vector_filter_mode: VectorFilterMode ) -> List[dict]: Perform a search request to the Azure AI Search service. - process_azure_result(results: SearchItemPaged, content_key: str, content_score: str) -> List[dict]: + process_azure_result( + self, + results:SearchItemPaged, + content_key:str, + content_score: str + ) -> List[dict]: Process the results received from the Azure AI Search service and map them to the correct format. + get_embeddings( + self, + query: str, + k_nearest_neighbors: int, + field_vector: str + ) -> List | Any: + Returns embeddings for the given query. + + check_sementic_configuration( + self, + semantic_configuration_name, + query_type + ): + Checks semantic configuration. + Raises: ImportError: If the required Azure AI Search libraries are not installed. 
@@ -89,7 +148,11 @@ def __init__( search_api_key: str, search_index_name: str, field_text: str, + field_vector: Optional[str] = None, k: int = 3, + azure_openai_client: Optional[openai.AzureOpenAI] = None, + openai_embed_model: Optional[str] = "text-embedding-ada-002", + embedding_func: Optional[Callable] = None, semantic_ranker: bool = False, filter: str = None, query_language: str = "en-Us", @@ -97,18 +160,25 @@ def __init__( use_semantic_captions: bool = False, query_type: Optional[QueryType] = QueryType.FULL, semantic_configuration_name: str = None, - + is_vector_search: Optional[bool] = False, + is_hybride_search: Optional[bool] = False, + is_fulltext_search: Optional[bool] = True, + vector_filter_mode: Optional[VectorFilterMode.PRE_FILTER] = None, ): self.search_service_name = search_service_name self.search_api_key = search_api_key self.search_index_name = search_index_name - self.endpoint=f"https://{self.search_service_name}.search.windows.net" - self.field_text = field_text # field name of the text content + self.endpoint = f"https://{self.search_service_name}.search.windows.net" + self.field_text = field_text # field name of the text content + self.field_vector = field_vector # field name of the vector content + self.azure_openai_client = azure_openai_client + self.openai_embed_model = openai_embed_model + self.embedding_func = embedding_func # Create a client self.credential = AzureKeyCredential(self.search_api_key) - self.client = SearchClient(endpoint=self.endpoint, - index_name=self.search_index_name, - credential=self.credential) + self.client = SearchClient( + endpoint=self.endpoint, index_name=self.search_index_name, credential=self.credential, + ) self.semantic_ranker = semantic_ranker self.filter = filter self.query_language = query_language @@ -116,39 +186,106 @@ def __init__( self.use_semantic_captions = use_semantic_captions self.query_type = query_type self.semantic_configuration_name = semantic_configuration_name + self.is_vector_search = is_vector_search + self.is_hybride_search = is_hybride_search + self.is_fulltext_search = is_fulltext_search + self.vector_filter_mode = vector_filter_mode super().__init__(k=k) - def azure_search_request(self,key_content: str, client: SearchClient, query: str, top: int, semantic_ranker: bool, filter: str, query_language: str, query_speller: str, use_semantic_captions: bool, query_type: QueryType, semantic_configuration_name: str): + def azure_search_request( + self, + key_content: str, + client: SearchClient, + query: str, + top: int, + semantic_ranker: bool, + filter: str, + query_language: str, + query_speller: str, + use_semantic_captions: bool, + query_type: QueryType, + semantic_configuration_name: str, + is_vector_search: bool, + is_hybrid_search: bool, + is_fulltext_search: bool, + field_vector: str, + vector_filter_mode: VectorFilterMode, + ): """ Search in Azure AI Search Index """ - # TODO: Add Support for Vector Search And Hybride Search - if semantic_ranker: - results = client.search(search_text=query, - filter=filter, - query_type=query_type, - query_language = query_language, - query_speller=query_speller, - semantic_configuration_name=semantic_configuration_name, - top=top, - query_caption = ( - 'extractive|highlight-false' - if use_semantic_captions - else None - ), - ) - else: - results = client.search(search_text=query,top=top,filter=filter) + if is_vector_search: + vector_query = self.get_embeddings(query, top, field_vector) + if semantic_ranker: + self.check_sementic_configuration(semantic_configuration_name, 
query_type) + results = client.search( + search_text=None, + filter=filter, + query_type=query_type, + vector_queries=[vector_query], + vector_filter_mode=vector_filter_mode, + semantic_configuration_name=semantic_configuration_name, + top=top, + ) + else: + results = client.search( + search_text=None, + filter=filter, + vector_queries=[vector_query], + vector_filter_mode=vector_filter_mode, + top=top, + ) + if is_hybrid_search: + vector_query = self.get_embeddings(query, top, field_vector) + if semantic_ranker: + self.check_sementic_configuration(semantic_configuration_name, query_type) + results = client.search( + search_text=query, + filter=filter, + query_type=query_type, + query_language=query_language, + query_speller=query_speller, + semantic_configuration_name=semantic_configuration_name, + top=top, + vector_queries=[vector_query], + vector_filter_mode=vector_filter_mode, + query_caption=("extractive|highlight-false" if use_semantic_captions else None), + ) + else: + results = client.search( + search_text=query, + filter=filter, + query_language=query_language, + query_speller=query_speller, + top=top, + vector_queries=[vector_query], + vector_filter_mode=vector_filter_mode, + ) + if is_fulltext_search: + if semantic_ranker: + self.check_sementic_configuration(semantic_configuration_name, query_type) + results = client.search( + search_text=query, + filter=filter, + query_type=query_type, + query_language=query_language, + query_speller=query_speller, + semantic_configuration_name=semantic_configuration_name, + top=top, + query_caption=("extractive|highlight-false" if use_semantic_captions else None), + ) + else: + results = client.search(search_text=query, top=top, filter=filter) - sorted_results = sorted(results, key=lambda x: x['@search.score'], reverse=True) + sorted_results = sorted(results, key=lambda x: x["@search.score"], reverse=True) sorted_results = self.process_azure_result(sorted_results, key_content, key_content) return sorted_results - def process_azure_result(self,results:SearchItemPaged, content_key:str, content_score: str): + def process_azure_result(self, results: SearchItemPaged, content_key: str, content_score: str): """ process received result from Azure AI Search as dictionary array and map content and score to correct format """ @@ -156,16 +293,16 @@ def process_azure_result(self,results:SearchItemPaged, content_key:str, content_ for result in results: tmp = {} for key, value in result.items(): - if(key == content_key): - tmp["text"] = value # assign content - elif(key == content_score): + if key == content_key: + tmp["text"] = value # assign content + elif key == content_score: tmp["score"] = value else: tmp[key] = value res.append(tmp) return res - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction: + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> dspy.Prediction: """ Search with pinecone for self.k top passages for query @@ -177,25 +314,77 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> """ k = k if k is not None else self.k - queries = ( - [query_or_queries] - if isinstance(query_or_queries, str) - else query_or_queries - ) + queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [q for q in queries if q] # Filter empty queries passages = [] for query in queries: - results = self.azure_search_request(self.field_text, - self.client, query, - k, - self.semantic_ranker, - self.filter, - 
self.query_language, - self.query_speller, - self.use_semantic_captions, - self.query_type, - self.semantic_configuration_name) - passages.extend(dotdict({"long_text": d['text']}) for d in results) + results = self.azure_search_request( + self.field_text, + self.client, + query, + k, + self.semantic_ranker, + self.filter, + self.query_language, + self.query_speller, + self.use_semantic_captions, + self.query_type, + self.semantic_configuration_name, + self.is_vector_search, + self.is_hybride_search, + self.is_fulltext_search, + self.field_vector, + self.vector_filter_mode, + ) + passages.extend(dotdict({"long_text": d["text"]}) for d in results) return passages + + def get_embeddings(self, query: str, k_nearest_neighbors: int, field_vector: str) -> List | Any: + """ + Returns embeddings for the given query. + + Args: + query (str): The query for which embeddings are to be retrieved. + k_nearest_neighbors (int): The number of nearest neighbors to consider. + field_vector (str): The field vector to use for embeddings. + + Returns: + list: A list containing the vectorized query. + Any: The result of embedding_func if azure_openai_client is not provided. + + Raises: + AssertionError: If neither azure_openai_client nor embedding_func is provided, + or if field_vector is not provided. + """ + assert ( + self.azure_openai_client or self.embedding_func + ), "Either azure_openai_client or embedding_func must be provided." + assert field_vector, "field_vector must be provided." + + if self.azure_openai_client is not None: + embedding = ( + self.azure_openai_client.embeddings.create(input=query, model=self.openai_embed_model).data[0].embedding + ) + vector_query = VectorizedQuery( + vector=embedding, k_nearest_neighbors=k_nearest_neighbors, fields=field_vector, + ) + return [vector_query] + else: + return self.embedding_func(query) + + def check_sementic_configuration(self, semantic_configuration_name, query_type): + """ + Checks semantic configuration. + + Args: + semantic_configuration_name: The name of the semantic configuration. + query_type: The type of the query. + + Raises: + AssertionError: If semantic_configuration_name is not provided + or if query_type is not QueryType.SEMANTIC. + """ + assert semantic_configuration_name, "Semantic configuration name must be provided." + assert query_type == QueryType.SEMANTIC, "Query type must be QueryType.SEMANTIC." 
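For orientation, here is a minimal sketch of how the vector-search path added in this patch could be exercised once it lands, assuming an Azure OpenAI resource supplies the query embeddings. Every credential, endpoint, and field name below (`content`, `content_vector`) is a placeholder, and the API version string is only illustrative; the constructor parameters themselves come from the patch above.

```python
from openai import AzureOpenAI

from dspy.retrieve.azureaisearch_rm import AzureAISearchRM

# Placeholder credentials and names; substitute your own Azure resources.
azure_openai_client = AzureOpenAI(
    api_key="<azure-openai-api-key>",
    api_version="2024-02-01",  # illustrative API version
    azure_endpoint="https://<your-openai-resource>.openai.azure.com",
)

rm = AzureAISearchRM(
    search_service_name="<search-service-name>",
    search_api_key="<search-api-key>",
    search_index_name="<index-name>",
    field_text="content",           # text field mapped to long_text in the results
    field_vector="content_vector",  # vector field queried via VectorizedQuery
    k=3,
    azure_openai_client=azure_openai_client,
    openai_embed_model="text-embedding-ada-002",
    is_vector_search=True,
    is_fulltext_search=False,  # defaults to True; disable it so only the vector branch runs
)

for passage in rm("What is thermodynamics?", k=3):
    print(passage.long_text)
```

Note that `is_fulltext_search` is switched off explicitly: it defaults to True, and in `azure_search_request` the full-text branch would otherwise run after the vector branch and overwrite its results.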
From 807882c390fa8167dbad91f03f83c6961bc285d8 Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 18:29:23 +0530 Subject: [PATCH 2/8] fix(dspy): update pyproject.toml --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f99df941a0..75c549c007 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic~=0.18.0"] +azure-ai-search = ["azure-search-documents~=11.6.0b1"] chromadb = ["chromadb~=0.4.14"] qdrant = ["qdrant-client>=1.6.2", "fastembed>=0.1.0"] marqo = ["marqo"] @@ -118,6 +119,7 @@ rich = "^13.7.1" psycopg2 = {version = "^2.9.9", optional = true} pgvector = {version = "^0.2.5", optional = true} structlog = "^24.1.0" +azure-search-documents = {version = "^11.6.0b1", optional = true} [tool.poetry.group.dev.dependencies] @@ -138,6 +140,7 @@ weaviate = ["weaviate-client"] milvus = ["pymilvus"] aws = ["boto3"] postgres = ["psycopg2", "pgvector"] +azure-ai-search = ["azure-search-documents"] docs = [ "sphinx", "furo", From fdf76cc2cf14366b2adbceb3dab8e22a08f42b61 Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 19:50:08 +0530 Subject: [PATCH 3/8] fix(dspy): update args description for azure ai search retrieval module --- dspy/retrieve/azureaisearch_rm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dspy/retrieve/azureaisearch_rm.py b/dspy/retrieve/azureaisearch_rm.py index 90e00624cf..cb61add571 100644 --- a/dspy/retrieve/azureaisearch_rm.py +++ b/dspy/retrieve/azureaisearch_rm.py @@ -36,9 +36,9 @@ class AzureAISearchRM(dspy.Retrieve): field_text (str): The name of the field containing text content in the search index. This field will be mapped to the "content" field in the dsp framework. field_vector (Optional[str]): The name of the field containing vector content in the search index. Defaults to None. k (int, optional): The default number of top passages to retrieve. Defaults to 3. - azure_openai_client (Optional[openai.AzureOpenAI]): An instance of the AzureOpenAI client. Defaults to None. + azure_openai_client (Optional[openai.AzureOpenAI]): An instance of the AzureOpenAI client. Either openai_client or embedding_func must be provided. Defaults to None. openai_embed_model (Optional[str]): The name of the OpenAI embedding model. Defaults to "text-embedding-ada-002". - embedding_func (Optional[Callable]): A function for generating embeddings. Defaults to None. + embedding_func (Optional[Callable]): A function for generating embeddings. Either openai_client or embedding_func must be provided. Defaults to None. semantic_ranker (bool, optional): Whether to use semantic ranking. Defaults to False. filter (str, optional): Additional filter query. Defaults to None. query_language (str, optional): The language of the query. Defaults to "en-Us". 
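The patch does not include an example of a custom `embedding_func`, but since `get_embeddings` passes that callable's return value straight into `vector_queries=[...]`, a reasonable assumption is that it should return a single ready-made `VectorizedQuery`. The sketch below works under that assumption, using a local sentence-transformers model purely as a stand-in embedder (not part of this PR); service, index, and field names are placeholders.

```python
from azure.search.documents.models import VectorizedQuery
from sentence_transformers import SentenceTransformer

from dspy.retrieve.azureaisearch_rm import AzureAISearchRM

# Any embedding model works as long as its dimension matches the index's vector field.
model = SentenceTransformer("all-MiniLM-L6-v2")


def embedding_func(query: str) -> VectorizedQuery:
    # get_embeddings() feeds this return value directly into `vector_queries=[...]`,
    # so we hand back a VectorizedQuery rather than a bare list of floats.
    # k_nearest_neighbors is fixed here; the Azure OpenAI path derives it from `top` instead.
    vector = model.encode(query).tolist()
    return VectorizedQuery(vector=vector, k_nearest_neighbors=3, fields="content_vector")


rm = AzureAISearchRM(
    search_service_name="<search-service-name>",
    search_api_key="<search-api-key>",
    search_index_name="<index-name>",
    field_text="content",
    field_vector="content_vector",
    embedding_func=embedding_func,
    is_vector_search=True,
    is_fulltext_search=False,  # the default full-text pass would otherwise overwrite the vector results
)
```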
From 94b7fcdb06c527d34da2eb2287539cf5cb266073 Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 19:55:46 +0530 Subject: [PATCH 4/8] fix(dspy): update documentation for azure ai search retrieval class --- .../retrieval_models_clients/Azure.mdx | 70 +++++++++++++++++-- 1 file changed, 65 insertions(+), 5 deletions(-) diff --git a/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx b/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx index 3eed7c4a50..d19162bfab 100644 --- a/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx +++ b/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx @@ -22,7 +22,11 @@ The constructor initializes an instance of the `AzureAISearchRM` class and sets - `search_api_key` (str): The API key for accessing the Azure AI Search service. - `search_index_name` (str): The name of the search index in the Azure AI Search service. - `field_text` (str): The name of the field containing text content in the search index. This field will be mapped to the "content" field in the dsp framework. +- `field_vector` (Optional[str]): The name of the field containing vector content in the search index. - `k` (int, optional): The default number of top passages to retrieve. Defaults to 3. +- `azure_openai_client` (Optional[openai.AzureOpenAI]): An instance of the AzureOpenAI client. Either openai_client or embedding_func must be provided. Defaults to None. +- `openai_embed_model` (Optional[str]): The name of the OpenAI embedding model. Defaults to "text-embedding-ada-002". +- `embedding_func` (Optional[Callable]): A function for generating embeddings. Either openai_client or embedding_func must be provided. Defaults to None. - `semantic_ranker` (bool, optional): Whether to use semantic ranking. Defaults to False. - `filter` (str, optional): Additional filter query. Defaults to None. - `query_language` (str, optional): The language of the query. Defaults to "en-Us". @@ -30,24 +34,47 @@ The constructor initializes an instance of the `AzureAISearchRM` class and sets - `use_semantic_captions` (bool, optional): Whether to use semantic captions. Defaults to False. - `query_type` (Optional[QueryType], optional): The type of query. Defaults to QueryType.FULL. - `semantic_configuration_name` (str, optional): The name of the semantic configuration. Defaults to None. +- `is_vector_search` (Optional[bool]): Whether to enable vector search. Defaults to False. +- `is_hybrid_search` (Optional[bool]): Whether to enable hybrid search. Defaults to False. +- `is_fulltext_search` (Optional[bool]): Whether to enable fulltext search. Defaults to True. +- `vector_filter_mode` (Optional[VectorFilterMode]): The vector filter mode. Defaults to None. -Available Query Types: - SIMPLE +**Available Query Types:** + +- SIMPLE """Uses the simple query syntax for searches. Search text is interpreted using a simple query #: language that allows for symbols such as +, * and "". Queries are evaluated across all #: searchable fields by default, unless the searchFields parameter is specified.""" - FULL +- FULL """Uses the full Lucene query syntax for searches. Search text is interpreted using the Lucene #: query language which allows field-specific and weighted searches, as well as other advanced #: features.""" - SEMANTIC +- SEMANTIC """Best suited for queries expressed in natural language as opposed to keywords. 
Improves #: precision of search results by re-ranking the top search results using a ranking model trained #: on the Web corpus."" More Details: https://learn.microsoft.com/en-us/azure/search/search-query-overview +**Available Vector Filter Mode:** + +- POST_FILTER = "postFilter" + """The filter will be applied after the candidate set of vector results is returned. Depending on + #: the filter selectivity, this can result in fewer results than requested by the parameter 'k'.""" + +- PRE_FILTER = "preFilter" + """The filter will be applied before the search query.""" + + More Details: https://learn.microsoft.com/en-us/azure/search/vector-search-filters + +**Note** + +- The `AzureAISearchRM` client allows you to perform Vector search, Hybrid search, or Full text search. +- By default, the `AzureAISearchRM` client uses the Azure OpenAI Client for generating embeddings. If you want to use something else, you can provide your custom embedding_func, but either the openai_client or embedding_func must be provided. +- If you need to enable semantic search, either with vector, hybrid, or full text search, then set the `semantic_ranker` flag to True. +- If `semantic_ranker` is True, always set the `query_type` to QueryType.SEMANTIC and always provide the `semantic_configuration_name`. + Example of the AzureAISearchRM constructor: ```python @@ -56,14 +83,22 @@ AzureAISearchRM( search_api_key: str, search_index_name: str, field_text: str, + field_vector: Optional[str] = None, k: int = 3, + azure_openai_client: Optional[openai.AzureOpenAI] = None, + openai_embed_model: Optional[str] = "text-embedding-ada-002", + embedding_func: Optional[Callable] = None, semantic_ranker: bool = False, filter: str = None, query_language: str = "en-Us", query_speller: str = "lexicon", use_semantic_captions: bool = False, query_type: Optional[QueryType] = QueryType.FULL, - semantic_configuration_name: str = None + semantic_configuration_name: str = None, + is_vector_search: Optional[bool] = False, + is_hybride_search: Optional[bool] = False, + is_fulltext_search: Optional[bool] = True, + vector_filter_mode: Optional[VectorFilterMode.PRE_FILTER] = None ) ``` @@ -128,6 +163,31 @@ for result in retrieval_response: print("Text:", result.long_text, "\n") ``` +3. Example of Semantic Hybrid Search. 
+ +```python +from dspy.retrieve.azureaisearch_rm import AzureAISearchRM + +azure_search = AzureAISearchRM( + search_service_name="search_service_name", + search_api_key="search_api_key", + search_index_name="search_index_name", + field_text="field_text", + field_vector="field_vector", + k=3, + azure_openai_client="azure_openai_client", + openai_embed_model="text-embedding-ada-002" + semantic_ranker=True, + query_type=QueryType.SEMANTIC, + semantic_configuration_name="semantic_configuration_name", + is_hybrid_search=True, +) + +retrieval_response = azure_search("What is Thermodynamics", k=3) +for result in retrieval_response: + print("Text:", result.long_text, "\n") +``` + *** \ No newline at end of file From 385a35d0daf273050dd1e6531af37110e7e9d4be Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 20:01:42 +0530 Subject: [PATCH 5/8] fix(dspy): add author doc string for azure ai search retrieval module --- dspy/retrieve/azureaisearch_rm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dspy/retrieve/azureaisearch_rm.py b/dspy/retrieve/azureaisearch_rm.py index cb61add571..42810555b5 100644 --- a/dspy/retrieve/azureaisearch_rm.py +++ b/dspy/retrieve/azureaisearch_rm.py @@ -1,3 +1,8 @@ +""" +Retriever module for Azure AI Search +Author: Prajapati Harishkumar Kishorkumar (@HARISHKUMAR1112001) +""" + import warnings from typing import Any, Callable, List, Optional, Union From 05bba1c88d1e556bbb13f1dd6914a0cba6060498 Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 20:27:02 +0530 Subject: [PATCH 6/8] fix(dspy): update the pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 75c549c007..b0802d3242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,7 @@ rich = "^13.7.1" psycopg2 = {version = "^2.9.9", optional = true} pgvector = {version = "^0.2.5", optional = true} structlog = "^24.1.0" -azure-search-documents = {version = "^11.6.0b1", optional = true} +azure-search-documents = {version = "11.6.0b1", optional = true} [tool.poetry.group.dev.dependencies] From d94294c29aeb909706b99454401a80af2aad9ec3 Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Sun, 14 Apr 2024 20:29:39 +0530 Subject: [PATCH 7/8] fix(dspy): update the pyproject.toml --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b0802d3242..f99df941a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ dependencies = [ [project.optional-dependencies] anthropic = ["anthropic~=0.18.0"] -azure-ai-search = ["azure-search-documents~=11.6.0b1"] chromadb = ["chromadb~=0.4.14"] qdrant = ["qdrant-client>=1.6.2", "fastembed>=0.1.0"] marqo = ["marqo"] @@ -119,7 +118,6 @@ rich = "^13.7.1" psycopg2 = {version = "^2.9.9", optional = true} pgvector = {version = "^0.2.5", optional = true} structlog = "^24.1.0" -azure-search-documents = {version = "11.6.0b1", optional = true} [tool.poetry.group.dev.dependencies] @@ -140,7 +138,6 @@ weaviate = ["weaviate-client"] milvus = ["pymilvus"] aws = ["boto3"] postgres = ["psycopg2", "pgvector"] -azure-ai-search = ["azure-search-documents"] docs = [ "sphinx", "furo", From 4067d9a3ac28b1e264c8de50d315a8e73e6827ce Mon Sep 17 00:00:00 2001 From: HARISHKUMAR1112001 Date: Tue, 16 Apr 2024 08:23:17 +0530 Subject: [PATCH 8/8] fix(dspy): fix spelling mistake in azure ai search --- .../retrieval_models_clients/Azure.mdx | 2 +- dspy/retrieve/azureaisearch_rm.py | 22 ++++++++++--------- 2 files 
changed, 13 insertions(+), 11 deletions(-) diff --git a/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx b/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx index d19162bfab..1f87dde237 100644 --- a/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx +++ b/docs/docs/deep-dive/retrieval_models_clients/Azure.mdx @@ -96,7 +96,7 @@ AzureAISearchRM( query_type: Optional[QueryType] = QueryType.FULL, semantic_configuration_name: str = None, is_vector_search: Optional[bool] = False, - is_hybride_search: Optional[bool] = False, + is_hybrid_search: Optional[bool] = False, is_fulltext_search: Optional[bool] = True, vector_filter_mode: Optional[VectorFilterMode.PRE_FILTER] = None ) diff --git a/dspy/retrieve/azureaisearch_rm.py b/dspy/retrieve/azureaisearch_rm.py index 42810555b5..a115d3344f 100644 --- a/dspy/retrieve/azureaisearch_rm.py +++ b/dspy/retrieve/azureaisearch_rm.py @@ -132,7 +132,7 @@ class AzureAISearchRM(dspy.Retrieve): ) -> List | Any: Returns embeddings for the given query. - check_sementic_configuration( + check_semantic_configuration( self, semantic_configuration_name, query_type @@ -166,7 +166,7 @@ def __init__( query_type: Optional[QueryType] = QueryType.FULL, semantic_configuration_name: str = None, is_vector_search: Optional[bool] = False, - is_hybride_search: Optional[bool] = False, + is_hybrid_search: Optional[bool] = False, is_fulltext_search: Optional[bool] = True, vector_filter_mode: Optional[VectorFilterMode.PRE_FILTER] = None, ): @@ -192,7 +192,7 @@ def __init__( self.query_type = query_type self.semantic_configuration_name = semantic_configuration_name self.is_vector_search = is_vector_search - self.is_hybride_search = is_hybride_search + self.is_hybrid_search = is_hybrid_search self.is_fulltext_search = is_fulltext_search self.vector_filter_mode = vector_filter_mode @@ -224,7 +224,7 @@ def azure_search_request( if is_vector_search: vector_query = self.get_embeddings(query, top, field_vector) if semantic_ranker: - self.check_sementic_configuration(semantic_configuration_name, query_type) + self.check_semantic_configuration(semantic_configuration_name, query_type) results = client.search( search_text=None, filter=filter, @@ -233,6 +233,7 @@ def azure_search_request( vector_filter_mode=vector_filter_mode, semantic_configuration_name=semantic_configuration_name, top=top, + query_caption=("extractive|highlight-false" if use_semantic_captions else None), ) else: results = client.search( @@ -245,7 +246,7 @@ def azure_search_request( if is_hybrid_search: vector_query = self.get_embeddings(query, top, field_vector) if semantic_ranker: - self.check_sementic_configuration(semantic_configuration_name, query_type) + self.check_semantic_configuration(semantic_configuration_name, query_type) results = client.search( search_text=query, filter=filter, @@ -270,7 +271,7 @@ def azure_search_request( ) if is_fulltext_search: if semantic_ranker: - self.check_sementic_configuration(semantic_configuration_name, query_type) + self.check_semantic_configuration(semantic_configuration_name, query_type) results = client.search( search_text=query, filter=filter, @@ -337,7 +338,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]) -> self.query_type, self.semantic_configuration_name, self.is_vector_search, - self.is_hybride_search, + self.is_hybrid_search, self.is_fulltext_search, self.field_vector, self.vector_filter_mode, @@ -366,9 +367,10 @@ def get_embeddings(self, query: str, k_nearest_neighbors: int, field_vector: str assert ( self.azure_openai_client or 
self.embedding_func ), "Either azure_openai_client or embedding_func must be provided." - assert field_vector, "field_vector must be provided." - + if self.azure_openai_client is not None: + assert field_vector, "field_vector must be provided." + embedding = ( self.azure_openai_client.embeddings.create(input=query, model=self.openai_embed_model).data[0].embedding ) @@ -379,7 +381,7 @@ def get_embeddings(self, query: str, k_nearest_neighbors: int, field_vector: str else: return self.embedding_func(query) - def check_sementic_configuration(self, semantic_configuration_name, query_type): + def check_semantic_configuration(self, semantic_configuration_name, query_type): """ Checks semantic configuration.
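To round out the three modes the PR title names, here is a hedged sketch of plain full-text search with the semantic ranker enabled, the combination guarded by `check_semantic_configuration` as renamed in this patch: `query_type` must be `QueryType.SEMANTIC` and a semantic configuration name must be supplied. Service, index, field, and configuration names are placeholders.

```python
from azure.search.documents.models import QueryType

from dspy.retrieve.azureaisearch_rm import AzureAISearchRM

rm = AzureAISearchRM(
    search_service_name="<search-service-name>",
    search_api_key="<search-api-key>",
    search_index_name="<index-name>",
    field_text="content",
    k=3,
    # is_fulltext_search defaults to True, so no extra flag is needed for this mode.
    semantic_ranker=True,
    query_type=QueryType.SEMANTIC,  # enforced by check_semantic_configuration
    semantic_configuration_name="<semantic-config-name>",
    use_semantic_captions=True,
)

for passage in rm("What is thermodynamics?", k=3):
    print(passage.long_text)
```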