Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dspy/evaluate/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tqdm

import dspy
from dspy.utils.callback import with_callbacks
from dspy.utils.parallelizer import ParallelExecutor

try:
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
self.provide_traceback = provide_traceback
self.failure_score = failure_score

@with_callbacks
def __call__(
self,
program: "dspy.Module",
Expand Down
36 changes: 36 additions & 0 deletions dspy/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,38 @@ def on_tool_end(
"""
pass

def on_evaluate_start(
    self,
    call_id: str,
    instance: Any,
    inputs: Dict[str, Any],
):
    """A handler triggered when evaluation is started.

    Args:
        call_id: A unique identifier for the call. Can be used to connect start/end handlers.
        instance: The Evaluate instance.
        inputs: The inputs to the Evaluate's __call__ method. Each argument is stored as
            a key-value pair in a dictionary.
    """
    # No-op by default; subclasses override this hook to observe evaluation starts.
    pass

def on_evaluate_end(
    self,
    call_id: str,
    outputs: Optional[Dict[str, Any]],
    exception: Optional[Exception] = None,
):
    """A handler triggered after evaluation is executed.

    Args:
        call_id: A unique identifier for the call. Can be used to connect start/end handlers.
        outputs: The outputs of the Evaluate's __call__ method. If the method is interrupted by
            an exception, this will be None.
        exception: If an exception is raised during the execution, it will be stored here.
    """
    # No-op by default; subclasses override this hook to observe evaluation results.
    pass


def with_callbacks(fn):
@functools.wraps(fn)
Expand Down Expand Up @@ -279,6 +311,8 @@ def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -
"""Selects the appropriate on_start handler of the callback based on the instance and function name."""
if isinstance(instance, dspy.LM):
return callback.on_lm_start
elif isinstance(instance, dspy.Evaluate):
return callback.on_evaluate_start

if isinstance(instance, dspy.Adapter):
if fn.__name__ == "format":
Expand All @@ -299,6 +333,8 @@ def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) ->
"""Selects the appropriate on_end handler of the callback based on the instance and function name."""
if isinstance(instance, (dspy.LM)):
return callback.on_lm_end
elif isinstance(instance, dspy.Evaluate):
return callback.on_evaluate_end

if isinstance(instance, (dspy.Adapter)):
if fn.__name__ == "format":
Expand Down
52 changes: 52 additions & 0 deletions tests/evaluate/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dspy.evaluate.evaluate import Evaluate
from dspy.evaluate.metrics import answer_exact_match
from dspy.predict import Predict
from dspy.utils.callback import BaseCallback
from dspy.utils.dummies import DummyLM


Expand Down Expand Up @@ -173,3 +174,54 @@ def test_evaluate_display_table(program_with_example, display_table, is_in_ipyth
# to the console
example_input = next(iter(example.inputs().values()))
assert example_input in out

def test_evaluate_callback():
    """Verify that Evaluate.__call__ fires on_evaluate_start/on_evaluate_end exactly once."""

    class SpyCallback(BaseCallback):
        # Records what the evaluate start/end hooks were called with, and how often.
        def __init__(self):
            self.start_call_inputs = None
            self.start_call_count = 0
            self.end_call_outputs = None
            self.end_call_count = 0

        def on_evaluate_start(self, call_id: str, instance, inputs):
            self.start_call_inputs = inputs
            self.start_call_count += 1

        def on_evaluate_end(self, call_id: str, outputs, exception=None):
            self.end_call_outputs = outputs
            self.end_call_count += 1

    spy = SpyCallback()
    dummy_lm = DummyLM(
        {
            "What is 1+1?": {"answer": "2"},
            "What is 2+2?": {"answer": "4"},
        }
    )
    dspy.settings.configure(lm=dummy_lm, callbacks=[spy])

    examples = [new_example("What is 1+1?", "2"), new_example("What is 2+2?", "4")]
    program = Predict("question -> answer")
    # Sanity-check the dummy LM before running the evaluation itself.
    assert program(question="What is 1+1?").answer == "2"

    evaluator = Evaluate(
        devset=examples,
        metric=answer_exact_match,
        display_progress=False,
    )
    result = evaluator(program)

    assert result == 100.0
    assert spy.start_call_inputs["program"] == program
    assert spy.start_call_count == 1
    assert spy.end_call_outputs == 100.0
    assert spy.end_call_count == 1
Loading