9 changes: 4 additions & 5 deletions dspy/clients/base_lm.py
@@ -166,12 +166,12 @@ def update_history(self, entry):

def _process_completion(self, response, merged_kwargs):
"""Process the response of OpenAI chat completion API and extract outputs.

Args:
response: The OpenAI chat completion response
https://platform.openai.com/docs/api-reference/chat/object
merged_kwargs: Merged kwargs from self.kwargs and method kwargs

Returns:
List of processed outputs
"""
@@ -200,10 +200,10 @@ def _process_completion(self, response, merged_kwargs):
def _extract_citations_from_response(self, choice):
"""Extract citations from LiteLLM response if available.
Reference: https://docs.litellm.ai/docs/providers/anthropic#beta-citations-api

Args:
choice: The choice object from response.choices

Returns:
A list of citation dictionaries or None if no citations found
"""
@@ -255,7 +255,6 @@ def _process_response(self, response):
return [result]



def inspect_history(n: int = 1):
"""The global history shared across all LMs."""
return pretty_print_history(GLOBAL_HISTORY, n)
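Note: for reference, a quick usage sketch of inspecting the shared history; the model name is illustrative and assumes an API key is configured for that provider.

```python
import dspy

# Illustrative setup (assumes credentials for the chosen model are available).
lm = dspy.LM("openai/gpt-4o-mini")
dspy.configure(lm=lm)
dspy.Predict("question -> answer")(question="What is 2 + 2?")

# Pretty-prints the last n entries of the global history shared across all LMs.
dspy.inspect_history(n=1)
```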
22 changes: 8 additions & 14 deletions dspy/clients/lm.py
@@ -104,11 +104,7 @@ def __init__(
self._warn_zero_temp_rollout(self.kwargs.get("temperature"), self.kwargs.get("rollout_id"))

def _warn_zero_temp_rollout(self, temperature: float | None, rollout_id):
if (
not self._warned_zero_temp_rollout
and rollout_id is not None
and (temperature is None or temperature == 0)
):
if not self._warned_zero_temp_rollout and rollout_id is not None and (temperature is None or temperature == 0):
warnings.warn(
"rollout_id has no effect when temperature=0; set temperature>0 to bypass the cache.",
stacklevel=3,
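Note: a small usage sketch of the behavior this warning guards against: `rollout_id` only yields a fresh, uncached completion when the temperature is nonzero. The model name and prompt are illustrative.

```python
import dspy

lm = dspy.LM("openai/gpt-4o-mini", temperature=0.0)

# temperature=0 with a rollout_id triggers the warning above: the rollout_id
# alone will not bypass the cache or produce a new completion.
lm("Name a prime number.", rollout_id=1)

# Setting temperature > 0 alongside rollout_id bypasses the cache as intended.
lm("Name a prime number.", rollout_id=1, temperature=1.0)
```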
@@ -134,10 +130,7 @@ def forward(self, prompt=None, messages=None, **kwargs):

messages = messages or [{"role": "user", "content": prompt}]
if self.use_developer_role and self.model_type == "responses":
messages = [
{**m, "role": "developer"} if m.get("role") == "system" else m
for m in messages
]
messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
kwargs = {**self.kwargs, **kwargs}
self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
if kwargs.get("rollout_id") is None:
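Note: for clarity, a standalone sketch of the role remapping collapsed onto one line above: when `use_developer_role` is set and the model type is `responses`, system messages are rewritten with the `developer` role and everything else passes through unchanged. The same remap appears in `aforward` below.

```python
# Standalone illustration of the one-liner above (not part of the diff).
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Define entropy."},
]

use_developer_role = True
model_type = "responses"

if use_developer_role and model_type == "responses":
    # Only system messages are touched; user/assistant messages pass through unchanged.
    messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]

assert messages[0]["role"] == "developer"
assert messages[1]["role"] == "user"
```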
@@ -170,10 +163,7 @@ async def aforward(self, prompt=None, messages=None, **kwargs):

messages = messages or [{"role": "user", "content": prompt}]
if self.use_developer_role and self.model_type == "responses":
messages = [
{**m, "role": "developer"} if m.get("role") == "system" else m
for m in messages
]
messages = [{**m, "role": "developer"} if m.get("role") == "system" else m for m in messages]
kwargs = {**self.kwargs, **kwargs}
self._warn_zero_temp_rollout(kwargs.get("temperature"), kwargs.get("rollout_id"))
if kwargs.get("rollout_id") is None:
@@ -237,7 +227,9 @@ def thread_function_wrapper():

return job

def reinforce(self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)) -> ReinforceJob:
def reinforce(
self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)
) -> ReinforceJob:
# TODO(GRPO Team): Should we return an initialized job here?
from dspy import settings as settings

@@ -424,6 +416,7 @@ async def alitellm_text_completion(request: dict[str, Any], num_retries: int, ca
**request,
)


def litellm_responses_completion(request: dict[str, Any], num_retries: int, cache: dict[str, Any] | None = None):
cache = cache or {"no-cache": True, "no-store": True}
request = dict(request)
@@ -451,6 +444,7 @@ async def alitellm_responses_completion(request: dict[str, Any], num_retries: in
**request,
)


def _convert_chat_request_to_responses_request(request: dict[str, Any]):
request = dict(request)
if "messages" in request: