From 9defca71a211b2c084fa43f201ce1ee7e5d39cb5 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Fri, 25 Oct 2024 14:18:02 -0700 Subject: [PATCH] Some fixes to callback --- dspy/clients/lm.py | 17 +++-- dspy/utils/__init__.py | 6 +- dspy/utils/callback.py | 112 +++++++++++++++----------------- tests/callback/test_callback.py | 34 +++++----- 4 files changed, 82 insertions(+), 87 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 89f259f5a5..523f8ffc11 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,23 +1,22 @@ -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime import functools import os +import uuid +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional + +import litellm import ujson -import uuid +from litellm.caching import Cache -from dspy.utils.logging import logger from dspy.clients.finetune import FinetuneJob, TrainingMethod from dspy.clients.lm_finetune_utils import ( - get_provider_finetune_job_class, execute_finetune_job, + get_provider_finetune_job_class, ) - from dspy.utils.callback import with_callbacks -import litellm -from litellm.caching import Cache - +from dspy.utils.logging import logger DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") diff --git a/dspy/utils/__init__.py b/dspy/utils/__init__.py index 1f8dfb4ce6..bff180328f 100644 --- a/dspy/utils/__init__.py +++ b/dspy/utils/__init__.py @@ -1,3 +1,3 @@ -from .callback import * -from .dummies import * -from .logging import * +from dspy.utils.callback import BaseCallback, with_callbacks +from dspy.utils.dummies import * +from dspy.utils.logging import * diff --git a/dspy/utils/callback.py b/dspy/utils/callback.py index b50bfd39aa..5f9bcdc33e 100644 --- a/dspy/utils/callback.py +++ b/dspy/utils/callback.py @@ -13,14 +13,16 @@ 
class BaseCallback: - """ - A base class for defining callback handlers for DSPy components. + """A base class for defining callback handlers for DSPy components. To use a callback, subclass this class and implement the desired handlers. Each handler - will be called at the appropriate time before/after the execution of the corresponding component. + will be called at the appropriate time before/after the execution of the corresponding component. For example, if + you want to print a message before and after an LM is called, implement the `on_lm_start` and `on_lm_end` handlers. + Users can set the callback globally using `dspy.settings.configure` or locally by passing it to the component + constructor. + - For example, if you want to print a message before and after an LM is called, implement - the on_llm_start and on_lm_end handler and set the callback to the global settings using `dspy.settings.configure`. + Example 1: Set a global callback using `dspy.settings.configure`. ``` import dspy @@ -45,19 +47,18 @@ def on_lm_end(self, call_id, outputs, exception): # > LM is finished with outputs: {'answer': '42'} ``` - Another way to set the callback is to pass it directly to the component constructor. - In this case, the callback will only be triggered for that specific instance. + Example 2: Set a local callback by passing it to the component constructor. ``` - lm = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()]) - lm(question="What is the meaning of life?") + lm_1 = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()]) + lm_1(question="What is the meaning of life?") # > LM is called with inputs: {'question': 'What is the meaning of life?'} # > LM is finished with outputs: {'answer': '42'} lm_2 = dspy.LM("gpt-3.5-turbo") lm_2(question="What is the meaning of life?") - # No logging here + # No logging here because only `lm_1` has the callback set. 
``` """ @@ -67,8 +68,7 @@ def on_module_start( instance: Any, inputs: Dict[str, Any], ): - """ - A handler triggered when forward() method of a module (subclass of dspy.Module) is called. + """A handler triggered when forward() method of a module (subclass of dspy.Module) is called. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -84,8 +84,7 @@ def on_module_end( outputs: Optional[Any], exception: Optional[Exception] = None, ): - """ - A handler triggered after forward() method of a module (subclass of dspy.Module) is executed. + """A handler triggered after forward() method of a module (subclass of dspy.Module) is executed. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -101,8 +100,7 @@ def on_lm_start( instance: Any, inputs: Dict[str, Any], ): - """ - A handler triggered when __call__ method of dspy.LM instance is called. + """A handler triggered when __call__ method of dspy.LM instance is called. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -118,8 +116,7 @@ def on_lm_end( outputs: Optional[Dict[str, Any]], exception: Optional[Exception] = None, ): - """ - A handler triggered after __call__ method of dspy.LM instance is executed. + """A handler triggered after __call__ method of dspy.LM instance is executed. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -129,14 +126,13 @@ def on_lm_end( """ pass - def on_format_start( + def on_adapter_format_start( self, call_id: str, instance: Any, inputs: Dict[str, Any], ): - """ - A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called. + """A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. 
@@ -146,14 +142,13 @@ def on_format_start( """ pass - def on_format_end( + def on_adapter_format_end( self, call_id: str, outputs: Optional[Dict[str, Any]], exception: Optional[Exception] = None, ): - """ - A handler triggered after format() method of dspy.LM instance is executed. + """A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is executed. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -163,14 +158,13 @@ def on_format_end( """ pass - def on_parse_start( + def on_adapter_parse_start( self, call_id: str, instance: Any, inputs: Dict[str, Any], ): - """ - A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called. + """A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -180,14 +174,13 @@ def on_parse_start( """ pass - def on_parse_end( + def on_adapter_parse_end( self, call_id: str, outputs: Optional[Dict[str, Any]], exception: Optional[Exception] = None, ): - """ - A handler triggered after parse() method of dspy.LM instance is executed. + """A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is executed. Args: call_id: A unique identifier for the call. Can be used to connect start/end handlers. @@ -200,23 +193,23 @@ def on_parse_end( def with_callbacks(fn): @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - # Combine global and local (per-instance) callbacks - callbacks = dspy.settings.get("callbacks", []) + getattr(self, "callbacks", []) + def wrapper(instance, *args, **kwargs): + # Combine global and local (per-instance) callbacks. 
+ callbacks = dspy.settings.get("callbacks", []) + getattr(instance, "callbacks", []) - # if no callbacks are provided, just call the function + # If no callbacks are provided, just call the function if not callbacks: - return fn(self, *args, **kwargs) + return fn(instance, *args, **kwargs) - # Generate call ID to connect start/end handlers if needed + # Generate call ID as the unique identifier for the call, this is useful for instrumentation. call_id = uuid.uuid4().hex - inputs = inspect.getcallargs(fn, self, *args, **kwargs) + inputs = inspect.getcallargs(fn, instance, *args, **kwargs) inputs.pop("self") # Not logging self as input for callback in callbacks: try: - _get_on_start_handler(callback, self, fn)(call_id=call_id, instance=self, inputs=inputs) + _get_on_start_handler(callback, instance, fn)(call_id=call_id, instance=instance, inputs=inputs) except Exception as e: logger.warning(f"Error when calling callback {callback}: {e}") @@ -225,58 +218,59 @@ def wrapper(self, *args, **kwargs): exception = None try: parent_call_id = ACTIVE_CALL_ID.get() - # Active ID must be set right before the function is called, - # not before calling the callbacks. + # Active ID must be set right before the function is called, not before calling the callbacks. ACTIVE_CALL_ID.set(call_id) - results = fn(self, *args, **kwargs) + results = fn(instance, *args, **kwargs) return results except Exception as e: exception = e raise exception finally: + # Execute the end handlers even if the function call raises an exception. ACTIVE_CALL_ID.set(parent_call_id) for callback in callbacks: try: - _get_on_end_handler(callback, self, fn)( + _get_on_end_handler(callback, instance, fn)( call_id=call_id, outputs=results, exception=exception, ) except Exception as e: - logger.warning(f"Error when calling callback {callback}: {e}") + logger.warning( + f"Error when applying callback {callback}'s end handler on function {fn.__name__}: {e}." 
+ ) return wrapper def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: - """ - Selects the appropriate on_start handler of the callback - based on the instance and function name. - """ - if isinstance(instance, (dspy.LM)): + """Selects the appropriate on_start handler of the callback based on the instance and function name.""" + if isinstance(instance, dspy.LM): return callback.on_lm_start - elif isinstance(instance, (dspy.Adapter)): + + if isinstance(instance, dspy.Adapter): if fn.__name__ == "format": - return callback.on_format_start + return callback.on_adapter_format_start elif fn.__name__ == "parse": - return callback.on_parse_start + return callback.on_adapter_parse_start + else: + raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") # We treat everything else as a module. return callback.on_module_start def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable: - """ - Selects the appropriate on_end handler of the callback - based on the instance and function name. - """ + """Selects the appropriate on_end handler of the callback based on the instance and function name.""" if isinstance(instance, (dspy.LM)): return callback.on_lm_end - elif isinstance(instance, (dspy.Adapter)): + + if isinstance(instance, (dspy.Adapter)): if fn.__name__ == "format": - return callback.on_format_end + return callback.on_adapter_format_end elif fn.__name__ == "parse": - return callback.on_parse_end - + return callback.on_adapter_parse_end + else: + raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.") # We treat everything else as a module. 
return callback.on_module_end diff --git a/tests/callback/test_callback.py b/tests/callback/test_callback.py index a80e4060fd..50b04d0e34 100644 --- a/tests/callback/test_callback.py +++ b/tests/callback/test_callback.py @@ -18,6 +18,8 @@ def reset_settings(): class MyCallback(BaseCallback): + """A simple callback that records the calls.""" + def __init__(self): self.calls = [] @@ -33,17 +35,17 @@ def on_lm_start(self, call_id, instance, inputs): def on_lm_end(self, call_id, outputs, exception): self.calls.append({"handler": "on_lm_end", "outputs": outputs, "exception": exception}) - def on_format_start(self, call_id, instance, inputs): - self.calls.append({"handler": "on_format_start", "instance": instance, "inputs": inputs}) + def on_adapter_format_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_adapter_format_start", "instance": instance, "inputs": inputs}) - def on_format_end(self, call_id, outputs, exception): - self.calls.append({"handler": "on_format_end", "outputs": outputs, "exception": exception}) + def on_adapter_format_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_adapter_format_end", "outputs": outputs, "exception": exception}) - def on_parse_start(self, call_id, instance, inputs): - self.calls.append({"handler": "on_parse_start", "instance": instance, "inputs": inputs}) + def on_adapter_parse_start(self, call_id, instance, inputs): + self.calls.append({"handler": "on_adapter_parse_start", "instance": instance, "inputs": inputs}) - def on_parse_end(self, call_id, outputs, exception): - self.calls.append({"handler": "on_parse_end", "outputs": outputs, "exception": exception}) + def on_adapter_parse_end(self, call_id, outputs, exception): + self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception}) @pytest.mark.parametrize( @@ -163,17 +165,17 @@ def test_callback_complex_module(): assert [call["handler"] for call in callback.calls] == [ "on_module_start", 
"on_module_start", - "on_format_start", - "on_format_end", + "on_adapter_format_start", + "on_adapter_format_end", "on_lm_start", "on_lm_end", # Parsing will run per output (n=3) - "on_parse_start", - "on_parse_end", - "on_parse_start", - "on_parse_end", - "on_parse_start", - "on_parse_end", + "on_adapter_parse_start", + "on_adapter_parse_end", + "on_adapter_parse_start", + "on_adapter_parse_end", + "on_adapter_parse_start", + "on_adapter_parse_end", "on_module_end", "on_module_end", ]