diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index beb74a28ac..c011819b6c 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -37,7 +37,14 @@ class ChatAdapter(Adapter): def __init__(self, callbacks: Optional[list[BaseCallback]] = None): super().__init__(callbacks) - def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def __call__( + self, + lm: LM, + lm_kwargs: dict[str, Any], + signature: Type[Signature], + demos: list[dict[str, Any]], + inputs: dict[str, Any], + ) -> list[dict[str, Any]]: try: return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception as e: @@ -46,8 +53,10 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature] raise e # fallback to JSONAdapter return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) - - def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + + def format( + self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any] + ) -> list[dict[str, Any]]: messages: list[dict[str, Any]] = [] # Extract demos where some of the output_fields are not filled in. @@ -88,7 +97,7 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: if match: # If the header pattern is found, split the rest of the line as content header = match.group(1) - remaining_content = line[match.end():].strip() + remaining_content = line[match.end() :].strip() sections.append((header, [remaining_content] if remaining_content else [])) else: sections[-1][1].append(line) @@ -111,7 +120,9 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: return fields # TODO(PR): Looks ok? - def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]: + def format_finetune_data( + self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any] + ) -> dict[str, list[Any]]: # Get system + user messages messages = self.format(signature, demos, inputs) @@ -134,7 +145,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role } return format_fields(fields_with_values) - def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]: + def format_turn( + self, + signature: Type[Signature], + values: dict[str, Any], + role: str, + incomplete: bool = False, + is_conversation_history: bool = False, + ) -> dict[str, Any]: return format_turn(signature, values, role, incomplete, is_conversation_history) @@ -158,7 +176,9 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: return "\n\n".join(output).strip() -def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False): +def format_turn( + signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False +): """ Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 2d1288f4f7..f5ea79a63b 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -13,11 +13,11 @@ from pydantic.fields import FieldInfo from dspy.adapters.base import Adapter -from dspy.adapters.types.image import try_expand_image_tags from dspy.adapters.types.history import History +from dspy.adapters.types.image import try_expand_image_tags from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json from dspy.clients.lm import LM -from dspy.signatures.signature import SignatureMeta, Signature +from dspy.signatures.signature import Signature, SignatureMeta from dspy.signatures.utils import get_dspy_field_type logger = logging.getLogger(__name__) @@ -27,11 +27,19 @@ class FieldInfoWithName(NamedTuple): name: str info: FieldInfo + class JSONAdapter(Adapter): def __init__(self): pass - def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def __call__( + self, + lm: LM, + lm_kwargs: dict[str, Any], + signature: Type[Signature], + demos: list[dict[str, Any]], + inputs: dict[str, Any], + ) -> list[dict[str, Any]]: inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) @@ -66,7 +74,9 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature] return values - def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def format( + self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any] + ) -> list[dict[str, Any]]: messages = [] # Extract demos where some of the output_fields are not filled in. @@ -118,11 +128,20 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role if field_name in values } return format_fields(role=role, fields_with_values=fields_with_values) - - def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]: + + def format_turn( + self, + signature: Type[Signature], + values, + role: str, + incomplete: bool = False, + is_conversation_history: bool = False, + ) -> dict[str, Any]: return format_turn(signature, values, role, incomplete, is_conversation_history) - - def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]: + + def format_finetune_data( + self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any] + ) -> dict[str, list[Any]]: # TODO: implement format_finetune_data method in JSONAdapter raise NotImplementedError @@ -136,7 +155,7 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) - Args: role: The role of the message ('user' or 'assistant') fields_with_values: A dictionary mapping information about a field to its corresponding value. - + Returns: The joined formatted values of the fields, represented as a string. """