Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@

# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program.
def inspect_history(*args, **kwargs):
return settings.lm.inspect_history(*args, **kwargs)
from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history
return _inspect_history(GLOBAL_HISTORY, *args, **kwargs)
3 changes: 2 additions & 1 deletion dspy/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base import Adapter
from .chat_adapter import ChatAdapter
from .chat_adapter import ChatAdapter
from .json_adapter import JsonAdapter
25 changes: 16 additions & 9 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,22 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
inputs = self.format(signature, demos, inputs)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)
inputs_ = self.format(signature, demos, inputs)
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)

outputs = lm(**inputs, **lm_kwargs)
outputs = lm(**inputs_, **lm_kwargs)
values = []

for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)

return values
try:
for output in outputs:
value = self.parse(signature, output, _parse_values=_parse_values)
assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}"
values.append(value)
return values

except Exception as e:
from .json_adapter import JsonAdapter
if _parse_values and not isinstance(self, JsonAdapter):
return JsonAdapter()(lm, lm_kwargs, signature, demos, inputs, _parse_values=_parse_values)
raise e

16 changes: 15 additions & 1 deletion dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import textwrap

from pydantic import TypeAdapter
from collections.abc import Mapping
from pydantic.fields import FieldInfo
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin

Expand Down Expand Up @@ -269,6 +270,19 @@ def enumerate_fields(fields):
return "\n".join(parts).strip()


def move_type_to_front(d):
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
if isinstance(d, Mapping):
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != 'type', item[0]))}
elif isinstance(d, list):
return [move_type_to_front(item) for item in d]
return d

def prepare_schema(type_):
schema = pydantic.TypeAdapter(type_).json_schema()
schema = move_type_to_front(schema)
return schema

def prepare_instructions(signature: SignatureMeta):
parts = []
parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
Expand All @@ -290,7 +304,7 @@ def field_metadata(field_name, field_info):
desc = f"must be one of: {'; '.join([str(x) for x in type_.__args__])}"
else:
desc = "must be pareseable according to the following JSON schema: "
desc += json.dumps(pydantic.TypeAdapter(type_).json_schema())
desc += json.dumps(prepare_schema(type_))

desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
return f"{{{field_name}}}{desc}"
Expand Down
Loading
Loading