diff --git a/dspy/predict/program_of_thought.py b/dspy/predict/program_of_thought.py index 4214861160..28ade9345c 100644 --- a/dspy/predict/program_of_thought.py +++ b/dspy/predict/program_of_thought.py @@ -1,6 +1,7 @@ import logging import re from typing import Union, Type +import json import dspy from dspy.signatures.signature import ensure_signature, Signature @@ -10,6 +11,7 @@ logger = logging.getLogger(__name__) + class ProgramOfThought(Module): """ A DSPy module that runs Python programs to solve a problem. @@ -39,28 +41,6 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3): self.input_fields = signature.input_fields self.output_fields = signature.output_fields - assert len(self.output_fields) == 1, "PoT only supports one output field." - - self.output_field_name = next(iter(self.output_fields)) - inputs_ = ", ".join( - [f"`{field_name}`" for field_name in self.input_fields.keys()], - ) - outputs_ = f"`{self.output_field_name}`" - - assert len(self.output_fields) == 1, "PoT only supports one output field." - - instr = [] - instr.append( - f"You will be given {inputs_} and you will respond with {outputs_}.", - ) - instr.append( - f"Generating executable Python code that programmatically computes the correct {outputs_}.", - ) - instr.append( - f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.", - ) - instr = "\n".join(instr) - self.code_generate = dspy.ChainOfThought( dspy.Signature( self._generate_signature("generate").fields, @@ -79,8 +59,7 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3): self._generate_instruction("answer"), ), ) - # Currently, the interpreter class checks the deno availability at execution time. - # We may consider checking it at the initialization time for better instruction. + # It will raises exception when dspy cannot find available deno instance by now. self.interpreter = PythonInterpreter() def _generate_signature(self, mode): @@ -119,25 +98,28 @@ def _generate_signature(self, mode): prefix="Code Output:", desc="output of previously-generated python code", ), - self.output_field_name: self.signature.fields[self.output_field_name], - }, + } + | self.signature.output_fields, } signature_dict.update(fields_for_mode[mode]) return dspy.Signature(signature_dict) def _generate_instruction(self, mode): mode_inputs = ", ".join( - [ - f"`{field_name}`" - for field_name in self._generate_signature(mode).input_fields - ], + [f"`{field_name}`" for field_name in self._generate_signature(mode).input_fields], + ) + mode_outputs = ", ".join( + [f"`{field_name}`" for field_name in self._generate_signature(mode).output_fields], + ) + final_outputs = ", ".join( + [f"`{field_name}`" for field_name in self.output_fields], ) - mode_outputs = f"`{self.output_field_name}`" if mode == "generate": instr = [ f"You will be given {mode_inputs} and you will respond with {mode_outputs}.", f"Generating executable Python code that programmatically computes the correct {mode_outputs}.", - f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.", + "After you're done with the computation and think you have the answer, make sure to provide your answer by calling the preloaded function `final_answer()`.", + f'You should structure your answer in a dict object, like {{"field_a": answer_a, ...}}, evaluates to the correct value mapping for {final_outputs}.', ] elif mode == "regenerate": instr = [ @@ -151,11 +133,8 @@ def _generate_instruction(self, mode): return "\n".join(instr) - def _parse_code(self, code_data): - code = ( - code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0] - ) + code = code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0] code_match = re.search(r"```python[ \n](.*?)[ \n]```?", code, re.DOTALL) code_block = (code_match.group(1) if code_match else code).replace("\\n", "\n") if not code_block: @@ -168,10 +147,14 @@ def _parse_code(self, code_data): code_block += "\n" + last_line_match.group(1) else: code_block = re.sub( - r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", r"\1\n", code_block, + r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", + r"\1\n", + code_block, ) code_block = re.sub( - r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", r"\1\n\2", code_block, + r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", + r"\1\n\2", + code_block, ) return code_block, None @@ -181,17 +164,16 @@ def _execute_code(self, code): """ if not code: return None, "Error: Empty code before execution." - + try: - output = str(self.interpreter.execute(code)) + # Since it's more complex structure now, just blindly use json to represents all. + output = json.dumps(self.interpreter.execute(code)) return output, None except Exception as e: return None, str(e) def forward(self, **kwargs): - input_kwargs = { - field_name: kwargs[field_name] for field_name in self.input_fields - } + input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields} code_data = self.code_generate(**input_kwargs) output = None code, error = self._parse_code(code_data) diff --git a/dspy/primitives/python_interpreter.py b/dspy/primitives/python_interpreter.py index 7aca51b074..7cb63bf98c 100644 --- a/dspy/primitives/python_interpreter.py +++ b/dspy/primitives/python_interpreter.py @@ -117,14 +117,16 @@ def execute( if "error" in result: error_msg = result["error"] error_type = result.get("errorType", "Sandbox Error") - if error_type == "SyntaxError": + if error_type == "FinalAnswer": + # The `FinalAnswer` trick to receive output from the sandbox interpreter, + # just simply replace the output with the arguments. + result["output"] = result.get("errorArgs", None) + elif error_type == "SyntaxError": raise SyntaxError(f"Invalid Python syntax. message: {error_msg}") - elif error_type == "FinalAnswer": - return result.get("errorArgs") else: raise InterpreterError(f"{error_type}: {result.get('errorArgs') or error_msg}") - # If there's no error, return the "output" field + # If there's no error or got `FinalAnswer`, return the "output" field return result.get("output", None) def __enter__(self): @@ -153,4 +155,4 @@ def shutdown(self) -> None: self.deno_process.stdin.flush() self.deno_process.stdin.close() self.deno_process.wait() - self.deno_process = None + self.deno_process = None \ No newline at end of file diff --git a/tests/predict/test_program_of_thought.py b/tests/predict/test_program_of_thought.py index aeef3aae1e..4932657892 100644 --- a/tests/predict/test_program_of_thought.py +++ b/tests/predict/test_program_of_thought.py @@ -9,6 +9,7 @@ # This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/ is_deno_available = shutil.which("deno") is not None + class BasicQA(Signature): question = dspy.InputField() answer = dspy.OutputField(desc="often between 1 and 5 words") @@ -16,6 +17,25 @@ class BasicQA(Signature): @pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") def test_pot_code_generation(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```", + }, + {"reasoning": "Reason_B", "answer": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + pot = ProgramOfThought(BasicQA) + res = pot(question="What is 1+1?") + assert res.answer == "2" + assert pot.interpreter.deno_process is None + + +# This test ensures the old finetuned saved models still work +@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") +def test_old_style_pot(): lm = DummyLM( [ {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+1\n```"}, @@ -29,17 +49,47 @@ def test_pot_code_generation(): assert pot.interpreter.deno_process is None +class ExtremumFinder(Signature): + input_list = dspy.InputField() + maximum = dspy.OutputField(desc="The maximum of the given numbers") + minimum = dspy.OutputField(desc="The minimum of the given numbers") + + +@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") +def test_pot_support_multiple_fields(): + lm = DummyLM( + [ + { + "reasoning": "Reason_A", + "generated_code": "```python\nmaximum = 6\nminimum = 2\nfinal_answer({'maximum': maximum, 'minimum': minimum})\n```", + }, + {"reasoning": "Reason_B", "maximum": "6", "minimum": "2"}, + ] + ) + dspy.settings.configure(lm=lm) + pot = ProgramOfThought(ExtremumFinder) + res = pot(input_list="2, 3, 5, 6") + assert res.maximum == "6" + assert res.minimum == "2" + assert pot.interpreter.deno_process is None + + @pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH") def test_pot_code_generation_with_one_error(): lm = DummyLM( [ - {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"}, - {"reasoning": "Reason_B", "generated_code": "```python\nresult = 1+1\n```"}, + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```", + }, + { + "reasoning": "Reason_B", + "generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```", + }, {"reasoning": "Reason_C", "answer": "2"}, ] ) dspy.settings.configure(lm=lm) - pot = ProgramOfThought(BasicQA) res = pot(question="What is 1+1?") assert res.answer == "2" @@ -51,8 +101,12 @@ def test_pot_code_generation_persistent_errors(): max_iters = 3 lm = DummyLM( [ - {"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"}, - ] * max_iters + { + "reasoning": "Reason_A", + "generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```", + }, + ] + * max_iters ) dspy.settings.configure(lm=lm) @@ -67,11 +121,16 @@ def test_pot_code_parse_error(): lm = DummyLM( [ {"reasoning": "Reason_A", "generated_code": "```python\ninvalid=python=code\n```"}, - ] * max_iters + ] + * max_iters ) dspy.settings.configure(lm=lm) - pot = ProgramOfThought(BasicQA, max_iters=max_iters) - with patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."): + with ( + patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, + pytest.raises( + RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct." + ), + ): pot(question="What is 1+1?") mock_execute_code.assert_not_called()