Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 61 additions & 26 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
import re
import ast
import json
import re
import enum
import inspect
import pydantic
import textwrap
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

import pydantic
from pydantic import TypeAdapter
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

from dspy.adapters.base import Adapter
from ..signatures.field import OutputField
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type
from .base import Adapter

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")


class FieldInfoWithName(NamedTuple):
"""
A tuple containing a field name and its corresponding FieldInfo object.
"""

name: str
info: FieldInfo


# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat
# thread) has been completed.
# Built-in field indicating that a chat turn has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())


Expand Down Expand Up @@ -114,6 +111,16 @@ def format_input_list_field_value(value: List[Any]) -> str:
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)])


def _serialize_for_json(value):
if isinstance(value, pydantic.BaseModel):
return value.model_dump()
elif isinstance(value, list):
return [_serialize_for_json(item) for item in value]
elif isinstance(value, dict):
return {key: _serialize_for_json(val) for key, val in value.items()}
else:
return value

def _format_field_value(field_info: FieldInfo, value: Any) -> str:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
Expand All @@ -125,24 +132,17 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
Returns:
The formatted value of the field, represented as a string.
"""
dspy_field_type: Literal["input", "output"] = get_dspy_field_type(field_info)
if isinstance(value, list):
if dspy_field_type == "input" or field_info.annotation is str:
# If the field is an input field or has no special type requirements, format it as
# numbered list so that it's organized in a way suitable for presenting long context
# to an LLM (i.e. not JSON)
return format_input_list_field_value(value)
else:
# If the field is an output field that has strict parsing requirements, format the
# value as a stringified JSON Array. This ensures that downstream routines can parse
# the field value correctly using methods from the `ujson` or `json` packages.
return json.dumps(value)
elif isinstance(value, pydantic.BaseModel):
return value.model_dump_json()

if isinstance(value, list) and field_info.annotation is str:
# If the field has no special type requirements, format it as a nice numbere list for the LM.
return format_input_list_field_value(value)
elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list):
return json.dumps(_serialize_for_json(value))
else:
return str(value)



def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
Expand All @@ -166,15 +166,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
def parse_value(value, annotation):
if annotation is str:
return str(value)

parsed_value = value
if isinstance(value, str):

if isinstance(annotation, enum.EnumMeta):
parsed_value = annotation[value]
elif isinstance(value, str):
try:
parsed_value = json.loads(value)
except json.JSONDecodeError:
try:
parsed_value = ast.literal_eval(value)
except (ValueError, SyntaxError):
parsed_value = value

return TypeAdapter(annotation).validate_python(parsed_value)


Expand Down Expand Up @@ -222,6 +227,16 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
content.append(formatted_fields)

if role == "user":
# def type_info(v):
# return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \
# if v.annotation is not str else ""
#
# content.append(
# "Respond with the corresponding output fields, starting with the field "
# + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
# + ", and then ending with the marker for `[[ ## completed ## ]]`."
# )

content.append(
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`{f}`" for f in signature.output_fields)
Expand Down Expand Up @@ -260,10 +275,30 @@ def prepare_instructions(signature: SignatureMeta):
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
type_ = field_info.annotation

if get_dspy_field_type(field_info) == 'input' or type_ is str:
desc = ""
elif type_ is bool:
desc = "must be True or False"
elif type_ in (int, float):
desc = f"must be a single {type_.__name__} value"
elif inspect.isclass(type_) and issubclass(type_, enum.Enum):
desc= f"must be one of: {'; '.join(type_.__members__)}"
elif hasattr(type_, '__origin__') and type_.__origin__ is Literal:
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())

desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"

def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
return format_fields(
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): f"{{{field_name}}}"
FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
for field_name, field_info in fields.items()
}
)
Expand Down
Loading