Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions dspy/predict/chain_of_thought.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
import dspy
from dspy.primitives.program import Module
from dspy.signatures.signature import ensure_signature
from dspy.signatures.field import OutputField
from dspy.signatures.signature import ensure_signature, Signature
from pydantic.fields import FieldInfo
from typing import Optional, Union, Type


class ChainOfThought(Module):
def __init__(self, signature, rationale_type=None, **config):

def __init__(
self,
signature: Type[Signature],
rationale_field: Optional[Union[OutputField, FieldInfo]] = None,
rationale_field_type: Type = str,
**config
):
"""
A module that reasons step by step in order to predict the output of a task.

Args:
signature (Type[dspy.Signature]): The signature of the module.
rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will contain the reasoning.
rationale_field_type (Type): The type of the rationale field.
**config: The configuration for the module.
"""
super().__init__()

signature = ensure_signature(signature)

prefix = "Reasoning: Let's think step by step in order to"
desc = "${reasoning}"
rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend("reasoning", rationale_type, type_=str)

rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc)
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type)
self.predict = dspy.Predict(extended_signature, **config)

def forward(self, **kwargs):
Expand Down
Loading