Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .base import Adapter
from .chat_adapter import ChatAdapter
from .json_adapter import JsonAdapter
from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JsonAdapter
43 changes: 28 additions & 15 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
    """Build one chat turn by delegating to the module-level ``format_turn``.

    All four arguments are forwarded unchanged and the helper's result is
    returned as-is.
    """
    return format_turn(
        signature=signature,
        values=values,
        role=role,
        incomplete=incomplete,
    )

def format_fields(self, signature, values):
    """Format the supplied values of a signature's fields into one string.

    Only fields of ``signature`` that have an entry in ``values`` are kept;
    fields without a value are left out. The selected pairs are handed to
    the module-level ``format_fields`` helper and its result is returned.
    """
    supplied = {}
    for name, info in signature.fields.items():
        if name not in values:
            continue
        key = FieldInfoWithName(name=name, info=info)
        supplied[key] = values.get(name, "Not supplied for this particular example.")

    return format_fields(supplied)



def format_blob(blob):
Expand Down Expand Up @@ -228,21 +240,22 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
content.append(formatted_fields)

if role == "user":
# def type_info(v):
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
# if v.annotation is not str else ""
#
# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
# )

content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
+ ", and then ending with the marker for `completed`."
)
def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

if not incomplete:
content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
)

# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`{f}`" for f in signature.output_fields)
# + ", and then ending with the marker for `completed`."
# )

return {"role": role, "content": "\n\n".join(content).strip()}

Expand Down
36 changes: 30 additions & 6 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,18 @@ def parse(self, signature, completion, _parse_values=True):

def format_turn(self, signature, values, role, incomplete=False):
    """Build one chat turn by delegating to the module-level ``format_turn``.

    All four arguments are forwarded unchanged and the helper's result is
    returned as-is.
    """
    return format_turn(
        signature=signature,
        values=values,
        role=role,
        incomplete=incomplete,
    )

def format_fields(self, signature, values):
    """Format the supplied values of a signature's fields into one string.

    Only fields of ``signature`` that have an entry in ``values`` are kept;
    fields without a value are left out. The selected pairs are handed to
    the module-level ``format_fields`` helper using the ``'user'`` role, and
    its result is returned.
    """
    supplied = {}
    for name, info in signature.fields.items():
        if name not in values:
            continue
        key = FieldInfoWithName(name=name, info=info)
        supplied[key] = values.get(name, "Not supplied for this particular example.")

    return format_fields(role='user', fields_with_values=supplied)



def parse_value(value, annotation):
Expand Down Expand Up @@ -168,7 +180,7 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:



