-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Colbert local mode support both as retriever and reranker. #797
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9632e5e
e415f39
a4b3844
321a768
6cd1d56
eeafacb
1639bd2
ec062b6
987d923
9ff5b28
825a272
c25e9c4
ab5b12e
63dd534
f6a9293
197a2c2
4698b00
81d142f
b73753c
0ec1ded
567d5c4
685df2a
fa2bc20
509b36c
34328fd
146ec7b
f0437e3
9cb522b
ec4b9b3
326ce01
b5913fc
c60fadc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
| @@ -1,9 +1,11 @@ | ||||
| import logging | ||||
| from collections.abc import Iterable | ||||
|
|
||||
| import numpy as np | ||||
|
|
||||
| import dsp | ||||
|
|
||||
| logger = logging.getLogger(__name__) | ||||
|
|
||||
| def retrieve(query: str, k: int, **kwargs) -> list[str]: | ||||
| """Retrieves passages from the RM for the query and returns the top k passages.""" | ||||
|
|
@@ -15,12 +17,25 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: | |||
| # TODO: we should unify the type signatures of dspy.Retriever | ||||
| passages = [passages] | ||||
| passages = [psg.long_text for psg in passages] | ||||
|
|
||||
| if dsp.settings.reranker: | ||||
| passages_cs_scores = dsp.settings.reranker(query, passages) | ||||
| passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] | ||||
| passages = [passages[idx] for idx in passages_cs_scores_sorted] | ||||
|
|
||||
|
|
||||
| return passages | ||||
| def retrievewithMetadata(query: str, k: int, **kwargs) -> list[str]: | ||||
| """Retrieves passages from the RM for the query and returns the top k passages.""" | ||||
|
|
||||
| if not dsp.settings.rm: | ||||
| raise AssertionError("No RM is loaded.") | ||||
| passages = dsp.settings.rm(query, k=k, **kwargs) | ||||
| if not isinstance(passages, Iterable): | ||||
| # it's not an iterable yet; make it one. | ||||
| # TODO: we should unify the type signatures of dspy.Retriever | ||||
| passages = [passages] | ||||
|
|
||||
| return passages | ||||
|
|
||||
|
|
||||
|
|
@@ -38,9 +53,31 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: | |||
| passages_cs_scores[idx], | ||||
| ] | ||||
|
|
||||
|
|
||||
| passages = [(np.average(score), text) for text, score in passages.items()] | ||||
| return [text for _, text in sorted(passages, reverse=True)[:k]] | ||||
|
|
||||
| def retrieveRerankEnsemblewithMetadata(queries: list[str], k: int, **kwargs) -> list[str]: | ||||
| if not (dsp.settings.rm and dsp.settings.reranker): | ||||
| raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") | ||||
| queries = [q for q in queries if q] | ||||
| all_queries_passages = [] | ||||
| for query in queries: | ||||
| passages = [] | ||||
| retrieved_passages = dsp.settings.rm(query, k=k * 3, **kwargs) | ||||
| passages_cs_scores = dsp.settings.reranker( | ||||
| query, passages=[psg["long_text"] for psg in retrieved_passages], | ||||
| ) | ||||
| for idx in np.argsort(passages_cs_scores)[::-1][:k]: | ||||
| curr_passage = retrieved_passages[idx] | ||||
| curr_passage["rerank_score"] = passages_cs_scores[idx] | ||||
| passages.append(curr_passage) | ||||
| all_queries_passages.append(passages) | ||||
| if len(queries) == 1: | ||||
| return all_queries_passages[0] | ||||
| else: | ||||
| return all_queries_passages | ||||
|
|
||||
|
|
||||
| def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: | ||||
| """Retrieves passages from the RM for each query in queries and returns the top k passages | ||||
|
|
@@ -50,7 +87,6 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) | |||
| raise AssertionError("No RM is loaded.") | ||||
| if dsp.settings.reranker: | ||||
| return retrieveRerankEnsemble(queries, k, **kwargs) | ||||
|
|
||||
| queries = [q for q in queries if q] | ||||
|
|
||||
| if len(queries) == 1: | ||||
|
|
@@ -68,4 +104,43 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) | |||
| passages = sorted(passages, reverse=True)[:k] | ||||
| passages = [text for _, text in passages] | ||||
|
|
||||
|
|
||||
| return passages | ||||
|
|
||||
| def retrieveEnsemblewithMetadata( | ||||
| queries: list[str], k: int, by_prob: bool = True, **kwargs, | ||||
| ) -> list[str]: | ||||
| """Retrieves passages from the RM for each query in queries and returns the top k passages | ||||
| based on the probability or score. | ||||
| """ | ||||
|
|
||||
| if not dsp.settings.rm: | ||||
| raise AssertionError("No RM is loaded.") | ||||
| if not dsp.settings.reranker: | ||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logging here too
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as with above - dspy/dspy/evaluate/evaluate.py Line 56 in d09d984
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as with above- |
||||
| return retrieveRerankEnsemblewithMetadata(queries=queries,k=k) | ||||
|
|
||||
| queries = [q for q in queries if q] | ||||
|
|
||||
| if len(queries) == 1: | ||||
| return retrieve(queries[0], k) | ||||
| all_queries_passages = [] | ||||
| for q in queries: | ||||
| passages = {} | ||||
| retrieved_passages = dsp.settings.rm(q, k=k * 3, **kwargs) | ||||
| for idx, psg in enumerate(retrieved_passages): | ||||
| if by_prob: | ||||
| passages[(idx, psg.long_text)] = ( | ||||
| passages.get(psg.long_text, 0.0) + psg.prob | ||||
| ) | ||||
| else: | ||||
| passages[(idx, psg.long_text)] = ( | ||||
| passages.get(psg.long_text, 0.0) + psg.score | ||||
| ) | ||||
| retrieved_passages[idx]["tracking_idx"] = idx | ||||
| passages = sorted(passages.items(), key=lambda item: item[1])[:k] | ||||
| req_indices = [psg[0][0] for psg in passages] | ||||
| passages = [ | ||||
| rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices | ||||
| ] | ||||
| all_queries_passages.append(passages) | ||||
| return all_queries_passages | ||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| from .retrieve import Retrieve | ||
| from .retrieve import Retrieve, RetrieveThenRerank |
Uh oh!
There was an error while loading. Please reload this page.