Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
import textwrap
from collections.abc import Mapping
from itertools import chain

from typing import Any, Dict, Literal, NamedTuple

import pydantic
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
from dspy.adapters.utils import parse_value, format_field_value, get_annotation_name
from dspy.adapters.image_utils import try_expand_image_tags
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value
from dspy.signatures.field import OutputField
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type
from dspy.adapters.image_utils import try_expand_image_tags

field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")

Expand Down Expand Up @@ -99,9 +98,6 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
# Wrap the messages in a dictionary with a "messages" key
return dict(messages=messages)

def format_turn(self, signature, values, role, incomplete=False):
return format_turn(signature, values, role, incomplete)

def format_fields(self, signature, values, role):
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
Expand Down Expand Up @@ -152,7 +148,9 @@ def format_turn(signature, values, role, incomplete=False):
"""
if role == "user":
fields = signature.input_fields
message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
message_prefix = (
"This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
)
else:
# Add the completed field for the assistant turn
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
Expand All @@ -167,31 +165,40 @@ def format_turn(signature, values, role, incomplete=False):
messages.append(message_prefix)

field_messages = format_fields(
{FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
for k, v in fields.items()},
{
FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.")
for k, v in fields.items()
},
)
messages.append(field_messages)

def type_info(v):
if v.annotation is not str:
return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})"
else:
return ""

# Add output field instructions for user messages
if role == "user" and signature.output_fields:
type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else ""
field_instructions = "Respond with the corresponding output fields, starting with the field " + \
", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \
", and then ending with the marker for `[[ ## completed ## ]]`."
field_instructions = (
"Respond with the corresponding output fields, starting with the field "
+ ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
+ ", and then ending with the marker for `[[ ## completed ## ]]`."
)
messages.append(field_instructions)
joined_messages = "\n\n".join(msg for msg in messages)
return {"role": role, "content": joined_messages}


def flatten_messages(messages):
"""Flatten nested message lists."""
return list(chain.from_iterable(
item if isinstance(item, list) else [item] for item in messages
))
return list(chain.from_iterable(item if isinstance(item, list) else [item] for item in messages))


def enumerate_fields(fields: dict) -> str:
parts = []
for idx, (k, v) in enumerate(fields.items()):
parts.append(f"{idx+1}. `{k}`")
parts.append(f"{idx + 1}. `{k}`")
parts[-1] += f" ({get_annotation_name(v.annotation)})"
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""

Expand All @@ -207,8 +214,8 @@ def move_type_to_front(d):
return d


def prepare_schema(type_):
schema = pydantic.TypeAdapter(type_).json_schema()
def prepare_schema(field_type):
schema = pydantic.TypeAdapter(field_type).json_schema()
schema = move_type_to_front(schema)
return schema

Expand Down Expand Up @@ -237,7 +244,6 @@ def field_metadata(field_name, field_info):
f"must exactly match (no extra characters) one of: {'; '.join([str(x) for x in field_type.__args__])}"
)
else:
# desc = "must be pareseable according to the following JSON schema: "
desc = "must adhere to the JSON schema: "
desc += json.dumps(prepare_schema(field_type), ensure_ascii=False)

Expand Down
Loading