Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dspy/predict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .chain_of_thought_with_hint import ChainOfThoughtWithHint
from .react import ReAct
from .aggregation import majority
from .program_of_thought import ProgramOfThought
114 changes: 114 additions & 0 deletions dspy/predict/program_of_thought.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import dsp
import dspy
from ..primitives.program import Module
from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
import re

class ProgramOfThought(Module):
def __init__(self, signature, max_iters=3):
super().__init__()
self.signature = signature = dspy.Predict(signature).signature
self.max_iters = max_iters

self.input_fields = signature.input_fields()
self.output_fields = signature.output_fields()

inputs_ = ', '.join([f"`{field_name}`" for field_name in self.input_fields.keys()])
outputs_ = ', '.join([f"`{field_name}`" for field_name in self.output_fields.keys()])

assert len(self.output_fields) == 1, "PoT only supports one output field."

instr = []
instr.append(f"You will be given {inputs_} and you will respond with {outputs_}.")
instr.append(f"Generating executable Python code that programmatically computes the correct {outputs_}.")
instr.append(f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.")
instr = '\n'.join(instr)

self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate')))
self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate')))
self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer')))

def _generate_signature(self, mode):
signature_dict = dict(self.input_fields)
fields_for_mode = {
'generate': {
'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
},
'regenerate': {
'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str),
'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"),
'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
},
'answer': {
'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str),
'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"),
'answer': self.signature.kwargs["answer"]
}
}
signature_dict.update(fields_for_mode[mode])
return signature_dict

def _generate_instruction(self, mode):
mode_inputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.InputField)])
mode_outputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.OutputField)])
if mode == 'generate':
instr = [
f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}."
]
elif mode == 'regenerate':
instr = [
f"You are given {mode_inputs} due to an error in previous code.",
f"Your task is to correct the error and provide the new {mode_outputs}."
]
else: # mode == 'answer'
instr = [
f"Given the final code {mode_inputs}, provide the final {mode_outputs}."
]

return '\n'.join(instr)

def parse_code(self, code_data):
code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0]
code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL)
code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n')
if not code_block:
return code, "Error: Empty code after parsing."
if "\n" not in code_block and code_block.count('=') > 1:
return code, "Error: Code format is not correct."
lines = code_block.split('\n')
last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip())
if last_line_match and len(lines) > 1:
code_block += '\n' + last_line_match.group(1)
else:
code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block)
code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block)
return code_block, None

def execute_code(self, code):
if not code:
return code, None, 'Error: Empty code before execution.'
code_prompt = CodePrompt(code, code_type="python")
interpreter = PythonInterpreter(action_space={"print": print})
try:
output = str(code_prompt.execute(interpreter=interpreter)[0])
return code, output, None
except Exception as e:
return code, None, str(e)

def forward(self, **kwargs):
code_data = self.code_generate(question=kwargs["question"])
parsed_code, error = self.parse_code(code_data)
code, output, error = self.execute_code(parsed_code)
hop = 0
while hop < self.max_iters and error:
print('Error in code execution')
code_data = self.code_regenerate(question=kwargs["question"], previous_code=code, error=error)
parsed_code, error = self.parse_code(code_data)
hop += 1
if hop == self.max_iters:
print('Max hops reached. Error persists.')
return None
answer_gen_result = self.generate_answer(question=kwargs["question"], final_generated_code=code, code_output=output)
return answer_gen_result
3 changes: 2 additions & 1 deletion dspy/primitives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .example import *
from .program import *
from .prediction import *
from .prediction import *
from .python_interpreter import *
Loading