diff --git a/dspy/predict/react.py b/dspy/predict/react.py index ce8edaa46d..28640d5d10 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -9,7 +9,7 @@ class Tool: def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None): - annotations_func = func if inspect.isfunction(func) else func.__call__ + annotations_func = func if inspect.isfunction(func) or inspect.ismethod(func) else func.__call__ self.func = func self.name = name or getattr(func, '__name__', type(func).__name__) self.desc = desc or getattr(func, '__doc__', None) or getattr(annotations_func, '__doc__', "") diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 8435f86a9e..4c6a150db4 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -2,6 +2,7 @@ import dspy from dspy.utils.dummies import DummyLM, dummy_rm +from dspy.predict import react # def test_example_no_tools(): @@ -121,4 +122,28 @@ # react = dspy.ReAct(ExampleSignature) # assert react.react[0].signature.instructions is not None -# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.") \ No newline at end of file +# assert react.react[0].signature.instructions.startswith("You are going to generate output based on input.") + +def test_tool_from_function(): + def foo(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + tool = react.Tool(foo) + assert tool.name == "foo" + assert tool.desc == "Add two numbers." + assert tool.args == {"a": "int", "b": "int"} + +def test_tool_from_class(): + class Foo: + def __init__(self, user_id: str): + self.user_id = user_id + + def foo(self, a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + tool = react.Tool(Foo("123").foo) + assert tool.name == "foo" + assert tool.desc == "Add two numbers." + assert tool.args == {"a": "int", "b": "int"}