Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 9 additions & 15 deletions dspy/retrieve/faiss_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import logging
from typing import Union
from typing import Optional, Union

import numpy as np

Expand Down Expand Up @@ -107,8 +107,8 @@ def _dump_raw_results(self, queries, index_list, distance_list) -> None:
logging.debug(f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}")
return

def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
"""Search the faiss index for self.k top passages for query.
def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
"""Search the faiss index for k or self.k top passages for query.

Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
Expand All @@ -122,21 +122,20 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
emb_npa = np.array(embeddings)
# For single query, just look up the top k passages
if len(queries) == 1:
distance_list, index_list = self._faiss_index.search(emb_npa, self.k)
distance_list, index_list = self._faiss_index.search(emb_npa, k or self.k)
# self._dump_raw_results(queries, index_list, distance_list)
passages = [(self._document_chunks[ind], ind) for ind in index_list[0]]
passages = [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages]
return dspy.Prediction(passages=passages)
return [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages]

distance_list, index_list = self._faiss_index.search(emb_npa, self.k * 3)
distance_list, index_list = self._faiss_index.search(emb_npa, (k or self.k) * 3)
# self._dump_raw_results(queries, index_list, distance_list)
passage_scores = {}
for emb in range(len(embeddings)):
indices = index_list[emb] # indices of neighbors for embeddings[emb] - this is an array of k*3 integers
distances = distance_list[
emb
] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers
for res in range(self.k * 3):
for res in range((k or self.k) * 3):
neighbor = indices[res]
distance = distances[res]
if neighbor in passage_scores:
Expand All @@ -147,10 +146,5 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction:
# first degree sort: number of queries that got a hit with any particular document chunk. More
# is a better match. This is len(queries)-len(x[1])
# second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match
sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: self.k]
return dspy.Prediction(
passages=[
dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index})
for passage_index, _ in sorted_passages
],
)
sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: k or self.k]
return [ dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) for passage_index, _ in sorted_passages ]