From 18f3b686cccbe0a009168154240cf710af2f4876 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 16 Dec 2024 11:09:46 -0800 Subject: [PATCH 1/4] Simplify built-in modules (remove extended_signature and new_signature) and remove assertions temporarily --- dspy/predict/__init__.py | 4 +- dspy/predict/chain_of_thought.py | 45 +- dspy/predict/predict.py | 15 +- dspy/predict/react.py | 12 + dspy/predict/retry.py | 122 +++--- dspy/primitives/assertions.py | 674 ++++++++++++++--------------- dspy/primitives/program.py | 34 +- dspy/teleprompt/copro_optimizer.py | 12 +- dspy/teleprompt/utils.py | 13 +- 9 files changed, 440 insertions(+), 491 deletions(-) diff --git a/dspy/predict/__init__.py b/dspy/predict/__init__.py index 0260ead378..e273a3dc91 100644 --- a/dspy/predict/__init__.py +++ b/dspy/predict/__init__.py @@ -6,5 +6,5 @@ from .predict import Predict from .program_of_thought import ProgramOfThought from .react import ReAct, Tool -from .retry import Retry -from .parallel import Parallel \ No newline at end of file +from .parallel import Parallel +# from .retry import Retry \ No newline at end of file diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index b6b9cb575e..2a06ca7db9 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -2,58 +2,23 @@ from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature -# TODO: This shouldn't inherit from Predict. It should be a module that has one or two predictors. -# Let's focus on the activated case. It's a predictor with the expanded signature. -# Now, when deactivated, it's a predictor with the original signature. -# When activate is None, though, we need the expanded one but during forward we need to pass the right signature. - class ChainOfThought(Module): - def __init__(self, signature, rationale_type=None, activated=True, **config): + def __init__(self, signature, rationale_type=None, **config): super().__init__() - self.activated = activated - - self.signature = signature = ensure_signature(signature) - *_keys, last_key = signature.output_fields.keys() + signature = ensure_signature(signature) prefix = "Reasoning: Let's think step by step in order to" - - if isinstance(dspy.settings.lm, dspy.LM): - desc = "${reasoning}" - elif dspy.settings.experimental: - desc = "${produce the output fields}. We ..." - else: - desc = f"${{produce the {last_key}}}. We ..." - + desc = "${reasoning}" rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) - - # Add "rationale" field to the output signature. - if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental: - extended_signature = signature.prepend("reasoning", rationale_type, type_=str) - else: - extended_signature = signature.prepend("rationale", rationale_type, type_=str) + extended_signature = signature.prepend("reasoning", rationale_type, type_=str) self._predict = dspy.Predict(extended_signature, **config) - self._predict.extended_signature = extended_signature def forward(self, **kwargs): - assert self.activated in [True, False] - - signature = kwargs.pop("new_signature", self._predict.extended_signature if self.activated else self.signature) - return self._predict(signature=signature, **kwargs) + return self._predict(**kwargs) @property def demos(self): return self._predict.demos - - @property - def extended_signature(self): - return self._predict.extended_signature - - -""" -TODO: In principle, we can update the field's prefix during forward too to fill any thing based on the input args. - -IF the user didn't overwrite our default rationale_type. -""" diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 3f851e7d61..05d66bb483 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -51,9 +51,6 @@ def dump_state(self): state["demos"].append(demo) state["signature"] = self.signature.dump_state() - # `extended_signature` is a special field for `Predict`s like CoT. - if hasattr(self, "extended_signature"): - state["extended_signature"] = self.extended_signature.dump_state() return state def load_state(self, state): @@ -82,8 +79,8 @@ def load_state(self, state): self.signature = self.signature.load_state(state["signature"]) - if "extended_signature" in state: - self.extended_signature = self.extended_signature.load_state(state["extended_signature"]) + if "extended_signature" in state: # legacy, up to and including 2.5, for CoT. + self.signature = self.signature.load_state(state["extended_signature"]) return self @@ -96,14 +93,14 @@ def forward(self, **kwargs): assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True" # Extract the three privileged keyword arguments. - new_signature = ensure_signature(kwargs.pop("new_signature", None)) + assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument." signature = ensure_signature(kwargs.pop("signature", self.signature)) demos = kwargs.pop("demos", self.demos) config = dict(**self.config, **kwargs.pop("config", {})) # Get the right LM to use. lm = kwargs.pop("lm", self.lm) or dspy.settings.lm - assert lm is not None, "No LM is loaded." + assert isinstance(lm, dspy.LM), "No LM is loaded." # If temperature is 0.0 but its n > 1, set temperature to 0.7. temperature = config.get("temperature") @@ -113,15 +110,11 @@ def forward(self, **kwargs): if (temperature is None or temperature <= 0.15) and num_generations > 1: config["temperature"] = 0.7 - if new_signature is not None: - signature = new_signature - if not all(k in kwargs for k in signature.input_fields): present = [k for k in signature.input_fields if k in kwargs] missing = [k for k in signature.input_fields if k not in kwargs] print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.") - assert isinstance(lm, dspy.LM) completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values) pred = Prediction.from_completions(completions, signature=signature) diff --git a/dspy/predict/react.py b/dspy/predict/react.py index b626f1dc37..2690d066af 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -124,6 +124,9 @@ def format(trajectory: dict[str, Any], last_iteration: bool): Another potential fix is to more natively support a "variadic" input field, where the input is a list of dictionaries, or a big dictionary, and have each adatper format it accordingly. +Trajectories also affect meta-programming modules that view the trace later. It's inefficient O(n^2) to view the +trace of every module repeating the prefix. + TOPIC 02: Handling default arguments in the Tool class. @@ -140,4 +143,13 @@ def format(trajectory: dict[str, Any], last_iteration: bool): TOPIC 05: Adding more structure around how the instruction is formatted. * Concretely, it's now a string, so an optimizer can and does rewrite it freely. * An alternative would be to add more structure, such that a certain template is fixed but values are variable? + + +TOPIC 06: Idiomatically allowing tools that maintain state across iterations, but not across different `forward` calls. + * So the tool would be newly initialized at the start of each `forward` call, but maintain state across iterations. + * This is pretty useful for allowing the agent to keep notes or count certain things, etc. + +TOPIC 07: Make max_iters a bit more expressive. + * Allow passing `max_iters` in forward to overwrite the default. + * Get rid of `last_iteration: bool` in the format function. It's not necessary now. """ diff --git a/dspy/predict/retry.py b/dspy/predict/retry.py index b515dfba72..66542ba439 100644 --- a/dspy/predict/retry.py +++ b/dspy/predict/retry.py @@ -1,74 +1,74 @@ -import copy +# import copy -import dspy +# import dspy -from .predict import Predict +# from .predict import Predict -class Retry(Predict): - def __init__(self, module): - super().__init__(module.signature) - self.module = module - self.original_signature = module.extended_signature if isinstance(module, dspy.ChainOfThought) else module.signature - self.original_forward = module.forward - self.new_signature = self._create_new_signature(self.original_signature) +# class Retry(Predict): +# def __init__(self, module): +# super().__init__(module.signature) +# self.module = module +# self.original_signature = module.signature +# self.original_forward = module.forward +# self.new_signature = self._create_new_signature(self.original_signature) - def _create_new_signature(self, signature): - # Add "Past" input fields for each output field - for key, value in signature.output_fields.items(): - actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" - signature = signature.append(f"past_{key}", dspy.InputField( - prefix="Previous " + actual_prefix, - desc=f"past {actual_prefix[:-1]} with errors", - format=value.json_schema_extra.get("format"), - )) +# def _create_new_signature(self, signature): +# # Add "Past" input fields for each output field +# for key, value in signature.output_fields.items(): +# actual_prefix = value.json_schema_extra["prefix"].split(":")[0] + ":" +# signature = signature.append(f"past_{key}", dspy.InputField( +# prefix="Previous " + actual_prefix, +# desc=f"past {actual_prefix[:-1]} with errors", +# format=value.json_schema_extra.get("format"), +# )) - signature = signature.append("feedback", dspy.InputField( - prefix="Instructions:", - desc="Some instructions you must satisfy", - format=str, - )) +# signature = signature.append("feedback", dspy.InputField( +# prefix="Instructions:", +# desc="Some instructions you must satisfy", +# format=str, +# )) - return signature +# return signature - def forward(self, *, past_outputs, **kwargs): - # Take into account the possible new signature, as in TypedPredictor - new_signature = kwargs.pop("new_signature", None) - if new_signature: - self.original_signature = new_signature - self.new_signature = self._create_new_signature(self.original_signature) +# def forward(self, *, past_outputs, **kwargs): +# # Take into account the possible new signature, as in TypedPredictor +# new_signature = kwargs.pop("new_signature", None) +# if new_signature: +# self.original_signature = new_signature +# self.new_signature = self._create_new_signature(self.original_signature) - # Convert the dict past_outputs={"answer": ...} to kwargs - # {past_answer=..., ...} - for key, value in past_outputs.items(): - past_key = f"past_{key}" - if past_key in self.new_signature.input_fields: - kwargs[past_key] = value - # Tell the wrapped module to use the new signature. - # Note: This only works if the wrapped module is a Predict or ChainOfThought. - kwargs["new_signature"] = self.new_signature - return self.original_forward(**kwargs) +# # Convert the dict past_outputs={"answer": ...} to kwargs +# # {past_answer=..., ...} +# for key, value in past_outputs.items(): +# past_key = f"past_{key}" +# if past_key in self.new_signature.input_fields: +# kwargs[past_key] = value +# # Tell the wrapped module to use the new signature. +# # Note: This only works if the wrapped module is a Predict or ChainOfThought. +# kwargs["new_signature"] = self.new_signature +# return self.original_forward(**kwargs) - def __call__(self, **kwargs): - copy.deepcopy(kwargs) - kwargs["_trace"] = False - kwargs.setdefault("demos", self.demos if self.demos is not None else []) +# def __call__(self, **kwargs): +# copy.deepcopy(kwargs) +# kwargs["_trace"] = False +# kwargs.setdefault("demos", self.demos if self.demos is not None else []) - # perform backtracking - if dspy.settings.backtrack_to == self: - for key, value in dspy.settings.backtrack_to_args.items(): - kwargs.setdefault(key, value) - pred = self.forward(**kwargs) - else: - pred = self.module(**kwargs) +# # perform backtracking +# if dspy.settings.backtrack_to == self: +# for key, value in dspy.settings.backtrack_to_args.items(): +# kwargs.setdefault(key, value) +# pred = self.forward(**kwargs) +# else: +# pred = self.module(**kwargs) - # now pop multiple reserved keys - # NOTE(shangyin) past_outputs seems not useful to include in demos, - # therefore dropped - for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: - kwargs.pop(key, None) +# # now pop multiple reserved keys +# # NOTE(shangyin) past_outputs seems not useful to include in demos, +# # therefore dropped +# for key in ["_trace", "demos", "signature", "new_signature", "config", "lm", "past_outputs"]: +# kwargs.pop(key, None) - if dspy.settings.trace is not None: - trace = dspy.settings.trace - trace.append((self, {**kwargs}, pred)) - return pred +# if dspy.settings.trace is not None: +# trace = dspy.settings.trace +# trace.append((self, {**kwargs}, pred)) +# return pred diff --git a/dspy/primitives/assertions.py b/dspy/primitives/assertions.py index a612d02a6c..8e456cf531 100644 --- a/dspy/primitives/assertions.py +++ b/dspy/primitives/assertions.py @@ -1,344 +1,344 @@ -import inspect -import logging -import uuid -from typing import Any - -# import dspy.dsp as dsp -import dspy - -logger = logging.getLogger(__name__) -#################### Assertion Helpers #################### - - -def _build_error_msg(feedback_msgs): - """Build an error message from a list of feedback messages.""" - return "\n".join([msg for msg in feedback_msgs]) - - -#################### Assertion Exceptions #################### - - -class DSPyAssertionError(AssertionError): - """Custom exception raised when a DSPy `Assert` fails.""" - - def __init__( - self, - id: str, - msg: str, - target_module: Any = None, - state: Any = None, - is_metric: bool = False, - ) -> None: - super().__init__(msg) - self.id = id - self.msg = msg - self.target_module = target_module - self.state = state - self.is_metric = is_metric - - -class DSPySuggestionError(AssertionError): - """Custom exception raised when a DSPy `Suggest` fails.""" - - def __init__( - self, - id: str, - msg: str, - target_module: Any = None, - state: Any = None, - is_metric: bool = False, - ) -> None: - super().__init__(msg) - self.id = id - self.msg = msg - self.target_module = target_module - self.state = state - self.is_metric = is_metric - - -#################### Assertion Primitives #################### - - -class Constraint: - def __init__( - self, - result: bool, - msg: str = "", - target_module=None, - is_metric: bool = False, - ): - self.id = str(uuid.uuid4()) - self.result = result - self.msg = msg - self.target_module = target_module - self.is_metric = is_metric +# import inspect +# import logging +# import uuid +# from typing import Any + +# # import dspy.dsp as dsp +# import dspy + +# logger = logging.getLogger(__name__) +# #################### Assertion Helpers #################### + + +# def _build_error_msg(feedback_msgs): +# """Build an error message from a list of feedback messages.""" +# return "\n".join([msg for msg in feedback_msgs]) + + +# #################### Assertion Exceptions #################### + + +# class DSPyAssertionError(AssertionError): +# """Custom exception raised when a DSPy `Assert` fails.""" + +# def __init__( +# self, +# id: str, +# msg: str, +# target_module: Any = None, +# state: Any = None, +# is_metric: bool = False, +# ) -> None: +# super().__init__(msg) +# self.id = id +# self.msg = msg +# self.target_module = target_module +# self.state = state +# self.is_metric = is_metric + + +# class DSPySuggestionError(AssertionError): +# """Custom exception raised when a DSPy `Suggest` fails.""" + +# def __init__( +# self, +# id: str, +# msg: str, +# target_module: Any = None, +# state: Any = None, +# is_metric: bool = False, +# ) -> None: +# super().__init__(msg) +# self.id = id +# self.msg = msg +# self.target_module = target_module +# self.state = state +# self.is_metric = is_metric + + +# #################### Assertion Primitives #################### + + +# class Constraint: +# def __init__( +# self, +# result: bool, +# msg: str = "", +# target_module=None, +# is_metric: bool = False, +# ): +# self.id = str(uuid.uuid4()) +# self.result = result +# self.msg = msg +# self.target_module = target_module +# self.is_metric = is_metric - self.__call__() +# self.__call__() -class Assert(Constraint): - """DSPy Assertion""" +# class Assert(Constraint): +# """DSPy Assertion""" - def __call__(self) -> bool: - if isinstance(self.result, bool): - if self.result: - return True - elif dspy.settings.bypass_assert: - logger.error(f"AssertionError: {self.msg}") - return True - else: - logger.error(f"AssertionError: {self.msg}") - raise DSPyAssertionError( - id=self.id, - msg=self.msg, - target_module=self.target_module, - state=dspy.settings.trace, - is_metric=self.is_metric, - ) - else: - raise ValueError("Assertion function should always return [bool]") +# def __call__(self) -> bool: +# if isinstance(self.result, bool): +# if self.result: +# return True +# elif dspy.settings.bypass_assert: +# logger.error(f"AssertionError: {self.msg}") +# return True +# else: +# logger.error(f"AssertionError: {self.msg}") +# raise DSPyAssertionError( +# id=self.id, +# msg=self.msg, +# target_module=self.target_module, +# state=dspy.settings.trace, +# is_metric=self.is_metric, +# ) +# else: +# raise ValueError("Assertion function should always return [bool]") -class Suggest(Constraint): - """DSPy Suggestion""" +# class Suggest(Constraint): +# """DSPy Suggestion""" - def __call__(self) -> Any: - if isinstance(self.result, bool): - if self.result: - return True - elif dspy.settings.bypass_suggest: - logger.info(f"SuggestionFailed: {self.msg}") - return True - else: - logger.info(f"SuggestionFailed: {self.msg}") - raise DSPySuggestionError( - id=self.id, - msg=self.msg, - target_module=self.target_module, - state=dspy.settings.trace, - is_metric=self.is_metric, - ) - else: - raise ValueError("Suggestion function should always return [bool]") - - -#################### Assertion Handlers #################### - - -def noop_handler(func): - """Handler to bypass assertions and suggestions. - - Now both assertions and suggestions will become noops. - """ - - def wrapper(*args, **kwargs): - with dspy.settings.context(bypass_assert=True, bypass_suggest=True): - return func(*args, **kwargs) - - return wrapper - - -def bypass_suggest_handler(func): - """Handler to bypass suggest only. - - If a suggestion fails, it will be logged but not raised. - And If an assertion fails, it will be raised. - """ - - def wrapper(*args, **kwargs): - with dspy.settings.context(bypass_suggest=True, bypass_assert=False): - return func(*args, **kwargs) - - return wrapper - - -def bypass_assert_handler(func): - """Handler to bypass assertion only. - - If a assertion fails, it will be logged but not raised. - And If an assertion fails, it will be raised. - """ - - def wrapper(*args, **kwargs): - with dspy.settings.context(bypass_assert=True): - return func(*args, **kwargs) - - return wrapper - - -def assert_no_except_handler(func): - """Handler to ignore assertion failure and return None.""" - - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except DSPyAssertionError: - return None - - return wrapper - - -def backtrack_handler(func, bypass_suggest=True, max_backtracks=2): - """Handler for backtracking suggestion and assertion. - - Re-run the latest predictor up to `max_backtracks` times, - with updated signature if an assertion fails. updated signature adds a new - input field to the signature, which is the feedback. - """ - - def wrapper(*args, **kwargs): - error_msg, result = None, None - with dspy.settings.lock: - dspy.settings.backtrack_to = None - dspy.settings.suggest_failures = 0 - dspy.settings.assert_failures = 0 - - # Predictor -> List[feedback_msg] - dspy.settings.predictor_feedbacks = {} - - current_error = None - for i in range(max_backtracks + 1): - if i > 0 and dspy.settings.backtrack_to is not None: - # generate values for new fields - feedback_msg = _build_error_msg( - dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to], - ) - - dspy.settings.backtrack_to_args = { - "feedback": feedback_msg, - "past_outputs": past_outputs, - } - - # if last backtrack: ignore suggestion errors - if i == max_backtracks: - if isinstance(current_error, DSPyAssertionError): - raise current_error - dspy.settings.trace.clear() - result = bypass_suggest_handler(func)(*args, **kwargs) if bypass_suggest else None - break - else: - try: - dspy.settings.trace.clear() - result = func(*args, **kwargs) - break - except (DSPySuggestionError, DSPyAssertionError) as e: - if not current_error: - current_error = e - _error_id, error_msg, error_target_module, error_state = ( - e.id, - e.msg, - e.target_module, - e.state[-1], - ) - - # increment failure count depending on type of error - if isinstance(e, DSPySuggestionError) and e.is_metric: - dspy.settings.suggest_failures += 1 - elif isinstance(e, DSPyAssertionError) and e.is_metric: - dspy.settings.assert_failures += 1 - - if dspy.settings.trace: - if error_target_module: - for i in range(len(dspy.settings.trace) - 1, -1, -1): - trace_element = dspy.settings.trace[i] - mod = trace_element[0] - if mod == error_target_module: - error_state = e.state[i] - dspy.settings.backtrack_to = mod - break - else: - dspy.settings.backtrack_to = dspy.settings.trace[-1][0] - - if dspy.settings.backtrack_to is None: - logger.error("Module not found in trace. If passing a DSPy Signature, please specify the intended module for the assertion (e.g., use `target_module = self.my_module(my_signature)` instead of `target_module = my_signature`).") - - # save unique feedback message for predictor - if error_msg not in dspy.settings.predictor_feedbacks.setdefault( - dspy.settings.backtrack_to, - [], - ): - dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to].append(error_msg) - - # use `new_signature` if available (CoT) - if hasattr(error_state[0], 'new_signature'): - output_fields = error_state[0].new_signature.output_fields - else: - output_fields = error_state[0].signature.output_fields - past_outputs = {} - for field_name in output_fields.keys(): - past_outputs[field_name] = getattr( - error_state[2], - field_name, - None, - ) - - # save latest failure trace for predictor per suggestion - error_state[1] - error_op = error_state[2].__dict__["_store"] - error_op.pop("_assert_feedback", None) - error_op.pop("_assert_traces", None) - - else: - logger.error( - "UNREACHABLE: No trace available, this should not happen. Is this run time?", - ) - - return result - - return wrapper - - -def handle_assert_forward(assertion_handler, **handler_args): - def forward(self, *args, **kwargs): - args_to_vals = inspect.getcallargs(self._forward, *args, **kwargs) - - # if user has specified a bypass_assert flag, set it - if "bypass_assert" in args_to_vals: - dspy.settings.configure(bypass_assert=args_to_vals["bypass_assert"]) - - wrapped_forward = assertion_handler(self._forward, **handler_args) - return wrapped_forward(*args, **kwargs) - - return forward - - -default_assertion_handler = backtrack_handler - - -def assert_transform_module( - module, - assertion_handler=default_assertion_handler, - **handler_args, -): - """ - Transform a module to handle assertions. - """ - if not getattr(module, "forward", False): - raise ValueError( - "Module must have a forward method to have assertions handled.", - ) - if getattr(module, "_forward", False): - logger.info( - f"Module {module.__class__.__name__} already has a _forward method. Skipping...", - ) - pass # TODO warning: might be overwriting a previous _forward method - - module._forward = module.forward - module.forward = handle_assert_forward(assertion_handler, **handler_args).__get__( - module, - ) - - if all( - map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors()), - ): - pass # we already applied the Retry mapping outside - elif all( - map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors()), - ): - module.map_named_predictors(dspy.retry.Retry) - else: - raise RuntimeError("Module has mixed predictors, can't apply Retry mapping.") - - module._assert_transformed = True - - return module +# def __call__(self) -> Any: +# if isinstance(self.result, bool): +# if self.result: +# return True +# elif dspy.settings.bypass_suggest: +# logger.info(f"SuggestionFailed: {self.msg}") +# return True +# else: +# logger.info(f"SuggestionFailed: {self.msg}") +# raise DSPySuggestionError( +# id=self.id, +# msg=self.msg, +# target_module=self.target_module, +# state=dspy.settings.trace, +# is_metric=self.is_metric, +# ) +# else: +# raise ValueError("Suggestion function should always return [bool]") + + +# #################### Assertion Handlers #################### + + +# def noop_handler(func): +# """Handler to bypass assertions and suggestions. + +# Now both assertions and suggestions will become noops. +# """ + +# def wrapper(*args, **kwargs): +# with dspy.settings.context(bypass_assert=True, bypass_suggest=True): +# return func(*args, **kwargs) + +# return wrapper + + +# def bypass_suggest_handler(func): +# """Handler to bypass suggest only. + +# If a suggestion fails, it will be logged but not raised. +# And If an assertion fails, it will be raised. +# """ + +# def wrapper(*args, **kwargs): +# with dspy.settings.context(bypass_suggest=True, bypass_assert=False): +# return func(*args, **kwargs) + +# return wrapper + + +# def bypass_assert_handler(func): +# """Handler to bypass assertion only. + +# If a assertion fails, it will be logged but not raised. +# And If an assertion fails, it will be raised. +# """ + +# def wrapper(*args, **kwargs): +# with dspy.settings.context(bypass_assert=True): +# return func(*args, **kwargs) + +# return wrapper + + +# def assert_no_except_handler(func): +# """Handler to ignore assertion failure and return None.""" + +# def wrapper(*args, **kwargs): +# try: +# return func(*args, **kwargs) +# except DSPyAssertionError: +# return None + +# return wrapper + + +# def backtrack_handler(func, bypass_suggest=True, max_backtracks=2): +# """Handler for backtracking suggestion and assertion. + +# Re-run the latest predictor up to `max_backtracks` times, +# with updated signature if an assertion fails. updated signature adds a new +# input field to the signature, which is the feedback. +# """ + +# def wrapper(*args, **kwargs): +# error_msg, result = None, None +# with dspy.settings.lock: +# dspy.settings.backtrack_to = None +# dspy.settings.suggest_failures = 0 +# dspy.settings.assert_failures = 0 + +# # Predictor -> List[feedback_msg] +# dspy.settings.predictor_feedbacks = {} + +# current_error = None +# for i in range(max_backtracks + 1): +# if i > 0 and dspy.settings.backtrack_to is not None: +# # generate values for new fields +# feedback_msg = _build_error_msg( +# dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to], +# ) + +# dspy.settings.backtrack_to_args = { +# "feedback": feedback_msg, +# "past_outputs": past_outputs, +# } + +# # if last backtrack: ignore suggestion errors +# if i == max_backtracks: +# if isinstance(current_error, DSPyAssertionError): +# raise current_error +# dspy.settings.trace.clear() +# result = bypass_suggest_handler(func)(*args, **kwargs) if bypass_suggest else None +# break +# else: +# try: +# dspy.settings.trace.clear() +# result = func(*args, **kwargs) +# break +# except (DSPySuggestionError, DSPyAssertionError) as e: +# if not current_error: +# current_error = e +# _error_id, error_msg, error_target_module, error_state = ( +# e.id, +# e.msg, +# e.target_module, +# e.state[-1], +# ) + +# # increment failure count depending on type of error +# if isinstance(e, DSPySuggestionError) and e.is_metric: +# dspy.settings.suggest_failures += 1 +# elif isinstance(e, DSPyAssertionError) and e.is_metric: +# dspy.settings.assert_failures += 1 + +# if dspy.settings.trace: +# if error_target_module: +# for i in range(len(dspy.settings.trace) - 1, -1, -1): +# trace_element = dspy.settings.trace[i] +# mod = trace_element[0] +# if mod == error_target_module: +# error_state = e.state[i] +# dspy.settings.backtrack_to = mod +# break +# else: +# dspy.settings.backtrack_to = dspy.settings.trace[-1][0] + +# if dspy.settings.backtrack_to is None: +# logger.error("Module not found in trace. If passing a DSPy Signature, please specify the intended module for the assertion (e.g., use `target_module = self.my_module(my_signature)` instead of `target_module = my_signature`).") + +# # save unique feedback message for predictor +# if error_msg not in dspy.settings.predictor_feedbacks.setdefault( +# dspy.settings.backtrack_to, +# [], +# ): +# dspy.settings.predictor_feedbacks[dspy.settings.backtrack_to].append(error_msg) + +# # use `new_signature` if available (CoT) +# if hasattr(error_state[0], 'new_signature'): +# output_fields = error_state[0].new_signature.output_fields +# else: +# output_fields = error_state[0].signature.output_fields +# past_outputs = {} +# for field_name in output_fields.keys(): +# past_outputs[field_name] = getattr( +# error_state[2], +# field_name, +# None, +# ) + +# # save latest failure trace for predictor per suggestion +# error_state[1] +# error_op = error_state[2].__dict__["_store"] +# error_op.pop("_assert_feedback", None) +# error_op.pop("_assert_traces", None) + +# else: +# logger.error( +# "UNREACHABLE: No trace available, this should not happen. Is this run time?", +# ) + +# return result + +# return wrapper + + +# def handle_assert_forward(assertion_handler, **handler_args): +# def forward(self, *args, **kwargs): +# args_to_vals = inspect.getcallargs(self._forward, *args, **kwargs) + +# # if user has specified a bypass_assert flag, set it +# if "bypass_assert" in args_to_vals: +# dspy.settings.configure(bypass_assert=args_to_vals["bypass_assert"]) + +# wrapped_forward = assertion_handler(self._forward, **handler_args) +# return wrapped_forward(*args, **kwargs) + +# return forward + + +# default_assertion_handler = backtrack_handler + + +# def assert_transform_module( +# module, +# assertion_handler=default_assertion_handler, +# **handler_args, +# ): +# """ +# Transform a module to handle assertions. +# """ +# if not getattr(module, "forward", False): +# raise ValueError( +# "Module must have a forward method to have assertions handled.", +# ) +# if getattr(module, "_forward", False): +# logger.info( +# f"Module {module.__class__.__name__} already has a _forward method. Skipping...", +# ) +# pass # TODO warning: might be overwriting a previous _forward method + +# module._forward = module.forward +# module.forward = handle_assert_forward(assertion_handler, **handler_args).__get__( +# module, +# ) + +# if all( +# map(lambda p: isinstance(p[1], dspy.retry.Retry), module.named_predictors()), +# ): +# pass # we already applied the Retry mapping outside +# elif all( +# map(lambda p: not isinstance(p[1], dspy.retry.Retry), module.named_predictors()), +# ): +# module.map_named_predictors(dspy.retry.Retry) +# else: +# raise RuntimeError("Module has mixed predictors, can't apply Retry mapping.") + +# module._assert_transformed = True + +# return module diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index a8b2d1346a..a8b9676daf 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -1,10 +1,10 @@ from dspy.utils.callback import with_callbacks import magicattr -import dspy from dspy.predict.parallel import Parallel -from dspy.primitives.assertions import * from dspy.primitives.module import BaseModule +# import dspy +# from dspy.primitives.assertions import * class ProgramMeta(type): @@ -32,28 +32,16 @@ def predictors(self): return [param for _, param in self.named_predictors()] def set_lm(self, lm): - if not dspy.settings.experimental: - raise ValueError( - "Setting or getting the LM of a program is an experimental feature. Please enable the " - "'dspy.settings.experimental' flag to use these features." - ) - for _, param in self.named_predictors(): param.lm = lm def get_lm(self): - if not dspy.settings.experimental: - raise ValueError( - "Setting or getting the LM of a program is an experimental feature. Please enable the " - "'dspy.settings.experimental' flag to use these features." - ) - all_used_lms = [param.lm for _, param in self.named_predictors()] if len(set(all_used_lms)) == 1: return all_used_lms[0] - raise ValueError("Multiple LMs are being used in the module.") + raise ValueError("Multiple LMs are being used in the module. There's no unique LM to return.") def __repr__(self): s = [] @@ -69,13 +57,13 @@ def map_named_predictors(self, func): set_attribute_by_name(self, name, func(predictor)) return self - def activate_assertions(self, handler=backtrack_handler, **handler_args): - """ - Activates assertions for the module. - The default handler is the backtrack_handler. - """ - assert_transform_module(self, handler, **handler_args) - return self + # def activate_assertions(self, handler=backtrack_handler, **handler_args): + # """ + # Activates assertions for the module. + # The default handler is the backtrack_handler. + # """ + # assert_transform_module(self, handler, **handler_args) + # return self # def __deepcopy__(self, memo): # # memo is a dict of id's to copies already made during the current call @@ -103,7 +91,7 @@ def batch( return_failed_examples: bool = False, provide_traceback: bool = False, disable_progress_bar: bool = False, - ) -> Any: + ): """ Processes a list of dspy.Example instances in parallel using the Parallel module. diff --git a/dspy/teleprompt/copro_optimizer.py b/dspy/teleprompt/copro_optimizer.py index de8cc9f0ec..feeecf9e31 100644 --- a/dspy/teleprompt/copro_optimizer.py +++ b/dspy/teleprompt/copro_optimizer.py @@ -113,16 +113,12 @@ def _print_signature(self, predictor): logger.debug(f"p: {list(signature.fields.values())[-1].json_schema_extra['prefix']}") def _get_signature(self, predictor): - if hasattr(predictor, "extended_signature"): - return predictor.extended_signature - elif hasattr(predictor, "signature"): - return predictor.signature + assert hasattr(predictor, "signature") + return predictor.signature def _set_signature(self, predictor, updated_signature): - if hasattr(predictor, "extended_signature"): - predictor.extended_signature = updated_signature - elif hasattr(predictor, "signature"): - predictor.signature = updated_signature + assert hasattr(predictor, "signature") + predictor.signature = updated_signature def compile(self, student, *, trainset, eval_kwargs): """ diff --git a/dspy/teleprompt/utils.py b/dspy/teleprompt/utils.py index d763d74da2..bdc6a529a1 100644 --- a/dspy/teleprompt/utils.py +++ b/dspy/teleprompt/utils.py @@ -257,18 +257,13 @@ def get_prompt_model(prompt_model): return dspy.settings.lm def get_signature(predictor): - if hasattr(predictor, "extended_signature"): - return predictor.extended_signature - elif hasattr(predictor, "signature"): - return predictor.signature - return None + assert hasattr(predictor, "signature") + return predictor.signature def set_signature(predictor, updated_signature): - if hasattr(predictor, "extended_signature"): - predictor.extended_signature = updated_signature - elif hasattr(predictor, "signature"): - predictor.signature = updated_signature + assert hasattr(predictor, "signature") + predictor.signature = updated_signature def create_n_fewshot_demo_sets( From 2229b4c2d44513a4f1cf56a02fa0ca4389f5865f Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 16 Dec 2024 11:14:09 -0800 Subject: [PATCH 2/4] Update test_retry.py --- tests/predict/test_retry.py | 130 ++++++++++++++++++------------------ 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/tests/predict/test_retry.py b/tests/predict/test_retry.py index ba28e90433..eb1c48b06e 100644 --- a/tests/predict/test_retry.py +++ b/tests/predict/test_retry.py @@ -1,86 +1,48 @@ -import functools +# import functools -import pydantic +# import pydantic -import dspy -from dspy.primitives.assertions import assert_transform_module, backtrack_handler -from dspy.utils import DummyLM +# import dspy +# from dspy.primitives.assertions import assert_transform_module, backtrack_handler +# from dspy.utils import DummyLM -def test_retry_simple(): - predict = dspy.Predict("question -> answer") - retry_module = dspy.Retry(predict) +# def test_retry_simple(): +# predict = dspy.Predict("question -> answer") +# retry_module = dspy.Retry(predict) - # Test Retry has created the correct new signature - for field in predict.signature.output_fields: - assert f"past_{field}" in retry_module.new_signature.input_fields - assert "feedback" in retry_module.new_signature.input_fields +# # Test Retry has created the correct new signature +# for field in predict.signature.output_fields: +# assert f"past_{field}" in retry_module.new_signature.input_fields +# assert "feedback" in retry_module.new_signature.input_fields - lm = DummyLM([{"answer": "blue"}]) - dspy.settings.configure(lm=lm) - result = retry_module.forward( - question="What color is the sky?", - past_outputs={"answer": "red"}, - feedback="Try harder", - ) - assert result.answer == "blue" - - -def test_retry_forward_with_feedback(): - # First we make a mistake, then we fix it - lm = DummyLM([{"answer": "red"}, {"answer": "blue"}]) - dspy.settings.configure(lm=lm, trace=[]) - - class SimpleModule(dspy.Module): - def __init__(self): - super().__init__() - self.predictor = dspy.Predict("question -> answer") - - def forward(self, **kwargs): - result = self.predictor(**kwargs) - print(f"SimpleModule got {result.answer=}") - dspy.Suggest(result.answer == "blue", "Please think harder") - return result - - program = SimpleModule() - program = assert_transform_module( - program.map_named_predictors(dspy.Retry), - functools.partial(backtrack_handler, max_backtracks=1), - ) - - result = program(question="What color is the sky?") - - assert result.answer == "blue" +# lm = DummyLM([{"answer": "blue"}]) +# dspy.settings.configure(lm=lm) +# result = retry_module.forward( +# question="What color is the sky?", +# past_outputs={"answer": "red"}, +# feedback="Try harder", +# ) +# assert result.answer == "blue" -# def test_retry_forward_with_typed_predictor(): +# def test_retry_forward_with_feedback(): # # First we make a mistake, then we fix it -# lm = DummyLM([{"output": '{"answer":"red"}'}, {"output": '{"answer":"blue"}'}]) +# lm = DummyLM([{"answer": "red"}, {"answer": "blue"}]) # dspy.settings.configure(lm=lm, trace=[]) -# class AnswerQuestion(dspy.Signature): -# """Answer questions with succinct responses.""" - -# class Input(pydantic.BaseModel): -# question: str - -# class Output(pydantic.BaseModel): -# answer: str - -# input: Input = dspy.InputField() -# output: Output = dspy.OutputField() - -# class QuestionAnswerer(dspy.Module): +# class SimpleModule(dspy.Module): # def __init__(self): # super().__init__() -# self.answer_question = dspy.TypedPredictor(AnswerQuestion) +# self.predictor = dspy.Predict("question -> answer") # def forward(self, **kwargs): -# result = self.answer_question(input=AnswerQuestion.Input(**kwargs)).output +# result = self.predictor(**kwargs) +# print(f"SimpleModule got {result.answer=}") # dspy.Suggest(result.answer == "blue", "Please think harder") # return result -# program = QuestionAnswerer() +# program = SimpleModule() # program = assert_transform_module( # program.map_named_predictors(dspy.Retry), # functools.partial(backtrack_handler, max_backtracks=1), @@ -89,3 +51,41 @@ def forward(self, **kwargs): # result = program(question="What color is the sky?") # assert result.answer == "blue" + + +# # def test_retry_forward_with_typed_predictor(): +# # # First we make a mistake, then we fix it +# # lm = DummyLM([{"output": '{"answer":"red"}'}, {"output": '{"answer":"blue"}'}]) +# # dspy.settings.configure(lm=lm, trace=[]) + +# # class AnswerQuestion(dspy.Signature): +# # """Answer questions with succinct responses.""" + +# # class Input(pydantic.BaseModel): +# # question: str + +# # class Output(pydantic.BaseModel): +# # answer: str + +# # input: Input = dspy.InputField() +# # output: Output = dspy.OutputField() + +# # class QuestionAnswerer(dspy.Module): +# # def __init__(self): +# # super().__init__() +# # self.answer_question = dspy.TypedPredictor(AnswerQuestion) + +# # def forward(self, **kwargs): +# # result = self.answer_question(input=AnswerQuestion.Input(**kwargs)).output +# # dspy.Suggest(result.answer == "blue", "Please think harder") +# # return result + +# # program = QuestionAnswerer() +# # program = assert_transform_module( +# # program.map_named_predictors(dspy.Retry), +# # functools.partial(backtrack_handler, max_backtracks=1), +# # ) + +# # result = program(question="What color is the sky?") + +# # assert result.answer == "blue" From 80e6ea9f2c4928b00e4f7dd555d1356ece4cec61 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 16 Dec 2024 12:07:42 -0800 Subject: [PATCH 3/4] Fully turn CoT into a typical Module --- dspy/predict/chain_of_thought.py | 8 ++------ dspy/primitives/module.py | 17 +++++++++-------- tests/predict/test_chain_of_thought.py | 2 +- tests/primitives/test_module.py | 14 +++++++------- 4 files changed, 19 insertions(+), 22 deletions(-) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 2a06ca7db9..4a56c0ceb2 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -14,11 +14,7 @@ def __init__(self, signature, rationale_type=None, **config): rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) extended_signature = signature.prepend("reasoning", rationale_type, type_=str) - self._predict = dspy.Predict(extended_signature, **config) + self.predict = dspy.Predict(extended_signature, **config) def forward(self, **kwargs): - return self._predict(**kwargs) - - @property - def demos(self): - return self._predict.demos + return self.predict(**kwargs) diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index c2146301c0..8bec91620f 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -256,14 +256,15 @@ def load(self, path): def postprocess_parameter_name(name, value): - # For ChainOfThought backward compatibility, remove ending ._predict if it's there - if name.endswith("._predict"): - name = name[:-9] + return name + # # For ChainOfThought backward compatibility, remove ending ._predict if it's there + # if name.endswith("._predict"): + # name = name[:-9] - if name.endswith(".self"): - name = name[:-5] + # if name.endswith(".self"): + # name = name[:-5] - if name == "_predict": - return "self" + # if name == "_predict": + # return "self" - return name + # return name diff --git a/tests/predict/test_chain_of_thought.py b/tests/predict/test_chain_of_thought.py index 04f73bb513..c050819efe 100644 --- a/tests/predict/test_chain_of_thought.py +++ b/tests/predict/test_chain_of_thought.py @@ -7,7 +7,7 @@ def test_initialization_with_string_signature(): lm = DummyLM([{"reasoning": "find the number after 1", "answer": "2"}]) dspy.settings.configure(lm=lm) predict = ChainOfThought("question -> answer") - assert list(predict.extended_signature.output_fields.keys()) == [ + assert list(predict.predict.signature.output_fields.keys()) == [ "reasoning", "answer", ] diff --git a/tests/primitives/test_module.py b/tests/primitives/test_module.py index 5e6790d8e3..d0532fca10 100644 --- a/tests/primitives/test_module.py +++ b/tests/primitives/test_module.py @@ -51,8 +51,8 @@ def __init__(self): def test_save_and_load_with_json(tmp_path): model = dspy.ChainOfThought(dspy.Signature("q -> a")) - model._predict.signature = model._predict.signature.with_instructions("You are a helpful assistant.") - model._predict.demos = [ + model.predict.signature = model.predict.signature.with_instructions("You are a helpful assistant.") + model.predict.demos = [ dspy.Example(q="What is the capital of France?", a="Paris", reasoning="n/a").with_inputs("q", "a") ] save_path = tmp_path / "model.json" @@ -60,8 +60,8 @@ def test_save_and_load_with_json(tmp_path): new_model = dspy.ChainOfThought(dspy.Signature("q -> a")) new_model.load(save_path) - assert str(new_model.signature) == str(model.signature) - assert new_model.demos[0] == model.demos[0].toDict() + assert str(new_model.predict.signature) == str(model.predict.signature) + assert new_model.predict.demos[0] == model.predict.demos[0].toDict() def test_save_and_load_with_pkl(tmp_path): @@ -96,7 +96,7 @@ def dummy_metric(example, pred, trace=None): optimizer = dspy.BootstrapFewShot(max_bootstrapped_demos=4, max_labeled_demos=4, max_rounds=5, metric=dummy_metric) compiled_cot = optimizer.compile(cot, trainset=trainset) - compiled_cot._predict.signature = compiled_cot._predict.signature.with_instructions("You are a helpful assistant.") + compiled_cot.predict.signature = compiled_cot.predict.signature.with_instructions("You are a helpful assistant.") save_path = tmp_path / "program.pkl" compiled_cot.save(save_path) @@ -104,5 +104,5 @@ def dummy_metric(example, pred, trace=None): new_cot = dspy.ChainOfThought(MySignature) new_cot.load(save_path) - assert str(new_cot.signature) == str(compiled_cot.signature) - assert new_cot.demos == compiled_cot.demos + assert str(new_cot.predict.signature) == str(compiled_cot.predict.signature) + assert new_cot.predict.demos == compiled_cot.predict.demos From 8035218d05812825b2842ecfcf4169c54917cc98 Mon Sep 17 00:00:00 2001 From: Omar Khattab Date: Mon, 16 Dec 2024 13:08:43 -0800 Subject: [PATCH 4/4] Remove dead code paths (and flags) from Predict, Adapters, and Settings --- dspy/adapters/base.py | 10 +++++----- dspy/adapters/chat_adapter.py | 4 ++-- dspy/adapters/json_adapter.py | 6 +++--- dspy/dsp/utils/settings.py | 7 ------- dspy/predict/predict.py | 27 +++++---------------------- dspy/teleprompt/bootstrap.py | 6 +----- 6 files changed, 16 insertions(+), 44 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index d472ddda3d..d9b72f2ff1 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -13,7 +13,7 @@ def __init_subclass__(cls, **kwargs) -> None: cls.format = with_callbacks(cls.format) cls.parse = with_callbacks(cls.parse) - def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): + def __call__(self, lm, lm_kwargs, signature, demos, inputs): inputs_ = self.format(signature, demos, inputs) inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_) @@ -27,7 +27,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): if isinstance(output, dict): output, output_logprobs = output["text"], output["logprobs"] - value = self.parse(signature, output, _parse_values=_parse_values) + value = self.parse(signature, output) assert set(value.keys()) == set(signature.output_fields.keys()), \ f"Expected {signature.output_fields.keys()} but got {value.keys()}" @@ -41,8 +41,8 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): except Exception as e: from .json_adapter import JSONAdapter - if _parse_values and not isinstance(self, JSONAdapter): - return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values) + if not isinstance(self, JSONAdapter): + return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) raise e @abstractmethod @@ -50,7 +50,7 @@ def format(self, signature, demos, inputs): raise NotImplementedError @abstractmethod - def parse(self, signature, completion, _parse_values): + def parse(self, signature, completion): raise NotImplementedError def format_finetune_data(self, signature, demos, inputs, outputs): diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index f92ee6595c..3f58fb0e88 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -58,7 +58,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict messages.append(format_turn(signature, inputs, role="user")) return messages - def parse(self, signature, completion, _parse_values=True): + def parse(self, signature, completion): sections = [(None, [])] for line in completion.splitlines(): @@ -74,7 +74,7 @@ def parse(self, signature, completion, _parse_values=True): for k, v in sections: if (k not in fields) and (k in signature.output_fields): try: - fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v + fields[k] = parse_value(v, signature.output_fields[k].annotation) except Exception as e: raise ValueError( f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```" diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 281df5cb4b..0636497d8f 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -32,7 +32,7 @@ class JSONAdapter(Adapter): def __init__(self): pass - def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): + def __call__(self, lm, lm_kwargs, signature, demos, inputs): inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) @@ -58,7 +58,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): values = [] for output in outputs: - value = self.parse(signature, output, _parse_values=_parse_values) + value = self.parse(signature, output) assert set(value.keys()) == set( signature.output_fields.keys() ), f"Expected {signature.output_fields.keys()} but got {value.keys()}" @@ -90,7 +90,7 @@ def format(self, signature, demos, inputs): return messages - def parse(self, signature, completion, _parse_values=True): + def parse(self, signature, completion): fields = json_repair.loads(completion) fields = {k: v for k, v in fields.items() if k in signature.output_fields} diff --git a/dspy/dsp/utils/settings.py b/dspy/dsp/utils/settings.py index f5ec1cd517..821f80e03b 100644 --- a/dspy/dsp/utils/settings.py +++ b/dspy/dsp/utils/settings.py @@ -8,18 +8,11 @@ adapter=None, rm=None, branch_idx=0, - reranker=None, - compiled_lm=None, - force_reuse_cached_compilation=False, - compiling=False, - skip_logprobs=False, trace=[], - release=0, bypass_assert=False, bypass_suggest=False, assert_failures=0, suggest_failures=0, - langchain_history=[], experimental=False, backoff_time=10, callbacks=[], diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 05d66bb483..0e15137b3d 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -1,6 +1,4 @@ -import logging import random -from functools import lru_cache from pydantic import BaseModel @@ -12,18 +10,12 @@ from dspy.utils.callback import with_callbacks -@lru_cache(maxsize=None) -def warn_once(msg: str): - logging.warning(msg) - - class Predict(Module, Parameter): - def __init__(self, signature, _parse_values=True, callbacks=None, **config): + def __init__(self, signature, callbacks=None, **config): self.stage = random.randbytes(8).hex() self.signature = ensure_signature(signature) self.config = config self.callbacks = callbacks or [] - self._parse_values = _parse_values self.reset() def reset(self): @@ -80,7 +72,7 @@ def load_state(self, state): self.signature = self.signature.load_state(state["signature"]) if "extended_signature" in state: # legacy, up to and including 2.5, for CoT. - self.signature = self.signature.load_state(state["extended_signature"]) + raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+") return self @@ -90,7 +82,6 @@ def __call__(self, **kwargs): def forward(self, **kwargs): import dspy - assert not dspy.settings.compiling, "It's no longer ever the case that .compiling is True" # Extract the three privileged keyword arguments. assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument." @@ -115,7 +106,9 @@ def forward(self, **kwargs): missing = [k for k in signature.input_fields if k not in kwargs] print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.") - completions = v2_5_generate(lm, config, signature, demos, kwargs, _parse_values=self._parse_values) + import dspy + adapter = dspy.settings.adapter or dspy.ChatAdapter() + completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs) pred = Prediction.from_completions(completions, signature=signature) @@ -135,16 +128,6 @@ def __repr__(self): return f"{self.__class__.__name__}({self.signature})" -def v2_5_generate(lm, lm_kwargs, signature, demos, inputs, _parse_values=True): - import dspy - - adapter = dspy.settings.adapter or dspy.ChatAdapter() - - return adapter( - lm, lm_kwargs=lm_kwargs, signature=signature, demos=demos, inputs=inputs, _parse_values=_parse_values - ) - - # TODO: get some defaults during init from the context window? # # TODO: FIXME: Hmm, I guess expected behavior is that contexts can # affect execution. Well, we need to determine whether context dominates, __init__ demoninates, or forward dominates. diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 840bfe911d..fdc48b0834 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -260,10 +260,6 @@ def _train(self): sample_size = max(0, sample_size) raw_demos = rng.sample(raw_demos, sample_size) - - if dspy.settings.release >= 20230928: - predictor.demos = raw_demos + augmented_demos - else: - predictor.demos = augmented_demos + raw_demos + predictor.demos = augmented_demos + raw_demos return self.student