60 changes: 52 additions & 8 deletions dspy/clients/lm.py
@@ -1,19 +1,20 @@
import functools
from .base_lm import BaseLM
import logging
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

import litellm
import ujson
from litellm import Router
from litellm.router import RetryPolicy

from dspy.clients.finetune import FinetuneJob, TrainingMethod
from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
from dspy.utils.callback import BaseCallback, with_callbacks

from .base_lm import BaseLM

logger = logging.getLogger(__name__)

@@ -32,7 +33,7 @@ def __init__(
cache: bool = True,
launch_kwargs: Optional[Dict[str, Any]] = None,
callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 3,

@dbczumar (Collaborator, Author) commented on Nov 5, 2024:
3 retries (equivalent to < 5 seconds) is insufficient to overcome rate limiting in many production environments with high traffic.

num_retries: int = 8,

@dbczumar (Collaborator, Author):
Empirically, 8 retries translates to ~1 minute of wall clock time, which should be sufficient to overcome rate limiting in most cases.

**kwargs,
):
"""
@@ -174,13 +175,55 @@ def cached_litellm_completion(request, num_retries: int):

def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
return litellm.completion(
num_retries=num_retries,
router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
return router.completion(
cache=cache,
**kwargs,
)


@functools.lru_cache(maxsize=None)
def _get_litellm_router(model: str, num_retries: int) -> Router:
"""
Get a LiteLLM router for the given model with the specified number of retries
for transient errors.

Args:
model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
Returns:
A LiteLLM router instance that can be used to query the given model.
"""
retry_policy = RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
# We don't retry on errors that are unlikely to be transient
# (e.g. bad request, invalid auth credentials)
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
ContentPolicyViolationErrorRetries=0,
)

return Router(
# LiteLLM routers must specify a `model_list`, which maps model names passed
# to `completions()` into actual LiteLLM model names. For our purposes, the
# model name is the same as the LiteLLM model name, so we add a single
# entry to the `model_list` that maps the model name to itself
model_list=[
{
"model_name": model,
"litellm_params": {
"model": model,
},
}
],
retry_policy=retry_policy,
)
Comment on lines +199 to +224:

@dbczumar (Collaborator, Author):
LiteLLM Routers appear to be the only mechanism allowing exponential backoff and configurable retry codes. Docs: https://docs.litellm.ai/docs/routing

@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request, num_retries: int):
return litellm_text_completion(
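
As a note on the Router comment above: because _get_litellm_router is memoized with functools.lru_cache, each (model, num_retries) pair builds its Router and RetryPolicy once and reuses them across requests. A minimal sketch, with an illustrative model name:

# Minimal sketch; the model name is an illustrative placeholder.
from dspy.clients.lm import _get_litellm_router

router_a = _get_litellm_router(model="openai/gpt-4o-mini", num_retries=8)
router_b = _get_litellm_router(model="openai/gpt-4o-mini", num_retries=8)
assert router_a is router_b  # same cached Router (and RetryPolicy) is reused

router_c = _get_litellm_router(model="openai/gpt-4o-mini", num_retries=3)
assert router_c is not router_a  # a different num_retries builds a new Router
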
@@ -197,6 +240,7 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
# TODO: Not all the models are in the format of "provider/model"
model = kwargs.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]
text_completion_model_name = f"text-completion-openai/{model}"

# Use the API key and base from the kwargs, or from the environment.
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
@@ -205,12 +249,12 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(
router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
return router.text_completion(
cache=cache,
model=f"text-completion-openai/{model}",
model=text_completion_model_name,
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
**kwargs,
)
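
For reference, a hedged end-to-end usage sketch of the retry knob this change tunes, assuming the usual dspy.LM / dspy.configure entry points; the model name and prompt are illustrative placeholders:

# Hedged usage sketch; model name and prompt are illustrative placeholders.
import dspy

lm = dspy.LM("openai/gpt-4o-mini", num_retries=8)  # 8 is now the default
dspy.configure(lm=lm)
print(lm("Say hello in one word."))  # transient failures are retried via the LiteLLM Router
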
107 changes: 94 additions & 13 deletions tests/clients/test_lm.py
@@ -1,25 +1,106 @@
from unittest import mock

from dspy.clients.lm import LM
from litellm.router import RetryPolicy

from dspy.clients.lm import LM, _get_litellm_router


def test_lm_chat_respects_max_retries():
lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17)
model_name = "openai/gpt4o"
num_retries = 17
temperature = 0.5
max_tokens = 100
prompt = "Hello, world!"

lm = LM(
model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
)

MockRouter = mock.MagicMock()
mock_completion = mock.MagicMock()
MockRouter.completion = mock_completion

with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
lm(messages=[{"content": "Hello, world!", "role": "user"}])
with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
lm(prompt=prompt)

assert litellm_completion_api.call_count == 1
assert litellm_completion_api.call_args[1]["max_retries"] == 17
# assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
MockRouterConstructor.assert_called_once_with(
model_list=[
{
"model_name": model_name,
"litellm_params": {
"model": model_name,
},
}
],
retry_policy=RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
ContentPolicyViolationErrorRetries=0,
),
)
mock_completion.assert_called_once_with(
model=model_name,
messages=[{"role": "user", "content": prompt}],
temperature=temperature,
max_tokens=max_tokens,
cache=mock.ANY,
)


def test_lm_completions_respects_max_retries():
lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17)
model_name = "openai/gpt-3.5-turbo"
expected_model = "text-completion-" + model_name
num_retries = 17
temperature = 0.5
max_tokens = 100
prompt = "Hello, world!"
api_base = "http://test.com"
api_key = "apikey"

lm = LM(
model=model_name,
model_type="text",
num_retries=num_retries,
temperature=temperature,
max_tokens=max_tokens,
api_base=api_base,
api_key=api_key,
)

MockRouter = mock.MagicMock()
mock_text_completion = mock.MagicMock()
MockRouter.text_completion = mock_text_completion

with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
lm(prompt="Hello, world!")
with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
lm(prompt=prompt)

assert litellm_completion_api.call_count == 1
assert litellm_completion_api.call_args[1]["max_retries"] == 17
# assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
MockRouterConstructor.assert_called_once_with(
model_list=[
{
"model_name": expected_model,
"litellm_params": {
"model": expected_model,
},
}
],
retry_policy=RetryPolicy(
TimeoutErrorRetries=num_retries,
RateLimitErrorRetries=num_retries,
InternalServerErrorRetries=num_retries,
BadRequestErrorRetries=0,
AuthenticationErrorRetries=0,
ContentPolicyViolationErrorRetries=0,
),
)
mock_text_completion.assert_called_once_with(
model=expected_model,
prompt=prompt + "\n\nBEGIN RESPONSE:",
temperature=temperature,
max_tokens=max_tokens,
api_key=api_key,
api_base=api_base,
cache=mock.ANY,
)