diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index fa80fe556f..f2e270d825 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -1,32 +1,29 @@ +import re import ast import json -import re +import enum +import inspect +import pydantic import textwrap -from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin -import pydantic from pydantic import TypeAdapter from pydantic.fields import FieldInfo +from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin +from dspy.adapters.base import Adapter from ..signatures.field import OutputField from ..signatures.signature import SignatureMeta from ..signatures.utils import get_dspy_field_type -from .base import Adapter field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") class FieldInfoWithName(NamedTuple): - """ - A tuple containing a field name and its corresponding FieldInfo object. - """ - name: str info: FieldInfo -# Built-in field indicating that a chat turn (i.e. a user or assistant reply to a chat -# thread) has been completed. +# Built-in field indicating that a chat turn has been completed. BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) @@ -114,6 +111,16 @@ def format_input_list_field_value(value: List[Any]) -> str: return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) +def _serialize_for_json(value): + if isinstance(value, pydantic.BaseModel): + return value.model_dump() + elif isinstance(value, list): + return [_serialize_for_json(item) for item in value] + elif isinstance(value, dict): + return {key: _serialize_for_json(val) for key, val in value.items()} + else: + return value + def _format_field_value(field_info: FieldInfo, value: Any) -> str: """ Formats the value of the specified field according to the field's DSPy type (input or output), @@ -125,24 +132,17 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str: Returns: The formatted value of the field, represented as a string. """ - dspy_field_type: Literal["input", "output"] = get_dspy_field_type(field_info) - if isinstance(value, list): - if dspy_field_type == "input" or field_info.annotation is str: - # If the field is an input field or has no special type requirements, format it as - # numbered list so that it's organized in a way suitable for presenting long context - # to an LLM (i.e. not JSON) - return format_input_list_field_value(value) - else: - # If the field is an output field that has strict parsing requirements, format the - # value as a stringified JSON Array. This ensures that downstream routines can parse - # the field value correctly using methods from the `ujson` or `json` packages. - return json.dumps(value) - elif isinstance(value, pydantic.BaseModel): - return value.model_dump_json() + + if isinstance(value, list) and field_info.annotation is str: + # If the field has no special type requirements, format it as a nice numbere list for the LM. + return format_input_list_field_value(value) + elif isinstance(value, pydantic.BaseModel) or isinstance(value, dict) or isinstance(value, list): + return json.dumps(_serialize_for_json(value)) else: return str(value) + def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: """ Formats the values of the specified fields according to the field's DSPy type (input or output), @@ -166,8 +166,12 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: def parse_value(value, annotation): if annotation is str: return str(value) + parsed_value = value - if isinstance(value, str): + + if isinstance(annotation, enum.EnumMeta): + parsed_value = annotation[value] + elif isinstance(value, str): try: parsed_value = json.loads(value) except json.JSONDecodeError: @@ -175,6 +179,7 @@ def parse_value(value, annotation): parsed_value = ast.literal_eval(value) except (ValueError, SyntaxError): parsed_value = value + return TypeAdapter(annotation).validate_python(parsed_value) @@ -222,6 +227,16 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple content.append(formatted_fields) if role == "user": + # def type_info(v): + # return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" \ + # if v.annotation is not str else "" + # + # content.append( + # "Respond with the corresponding output fields, starting with the field " + # + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + # + ", and then ending with the marker for `[[ ## completed ## ]]`." + # ) + content.append( "Respond with the corresponding output fields, starting with the field " + ", then ".join(f"`{f}`" for f in signature.output_fields) @@ -260,10 +275,30 @@ def prepare_instructions(signature: SignatureMeta): parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields)) parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") + def field_metadata(field_name, field_info): + type_ = field_info.annotation + + if get_dspy_field_type(field_info) == 'input' or type_ is str: + desc = "" + elif type_ is bool: + desc = "must be True or False" + elif type_ in (int, float): + desc = f"must be a single {type_.__name__} value" + elif inspect.isclass(type_) and issubclass(type_, enum.Enum): + desc= f"must be one of: {'; '.join(type_.__members__)}" + elif hasattr(type_, '__origin__') and type_.__origin__ is Literal: + desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}" + else: + desc = "must be pareseable according to the following JSON schema: " + desc += json.dumps(pydantic.TypeAdapter(type_).json_schema()) + + desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else "" + return f"{{{field_name}}}{desc}" + def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): return format_fields( fields_with_values={ - FieldInfoWithName(name=field_name, info=field_info): f"{{{field_name}}}" + FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info) for field_name, field_info in fields.items() } )