diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 7188044ed7..0f93fd2781 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -10,15 +10,12 @@ from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature, signature_to_template from dspy.utils.callback import with_callbacks - - from dspy.adapters.image_utils import Image @lru_cache(maxsize=None) def warn_once(msg: str): logging.warning(msg) - class Predict(Module, Parameter): def __init__(self, signature, _parse_values=True, callbacks=None, **config): self.stage = random.randbytes(8).hex() @@ -71,10 +68,13 @@ def load_state(self, state, use_legacy_loading=False): state (dict): The saved state of a `Predict` object. use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a saved state from a version of DSPy prior to v2.5.3. + Returns: + self: Returns self to allow method chaining """ if use_legacy_loading: self._load_state_legacy(state) - return + return self + if "signature" not in state: # Check if the state is from a version of DSPy prior to v2.5.3. raise ValueError( @@ -102,10 +102,14 @@ def load_state(self, state, use_legacy_loading=False): if "extended_signature" in state: self.extended_signature = self.extended_signature.load_state(state["extended_signature"]) + return self + def _load_state_legacy(self, state): """Legacy state loading for backwards compatibility. This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3. + Returns: + self: Returns self to allow method chaining """ for name, value in state.items(): setattr(self, name, value) @@ -130,6 +134,21 @@ def _load_state_legacy(self, state): *_, last_key = self.extended_signature.fields.keys() self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix) + return self + + def load(self, path, return_self=False): + """Load a saved state from a file. + + Args: + path (str): Path to the saved state file + return_self (bool): If True, returns self to allow method chaining. Default is False for backwards compatibility. + + Returns: + Union[None, Predict]: Returns None if return_self is False (default), returns self if return_self is True + """ + super().load(path) + return self if return_self else None + @with_callbacks def __call__(self, **kwargs): return self.forward(**kwargs) @@ -213,8 +232,6 @@ def old_generate(demos, signature, kwargs, config, lm, stage): with dsp.settings.context(lm=lm, query_only=True): x, C = dsp.generate(template, **config)(x, stage=stage) - # assert stage in x, "The generated (input, output) example was not stored" - completions = [] for c in C: @@ -279,7 +296,6 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True): lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values ) - # TODO: get some defaults during init from the context window? # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can # affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates. diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index 8d0121eedd..3fb4fce4e1 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -218,3 +218,48 @@ class OutputOnlySignature(dspy.Signature): lm = DummyLM([{"output": "short answer"}]) dspy.settings.configure(lm=lm) assert predictor().output == "short answer" + + + +def test_chainable_load(tmp_path): + """Test both traditional and chainable load methods.""" + + file_path = tmp_path / "test_chainable.json" + + + original = Predict("question -> answer") + original.demos = [{"question": "test", "answer": "response"}] + original.save(file_path) + + + traditional = Predict("question -> answer") + traditional.load(file_path) + assert traditional.demos == original.demos + + + chainable = Predict("question -> answer").load(file_path, return_self=True) + assert chainable is not None + assert chainable.demos == original.demos + + + assert chainable.signature.dump_state() == original.signature.dump_state() + + + result = Predict("question -> answer").load(file_path) + assert result is None + +def test_load_state_chaining(): + """Test that load_state returns self for chaining.""" + original = Predict("question -> answer") + original.demos = [{"question": "test", "answer": "response"}] + state = original.dump_state() + + + new_instance = Predict("question -> answer").load_state(state) + assert new_instance is not None + assert new_instance.demos == original.demos + + + legacy_instance = Predict("question -> answer").load_state(state, use_legacy_loading=True) + assert legacy_instance is not None + assert legacy_instance.demos == original.demos \ No newline at end of file