Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 12 additions & 30 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
from dspy.adapters.types.image import Image
from dspy.adapters.types.image import try_expand_image_tags
from dspy.adapters.types.history import History
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json
from dspy.clients.lm import LM
Expand Down Expand Up @@ -42,11 +42,11 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
try:
response_format = _get_structured_outputs_response_format(signature)
outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
except Exception:
except Exception as e:
logger.debug(
"Failed to obtain response using signature-based structured outputs"
" response format: Falling back to default 'json_object' response format."
" Exception: {e}"
f"Failed to obtain response using signature-based structured outputs"
f" response format: Falling back to default 'json_object' response format."
f" Exception: {e}"
)
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
else:
Expand Down Expand Up @@ -92,6 +92,7 @@ def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs
else:
messages.append(self.format_turn(signature, inputs, role="user"))

messages = try_expand_image_tags(messages)
return messages

def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
Expand All @@ -116,7 +117,6 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
for field_name, field_info in signature.fields.items()
if field_name in values
}

return format_fields(role=role, fields_with_values=fields_with_values)

def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
Expand All @@ -127,34 +127,16 @@ def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str,
raise NotImplementedError


def _format_field_value(field_info: FieldInfo, value: Any) -> str:
"""
Formats the value of the specified field according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself.

Args:
field_info: Information about the field, including its DSPy field type and annotation.
value: The value of the field.

Returns:
The formatted value of the field, represented as a string.
"""
# TODO: Wasnt this easy to fix?
if field_info.annotation is Image:
raise NotImplementedError("Images are not yet supported in JSON mode.")

return format_field_value(field_info=field_info, value=value)


def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
"""
Formats the values of the specified fields according to the field's DSPy type (input or output),
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
into a single string, which is is a multiline string if there are multiple fields.
into a single string, which is a multiline string if there are multiple fields.

Args:
fields_with_values: A dictionary mapping information about a field to its corresponding
value.
role: The role of the message ('user' or 'assistant')
fields_with_values: A dictionary mapping information about a field to its corresponding value.

Returns:
The joined formatted values of the fields, represented as a string.
"""
Expand All @@ -166,7 +148,7 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -

output = []
for field, field_value in fields_with_values.items():
formatted_field_value = _format_field_value(field_info=field.info, value=field_value)
formatted_field_value = format_field_value(field_info=field.info, value=field_value)
output.append(f"[[ ## {field.name} ## ]]\n{formatted_field_value}")

return "\n\n".join(output).strip()
Expand All @@ -175,7 +157,7 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
def format_turn(
signature: SignatureMeta,
values: Dict[str, Any],
role,
role: str,
incomplete=False,
is_conversation_history=False,
) -> Dict[str, str]:
Expand Down
Loading