From 858561aa82eedf5f582fb4aa84c5a24c77822ae5 Mon Sep 17 00:00:00 2001 From: Zach Bamberger <94684184+zbambergerNLP@users.noreply.github.com> Date: Mon, 24 Mar 2025 13:54:09 +0200 Subject: [PATCH 1/4] Update chain_of_thought.py Allow use of custom CoT representations (e.g., List of strings) instead of forcing CoT be a string output. Previously, even if the user passed in a custom definition for the reasoning field, the type of the reasoning field output would be a string. This change ensures that the type of the field is consistent with its annotation, and allows for users to specify a custom type for reasoning without creating a custom field for it. Also, this change introduces a docstring and type hints into the ChainOfThought module. --- dspy/predict/chain_of_thought.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 4a56c0ceb2..d81e136d13 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -4,17 +4,34 @@ class ChainOfThought(Module): - def __init__(self, signature, rationale_type=None, **config): + + def __init__( + self, + signature: Type[dspy.Signature], + rationale_field: Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]] = None, + rationale_field_type: Type = str, + **config + ): + """ + A module that reasons step by step in order to predict the output of a task. + + Args: + signature (Type[dspy.Signature]): The signature of the module. + rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will contain the reasoning. + rationale_field_type (Type): The type of the rationale field. + **config: The configuration for the module. + """ super().__init__() - signature = ensure_signature(signature) - prefix = "Reasoning: Let's think step by step in order to" desc = "${reasoning}" - rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) - extended_signature = signature.prepend("reasoning", rationale_type, type_=str) - + rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type + rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) + extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) self.predict = dspy.Predict(extended_signature, **config) def forward(self, **kwargs): return self.predict(**kwargs) + + def forward(self, **kwargs): + return self.predict(**kwargs) From a650fdd6804a194ccfac4c4ffcc6bf4bfc53c55c Mon Sep 17 00:00:00 2001 From: Zach Bamberger <94684184+zbambergerNLP@users.noreply.github.com> Date: Mon, 24 Mar 2025 13:57:52 +0200 Subject: [PATCH 2/4] Update chain_of_thought.py Fixed error due to missing import. --- dspy/predict/chain_of_thought.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index d81e136d13..33654dab18 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -1,14 +1,15 @@ import dspy from dspy.primitives.program import Module from dspy.signatures.signature import ensure_signature - +from pydantic.fields import FieldInfo +from typing import Optional, Union, Type class ChainOfThought(Module): def __init__( self, signature: Type[dspy.Signature], - rationale_field: Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]] = None, + rationale_field: Optional[Union[dspy.OutputField, FieldInfo]] = None, rationale_field_type: Type = str, **config ): @@ -32,6 +33,3 @@ def __init__( def forward(self, **kwargs): return self.predict(**kwargs) - - def forward(self, **kwargs): - return self.predict(**kwargs) From cbb129029ce806e64b143f3a1af990a932833fda Mon Sep 17 00:00:00 2001 From: Zach Bamberger <94684184+zbambergerNLP@users.noreply.github.com> Date: Mon, 24 Mar 2025 14:01:05 +0200 Subject: [PATCH 3/4] Update chain_of_thought.py Avoid circular import and ensure there are two newlines between imports and initial code. --- dspy/predict/chain_of_thought.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 33654dab18..1a58da5fba 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -1,14 +1,15 @@ import dspy from dspy.primitives.program import Module -from dspy.signatures.signature import ensure_signature +from dspy.signatures.signature import ensure_signature, Signature from pydantic.fields import FieldInfo from typing import Optional, Union, Type + class ChainOfThought(Module): def __init__( self, - signature: Type[dspy.Signature], + signature: Type[Signature], rationale_field: Optional[Union[dspy.OutputField, FieldInfo]] = None, rationale_field_type: Type = str, **config From ec7c6e3e8c2662ee0113c536f6af9a4300b72bf6 Mon Sep 17 00:00:00 2001 From: Zach Bamberger <94684184+zbambergerNLP@users.noreply.github.com> Date: Mon, 24 Mar 2025 14:06:16 +0200 Subject: [PATCH 4/4] Update chain_of_thought.py Avoided circular import caused by referencing `dspy.OutputField` --- dspy/predict/chain_of_thought.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 1a58da5fba..d7b56448b8 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -1,5 +1,6 @@ import dspy from dspy.primitives.program import Module +from dspy.signatures.field import OutputField from dspy.signatures.signature import ensure_signature, Signature from pydantic.fields import FieldInfo from typing import Optional, Union, Type @@ -10,7 +11,7 @@ class ChainOfThought(Module): def __init__( self, signature: Type[Signature], - rationale_field: Optional[Union[dspy.OutputField, FieldInfo]] = None, + rationale_field: Optional[Union[OutputField, FieldInfo]] = None, rationale_field_type: Type = str, **config ):