From 7e05e43a8cdadc26cb3dd13f2dc5b2d3dff843e4 Mon Sep 17 00:00:00 2001 From: lxdlam Date: Mon, 24 Mar 2025 17:53:18 +0800 Subject: [PATCH 1/4] feat: ProgramOfThought supports multiple return and more --- dspy/predict/program_of_thought.py | 64 +++++++-------------- dspy/primitives/python_interpreter.py | 12 ++-- tests/predict/test_program_of_thought.py | 73 ++++++++++++++++++++++-- 3 files changed, 95 insertions(+), 54 deletions(-) diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py index 4214861160..8488acbef5 100644 --- a/dspy/predict/program_of_thought.py +++ b/dspy/predict/program_of_thought.py @@ -1,6 +1,7 @@ import logging import re from typing import Union, Type +import json import dspy from dspy.signatures.signature import ensure_signature, Signature @@ -10,6 +11,7 @@ logger = logging.getLogger(__name__) + class ProgramOfThought(Module): """ A DSPy module that runs Python programs to solve a problem. @@ -39,28 +41,6 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3): self.input_fields = signature.input_fields self.output_fields = signature.output_fields - assert len(self.output_fields) == 1, "PoT only supports one output field." - - self.output_field_name = next(iter(self.output_fields)) - inputs_ = ", ".join( - [f"`{field_name}`" for field_name in self.input_fields.keys()], - ) - outputs_ = f"`{self.output_field_name}`" - - assert len(self.output_fields) == 1, "PoT only supports one output field." 
- - instr = [] - instr.append( - f"You will be given {inputs_} and you will respond with {outputs_}.", - ) - instr.append( - f"Generating executable Python code that programmatically computes the correct {outputs_}.", - ) - instr.append( - f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.", - ) - instr = "\n".join(instr) - self.code_generate = dspy.ChainOfThought( dspy.Signature( self._generate_signature("generate").fields, @@ -79,8 +59,7 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3): self._generate_instruction("answer"), ), ) - # Currently, the interpreter class checks the deno availability at execution time. - # We may consider checking it at the initialization time for better instruction. + # This raises an exception if dspy cannot find an available deno instance at this point. self.interpreter = PythonInterpreter() def _generate_signature(self, mode): @@ -119,25 +98,24 @@ def _generate_signature(self, mode): prefix="Code Output:", desc="output of previously-generated python code", ), - self.output_field_name: self.signature.fields[self.output_field_name], - }, + } + | self.signature.output_fields, } signature_dict.update(fields_for_mode[mode]) return dspy.Signature(signature_dict) def _generate_instruction(self, mode): mode_inputs = ", ".join( - [ - f"`{field_name}`" - for field_name in self._generate_signature(mode).input_fields - ], + [f"`{field_name}`" for field_name in self._generate_signature(mode).input_fields], + ) + mode_outputs = ", ".join( + [f"`{field_name}`" for field_name in self._generate_signature(mode).output_fields], ) - mode_outputs = f"`{self.output_field_name}`" if mode == "generate": instr = [ f"You will be given {mode_inputs} and you will respond with {mode_outputs}.", f"Generating executable Python code that programmatically computes the correct {mode_outputs}.", - f"After you're done with the computation, make sure the last line in your code 
evaluates to the correct value for {mode_outputs}.", + f"After you're done with the computation, make sure format your output in json object that evaluates to the correct value mapping for {mode_outputs}, then response with the preloaded function `final_answer()`.", ] elif mode == "regenerate": instr = [ @@ -151,11 +129,8 @@ def _generate_instruction(self, mode): return "\n".join(instr) - def _parse_code(self, code_data): - code = ( - code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0] - ) + code = code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0] code_match = re.search(r"```python[ \n](.*?)[ \n]```?", code, re.DOTALL) code_block = (code_match.group(1) if code_match else code).replace("\\n", "\n") if not code_block: @@ -168,10 +143,14 @@ def _parse_code(self, code_data): code_block += "\n" + last_line_match.group(1) else: code_block = re.sub( - r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", r"\1\n", code_block, + r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", + r"\1\n", + code_block, ) code_block = re.sub( - r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", r"\1\n\2", code_block, + r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", + r"\1\n\2", + code_block, ) return code_block, None @@ -181,17 +160,16 @@ def _execute_code(self, code): """ if not code: return None, "Error: Empty code before execution." - + try: - output = str(self.interpreter.execute(code)) + # The result can be a more complex structure now, so serialize all of it as JSON. 
+ output = json.dumps(self.interpreter.execute(code)) return output, None except Exception as e: return None, str(e) def forward(self, **kwargs): - input_kwargs = { - field_name: kwargs[field_name] for field_name in self.input_fields - } + input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields} code_data = self.code_generate(**input_kwargs) output = None code, error = self._parse_code(code_data) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 7aca51b074..291477fcf2 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -117,14 +117,16 @@ def execute( if "error" in result: error_msg = result["error"] error_type = result.get("errorType", "Sandbox Error") - if error_type == "SyntaxError": + if error_type == "FinalAnswer": + # The `FinalAnswer` trick to receive output from the sandboxed address + # Don't rabbit hole here, let replace the output field to create a consistent + result["output"] = result.get("errorArgs", None) + elif error_type == "SyntaxError": raise SyntaxError(f"Invalid Python syntax. message: {error_msg}") - elif error_type == "FinalAnswer": - return result.get("errorArgs") else: raise InterpreterError(f"{error_type}: {result.get('errorArgs') or error_msg}") - # If there's no error, return the "output" field + # If there's no error or got `FinalAnswer`, return the "output" field return result.get("output", None) def __enter__(self): @@ -153,4 +155,4 @@ def shutdown(self) -> None: self.deno_process.stdin.flush() self.deno_process.stdin.close() self.deno_process.wait() - self.deno_process = None + self.deno_process = None \ No newline at end of file diff --git a/tests/predict/test_program_of_thought.py b/tests/predict/test_program_of_thought.py index aeef3aae1e..b457c47dda 100644 --- a/tests/predict/test_program_of_thought.py +++ b/tests/predict/test_program_of_thought.py @@ -9,6 +9,7 @@ # This test suite requires deno to be installed. 
Please install deno following https://docs.deno.com/runtime/getting_started/installation/ is_deno_available = shutil.which("deno") is not None + class BasicQA(Signature): question = dspy.InputField() answer = dspy.OutputField(desc="often between 1 and 5 words") @@ -16,6 +17,25 @@ class BasicQA(Signature): @pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") def test_pot_code_generation(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```", + }, + {"reasoning": "Reason_B", "answer": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + pot = ProgramOfThought(BasicQA) + res = pot(question="What is 1+1?") + assert res.answer == "2" + assert pot.interpreter.deno_process is None + + +# This test ensures the old finetuned saved models still work +@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") +def test_old_style_pot(): lm = DummyLM( [ {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+1\n```"}, @@ -29,12 +49,43 @@ def test_pot_code_generation(): assert pot.interpreter.deno_process is None +class ExtremumFinder(Signature): + input_list = dspy.InputField() + maximum = dspy.OutputField() + minimum = dspy.OutputField() + + +@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") +def test_pot_support_multiple_fields(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nmaximum = 6\nminimum = 2\nfinal_answer({'maximum': maximum, 'minimum': minimum})\n```", + }, + {"reasoning": "Reason_B", "maximum": "6", "minimum": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + pot = ProgramOfThought(ExtremumFinder) + res = pot(input_list="2, 3, 5, 6") + assert res.maximum == "6" + assert res.minimum == "2" + assert pot.interpreter.deno_process is None + + @pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") def 
test_pot_code_generation_with_one_error(): lm = DummyLM( [ - {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"}, - {"reasoning": "Reason_B", "generated_code": "```python\nresult = 1+1\n```"}, + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```", + }, + { + "reasoning": "Reason_B", + "generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```", + }, {"reasoning": "Reason_C", "answer": "2"}, ] ) @@ -51,8 +102,12 @@ def test_pot_code_generation_persistent_errors(): max_iters = 3 lm = DummyLM( [ - {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"}, - ] * max_iters + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```", + }, + ] + * max_iters ) dspy.settings.configure(lm=lm) @@ -67,11 +122,17 @@ def test_pot_code_parse_error(): lm = DummyLM( [ {"reasoning": "Reason_A", "generated_code": "```python\ninvalid=python=code\n```"}, - ] * max_iters + ] + * max_iters ) dspy.settings.configure(lm=lm) pot = ProgramOfThought(BasicQA, max_iters=max_iters) - with patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."): + with ( + patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, + pytest.raises( + RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct." 
+ ), + ): pot(question="What is 1+1?") mock_execute_code.assert_not_called() From 1f1cd2e07db261a0e11190ca32ebdb0001d583bc Mon Sep 17 00:00:00 2001 From: lxdlam Date: Mon, 24 Mar 2025 18:46:43 +0800 Subject: [PATCH 2/4] chore: prompt fixes --- tests/predict/test_program_of_thought.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/predict/test_program_of_thought.py b/tests/predict/test_program_of_thought.py index b457c47dda..4932657892 100644 --- a/tests/predict/test_program_of_thought.py +++ b/tests/predict/test_program_of_thought.py @@ -51,8 +51,8 @@ def test_old_style_pot(): class ExtremumFinder(Signature): input_list = dspy.InputField() - maximum = dspy.OutputField() - minimum = dspy.OutputField() + maximum = dspy.OutputField(desc="The maximum of the given numbers") + minimum = dspy.OutputField(desc="The minimum of the given numbers") @pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") @@ -90,7 +90,6 @@ def test_pot_code_generation_with_one_error(): ] ) dspy.settings.configure(lm=lm) - pot = ProgramOfThought(BasicQA) res = pot(question="What is 1+1?") assert res.answer == "2" @@ -126,7 +125,6 @@ def test_pot_code_parse_error(): * max_iters ) dspy.settings.configure(lm=lm) - pot = ProgramOfThought(BasicQA, max_iters=max_iters) with ( patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, From 5e53ba89d221878e50355ff79b08af1bd8495551 Mon Sep 17 00:00:00 2001 From: lxdlam Date: Mon, 24 Mar 2025 18:51:18 +0800 Subject: [PATCH 3/4] chore: add missing codes --- dspy/predict/program_of_thought.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py index 8488acbef5..28ade9345c 100644 --- a/dspy/predict/program_of_thought.py +++ b/dspy/predict/program_of_thought.py @@ -111,11 +111,15 @@ def _generate_instruction(self, mode): mode_outputs = ", ".join( [f"`{field_name}`" for 
field_name in self._generate_signature(mode).output_fields], ) + final_outputs = ", ".join( + [f"`{field_name}`" for field_name in self.output_fields], + ) if mode == "generate": instr = [ f"You will be given {mode_inputs} and you will respond with {mode_outputs}.", f"Generating executable Python code that programmatically computes the correct {mode_outputs}.", - f"After you're done with the computation, make sure format your output in json object that evaluates to the correct value mapping for {mode_outputs}, then response with the preloaded function `final_answer()`.", + "After you're done with the computation and think you have the answer, make sure to provide your answer by calling the preloaded function `final_answer()`.", + f'You should structure your answer in a dict object, like {{"field_a": answer_a, ...}}, evaluates to the correct value mapping for {final_outputs}.', ] elif mode == "regenerate": instr = [ From 28637a8302d36a1e2f7b6dda0a3c868e983f92a7 Mon Sep 17 00:00:00 2001 From: lxdlam Date: Mon, 24 Mar 2025 18:54:34 +0800 Subject: [PATCH 4/4] chore: comment wording fixes --- dspy/primitives/python_interpreter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 291477fcf2..7cb63bf98c 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -118,8 +118,8 @@ def execute( error_msg = result["error"] error_type = result.get("errorType", "Sandbox Error") if error_type == "FinalAnswer": - # The `FinalAnswer` trick to receive output from the sandboxed address - # Don't rabbit hole here, let replace the output field to create a consistent + # The `FinalAnswer` trick to receive output from the sandbox interpreter, + # just simply replace the output with the arguments. result["output"] = result.get("errorArgs", None) elif error_type == "SyntaxError": raise SyntaxError(f"Invalid Python syntax. message: {error_msg}")