Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 25 additions & 43 deletions dspy/predict/program_of_thought.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import re
from typing import Union, Type
import json

import dspy
from dspy.signatures.signature import ensure_signature, Signature
Expand All @@ -10,6 +11,7 @@

logger = logging.getLogger(__name__)


class ProgramOfThought(Module):
"""
A DSPy module that runs Python programs to solve a problem.
Expand Down Expand Up @@ -39,28 +41,6 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3):
self.input_fields = signature.input_fields
self.output_fields = signature.output_fields

assert len(self.output_fields) == 1, "PoT only supports one output field."

self.output_field_name = next(iter(self.output_fields))
inputs_ = ", ".join(
[f"`{field_name}`" for field_name in self.input_fields.keys()],
)
outputs_ = f"`{self.output_field_name}`"

assert len(self.output_fields) == 1, "PoT only supports one output field."

instr = []
instr.append(
f"You will be given {inputs_} and you will respond with {outputs_}.",
)
instr.append(
f"Generating executable Python code that programmatically computes the correct {outputs_}.",
)
instr.append(
f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.",
)
instr = "\n".join(instr)

self.code_generate = dspy.ChainOfThought(
dspy.Signature(
self._generate_signature("generate").fields,
Expand All @@ -79,8 +59,7 @@ def __init__(self, signature: Union[str, Type[Signature]], max_iters=3):
self._generate_instruction("answer"),
),
)
# Currently, the interpreter class checks the deno availability at execution time.
# We may consider checking it at the initialization time for better instruction.
# For now, it raises an exception when dspy cannot find an available deno instance.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment will be fulfilled by a new PR later.

self.interpreter = PythonInterpreter()

def _generate_signature(self, mode):
Expand Down Expand Up @@ -119,25 +98,28 @@ def _generate_signature(self, mode):
prefix="Code Output:",
desc="output of previously-generated python code",
),
self.output_field_name: self.signature.fields[self.output_field_name],
},
}
| self.signature.output_fields,
}
signature_dict.update(fields_for_mode[mode])
return dspy.Signature(signature_dict)

def _generate_instruction(self, mode):
mode_inputs = ", ".join(
[
f"`{field_name}`"
for field_name in self._generate_signature(mode).input_fields
],
[f"`{field_name}`" for field_name in self._generate_signature(mode).input_fields],
)
mode_outputs = ", ".join(
[f"`{field_name}`" for field_name in self._generate_signature(mode).output_fields],
)
final_outputs = ", ".join(
[f"`{field_name}`" for field_name in self.output_fields],
)
mode_outputs = f"`{self.output_field_name}`"
if mode == "generate":
instr = [
f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}.",
"After you're done with the computation and think you have the answer, make sure to provide your answer by calling the preloaded function `final_answer()`.",
f'You should structure your answer in a dict object, like {{"field_a": answer_a, ...}}, evaluates to the correct value mapping for {final_outputs}.',
]
elif mode == "regenerate":
instr = [
Expand All @@ -151,11 +133,8 @@ def _generate_instruction(self, mode):

return "\n".join(instr)


def _parse_code(self, code_data):
code = (
code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0]
)
code = code_data.get("generated_code", "").split("---", 1)[0].split("\n\n\n", 1)[0]
code_match = re.search(r"```python[ \n](.*?)[ \n]```?", code, re.DOTALL)
code_block = (code_match.group(1) if code_match else code).replace("\\n", "\n")
if not code_block:
Expand All @@ -168,10 +147,14 @@ def _parse_code(self, code_data):
code_block += "\n" + last_line_match.group(1)
else:
code_block = re.sub(
r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)", r"\1\n", code_block,
r"([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)",
r"\1\n",
code_block,
)
code_block = re.sub(
r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$", r"\1\n\2", code_block,
r"([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$",
r"\1\n\2",
code_block,
)
return code_block, None

Expand All @@ -181,17 +164,16 @@ def _execute_code(self, code):
"""
if not code:
return None, "Error: Empty code before execution."

try:
output = str(self.interpreter.execute(code))
# The result may now be a more complex structure, so serialize everything uniformly as JSON.
output = json.dumps(self.interpreter.execute(code))
return output, None
except Exception as e:
return None, str(e)

