From b71b2364b850b359aa999cc524b822366f7ec27b Mon Sep 17 00:00:00 2001 From: koshyviv <32818583+koshyviv@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:44:24 +0530 Subject: [PATCH] handle passages of Prediction objects --- dsp/primitives/search.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index a2071327d8..42cf17bdc9 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -10,6 +10,11 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: if not dsp.settings.rm: raise AssertionError("No RM is loaded.") passages = dsp.settings.rm(query, k=k, **kwargs) + + # Check if the returned object has a 'passages' attribute (i.e a Prediction object) + # TODO: use a better approach to determine the Prediction object + if hasattr(passages, 'passages'): + passages = passages.passages if not isinstance(passages, Iterable): # it's not an iterable yet; make it one. # TODO: we should unify the type signatures of dspy.Retriever