From d139afa6575025804c25400a92e37321148a6e9e Mon Sep 17 00:00:00 2001 From: alpha Date: Sat, 2 Mar 2024 21:18:53 -0500 Subject: [PATCH 1/4] Add Elastic Search Retriever Module --- dspy/retrieve/elasticsearch_rm.py | 68 +++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 dspy/retrieve/elasticsearch_rm.py diff --git a/dspy/retrieve/elasticsearch_rm.py b/dspy/retrieve/elasticsearch_rm.py new file mode 100644 index 0000000000..7502b33609 --- /dev/null +++ b/dspy/retrieve/elasticsearch_rm.py @@ -0,0 +1,68 @@ +import dspy + +class elastic_rm(dspy.Retrieve): + def __init__(self, es_client, es_index, es_field, k=3): + """" + A retrieval module that uses Elastic simple vector search to return the top passages for a given query. + Assumes that you already have instanciate your ESClient. + + The code has been tested with ElasticSearch 8.12 + For more information on how to instanciate your ESClient, please refer to the official documentation. + Ref: https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html + + Args: + es_client (Elasticsearch): An instance of the Elasticsearch client. + es_index (str): The name of the index to search. + es_field (str): The name of the field to search. + k (Optional[int]): The number of context strings to return. Default is 3. + """ + super().__init__() + self.k=k + self.es_index=es_index + self.es_client=es_client + self.field=es_field + + + def forward(self, query) -> dspy.Prediction: + """Search with Elastic Search - local or cloud for top k passages for query or queries + + + Args: + query_or_queries (Union[str, List[str]]): The query or queries to search for. + k (Optional[int]): The number of context strings to return, if not already specified in self.k + + Returns: + dspy.Prediction: An object containing the retrieved passages. + """ + + passages = [] + + # Define the index to search + index_name = self.es_index #the name of the index of your elastic-search-dump + + # Define the search query + search_query = { + "query": { + "match": { + self.field: query #took for granted that your index has : title, text as document format + } + } + } + + # Perform the search + response = self.es_client.search(index=index_name, body=search_query) + + for hit in response['hits']['hits']: + + #Uncomment for debug... + # Retrieve the score + #score = hit["_score"] + # Retrieve other fields from the source + #title = hit["_source"]["title"] + text = hit["_source"]["text"] + #print("Score: %.2f | Tile: %s | Text: %s" % (score,title, text)) + passages.append(text) + if len(passages) == self.k: # Break the loop once k documents are retrieved + break + + return dspy.Prediction(passages=passages) \ No newline at end of file From c88a18feec343a11c18468c6d96b4ad31794b6be Mon Sep 17 00:00:00 2001 From: alpha Date: Sat, 2 Mar 2024 21:44:47 -0500 Subject: [PATCH 2/4] Add Elastic Search Retriever Module - Minor Fix --- dspy/retrieve/elasticsearch_rm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dspy/retrieve/elasticsearch_rm.py b/dspy/retrieve/elasticsearch_rm.py index 7502b33609..b1d9c33d2f 100644 --- a/dspy/retrieve/elasticsearch_rm.py +++ b/dspy/retrieve/elasticsearch_rm.py @@ -1,4 +1,5 @@ import dspy +from typing import Optional class elastic_rm(dspy.Retrieve): def __init__(self, es_client, es_index, es_field, k=3): @@ -23,7 +24,7 @@ def __init__(self, es_client, es_index, es_field, k=3): self.field=es_field - def forward(self, query) -> dspy.Prediction: + def forward(self, query,k: Optional[int] = None) -> dspy.Prediction: """Search with Elastic Search - local or cloud for top k passages for query or queries @@ -35,6 +36,8 @@ def forward(self, query) -> dspy.Prediction: dspy.Prediction: An object containing the retrieved passages. """ + k = k if k is not None else self.k + passages = [] # Define the index to search From 6d7e5babaf6b26f1475f51b5acbf8f69b5a1d7d1 Mon Sep 17 00:00:00 2001 From: PM <47463801+pmenkidoo@users.noreply.github.com> Date: Sun, 3 Mar 2024 20:38:50 -0500 Subject: [PATCH 3/4] Update elasticsearch_rm.py --- dspy/retrieve/elasticsearch_rm.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/dspy/retrieve/elasticsearch_rm.py b/dspy/retrieve/elasticsearch_rm.py index b1d9c33d2f..b030ae796b 100644 --- a/dspy/retrieve/elasticsearch_rm.py +++ b/dspy/retrieve/elasticsearch_rm.py @@ -40,32 +40,24 @@ def forward(self, query,k: Optional[int] = None) -> dspy.Prediction: passages = [] - # Define the index to search + index_name = self.es_index #the name of the index of your elastic-search-dump - # Define the search query search_query = { "query": { "match": { - self.field: query #took for granted that your index has : title, text as document format + self.field: query } } } - # Perform the search response = self.es_client.search(index=index_name, body=search_query) for hit in response['hits']['hits']: - #Uncomment for debug... - # Retrieve the score - #score = hit["_score"] - # Retrieve other fields from the source - #title = hit["_source"]["title"] text = hit["_source"]["text"] - #print("Score: %.2f | Tile: %s | Text: %s" % (score,title, text)) passages.append(text) if len(passages) == self.k: # Break the loop once k documents are retrieved break - return dspy.Prediction(passages=passages) \ No newline at end of file + return dspy.Prediction(passages=passages) From cfdfde04d577399b077214a01b545ad8bca0ed1b Mon Sep 17 00:00:00 2001 From: PM <47463801+pmenkidoo@users.noreply.github.com> Date: Mon, 4 Mar 2024 12:06:18 -0500 Subject: [PATCH 4/4] Create ElasticSearch.md --- .../retrieval_model_clients/ElasticSearch.md | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 docs/api/retrieval_model_clients/ElasticSearch.md diff --git a/docs/api/retrieval_model_clients/ElasticSearch.md b/docs/api/retrieval_model_clients/ElasticSearch.md new file mode 100644 index 0000000000..e2bcde49e7 --- /dev/null +++ b/docs/api/retrieval_model_clients/ElasticSearch.md @@ -0,0 +1,85 @@ + +# retrieve.elastic_rm + +### Constructor + +Initialize an instance of the `elastic_rm` class, . + +```python +elastic_rm( + es_client: str, + es_index: str, + es_field: str, + k: int = 3, +) +``` + +**Parameters:** +- `es_client` (_str_): The Elastic Search Client previously created and initialized (Ref. 1) +- `es_index` (_str_): Path to the directory where chromadb data is persisted. +- `es_field` (_str): The function used for embedding documents and queries. Defaults to `DefaultEmbeddingFunction()` if not specified. +- `k` (_int_, _optional_): The number of top passages to retrieve. Defaults to 3. + +Ref. 1 - Connecting to Elastic Cloud - +https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html + +### Methods + +#### `forward(self, query: [str], k: Optional[int] = None) -> dspy.Prediction` + +Search the chromadb collection for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_function`. + +**Parameters:** +- `query` (str_): The query. +- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization. + +**Returns:** +- `dspy.Prediction`: Contains the retrieved passages as a list of string with the prediction signature. + +ex: +```python +Prediction( + passages=['Passage 1 Lorem Ipsum awesome', 'Passage 2 Lorem Ipsum Youppidoo', 'Passage 3 Lorem Ipsum Yassssss'] +) +``` + +### Quick Example how to use Elastic Search in a local environment. + +Please refer to official doc if your instance is in the cloud. See (Ref. 1) above. + +```python +from dspy.retrieve import elastic_rm +import os +from elasticsearch import Elasticsearch + + +ELASTIC_PASSWORD = os.getenv('ELASTIC_PASSWORD') + +# Create the client instance +es = Elasticsearch( + "https://localhost:9200", + ca_certs="http_ca.crt", #Make sure you specifi the path to the certificate, generate one if you don't have. + basic_auth=("elastic", ELASTIC_PASSWORD) +) + +# Check your connection +if es.ping(): + print("Connected to Elasticsearch cluster") +else: + print("Could not connect to Elasticsearch") + +# Index name you want to search +index_name = "wiki-summary" + +retriever_model = elastic_rm( + 'es_client', + 'es_index', + es_field=embedding_function, + k=3 +) + +results = retriever_model("Explore the significance of quantum computing", k=3) + +for passage in results.passages: + print("Document:", result, "\n") +```