diff --git a/dspy/evaluate/evaluate.py b/dspy/evaluate/evaluate.py index 89be568937..407789cc03 100644 --- a/dspy/evaluate/evaluate.py +++ b/dspy/evaluate/evaluate.py @@ -6,6 +6,7 @@ import tqdm import dspy +from dspy.utils.callback import with_callbacks from dspy.utils.parallelizer import ParallelExecutor try: @@ -83,6 +84,7 @@ def __init__( self.provide_traceback = provide_traceback self.failure_score = failure_score + @with_callbacks def __call__( self, program: "dspy.Module", diff --git a/dspy/utils/callback.py b/dspy/utils/callback.py index a1292ce520..ea88a8ac51 100644 --- a/dspy/utils/callback.py +++ b/dspy/utils/callback.py @@ -222,6 +222,38 @@ def on_tool_end( """ pass + def on_evaluate_start( + self, + call_id: str, + instance: Any, + inputs: Dict[str, Any], + ): + """A handler triggered when evaluation is started. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + instance: The Evaluate instance. + inputs: The inputs to the Evaluate's __call__ method. Each arguments is stored as + a key-value pair in a dictionary. + """ + pass + + def on_evaluate_end( + self, + call_id: str, + outputs: Optional[Dict[str, Any]], + exception: Optional[Exception] = None, + ): + """A handler triggered after evaluation is executed. + + Args: + call_id: A unique identifier for the call. Can be used to connect start/end handlers. + outputs: The outputs of the Evaluate's __call__ method. If the method is interrupted by + an exception, this will be None. + exception: If an exception is raised during the execution, it will be stored here. + """ + pass + def with_callbacks(fn): @functools.wraps(fn) @@ -279,6 +311,8 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) - """Selects the appropriate on_start handler of the callback based on the instance and function name.""" if isinstance(instance, dspy.LM): return callback.on_lm_start + elif isinstance(instance, dspy.Evaluate): + return callback.on_evaluate_start if isinstance(instance, dspy.Adapter): if fn.__name__ == "format": @@ -299,6 +333,8 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> """Selects the appropriate on_end handler of the callback based on the instance and function name.""" if isinstance(instance, (dspy.LM)): return callback.on_lm_end + elif isinstance(instance, dspy.Evaluate): + return callback.on_evaluate_end if isinstance(instance, (dspy.Adapter)): if fn.__name__ == "format": diff --git a/tests/evaluate/test_evaluate.py b/tests/evaluate/test_evaluate.py index 78935e3b86..8ac6f6564a 100644 --- a/tests/evaluate/test_evaluate.py +++ b/tests/evaluate/test_evaluate.py @@ -8,6 +8,7 @@ from dspy.evaluate.evaluate import Evaluate from dspy.evaluate.metrics import answer_exact_match from dspy.predict import Predict +from dspy.utils.callback import BaseCallback from dspy.utils.dummies import DummyLM @@ -173,3 +174,54 @@ def test_evaluate_display_table(program_with_example, display_table, is_in_ipyth # to the console example_input = next(iter(example.inputs().values())) assert example_input in out + +def test_evaluate_callback(): + class TestCallback(BaseCallback): + def __init__(self): + self.start_call_inputs = None + self.start_call_count = 0 + self.end_call_outputs = None + self.end_call_count = 0 + + def on_evaluate_start( + self, + call_id: str, + instance, + inputs, + ): + self.start_call_inputs = inputs + self.start_call_count += 1 + + def on_evaluate_end( + self, + call_id: str, + outputs, + exception = None, + ): + self.end_call_outputs = outputs + self.end_call_count += 1 + + callback = TestCallback() + dspy.settings.configure( + lm=DummyLM( + { + "What is 1+1?": {"answer": "2"}, + "What is 2+2?": {"answer": "4"}, + } + ), + callbacks=[callback] + ) + devset = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")] + program = Predict("question -> answer") + assert program(question="What is 1+1?").answer == "2" + ev = Evaluate( + devset=devset, + metric=answer_exact_match, + display_progress=False, + ) + score = ev(program) + assert score == 100.0 + assert callback.start_call_inputs["program"] == program + assert callback.start_call_count == 1 + assert callback.end_call_outputs == 100.0 + assert callback.end_call_count == 1