diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 0121de9ea7..455522d836 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -45,6 +45,8 @@ def __call__(cls, *args, **kwargs): # noqa: ANN002 return super().__call__(*args, **kwargs) def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 + # At this point, the orders have been swapped already. + field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)] # Set `str` as the default type for all fields raw_annotations = namespace.get("__annotations__", {}) for name, field in namespace.items(): @@ -52,7 +54,11 @@ def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804 continue # Don't add types to non-field attributes if not name.startswith("__") and name not in raw_annotations: raw_annotations[name] = str - namespace["__annotations__"] = raw_annotations + # Create ordered annotations dictionary that preserves field order + ordered_annotations = {name: raw_annotations[name] for name in field_order if name in raw_annotations} + # Add any remaining annotations that weren't in field_order + ordered_annotations.update({k: v for k, v in raw_annotations.items() if k not in ordered_annotations}) + namespace["__annotations__"] = ordered_annotations # Let Pydantic do its thing cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs) diff --git a/tests/predict/test_parallel.py b/tests/predict/test_parallel.py index 072aa7eb1a..1c3a932fb0 100644 --- a/tests/predict/test_parallel.py +++ b/tests/predict/test_parallel.py @@ -71,26 +71,18 @@ def forward(self, input): res2 = self.predictor2.batch([input] * 5) return (res1, res2) - - result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input")) - assert result[0].output == "test output 1" - assert result[1].output == "test output 2" - assert result[2].output == "test output 3" - assert result[3].output == "test output 4" - assert result[4].output == "test output 5" + result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input")) - assert reason_result[0].output == "test output 1" - assert reason_result[1].output == "test output 2" - assert reason_result[2].output == "test output 3" - assert reason_result[3].output == "test output 4" - assert reason_result[4].output == "test output 5" + # Check that we got all expected outputs without caring about order + expected_outputs = {f"test output {i}" for i in range(1, 6)} + assert {r.output for r in result} == expected_outputs + assert {r.output for r in reason_result} == expected_outputs - assert reason_result[0].reasoning == "test reasoning 1" - assert reason_result[1].reasoning == "test reasoning 2" - assert reason_result[2].reasoning == "test reasoning 3" - assert reason_result[3].reasoning == "test reasoning 4" - assert reason_result[4].reasoning == "test reasoning 5" + # Check that reasoning matches outputs for reason_result + for r in reason_result: + num = r.output.split()[-1] # get the number from "test output X" + assert r.reasoning == f"test reasoning {num}" def test_nested_parallel_module(): @@ -120,7 +112,7 @@ def forward(self, input): (self.predictor, input), ]), ]) - + output = MyModule()(dspy.Example(input="test input").with_inputs("input")) assert output[0].output == "test output 1" @@ -148,7 +140,7 @@ def forward(self, input): res = self.predictor.batch([dspy.Example(input=input).with_inputs("input")]*2) return res - + result = MyModule().batch([dspy.Example(input="test input").with_inputs("input")]*2) assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} \ diff --git a/tests/signatures/test_signature.py b/tests/signatures/test_signature.py index 864a574a71..136fac213b 100644 --- a/tests/signatures/test_signature.py +++ b/tests/signatures/test_signature.py @@ -151,6 +151,17 @@ class InitialSignature(Signature): assert "new_output_end" == list(S4.output_fields.keys())[-1] +def test_order_preserved_with_mixed_annotations(): + class ExampleSignature(dspy.Signature): + text: str = dspy.InputField() + output = dspy.OutputField() + pass_evaluation: bool = dspy.OutputField() + + expected_order = ["text", "output", "pass_evaluation"] + actual_order = list(ExampleSignature.fields.keys()) + assert actual_order == expected_order + + def test_infer_prefix(): assert infer_prefix("someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool" assert infer_prefix("version2Update") == "Version 2 Update"