From caade2ddc0f51a9c9c44332b7ecd855e83e4f012 Mon Sep 17 00:00:00 2001 From: Isaac Miller Date: Tue, 5 Mar 2024 14:20:54 -0600 Subject: [PATCH 1/2] Add optional parameter 'k' to forward method in FaissRM class --- dspy/retrieve/faiss_rm.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/dspy/retrieve/faiss_rm.py b/dspy/retrieve/faiss_rm.py index 7a74ec4b91..b7d16a0062 100644 --- a/dspy/retrieve/faiss_rm.py +++ b/dspy/retrieve/faiss_rm.py @@ -3,7 +3,7 @@ """ import logging -from typing import Union +from typing import Union, Optional import numpy as np @@ -107,8 +107,8 @@ def _dump_raw_results(self, queries, index_list, distance_list) -> None: logging.debug(f" Hit {j} = {indices[j]}/{distances[j]}: {self._document_chunks[indices[j]]}") return - def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction: - """Search the faiss index for self.k top passages for query. + def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction: + """Search the faiss index for k or self.k top passages for query. Args: query_or_queries (Union[str, List[str]]): The query or queries to search for. @@ -122,13 +122,12 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction: emb_npa = np.array(embeddings) # For single query, just look up the top k passages if len(queries) == 1: - distance_list, index_list = self._faiss_index.search(emb_npa, self.k) + distance_list, index_list = self._faiss_index.search(emb_npa, k or self.k) # self._dump_raw_results(queries, index_list, distance_list) passages = [(self._document_chunks[ind], ind) for ind in index_list[0]] - passages = [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages] - return dspy.Prediction(passages=passages) + return [dotdict({"long_text": passage[0], "index": passage[1]}) for passage in passages] - distance_list, index_list = self._faiss_index.search(emb_npa, self.k * 3) + distance_list, index_list = self._faiss_index.search(emb_npa, (k or self.k) * 3) # self._dump_raw_results(queries, index_list, distance_list) passage_scores = {} for emb in range(len(embeddings)): @@ -136,7 +135,7 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction: distances = distance_list[ emb ] # distances of neighbors for embeddings[emb] - this is an array of k*3 floating point numbers - for res in range(self.k * 3): + for res in range((k or self.k) * 3): neighbor = indices[res] distance = distances[res] if neighbor in passage_scores: @@ -147,10 +146,5 @@ def forward(self, query_or_queries: Union[str, list[str]]) -> dspy.Prediction: # first degree sort: number of queries that got a hit with any particular document chunk. More # is a better match. This is len(queries)-len(x[1]) # second degree sort: sum of the distances of each hit returned by faiss. Smaller distance is a better match - sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: self.k] - return dspy.Prediction( - passages=[ - dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) - for passage_index, _ in sorted_passages - ], - ) + sorted_passages = sorted(passage_scores.items(), key=lambda x: (len(queries) - len(x[1]), sum(x[1])))[: k or self.k] + return [ dotdict({"long_text": self._document_chunks[passage_index], "index": passage_index}) for passage_index, _ in sorted_passages ] \ No newline at end of file From 6f30d0820d9ea7412fe1cd55759e06a033fe66d9 Mon Sep 17 00:00:00 2001 From: isaacbmiller Date: Tue, 5 Mar 2024 20:22:18 +0000 Subject: [PATCH 2/2] Automatic Style fixes --- dspy/retrieve/faiss_rm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/retrieve/faiss_rm.py b/dspy/retrieve/faiss_rm.py index b7d16a0062..3dd34dbeeb 100644 --- a/dspy/retrieve/faiss_rm.py +++ b/dspy/retrieve/faiss_rm.py @@ -3,7 +3,7 @@ """ import logging -from typing import Union, Optional +from typing import Optional, Union import numpy as np