diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 4a56c0ceb2..d7b56448b8 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -1,19 +1,36 @@ import dspy from dspy.primitives.program import Module -from dspy.signatures.signature import ensure_signature +from dspy.signatures.field import OutputField +from dspy.signatures.signature import ensure_signature, Signature +from pydantic.fields import FieldInfo +from typing import Optional, Union, Type class ChainOfThought(Module): - def __init__(self, signature, rationale_type=None, **config): + + def __init__( + self, + signature: Type[Signature], + rationale_field: Optional[Union[OutputField, FieldInfo]] = None, + rationale_field_type: Type = str, + **config + ): + """ + A module that reasons step by step in order to predict the output of a task. + + Args: + signature (Type[dspy.Signature]): The signature of the module. + rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will contain the reasoning. + rationale_field_type (Type): The type of the rationale field. + **config: The configuration for the module. + """ super().__init__() - signature = ensure_signature(signature) - prefix = "Reasoning: Let's think step by step in order to" desc = "${reasoning}" - rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) - extended_signature = signature.prepend("reasoning", rationale_type, type_=str) - + rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type + rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) + extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) self.predict = dspy.Predict(extended_signature, **config) def forward(self, **kwargs):