Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 61 additions & 29 deletions dspy/retrieve/qdrant_rm.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from collections import defaultdict
from typing import List, Optional, Union
from typing import Optional, Union

import dspy
from dsp.modules.sentence_vectorizer import BaseSentenceVectorizer, FastEmbedVectorizer
from dsp.utils import dotdict

try:
import fastembed
from qdrant_client import QdrantClient
except ImportError:
from qdrant_client import QdrantClient, models
except ImportError as e:
raise ImportError(
"The 'qdrant' extra is required to use QdrantRM. Install it with `pip install dspy-ai[qdrant]`",
)
) from e


class QdrantRM(dspy.Retrieve):
"""
A retrieval module that uses Qdrant to return the top passages for a given query.

Assumes that a Qdrant collection has been created and populated with the following payload:
- document: The text of the passage
"""A retrieval module that uses Qdrant to return the top passages for a given query.

Args:
qdrant_collection_name (str): The name of the Qdrant collection.
qdrant_client (QdrantClient): A QdrantClient instance.
k (int, optional): The default number of top passages to retrieve. Defaults to 3.
qdrant_client (QdrantClient): An instance of `qdrant_client.QdrantClient`.
k (int, optional): The default number of top passages to retrieve. Default: 3.
document_field (str, optional): The key in the Qdrant payload with the content. Default: `"document"`.
vectorizer (BaseSentenceVectorizer, optional): An implementation of `BaseSentenceVectorizer`.
Default: `FastEmbedVectorizer`.
vector_name (str, optional): Name of the vector in the collection. Default: The first available vector name.

Examples:
Below is a code snippet that shows how to use Qdrant as the default retriever:
Expand All @@ -47,43 +47,75 @@ def __init__(
qdrant_collection_name: str,
qdrant_client: QdrantClient,
k: int = 3,
document_field: str = "document",
vectorizer: Optional[BaseSentenceVectorizer] = None,
vector_name: Optional[str] = None,
):
self._qdrant_collection_name = qdrant_collection_name
self._qdrant_client = qdrant_client
self._collection_name = qdrant_collection_name
self._client = qdrant_client

self._vectorizer = vectorizer or FastEmbedVectorizer(self._client.embedding_model_name)

self._document_field = document_field

self._vector_name = vector_name or self._get_first_vector_name()

super().__init__(k=k)

def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> dspy.Prediction:
"""Search with Qdrant for self.k top passages for query
def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
"""Search with Qdrant for self.k top passages for query.

Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.

Returns:
dspy.Prediction: An object containing the retrieved passages.
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
queries = [q for q in queries if q] # Filter empty queries

k = k if k is not None else self.k
batch_results = self._qdrant_client.query_batch(
self._qdrant_collection_name, query_texts=queries, limit=k,**kwargs)
vectors = self._vectorizer(queries)

# If vector_name is None
# vector = [0.8, 0.2, 0.3...]
# Else
# vector = {"name": vector_name, "vector": [0.8, 0.2, 0.3...]}
vectors = [
vector if self._vector_name is None else {"name": self._vector_name, "vector": vector} for vector in vectors
]

search_requests = [
models.SearchRequest(
vector=vector,
limit=k or self.k,
with_payload=[self._document_field],
)
for vector in vectors
]
batch_results = self._client.search_batch(self._collection_name, requests=search_requests)

passages_scores = defaultdict(float)
for batch in batch_results:
for result in batch:
# If a passage is returned multiple times, the score is accumulated.
passages_scores[result.document] += result.score
document = result.payload.get(self._document_field)
passages_scores[document] += result.score

# Sort passages by their accumulated scores in descending order
sorted_passages = sorted(
passages_scores.items(), key=lambda x: x[1], reverse=True)[:k]
sorted_passages = sorted(passages_scores.items(), key=lambda x: x[1], reverse=True)[:k]

# Wrap each sorted passage in a dotdict with 'long_text'
passages = [dotdict({"long_text": passage}) for passage, _ in sorted_passages]
return [dotdict({"long_text": passage}) for passage, _ in sorted_passages]

def _get_first_vector_name(self) -> str | None:
vectors = self._client.get_collection(self._collection_name).config.params.vectors

if not isinstance(vectors, dict):
# The collection only has the default, unnamed vector
return None

first_vector_name = list(vectors.keys())[0]

return passages
# The collection defines its vectors as a dict keyed by name. The first key may be the empty string (""), which is falsy, so it is normalized to None.
return first_vector_name or None