From f8f72035c5e4cce94210999e09359e797040e3be Mon Sep 17 00:00:00 2001 From: Anush008 Date: Mon, 6 May 2024 13:00:07 +0530 Subject: [PATCH 1/2] refactor: QdrantRM --- dspy/retrieve/qdrant_rm.py | 90 ++++++++++++++++++++++++++------------ 1 file changed, 61 insertions(+), 29 deletions(-) diff --git a/dspy/retrieve/qdrant_rm.py b/dspy/retrieve/qdrant_rm.py index 0ffd42dd80..1172cd23c0 100644 --- a/dspy/retrieve/qdrant_rm.py +++ b/dspy/retrieve/qdrant_rm.py @@ -1,29 +1,29 @@ from collections import defaultdict -from typing import List, Optional, Union +from typing import Optional, Union import dspy +from dsp.modules.sentence_vectorizer import BaseSentenceVectorizer, FastEmbedVectorizer from dsp.utils import dotdict try: - import fastembed - from qdrant_client import QdrantClient -except ImportError: + from qdrant_client import QdrantClient, models +except ImportError as e: raise ImportError( "The 'qdrant' extra is required to use QdrantRM. Install it with `pip install dspy-ai[qdrant]`", - ) + ) from e class QdrantRM(dspy.Retrieve): - """ - A retrieval module that uses Qdrant to return the top passages for a given query. - - Assumes that a Qdrant collection has been created and populated with the following payload: - - document: The text of the passage + """A retrieval module that uses Qdrant to return the top passages for a given query. Args: qdrant_collection_name (str): The name of the Qdrant collection. - qdrant_client (QdrantClient): A QdrantClient instance. - k (int, optional): The default number of top passages to retrieve. Defaults to 3. + qdrant_client (QdrantClient): An instance of `qdrant_client.QdrantClient`. + k (int, optional): The default number of top passages to retrieve. Default: 3. + document_field (str, optional): The key in the Qdrant payload with the content. Default: `"document"`. + vectorizer (BaseSentenceVectorizer, optional): An implementation `BaseSentenceVectorizer`. + Default: `FastEmbedVectorizer`. + vector_name (str, optional): Name of the vector in the collection. Default: The first available vector name. Examples: Below is a code snippet that shows how to use Qdrant as the default retriver: @@ -47,43 +47,75 @@ def __init__( qdrant_collection_name: str, qdrant_client: QdrantClient, k: int = 3, + document_field: str = "document", + vectorizer: Optional[BaseSentenceVectorizer] = None, + vector_name: Optional[str] = None, ): - self._qdrant_collection_name = qdrant_collection_name - self._qdrant_client = qdrant_client + self._collection_name = qdrant_collection_name + self._client = qdrant_client + + self._vectorizer = vectorizer or FastEmbedVectorizer(self._client.embedding_model_name) + + self._document_field = document_field + + self._vector_name = vector_name or self._get_first_vector_name() super().__init__(k=k) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> dspy.Prediction: - """Search with Qdrant for self.k top passages for query + def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction: + """Search with Qdrant for self.k top passages for query. Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. k (Optional[int]): The number of top passages to retrieve. Defaults to self.k. + Returns: dspy.Prediction: An object containing the retrieved passages. """ - queries = ( - [query_or_queries] - if isinstance(query_or_queries, str) - else query_or_queries - ) + queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [q for q in queries if q] # Filter empty queries - k = k if k is not None else self.k - batch_results = self._qdrant_client.query_batch( - self._qdrant_collection_name, query_texts=queries, limit=k,**kwargs) + vectors = self._vectorizer(queries) + + # If vector_name is None + # vector = [0.8, 0.2, 0.3...] + # Else + # vector = {"name": vector_name, "vector": [0.8, 0.2, 0.3...]} + vectors = [ + vector if self._vector_name is None else {"name": self._vector_name, "vector": vector} for vector in vectors + ] + + search_requests = [ + models.SearchRequest( + vector=vector, + limit=k or self.k, + with_payload=[self._document_field], + ) + for vector in vectors + ] + batch_results = self._client.search_batch(self._collection_name, requests=search_requests) passages_scores = defaultdict(float) for batch in batch_results: for result in batch: # If a passage is returned multiple times, the score is accumulated. - passages_scores[result.document] += result.score + document = result.payload.get(self._document_field) + passages_scores[document] += result.score # Sort passages by their accumulated scores in descending order - sorted_passages = sorted( - passages_scores.items(), key=lambda x: x[1], reverse=True)[:k] + sorted_passages = sorted(passages_scores.items(), key=lambda x: x[1], reverse=True)[:k] # Wrap each sorted passage in a dotdict with 'long_text' - passages = [dotdict({"long_text": passage}) for passage, _ in sorted_passages] + return [dotdict({"long_text": passage}) for passage, _ in sorted_passages] + + def _get_first_vector_name(self) -> str | None: + vectors = self._client.get_collection(self._collection_name).config.params.vectors + + if not isinstance(vectors, dict): + # The collection only has the default, unnamed vector + return None + + first_vector_name = list(vectors.keys())[0] - return passages + # The collection has multiple vectors. Could also include the falsy unnamed vector - Empty string("") + return first_vector_name if first_vector_name else None From 33949d015e129b6c128108863abc5f07306969af Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 8 May 2024 08:42:55 +0530 Subject: [PATCH 2/2] chore: Simplified return qdrant_rm.py --- dspy/retrieve/qdrant_rm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/retrieve/qdrant_rm.py b/dspy/retrieve/qdrant_rm.py index 1172cd23c0..89a5849288 100644 --- a/dspy/retrieve/qdrant_rm.py +++ b/dspy/retrieve/qdrant_rm.py @@ -118,4 +118,4 @@ def _get_first_vector_name(self) -> str | None: first_vector_name = list(vectors.keys())[0] # The collection has multiple vectors. Could also include the falsy unnamed vector - Empty string("") - return first_vector_name if first_vector_name else None + return first_vector_name or None