Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion dspy/signatures/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,20 @@ def __call__(cls, *args, **kwargs): # noqa: ANN002
return super().__call__(*args, **kwargs)

def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
# At this point, the orders have been swapped already.
field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)]
# Set `str` as the default type for all fields
raw_annotations = namespace.get("__annotations__", {})
for name, field in namespace.items():
if not isinstance(field, FieldInfo):
continue # Don't add types to non-field attributes
if not name.startswith("__") and name not in raw_annotations:
raw_annotations[name] = str
namespace["__annotations__"] = raw_annotations
# Create ordered annotations dictionary that preserves field order
ordered_annotations = {name: raw_annotations[name] for name in field_order if name in raw_annotations}
# Add any remaining annotations that weren't in field_order
ordered_annotations.update({k: v for k, v in raw_annotations.items() if k not in ordered_annotations})
namespace["__annotations__"] = ordered_annotations

# Let Pydantic do its thing
cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs)
Expand Down
30 changes: 11 additions & 19 deletions tests/predict/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,18 @@ def forward(self, input):
res2 = self.predictor2.batch([input] * 5)

return (res1, res2)

result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input"))

assert result[0].output == "test output 1"
assert result[1].output == "test output 2"
assert result[2].output == "test output 3"
assert result[3].output == "test output 4"
assert result[4].output == "test output 5"
result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input"))

assert reason_result[0].output == "test output 1"
assert reason_result[1].output == "test output 2"
assert reason_result[2].output == "test output 3"
assert reason_result[3].output == "test output 4"
assert reason_result[4].output == "test output 5"
# Check that we got all expected outputs without caring about order
expected_outputs = {f"test output {i}" for i in range(1, 6)}
assert {r.output for r in result} == expected_outputs
assert {r.output for r in reason_result} == expected_outputs

assert reason_result[0].reasoning == "test reasoning 1"
assert reason_result[1].reasoning == "test reasoning 2"
assert reason_result[2].reasoning == "test reasoning 3"
assert reason_result[3].reasoning == "test reasoning 4"
assert reason_result[4].reasoning == "test reasoning 5"
# Check that reasoning matches outputs for reason_result
for r in reason_result:
num = r.output.split()[-1] # get the number from "test output X"
assert r.reasoning == f"test reasoning {num}"


def test_nested_parallel_module():
Expand Down Expand Up @@ -120,7 +112,7 @@ def forward(self, input):
(self.predictor, input),
]),
])

output = MyModule()(dspy.Example(input="test input").with_inputs("input"))

assert output[0].output == "test output 1"
Expand Down Expand Up @@ -148,7 +140,7 @@ def forward(self, input):
res = self.predictor.batch([dspy.Example(input=input).with_inputs("input")]*2)

return res

result = MyModule().batch([dspy.Example(input="test input").with_inputs("input")]*2)

assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} \
Expand Down
11 changes: 11 additions & 0 deletions tests/signatures/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ class InitialSignature(Signature):
assert "new_output_end" == list(S4.output_fields.keys())[-1]


def test_order_preserved_with_mixed_annotations():
class ExampleSignature(dspy.Signature):
text: str = dspy.InputField()
output = dspy.OutputField()
pass_evaluation: bool = dspy.OutputField()

expected_order = ["text", "output", "pass_evaluation"]
actual_order = list(ExampleSignature.fields.keys())
assert actual_order == expected_order


def test_infer_prefix():
assert infer_prefix("someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool"
assert infer_prefix("version2Update") == "Version 2 Update"
Expand Down
Loading