From 2ce3d462981efabe507074efe3e12384e4f96805 Mon Sep 17 00:00:00 2001 From: B-Step62 Date: Sun, 1 Dec 2024 15:49:11 +0900 Subject: [PATCH] Add on_tool_start/end callbacks Signed-off-by: B-Step62 --- dspy/predict/__init__.py | 2 +- dspy/predict/react.py | 2 ++ dspy/utils/callback.py | 39 +++++++++++++++++++++++++++++++ tests/callback/test_callback.py | 41 +++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 1 deletion(-) diff --git a/dspy/predict/__init__.py b/dspy/predict/__init__.py index d889d85493..0260ead378 100644 --- a/dspy/predict/__init__.py +++ b/dspy/predict/__init__.py @@ -5,6 +5,6 @@ from .multi_chain_comparison import MultiChainComparison from .predict import Predict from .program_of_thought import ProgramOfThought -from .react import ReAct +from .react import ReAct, Tool from .retry import Retry from .parallel import Parallel \ No newline at end of file diff --git a/dspy/predict/react.py b/dspy/predict/react.py index f114460181..c6fcb61300 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -5,6 +5,7 @@ from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature from dspy.adapters.json_adapter import get_annotation_name +from dspy.utils.callback import with_callbacks from typing import Callable, Any, get_type_hints, get_origin, Literal class Tool: @@ -19,6 +20,7 @@ def __init__(self, func: Callable, name: str = None, desc: str = None, args: dic for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return' } + @with_callbacks def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) diff --git a/dspy/utils/callback.py b/dspy/utils/callback.py index 5f9bcdc33e..a1292ce520 100644 --- a/dspy/utils/callback.py +++ b/dspy/utils/callback.py @@ -190,6 +190,38 @@ def on_adapter_parse_end( """ pass + def on_tool_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """A handler triggered when a tool is called. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The Tool instance. + inputs: The inputs to the Tool's __call__ method. Each arguments is stored as + a key-value pair in a dictionary. + """ + pass + + def on_tool_end( + self, + call_id: str, + outputs: Optional[Dict[str, Any]], + exception: Optional[Exception] = None, + ): + """A handler triggered after a tool is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the Tool's __call__ method. If the method is interrupted by + an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here. + """ + pass + def with_callbacks(fn): @functools.wraps(fn) @@ -256,6 +288,9 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) - else: raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") + if isinstance(instance, dspy.Tool): + return callback.on_tool_start + # We treat everything else as a module. return callback.on_module_start @@ -272,5 +307,9 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> return callback.on_adapter_parse_end else: raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") + + if isinstance(instance, dspy.Tool): + return callback.on_tool_end + # We treat everything else as a module. return callback.on_module_end diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index 50b04d0e34..ae87538897 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -47,6 +47,12 @@ def on_adapter_parse_start(self, call_id, instance, inputs): def on_adapter_parse_end(self, call_id, outputs, exception): self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception}) + def on_tool_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_tool_start", "instance": instance, "inputs": inputs}) + + def on_tool_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_tool_end", "outputs": outputs, "exception": exception}) + @pytest.mark.parametrize( ("args", "kwargs"), @@ -181,6 +187,41 @@ def test_callback_complex_module(): ] +def test_tool_calls(): + callback = MyCallback() + dspy.settings.configure(callbacks=[callback]) + + def tool_1(query: str) -> str: + """A dummy tool function.""" + return "result 1" + + def tool_2(query: str) -> str: + """Another dummy tool function.""" + return "result 2" + + class MyModule(dspy.Module): + def __init__(self): + self.tools = [dspy.Tool(tool_1), dspy.Tool(tool_2)] + + def forward(self, query: str) -> str: + query = self.tools[0](query) + return self.tools[1](query) + + module = MyModule() + result = module("query") + + assert result == "result 2" + assert len(callback.calls) == 6 + assert [call["handler"] for call in callback.calls] == [ + "on_module_start", + "on_tool_start", + "on_tool_end", + "on_tool_start", + "on_tool_end", + "on_module_end", + ] + + def test_active_id(): # Test the call ID is generated and handled properly class CustomCallback(BaseCallback):