Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,14 @@ class ChatAdapter(Adapter):
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
super().__init__(callbacks)

def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
def __call__(
self,
lm: LM,
lm_kwargs: dict[str, Any],
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
try:
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
except Exception as e:
Expand All @@ -46,8 +53,10 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
raise e
# fallback to JSONAdapter
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)

def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:

def format(
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []

# Extract demos where some of the output_fields are not filled in.
Expand Down Expand Up @@ -88,7 +97,7 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
if match:
# If the header pattern is found, split the rest of the line as content
header = match.group(1)
remaining_content = line[match.end():].strip()
remaining_content = line[match.end() :].strip()
sections.append((header, [remaining_content] if remaining_content else []))
else:
sections[-1][1].append(line)
Expand All @@ -111,7 +120,9 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
return fields

# TODO(PR): Looks ok?
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
def format_finetune_data(
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
) -> dict[str, list[Any]]:
# Get system + user messages
messages = self.format(signature, demos, inputs)

Expand All @@ -134,7 +145,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
}
return format_fields(fields_with_values)

def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
def format_turn(
self,
signature: Type[Signature],
values: dict[str, Any],
role: str,
incomplete: bool = False,
is_conversation_history: bool = False,
) -> dict[str, Any]:
return format_turn(signature, values, role, incomplete, is_conversation_history)


Expand All @@ -158,7 +176,9 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
return "\n\n".join(output).strip()


def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False):
def format_turn(
signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False
):
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Expand Down
37 changes: 28 additions & 9 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
from dspy.adapters.types.image import try_expand_image_tags
from dspy.adapters.types.history import History
from dspy.adapters.types.image import try_expand_image_tags
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json
from dspy.clients.lm import LM
from dspy.signatures.signature import SignatureMeta, Signature
from dspy.signatures.signature import Signature, SignatureMeta
from dspy.signatures.utils import get_dspy_field_type

logger = logging.getLogger(__name__)
Expand All @@ -27,11 +27,19 @@ class FieldInfoWithName(NamedTuple):
name: str
info: FieldInfo


class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
def __call__(
self,
lm: LM,
lm_kwargs: dict[str, Any],
signature: Type[Signature],
demos: list[dict[str, Any]],
inputs: dict[str, Any],
) -> list[dict[str, Any]]:
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)

Expand Down Expand Up @@ -66,7 +74,9 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]

return values

def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
def format(
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]
) -> list[dict[str, Any]]:
messages = []

# Extract demos where some of the output_fields are not filled in.
Expand Down Expand Up @@ -118,11 +128,20 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
if field_name in values
}
return format_fields(role=role, fields_with_values=fields_with_values)

def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:

def format_turn(
self,
signature: Type[Signature],
values,
role: str,
incomplete: bool = False,
is_conversation_history: bool = False,
) -> dict[str, Any]:
return format_turn(signature, values, role, incomplete, is_conversation_history)

def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:

def format_finetune_data(
self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]
) -> dict[str, list[Any]]:
# TODO: implement format_finetune_data method in JSONAdapter
raise NotImplementedError

Expand All @@ -136,7 +155,7 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
Args:
role: The role of the message ('user' or 'assistant')
fields_with_values: A dictionary mapping information about a field to its corresponding value.

Returns:
The joined formatted values of the fields, represented as a string.
"""
Expand Down
Loading