Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,12 @@
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature, signature_to_template
from dspy.utils.callback import with_callbacks


from dspy.adapters.image_utils import Image

@lru_cache(maxsize=None)
def warn_once(msg: str):
logging.warning(msg)


class Predict(Module, Parameter):
def __init__(self, signature, _parse_values=True, callbacks=None, **config):
self.stage = random.randbytes(8).hex()
Expand Down Expand Up @@ -71,10 +68,13 @@ def load_state(self, state, use_legacy_loading=False):
state (dict): The saved state of a `Predict` object.
use_legacy_loading (bool): Whether to use the legacy loading method. Only use it when you are loading a
saved state from a version of DSPy prior to v2.5.3.
Returns:
self: Returns self to allow method chaining
"""
if use_legacy_loading:
self._load_state_legacy(state)
return
return self

if "signature" not in state:
# Check if the state is from a version of DSPy prior to v2.5.3.
raise ValueError(
Expand Down Expand Up @@ -102,10 +102,14 @@ def load_state(self, state, use_legacy_loading=False):
if "extended_signature" in state:
self.extended_signature = self.extended_signature.load_state(state["extended_signature"])

return self

def _load_state_legacy(self, state):
"""Legacy state loading for backwards compatibility.

This method is used to load the saved state of a `Predict` object from a version of DSPy prior to v2.5.3.
Returns:
self: Returns self to allow method chaining
"""
for name, value in state.items():
setattr(self, name, value)
Expand All @@ -130,6 +134,21 @@ def _load_state_legacy(self, state):
*_, last_key = self.extended_signature.fields.keys()
self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix)

return self

def load(self, path, return_self=False):
"""Load a saved state from a file.

Args:
path (str): Path to the saved state file
return_self (bool): If True, returns self to allow method chaining. Default is False for backwards compatibility.

Returns:
Union[None, Predict]: Returns None if return_self is False (default), returns self if return_self is True
"""
super().load(path)
return self if return_self else None

@with_callbacks
def __call__(self, **kwargs):
return self.forward(**kwargs)
Expand Down Expand Up @@ -213,8 +232,6 @@ def old_generate(demos, signature, kwargs, config, lm, stage):
with dsp.settings.context(lm=lm, query_only=True):
x, C = dsp.generate(template, **config)(x, stage=stage)

# assert stage in x, "The generated (input, output) example was not stored"

completions = []

for c in C:
Expand Down Expand Up @@ -279,7 +296,6 @@ def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values
)


# TODO: get some defaults during init from the context window?
# # TODO: FIXME: Hmm, I guess expected behavior is that contexts can
# affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates.
Expand Down
45 changes: 45 additions & 0 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,48 @@ class OutputOnlySignature(dspy.Signature):
lm = DummyLM([{"output": "short answer"}])
dspy.settings.configure(lm=lm)
assert predictor().output == "short answer"



def test_chainable_load(tmp_path):
"""Test both traditional and chainable load methods."""

file_path = tmp_path / "test_chainable.json"


original = Predict("question -> answer")
original.demos = [{"question": "test", "answer": "response"}]
original.save(file_path)


traditional = Predict("question -> answer")
traditional.load(file_path)
assert traditional.demos == original.demos


chainable = Predict("question -> answer").load(file_path, return_self=True)
assert chainable is not None
assert chainable.demos == original.demos


assert chainable.signature.dump_state() == original.signature.dump_state()


result = Predict("question -> answer").load(file_path)
assert result is None

def test_load_state_chaining():
"""Test that load_state returns self for chaining."""
original = Predict("question -> answer")
original.demos = [{"question": "test", "answer": "response"}]
state = original.dump_state()


new_instance = Predict("question -> answer").load_state(state)
assert new_instance is not None
assert new_instance.demos == original.demos


legacy_instance = Predict("question -> answer").load_state(state, use_legacy_loading=True)
assert legacy_instance is not None
assert legacy_instance.demos == original.demos
Loading