From 7fa8ffe3432f67837743c007ca1aff9080ad2a67 Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 20 Feb 2025 18:27:31 -0800 Subject: [PATCH] fix lint for chat adapter --- dspy/adapters/chat_adapter.py | 46 ++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 1f4543fcdc..ef38073ce2 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -5,18 +5,17 @@ import textwrap from collections.abc import Mapping from itertools import chain - from typing import Any, Dict, Literal, NamedTuple import pydantic from pydantic.fields import FieldInfo from dspy.adapters.base import Adapter -from dspy.adapters.utils import parse_value, format_field_value, get_annotation_name +from dspy.adapters.image_utils import try_expand_image_tags +from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value from dspy.signatures.field import OutputField from dspy.signatures.signature import Signature, SignatureMeta from dspy.signatures.utils import get_dspy_field_type -from dspy.adapters.image_utils import try_expand_image_tags field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]") @@ -99,9 +98,6 @@ def format_finetune_data(self, signature, demos, inputs, outputs): # Wrap the messages in a dictionary with a "messages" key return dict(messages=messages) - def format_turn(self, signature, values, role, incomplete=False): - return format_turn(signature, values, role, incomplete) - def format_fields(self, signature, values, role): fields_with_values = { FieldInfoWithName(name=field_name, info=field_info): values.get( @@ -152,7 +148,9 @@ def format_turn(signature, values, role, incomplete=False): """ if role == "user": fields = signature.input_fields - message_prefix = "This is an example of the task, though some input or output fields are not supplied." if incomplete else "" + message_prefix = ( + "This is an example of the task, though some input or output fields are not supplied." if incomplete else "" + ) else: # Add the completed field for the assistant turn fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info} @@ -167,31 +165,40 @@ def format_turn(signature, values, role, incomplete=False): messages.append(message_prefix) field_messages = format_fields( - {FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.") - for k, v in fields.items()}, + { + FieldInfoWithName(name=k, info=v): values.get(k, "Not supplied for this particular example.") + for k, v in fields.items() + }, ) messages.append(field_messages) + def type_info(v): + if v.annotation is not str: + return f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" + else: + return "" + # Add output field instructions for user messages if role == "user" and signature.output_fields: - type_info = lambda v: f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" if v.annotation is not str else "" - field_instructions = "Respond with the corresponding output fields, starting with the field " + \ - ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + \ - ", and then ending with the marker for `[[ ## completed ## ]]`." + field_instructions = ( + "Respond with the corresponding output fields, starting with the field " + + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) + + ", and then ending with the marker for `[[ ## completed ## ]]`." + ) messages.append(field_instructions) joined_messages = "\n\n".join(msg for msg in messages) return {"role": role, "content": joined_messages} + def flatten_messages(messages): """Flatten nested message lists.""" - return list(chain.from_iterable( - item if isinstance(item, list) else [item] for item in messages - )) + return list(chain.from_iterable(item if isinstance(item, list) else [item] for item in messages)) + def enumerate_fields(fields: dict) -> str: parts = [] for idx, (k, v) in enumerate(fields.items()): - parts.append(f"{idx+1}. `{k}`") + parts.append(f"{idx + 1}. `{k}`") parts[-1] += f" ({get_annotation_name(v.annotation)})" parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" @@ -207,8 +214,8 @@ def move_type_to_front(d): return d -def prepare_schema(type_): - schema = pydantic.TypeAdapter(type_).json_schema() +def prepare_schema(field_type): + schema = pydantic.TypeAdapter(field_type).json_schema() schema = move_type_to_front(schema) return schema @@ -237,7 +244,6 @@ def field_metadata(field_name, field_info): f"must exactly match (no extra characters) one of: {'; '.join([str(x) for x in field_type.__args__])}" ) else: - # desc = "must be pareseable according to the following JSON schema: " desc = "must adhere to the JSON schema: " desc += json.dumps(prepare_schema(field_type), ensure_ascii=False)