def forward(self, **kwargs):
input_kwargs = {
field_name: kwargs[field_name] for field_name in self.input_fields
}
input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields}
code_data = self.code_generate(**input_kwargs)
output = None
code, error = self._parse_code(code_data)
Expand Down
12 changes: 7 additions & 5 deletions dspy/primitives/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ def execute(
if "error" in result:
error_msg = result["error"]
error_type = result.get("errorType", "Sandbox Error")
if error_type == "SyntaxError":
if error_type == "FinalAnswer":
# The `FinalAnswer` trick: the sandbox interpreter hands back its output through this error type,
# so simply replace the output with the error arguments.
result["output"] = result.get("errorArgs", None)
elif error_type == "SyntaxError":
raise SyntaxError(f"Invalid Python syntax. message: {error_msg}")
elif error_type == "FinalAnswer":
return result.get("errorArgs")
else:
raise InterpreterError(f"{error_type}: {result.get('errorArgs') or error_msg}")

# If there's no error, return the "output" field
# If there was no error, or a `FinalAnswer` was received, return the "output" field
return result.get("output", None)

def __enter__(self):
Expand Down Expand Up @@ -153,4 +155,4 @@ def shutdown(self) -> None:
self.deno_process.stdin.flush()
self.deno_process.stdin.close()
self.deno_process.wait()
self.deno_process = None
self.deno_process = None
75 changes: 67 additions & 8 deletions tests/predict/test_program_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,33 @@
# This test suite requires deno to be installed. Please install deno following https://docs.deno.com/runtime/getting_started/installation/
is_deno_available = shutil.which("deno") is not None


# Question-answering signature used by the tests in this module: one input field, one output field.
class BasicQA(Signature):
    question = dspy.InputField()  # the question text supplied by each test
    answer = dspy.OutputField(desc="often between 1 and 5 words")  # the single output the tests assert on


@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
def test_pot_code_generation():
    """Check the single-output path where the generated code reports its result via `final_answer()`."""
    # The first scripted reply carries the generated code; the second carries the final `answer` field.
    code_reply = {
        "reasoning": "Reason_A",
        "generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```",
    }
    answer_reply = {"reasoning": "Reason_B", "answer": "2"}
    dspy.settings.configure(lm=DummyLM([code_reply, answer_reply]))

    program = ProgramOfThought(BasicQA)
    prediction = program(question="What is 1+1?")

    assert prediction.answer == "2"
    # The interpreter's deno process handle is expected to be cleared after the call.
    assert program.interpreter.deno_process is None


# This test ensures the old finetuned saved models still work
@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
def test_old_style_pot():
lm = DummyLM(
[
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+1\n```"},
Expand All @@ -29,17 +49,47 @@ def test_pot_code_generation():
assert pot.interpreter.deno_process is None


# Signature with one input and two output fields, used by `test_pot_support_multiple_fields`.
class ExtremumFinder(Signature):
    input_list = dspy.InputField()  # the test passes the numbers as a comma-separated string
    maximum = dspy.OutputField(desc="The maximum of the given numbers")
    minimum = dspy.OutputField(desc="The minimum of the given numbers")


@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
def test_pot_support_multiple_fields():
    """Check that a signature with two output fields is answered through one `final_answer()` dict."""
    # The generated code passes both outputs to `final_answer()` in a single dict keyed by field name.
    code_reply = {
        "reasoning": "Reason_A",
        "generated_code": "```python\nmaximum = 6\nminimum = 2\nfinal_answer({'maximum': maximum, 'minimum': minimum})\n```",
    }
    answer_reply = {"reasoning": "Reason_B", "maximum": "6", "minimum": "2"}
    dspy.settings.configure(lm=DummyLM([code_reply, answer_reply]))

    program = ProgramOfThought(ExtremumFinder)
    prediction = program(input_list="2, 3, 5, 6")

    assert prediction.maximum == "6"
    assert prediction.minimum == "2"
    # The interpreter's deno process handle is expected to be cleared after the call.
    assert program.interpreter.deno_process is None


@pytest.mark.skipif(not is_deno_available, reason="Deno is not installed or not in PATH")
def test_pot_code_generation_with_one_error():
lm = DummyLM(
[
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
{"reasoning": "Reason_B", "generated_code": "```python\nresult = 1+1\n```"},
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```",
},
{
"reasoning": "Reason_B",
"generated_code": "```python\nresult = 1+1\nfinal_answer({'answer': result})\n```",
},
{"reasoning": "Reason_C", "answer": "2"},
]
)
dspy.settings.configure(lm=lm)

pot = ProgramOfThought(BasicQA)
res = pot(question="What is 1+1?")
assert res.answer == "2"
Expand All @@ -51,8 +101,12 @@ def test_pot_code_generation_persistent_errors():
max_iters = 3
lm = DummyLM(
[
{"reasoning": "Reason_A", "generated_code": "```python\nresult = 1+0/0\n```"},
] * max_iters
{
"reasoning": "Reason_A",
"generated_code": "```python\nresult = 1+0/0\nfinal_answer({'answer': result})\n```",
},
]
* max_iters
)
dspy.settings.configure(lm=lm)

Expand All @@ -67,11 +121,16 @@ def test_pot_code_parse_error():
lm = DummyLM(
[
{"reasoning": "Reason_A", "generated_code": "```python\ninvalid=python=code\n```"},
] * max_iters
]
* max_iters
)
dspy.settings.configure(lm=lm)

pot = ProgramOfThought(BasicQA, max_iters=max_iters)
with patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code, pytest.raises(RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."):
with (
patch("dspy.predict.program_of_thought.ProgramOfThought._execute_code") as mock_execute_code,
pytest.raises(
RuntimeError, match="Max hops reached. Failed to run ProgramOfThought: Error: Code format is not correct."
),
):
pot(question="What is 1+1?")
mock_execute_code.assert_not_called()
Loading