diff --git a/dsp/utils/settings.py b/dsp/utils/settings.py index 61fc37ae7e..ea302f35aa 100644 --- a/dsp/utils/settings.py +++ b/dsp/utils/settings.py @@ -23,6 +23,7 @@ langchain_history=[], experimental=False, backoff_time=10, + callbacks=[], ) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index ddb0ee3e80..e9010d6efc 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,4 +1,21 @@ +from dspy.utils.callback import with_callbacks + + class Adapter: + def __init__(self, callbacks=None): + self.callbacks = callbacks or [] + + def __init_subclass__(cls, **kwargs) -> None: + super().__init_subclass__(**kwargs) + + # Decorate format() and parse() with with_callbacks. A method inherited from a + # parent adapter is already decorated (functools.wraps sets __wrapped__), so + # wrapping it again would fire every callback twice per call. + for name in ("format", "parse"): + method = getattr(cls, name) + if name in cls.__dict__ or not hasattr(method, "__wrapped__"): + setattr(cls, name, with_callbacks(method)) + def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index f2949372ac..89f259f5a5 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -14,6 +14,7 @@ execute_finetune_job, ) +from dspy.utils.callback import with_callbacks import litellm from litellm.caching import Cache @@ -35,6 +36,7 @@ def __init__( max_tokens=1000, cache=True, launch_kwargs=None, + callbacks=None, **kwargs ): # Remember to update LM.copy() if you modify the constructor! @@ -44,6 +46,7 @@ def __init__( self.launch_kwargs = launch_kwargs or {} self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] + self.callbacks = callbacks or [] # TODO: Arbitrary model strings could include the substring "o1-". We # should find a more robust way to check for the "o1-" family models.
@@ -52,6 +55,7 @@ def __init__( max_tokens >= 5000 and temperature == 1.0 ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`" + @with_callbacks def __call__(self, prompt=None, messages=None, **kwargs): # Build the request. cache = kwargs.pop("cache", self.cache) diff --git a/dspy/predict/predict.py b/dspy/predict/predict.py index 35260fabad..164036f559 100644 --- a/dspy/predict/predict.py +++ b/dspy/predict/predict.py @@ -9,6 +9,7 @@ from dspy.primitives.prediction import Prediction from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature, signature_to_template +from dspy.utils.callback import with_callbacks @lru_cache(maxsize=None) @@ -17,10 +18,11 @@ def warn_once(msg: str): class Predict(Module, Parameter): - def __init__(self, signature, _parse_values=True, **config): + def __init__(self, signature, _parse_values=True, callbacks=None, **config): self.stage = random.randbytes(8).hex() self.signature = ensure_signature(signature) self.config = config + self.callbacks = callbacks or [] self._parse_values = _parse_values self.reset() @@ -114,6 +116,7 @@ def _load_state_legacy(self, state): *_, last_key = self.extended_signature.fields.keys() self.extended_signature = self.extended_signature.with_updated_fields(last_key, prefix=prefix) + @with_callbacks def __call__(self, **kwargs): return self.forward(**kwargs) diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index 3cf6220b6c..eb43af38b9 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -1,3 +1,4 @@ +from dspy.utils.callback import with_callbacks import magicattr import dspy @@ -13,9 +14,11 @@ class Module(BaseModule, metaclass=ProgramMeta): def _base_init(self): self._compiled = False - def __init__(self): + def __init__(self, callbacks=None): + self.callbacks = callbacks or [] self._compiled = False + @with_callbacks def __call__(self, *args, **kwargs): return self.forward(*args, 
**kwargs) diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index dc659a3fed..5fa55f2d93 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -4,6 +4,7 @@ import dsp from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction +from dspy.utils.callback import with_callbacks def single_query_passage(passages): @@ -21,9 +22,10 @@ class Retrieve(Parameter): input_variable = "query" desc = "takes a search query and returns one or more potentially relevant passages from a corpus" - def __init__(self, k=3): + def __init__(self, k=3, callbacks=None): self.stage = random.randbytes(8).hex() self.k = k + self.callbacks = callbacks or [] def reset(self): pass @@ -37,6 +39,7 @@ def load_state(self, state): for name, value in state.items(): setattr(self, name, value) + @with_callbacks def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) diff --git a/dspy/utils/__init__.py b/dspy/utils/__init__.py index 48544eae4a..1f8dfb4ce6 100644 --- a/dspy/utils/__init__.py +++ b/dspy/utils/__init__.py @@ -1,2 +1,3 @@ +from .callback import * from .dummies import * from .logging import * diff --git a/dspy/utils/callback.py b/dspy/utils/callback.py new file mode 100644 index 0000000000..b50bfd39aa --- /dev/null +++ b/dspy/utils/callback.py @@ -0,0 +1,282 @@ +import functools +import inspect +import logging +import uuid +from contextvars import ContextVar +from typing import Any, Callable, Dict, Optional + +import dspy + +ACTIVE_CALL_ID = ContextVar("active_call_id", default=None) + +logger = logging.getLogger(__name__) + + +class BaseCallback: + """ + A base class for defining callback handlers for DSPy components. + + To use a callback, subclass this class and implement the desired handlers. Each handler + will be called at the appropriate time before/after the execution of the corresponding component. 
+ + For example, if you want to print a message before and after an LM is called, implement + the on_lm_start and on_lm_end handlers and set the callback to the global settings using `dspy.settings.configure`. + + ``` + import dspy + from dspy.utils.callback import BaseCallback + + class LoggingCallback(BaseCallback): + + def on_lm_start(self, call_id, instance, inputs): + print(f"LM is called with inputs: {inputs}") + + def on_lm_end(self, call_id, outputs, exception): + print(f"LM is finished with outputs: {outputs}") + + dspy.settings.configure( + callbacks=[LoggingCallback()] + ) + + cot = dspy.ChainOfThought("question -> answer") + cot(question="What is the meaning of life?") + + # > LM is called with inputs: {'question': 'What is the meaning of life?'} + # > LM is finished with outputs: {'answer': '42'} + ``` + + Another way to set the callback is to pass it directly to the component constructor. + In this case, the callback will only be triggered for that specific instance. + + ``` + lm = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()]) + lm(question="What is the meaning of life?") + + # > LM is called with inputs: {'question': 'What is the meaning of life?'} + # > LM is finished with outputs: {'answer': '42'} + + lm_2 = dspy.LM("gpt-3.5-turbo") + lm_2(question="What is the meaning of life?") + # No logging here + ``` + """ + + def on_module_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """ + A handler triggered when forward() method of a module (subclass of dspy.Module) is called. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The Module instance. + inputs: The inputs to the module's forward() method. Each argument is stored as + a key-value pair in a dictionary.
+ """ + pass + + def on_module_end( + self, + call_id: str, + outputs: Optional[Any], + exception: Optional[Exception] = None, + ): + """ + A handler triggered after forward() method of a module (subclass of dspy.Module) is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the module's forward() method. If the method is interrupted by + an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here. + """ + pass + + def on_lm_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """ + A handler triggered when __call__ method of dspy.LM instance is called. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The LM instance. + inputs: The inputs to the LM's __call__ method. Each argument is stored as + a key-value pair in a dictionary. + """ + pass + + def on_lm_end( + self, + call_id: str, + outputs: Optional[Dict[str, Any]], + exception: Optional[Exception] = None, + ): + """ + A handler triggered after __call__ method of dspy.LM instance is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the LM's __call__ method. If the method is interrupted by + an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here. + """ + pass + + def on_format_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """ + A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The Adapter instance. + inputs: The inputs to the Adapter's format() method. Each argument is stored as + a key-value pair in a dictionary.
+ """ + pass + + def on_format_end( + self, + call_id: str, + outputs: Optional[Dict[str, Any]], + exception: Optional[Exception] = None, + ): + """ + A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the Adapter's format() method. If the method is interrupted + by an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here. + """ + pass + + def on_parse_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """ + A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The Adapter instance. + inputs: The inputs to the Adapter's parse() method. Each argument is stored as + a key-value pair in a dictionary. + """ + pass + + def on_parse_end( + self, + call_id: str, + outputs: Optional[Dict[str, Any]], + exception: Optional[Exception] = None, + ): + """ + A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the Adapter's parse() method. If the method is interrupted + by an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here.
+ """ + pass + + +def with_callbacks(fn): + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + # Combine global and local (per-instance) callbacks + callbacks = dspy.settings.get("callbacks", []) + getattr(self, "callbacks", []) + + # if no callbacks are provided, just call the function + if not callbacks: + return fn(self, *args, **kwargs) + + # Generate call ID to connect start/end handlers if needed + call_id = uuid.uuid4().hex + + inputs = inspect.getcallargs(fn, self, *args, **kwargs) + inputs.pop("self") # Not logging self as input + + for callback in callbacks: + try: + _get_on_start_handler(callback, self, fn)(call_id=call_id, instance=self, inputs=inputs) + + except Exception as e: + logger.warning(f"Error when calling callback {callback}: {e}") + + results = None + exception = None + try: + parent_call_id = ACTIVE_CALL_ID.get() + # Active ID must be set right before the function is called, + # not before calling the callbacks. + ACTIVE_CALL_ID.set(call_id) + results = fn(self, *args, **kwargs) + return results + except Exception as e: + exception = e + raise exception + finally: + ACTIVE_CALL_ID.set(parent_call_id) + for callback in callbacks: + try: + _get_on_end_handler(callback, self, fn)( + call_id=call_id, + outputs=results, + exception=exception, + ) + except Exception as e: + logger.warning(f"Error when calling callback {callback}: {e}") + + return wrapper + + +def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: + """ + Selects the appropriate on_start handler of the callback + based on the instance and function name. + """ + if isinstance(instance, (dspy.LM)): + return callback.on_lm_start + elif isinstance(instance, (dspy.Adapter)): + if fn.__name__ == "format": + return callback.on_format_start + elif fn.__name__ == "parse": + return callback.on_parse_start + + # We treat everything else as a module. 
+ return callback.on_module_start + + +def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: + """ + Selects the appropriate on_end handler of the callback + based on the instance and function name. + """ + if isinstance(instance, (dspy.LM)): + return callback.on_lm_end + elif isinstance(instance, (dspy.Adapter)): + if fn.__name__ == "format": + return callback.on_format_end + elif fn.__name__ == "parse": + return callback.on_parse_end + + # We treat everything else as a module. + return callback.on_module_end diff --git a/dspy/utils/dummies.py b/dspy/utils/dummies.py index ccf758aa14..90f5def66a 100644 --- a/dspy/utils/dummies.py +++ b/dspy/utils/dummies.py @@ -10,6 +10,7 @@ from dspy.adapters.chat_adapter import FieldInfoWithName, field_header_pattern, format_fields from dspy.clients.lm import LM from dspy.signatures.field import OutputField +from dspy.utils.callback import with_callbacks class DSPDummyLM(DSPLM): @@ -170,6 +171,7 @@ def _use_example(self, messages): if any(field in output["content"] for field in output_fields) and final_input in input["content"]: return output["content"] + @with_callbacks def __call__(self, prompt=None, messages=None, **kwargs): def format_answer_fields(field_names_and_values: Dict[str, Any]): return format_fields( diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py new file mode 100644 index 0000000000..a80e4060fd --- /dev/null +++ b/tests/callback/test_callback.py @@ -0,0 +1,217 @@ +import time + +import pytest + +import dspy +from dspy.utils.callback import ACTIVE_CALL_ID, BaseCallback, with_callbacks +from dspy.utils.dummies import DummyLM + + +@pytest.fixture(autouse=True) +def reset_settings(): + # Make sure the settings are reset after each test + original_settings = dspy.settings.copy() + + yield + + dspy.settings.configure(**original_settings) + + +class MyCallback(BaseCallback): + def __init__(self): + self.calls = [] + + def on_module_start(self, call_id, 
instance, inputs): + self.calls.append({"handler": "on_module_start", "instance": instance, "inputs": inputs}) + + def on_module_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_module_end", "outputs": outputs, "exception": exception}) + + def on_lm_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_lm_start", "instance": instance, "inputs": inputs}) + + def on_lm_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_lm_end", "outputs": outputs, "exception": exception}) + + def on_format_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_format_start", "instance": instance, "inputs": inputs}) + + def on_format_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_format_end", "outputs": outputs, "exception": exception}) + + def on_parse_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_parse_start", "instance": instance, "inputs": inputs}) + + def on_parse_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_parse_end", "outputs": outputs, "exception": exception}) + + +@pytest.mark.parametrize( + ("args", "kwargs"), + [ + ([1, "2", 3.0], {}), + ([1, "2"], {"z": 3.0}), + ([1], {"y": "2", "z": 3.0}), + ([], {"x": 1, "y": "2", "z": 3.0}), + ], +) +def test_callback_injection(args, kwargs): + class Target(dspy.Module): + @with_callbacks + def forward(self, x: int, y: str, z: float) -> int: + time.sleep(0.1) + return x + int(y) + int(z) + + callback = MyCallback() + dspy.settings.configure(callbacks=[callback]) + + target = Target() + result = target.forward(*args, **kwargs) + + assert result == 6 + + assert len(callback.calls) == 2 + assert callback.calls[0]["handler"] == "on_module_start" + assert callback.calls[0]["inputs"] == {"x": 1, "y": "2", "z": 3.0} + assert callback.calls[1]["handler"] == "on_module_end" + assert callback.calls[1]["outputs"] == 6 + + +def test_callback_injection_local(): + 
class Target(dspy.Module): + @with_callbacks + def forward(self, x: int, y: str, z: float) -> int: + time.sleep(0.1) + return x + int(y) + int(z) + + callback = MyCallback() + + target_1 = Target(callbacks=[callback]) + result = target_1.forward(1, "2", 3.0) + + assert result == 6 + + assert len(callback.calls) == 2 + assert callback.calls[0]["handler"] == "on_module_start" + assert callback.calls[0]["inputs"] == {"x": 1, "y": "2", "z": 3.0} + assert callback.calls[1]["handler"] == "on_module_end" + assert callback.calls[1]["outputs"] == 6 + + callback.calls = [] + + target_2 = Target() + result = target_2.forward(1, "2", 3.0) + + # Other instance should not trigger the callback + assert not callback.calls + + +def test_callback_error_handling(): + class Target(dspy.Module): + @with_callbacks + def forward(self, x: int, y: str, z: float) -> int: + time.sleep(0.1) + raise ValueError("Error") + + callback = MyCallback() + dspy.settings.configure(callbacks=[callback]) + + target = Target() + + with pytest.raises(ValueError, match="Error"): + target.forward(1, "2", 3.0) + + assert len(callback.calls) == 2 + assert callback.calls[0]["handler"] == "on_module_start" + assert callback.calls[1]["handler"] == "on_module_end" + assert isinstance(callback.calls[1]["exception"], ValueError) + + +def test_multiple_callbacks(): + class Target(dspy.Module): + @with_callbacks + def forward(self, x: int, y: str, z: float) -> int: + time.sleep(0.1) + return x + int(y) + int(z) + + callback_1 = MyCallback() + callback_2 = MyCallback() + dspy.settings.configure(callbacks=[callback_1, callback_2]) + + target = Target() + result = target.forward(1, "2", 3.0) + + assert result == 6 + + assert len(callback_1.calls) == 2 + assert len(callback_2.calls) == 2 + + +def test_callback_complex_module(): + callback = MyCallback() + dspy.settings.configure( + lm=DummyLM({"How are you?": {"answer": "test output", "reasoning": "No more responses"}}), + callbacks=[callback], + ) + + cot = 
dspy.ChainOfThought("question -> answer", n=3) + result = cot(question="How are you?") + assert result["answer"] == "test output" + assert result["reasoning"] == "No more responses" + + assert len(callback.calls) == 14 + assert [call["handler"] for call in callback.calls] == [ + "on_module_start", + "on_module_start", + "on_format_start", + "on_format_end", + "on_lm_start", + "on_lm_end", + # Parsing will run per output (n=3) + "on_parse_start", + "on_parse_end", + "on_parse_start", + "on_parse_end", + "on_parse_start", + "on_parse_end", + "on_module_end", + "on_module_end", + ] + + +def test_active_id(): + # Test the call ID is generated and handled properly + class CustomCallback(BaseCallback): + def __init__(self): + self.parent_call_ids = [] + self.call_ids = [] + + def on_module_start(self, call_id, instance, inputs): + parent_call_id = ACTIVE_CALL_ID.get() + self.parent_call_ids.append(parent_call_id) + self.call_ids.append(call_id) + + class Parent(dspy.Module): + def __init__(self): + self.child_1 = Child() + self.child_2 = Child() + + def forward(self): + self.child_1() + self.child_2() + + class Child(dspy.Module): + def forward(self): + pass + + callback = CustomCallback() + dspy.settings.configure(callbacks=[callback]) + + parent = Parent() + parent() + + assert len(callback.call_ids) == 3 + # All three calls should have different call ids + assert len(set(callback.call_ids)) == 3 + parent_call_id = callback.call_ids[0] + assert callback.parent_call_ids == [None, parent_call_id, parent_call_id] diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 2573abcb8d..1a85a1267e 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -1,12 +1,12 @@ from dataclasses import dataclass import dspy -from dspy.utils.dummies import dummy_rm +from dspy.utils.dummies import DummyLM, dummy_rm def test_example_no_tools(): # Createa a simple dataset which the model will use with the Retrieve tool. 
- lm = dspy.utils.DummyLM( + lm = DummyLM( [ {"Thought_1": "Initial thoughts", "Action_1": "Finish[blue]"}, ] @@ -26,7 +26,7 @@ def test_example_no_tools(): def test_example_search(): # Createa a simple dataset which the model will use with the Retrieve tool. - lm = dspy.utils.DummyLM( + lm = DummyLM( [ {"Thought_1": "Initial thoughts", "Action_1": "Search[the color of the sky]"}, {"Thought_2": "More thoughts", "Action_2": "Finish[blue]\n\n"}, @@ -89,7 +89,7 @@ def __call__(self, *args, **kwargs): def test_custom_tools(): - lm = dspy.utils.DummyLM( + lm = DummyLM( [ {"Thought_1": "Initial thoughts", "Action_1": "Tool1[foo]"}, {"Thought_2": "More thoughts", "Action_2": "Tool2[bar]"},