def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
Expand All @@ -180,6 +192,13 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
Returns:
The joined formatted values of the fields, represented as a string.
"""

if role == "assistant":
d = fields_with_values.items()
d = {k.name: _serialize_for_json(v) for k, v in d}

return json.dumps(_serialize_for_json(d), indent=2)

output = []
for field, field_value in fields_with_values.items():
formatted_field_value = _format_field_value(field_info=field.info, value=field_value)
Expand Down Expand Up @@ -219,6 +238,7 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
raise ValueError(f"Expected {field_names} but got {values.keys()}")

formatted_fields = format_fields(
role=role,
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
Expand All @@ -233,6 +253,7 @@ def type_info(v):
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
if v.annotation is not str else ""

# TODO: Consider if not incomplete:
content.append(
"Respond with a JSON object in the following order of fields: "
+ ", then ".join(f"`{f}`{type_info(v)}" for f, v in signature.output_fields.items())
Expand Down Expand Up @@ -269,7 +290,7 @@ def prepare_instructions(signature: SignatureMeta):
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
# parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
type_ = field_info.annotation
Expand All @@ -291,16 +312,19 @@ def field_metadata(field_name, field_info):
desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"

def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo]):
return format_fields(
role=role,
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
}
)

parts.append(format_signature_fields_for_instructions(signature.input_fields))
parts.append(format_signature_fields_for_instructions(signature.output_fields))

parts.append("Inputs will have the following structure:")
parts.append(format_signature_fields_for_instructions('user', signature.input_fields))
parts.append("Outputs will be a JSON object with the following fields.")
parts.append(format_signature_fields_for_instructions('assistant', signature.output_fields))
# parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}))

instructions = textwrap.dedent(signature.instructions)
Expand Down
193 changes: 82 additions & 111 deletions dspy/predict/react.py
Original file line number Diff line number Diff line change
@@ -1,130 +1,101 @@
import dsp
import dspy
from dspy.signatures.signature import ensure_signature

from ..primitives.program import Module
from .predict import Predict
import inspect

# TODO: Simplify a lot.
# TODO: Divide Action and Action Input like langchain does for ReAct.
from pydantic import BaseModel
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.adapters.json_adapter import get_annotation_name
from typing import Callable, Any, get_type_hints, get_origin, Literal

# TODO: There's a lot of value in having a stopping condition in the LM calls at `\n\nObservation:`
class Tool:
    """Wrap a callable together with the metadata used to describe it.

    Attributes:
        func: The wrapped callable; invoked by ``__call__``.
        name: Explicit ``name``, else ``func.__name__``, else the name of
            ``func``'s type.
        desc: Explicit ``desc``, else the first non-empty docstring found on
            ``func`` or on the callable that was inspected for annotations,
            else ``"No description"``.
        args: Maps each argument name (``'return'`` excluded) to either the
            pydantic schema of its annotation (for ``BaseModel`` subclasses)
            or the value returned by ``get_annotation_name``.
    """

    def __init__(self, func: Callable, name: str = None, desc: str = None, args: dict[str, Any] = None):
        """Collect the name, description and argument metadata for ``func``.

        Args:
            func: A function, bound method, or any object with ``__call__``.
            name: Optional override for the tool name.
            desc: Optional override for the tool description.
            args: Optional mapping of argument name to annotation; when
                empty or omitted, type hints are read from the callable.
        """
        # Plain functions and bound methods carry their own annotations; for
        # any other callable object the annotations live on its ``__call__``.
        if inspect.isfunction(func) or inspect.ismethod(func):
            annotations_func = func
        else:
            annotations_func = func.__call__

        self.func = func
        self.name = name or getattr(func, '__name__', type(func).__name__)

        # ``__doc__`` exists on every callable and is ``None`` when there is
        # no docstring, so a ``getattr`` default never applies here; the final
        # fallback has to be chained with ``or`` to keep ``desc`` a string.
        self.desc = (
            desc
            or getattr(func, '__doc__', None)
            or getattr(annotations_func, '__doc__', None)
            or "No description"
        )

        self.args = {
            k: v.schema() if isinstance((origin := get_origin(v) or v), type) and issubclass(origin, BaseModel)
            else get_annotation_name(v)
            for k, v in (args or get_type_hints(annotations_func)).items() if k != 'return'
        }

    # TODO [NEW]: When max_iters is about to be reached, reduce the set of available actions to only the Finish action.
    def __call__(self, *args, **kwargs):
        """Invoke the wrapped callable with the given arguments and return its result."""
        return self.func(*args, **kwargs)


class ReAct(Module):
def __init__(self, signature, max_iters=5, num_results=3, tools=None):
super().__init__()
def __init__(self, signature, tools: list[Callable], max_iters=5):
"""
Tools is either a list of functions, callable classes, or dspy.Tool instances.
"""

self.signature = signature = ensure_signature(signature)
self.max_iters = max_iters

self.tools = tools or [dspy.Retrieve(k=num_results)]
self.tools = {tool.name: tool for tool in self.tools}

self.input_fields = self.signature.input_fields
self.output_fields = self.signature.output_fields

assert len(self.output_fields) == 1, "ReAct only supports one output field."
tools = [t if isinstance(t, Tool) or hasattr(t, 'input_variable') else Tool(t) for t in tools]
tools = {tool.name: tool for tool in tools}

inputs_ = ", ".join([f"`{k}`" for k in self.input_fields.keys()])
outputs_ = ", ".join([f"`{k}`" for k in self.output_fields.keys()])
inputs_ = ", ".join([f"`{k}`" for k in signature.input_fields.keys()])
outputs_ = ", ".join([f"`{k}`" for k in signature.output_fields.keys()])
instr = [f"{signature.instructions}\n"] if signature.instructions else []

instr = []

if self.signature.instructions is not None:
instr.append(f"{self.signature.instructions}\n")

instr.extend([
f"You will be given {inputs_} and you will respond with {outputs_}.\n",
"To do this, you will interleave Thought, Action, and Observation steps.\n",
"Thought can reason about the current situation, and Action can be the following types:\n",
f"You will be given {inputs_} and your goal is to finish with {outputs_}.\n",
"To do this, you will interleave Thought, Tool Name, and Tool Args, and receive a resulting Observation.\n",
"Thought can reason about the current situation, and Tool Name can be the following types:\n",
])

self.tools["Finish"] = dspy.Example(
name="Finish",
input_variable=outputs_.strip("`"),
desc=f"returns the final {outputs_} and finishes the task",
finish_desc = f"Signals that the final outputs, i.e. {outputs_}, are now available and marks the task as complete."
finish_args = {} #k: v.annotation for k, v in signature.output_fields.items()}
tools["finish"] = Tool(func=lambda **kwargs: kwargs, name="finish", desc=finish_desc, args=finish_args)

for idx, tool in enumerate(tools.values()):
desc = tool.desc.replace("\n", " ")
args = tool.args if hasattr(tool, 'args') else str({tool.input_variable: str})
desc = f"whose description is <desc>{desc}</desc>. It takes arguments {args} in JSON format."
instr.append(f"({idx+1}) {tool.name}, {desc}")

signature_ = (
dspy.Signature({**signature.input_fields}, "\n".join(instr))
.append("trajectory", dspy.InputField(), type_=str)
.append("next_thought", dspy.OutputField(), type_=str)
.append("next_tool_name", dspy.OutputField(), type_=Literal[tuple(tools.keys())])
.append("next_tool_args", dspy.OutputField(), type_=dict[str, Any])
)

for idx, tool in enumerate(self.tools):
tool = self.tools[tool]
instr.append(
f"({idx+1}) {tool.name}[{tool.input_variable}], which {tool.desc}",
)

instr = "\n".join(instr)
self.react = [
Predict(dspy.Signature(self._generate_signature(i), instr))
for i in range(1, max_iters + 1)
]

def _generate_signature(self, iters):
signature_dict = {}
for key, val in self.input_fields.items():
signature_dict[key] = val

for j in range(1, iters + 1):
IOField = dspy.OutputField if j == iters else dspy.InputField

signature_dict[f"Thought_{j}"] = IOField(
prefix=f"Thought {j}:",
desc="next steps to take based on last observation",
)

tool_list = " or ".join(
[
f"{tool.name}[{tool.input_variable}]"
for tool in self.tools.values()
if tool.name != "Finish"
],
)
signature_dict[f"Action_{j}"] = IOField(
prefix=f"Action {j}:",
desc=f"always either {tool_list} or, when done, Finish[<answer>], where <answer> is the answer to the question itself.",
)

if j < iters:
signature_dict[f"Observation_{j}"] = IOField(
prefix=f"Observation {j}:",
desc="observations based on action",
format=dsp.passages2text,
)

return signature_dict

def act(self, output, hop):
try:
action = output[f"Action_{hop+1}"]
action_name, action_val = action.strip().split("\n")[0].split("[", 1)
action_val = action_val.rsplit("]", 1)[0]

if action_name == "Finish":
return action_val

result = self.tools[action_name](action_val) #result must be a str, list, or tuple
# Handle the case where 'passages' attribute is missing
output[f"Observation_{hop+1}"] = getattr(result, "passages", result)

except Exception:
output[f"Observation_{hop+1}"] = (
"Failed to parse action. Bad formatting or incorrect action name."
)
# raise e

def forward(self, **kwargs):
args = {key: kwargs[key] for key in self.input_fields.keys() if key in kwargs}

for hop in range(self.max_iters):
# with dspy.settings.context(show_guidelines=(i <= 2)):
output = self.react[hop](**args)
output[f'Action_{hop + 1}'] = output[f'Action_{hop + 1}'].split('\n')[0]

if action_val := self.act(output, hop):
break
args.update(output)
fallback_signature = (
dspy.Signature({**signature.input_fields, **signature.output_fields})
.append("trajectory", dspy.InputField(), type_=str)
)

observations = [args[key] for key in args if key.startswith("Observation")]
self.tools = tools
self.react = dspy.Predict(signature_)
self.extract = dspy.ChainOfThought(fallback_signature)

def forward(self, **input_args):
trajectory = {}

def format(trajectory_: dict[str, Any], last_iteration: bool):
adapter = dspy.settings.adapter or dspy.ChatAdapter()
blob = adapter.format_fields(dspy.Signature(f"{', '.join(trajectory_.keys())} -> x"), trajectory_)
warning = f"\n\nWarning: The maximum number of iterations ({self.max_iters}) has been reached."
warning += " You must now produce the finish action."
return blob + (warning if last_iteration else "")

for idx in range(self.max_iters):
pred = self.react(**input_args, trajectory=format(trajectory, last_iteration=(idx == self.max_iters-1)))

trajectory[f"thought_{idx}"] = pred.next_thought
trajectory[f"tool_name_{idx}"] = pred.next_tool_name
trajectory[f"tool_args_{idx}"] = pred.next_tool_args

try:
trajectory[f"observation_{idx}"] = self.tools[pred.next_tool_name](**pred.next_tool_args)
except Exception as e:
trajectory[f"observation_{idx}"] = f"Failed to execute: {e}"

if pred.next_tool_name == "finish":
break

# assumes only 1 output field for now - TODO: handling for multiple output fields
return dspy.Prediction(observations=observations, **{list(self.output_fields.keys())[0]: action_val or ""})
extract = self.extract(**input_args, trajectory=format(trajectory, last_iteration=False))
return dspy.Prediction(trajectory=trajectory, **extract)
5 changes: 4 additions & 1 deletion dspy/retrieve/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ def __call__(self, *args, **kwargs):

def forward(
self,
query_or_queries: Union[str, List[str]],
query_or_queries: Union[str, List[str]] = None,
query: Optional[str] = None,
k: Optional[int] = None,
by_prob: bool = True,
with_metadata: bool = False,
**kwargs,
) -> Union[List[str], Prediction, List[Prediction]]:
query_or_queries = query_or_queries or query

# queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
# queries = [query.strip().split('\n')[0].strip() for query in queries]

Expand Down
Loading
Loading