-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Allow DSPy to use the native reasoning from models #8822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6af5b75
c699a1f
beb85de
4c5b633
5228863
3210914
d5b0dfb
b2daf8f
3cff43a
3258da5
8de0a65
ec2fbe4
56973f0
c65b774
8c1630c
93991f5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Any, Optional | ||
|
||
import litellm | ||
import pydantic | ||
|
||
from dspy.adapters.types.base_type import Type | ||
|
||
|
||
class Reasoning(Type): | ||
"""Reasoning type in DSPy. | ||
This type is useful when you want the DSPy output to include the reasoning of the LM. We build this type so that | ||
DSPy can support the reasoning model and non-reasoning model with the same code. | ||
This is a str-like type, you can convert a string directly to a Reasoning object, and from DSPy adapters' | ||
perspective, `Reasoning` is treated as a string. | ||
""" | ||
|
||
content: str | ||
|
||
def format(self): | ||
return f"{self.content}" | ||
|
||
@pydantic.model_validator(mode="before") | ||
@classmethod | ||
def validate_input(cls, data: Any): | ||
if isinstance(data, cls): | ||
return data | ||
|
||
if isinstance(data, str): | ||
return {"content": data} | ||
|
||
if isinstance(data, dict): | ||
if "content" not in data: | ||
raise ValueError("`content` field is required for `dspy.Reasoning`") | ||
if not isinstance(data["content"], str): | ||
raise ValueError(f"`content` field must be a string, but received type: {type(data['content'])}") | ||
return {"content": data["content"]} | ||
|
||
raise ValueError(f"Received invalid value for `dspy.Reasoning`: {data}") | ||
|
||
@classmethod | ||
def adapt_to_native_lm_feature(cls, lm, lm_kwargs) -> bool: | ||
if not litellm.supports_reasoning(lm.model): | ||
return False | ||
|
||
reasoning_effort = "unspecified" | ||
if "reasoning_effort" in lm_kwargs: | ||
# `lm_kwargs` overrides `lm.kwargs` | ||
reasoning_effort = lm_kwargs["reasoning_effort"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: we can simplify by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ohh good catch! The structure is right but code is actually buggy - if users explicitly turn off native reasoning, we should respect the setting, so I am not using |
||
elif "reasoning_effort" in lm.kwargs: | ||
reasoning_effort = lm.kwargs["reasoning_effort"] | ||
|
||
if reasoning_effort is None: | ||
# If users explicitly set `reasoning_effort` to None, we don't enable native reasoning | ||
return False | ||
|
||
# Turn on the native reasoning | ||
lm_kwargs["reasoning_effort"] = "low" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to @arnavsinghvi11, the OpenAI default is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes we talked offline a bit about that, basically I won't rely on that default. From Responses API OpenAI is no longer the industry API standard. I am setting it to |
||
return True | ||
|
||
@classmethod | ||
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Reasoning"]: | ||
"""Parse the LM response into a Reasoning object.""" | ||
if "reasoning_content" in response: | ||
return Reasoning(content=response["reasoning_content"]) | ||
return None | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.content!r}" | ||
|
||
def __str__(self) -> str: | ||
return self.content | ||
|
||
def __eq__(self, other: object) -> bool: | ||
if isinstance(other, Reasoning): | ||
return self.content == other.content | ||
if isinstance(other, str): | ||
return self.content == other |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,36 @@ | ||
import logging | ||
from typing import Any | ||
|
||
from pydantic.fields import FieldInfo | ||
|
||
import dspy | ||
from dspy.primitives.module import Module | ||
from dspy.signatures.field import OutputField | ||
from dspy.signatures.signature import Signature, ensure_signature | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class ChainOfThought(Module): | ||
def __init__( | ||
self, | ||
signature: str | type[Signature], | ||
rationale_field: FieldInfo | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we show deprecation warnings if these arguments are used? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sg, let me add that |
||
rationale_field_type: type = str, | ||
**config: dict[str, Any], | ||
): | ||
""" | ||
A module that reasons step by step in order to predict the output of a task. | ||
Args: | ||
signature (Type[dspy.Signature]): The signature of the module. | ||
rationale_field (Optional[Union[dspy.OutputField, pydantic.fields.FieldInfo]]): The field that will contain the reasoning. | ||
rationale_field_type (Type): The type of the rationale field. | ||
**config: The configuration for the module. | ||
""" | ||
super().__init__() | ||
signature = ensure_signature(signature) | ||
prefix = "Reasoning: Let's think step by step in order to" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems we will lose this CoT prompting, is that fine? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's actually not used today There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, it's not used today because it's the prefix of the output field? Does it mean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this feels odd, but yes exactly - -" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. btw guys lets hold off on merging this until we fully hash out the right design choices for this PR, given that its a decently significant change! |
||
desc = "${reasoning}" | ||
rationale_field_type = rationale_field.annotation if rationale_field else rationale_field_type | ||
rationale_field = rationale_field if rationale_field else dspy.OutputField(prefix=prefix, desc=desc) | ||
extended_signature = signature.prepend(name="reasoning", field=rationale_field, type_=rationale_field_type) | ||
|
||
if "rationale_field" in config or "rationale_field_type" in config: | ||
logger.warning("`rationale_field` and `rationale_field_type` are deprecated, they are no-op now.") | ||
|
||
from dspy.adapters.types.reasoning import Reasoning | ||
|
||
extended_signature = signature.prepend(name="reasoning", field=OutputField(), type_=Reasoning) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we continue to return string in the reasoning field for backward compatibility? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We need this specific type to modify the signature inside Adapter. I don't feel this is a perfectly clean solution, but it's the most robust way in my mind. would like to hear your thoughts! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, but what's gonna happen when users directly call the string funcitons on the reasoning field like this? Does it continue to work? cot = dspy.ChainOfThough("question -> answer")
cot(question="where is the capital of Japan?").strip() |
||
self.predict = dspy.Predict(extended_signature, **config) | ||
|
||
def forward(self, **kwargs): | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to check
lm_kwargs
, instead we should probably overwritereasoning_effort
if native reasoning is onThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so
reasoning_effort
has multiple valid values - low, medium, and high. I am setting the default value aslow
, but if users specify the other values we want to keep those unchanged.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the default
reasoning_effort
on OpenAI reasoning models ismedium
, so I'd be careful setting it differently here