LMs: retry with exponential backoff for a limited set of error codes #1753
`dspy/clients/lm.py`

```diff
@@ -1,19 +1,20 @@
 import functools
-from .base_lm import BaseLM
 import logging
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional

 import litellm
 import ujson
+from litellm import Router
+from litellm.router import RetryPolicy

 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
 from dspy.utils.callback import BaseCallback, with_callbacks

+from .base_lm import BaseLM

 logger = logging.getLogger(__name__)
```
```diff
@@ -32,7 +33,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 3,
+        num_retries: int = 8,
         **kwargs,
     ):
         """
```

Review comment from the PR author on `num_retries: int = 8,`: Empirically, 8 retries translates to ~1 minute of wall clock time, which should be sufficient to overcome rate limiting in most cases.
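For a rough sense of where these numbers come from, here is a back-of-the-envelope sketch; the 0.5 s initial delay, doubling factor, and 20 s cap are illustrative assumptions, not LiteLLM's exact backoff schedule:

```python
# Back-of-the-envelope estimate of cumulative wall-clock time spent waiting
# across retries. The 0.5 s initial delay, doubling per attempt, and 20 s cap
# are assumptions for illustration only (not LiteLLM's exact schedule).
def total_backoff_seconds(num_retries: int, base: float = 0.5, cap: float = 20.0) -> float:
    return sum(min(base * 2**attempt, cap) for attempt in range(num_retries))

print(total_backoff_seconds(3))  # 0.5 + 1 + 2 = 3.5 s -> well under 5 seconds
print(total_backoff_seconds(8))  # 0.5 + 1 + 2 + 4 + 8 + 16 + 20 + 20 = 71.5 s -> roughly a minute
```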
```diff
@@ -174,13 +175,55 @@ def cached_litellm_completion(request, num_retries: int):

 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    return litellm.completion(
-        num_retries=num_retries,
+    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
+    return router.completion(
         cache=cache,
         **kwargs,
     )


+@functools.lru_cache(maxsize=None)
+def _get_litellm_router(model: str, num_retries: int) -> Router:
+    """
+    Get a LiteLLM router for the given model with the specified number of retries
+    for transient errors.
+
+    Args:
+        model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4').
+        num_retries: The number of times to retry a request if it fails transiently due to
+                     network error, rate limiting, etc. Requests are retried with exponential
+                     backoff.
+    Returns:
+        A LiteLLM router instance that can be used to query the given model.
+    """
+    retry_policy = RetryPolicy(
+        TimeoutErrorRetries=num_retries,
+        RateLimitErrorRetries=num_retries,
+        InternalServerErrorRetries=num_retries,
+        # We don't retry on errors that are unlikely to be transient
+        # (e.g. bad request, invalid auth credentials)
+        BadRequestErrorRetries=0,
+        AuthenticationErrorRetries=0,
+        ContentPolicyViolationErrorRetries=0,
+    )
+
+    return Router(
+        # LiteLLM routers must specify a `model_list`, which maps model names passed
+        # to `completions()` into actual LiteLLM model names. For our purposes, the
+        # model name is the same as the LiteLLM model name, so we add a single
+        # entry to the `model_list` that maps the model name to itself
+        model_list=[
+            {
+                "model_name": model,
+                "litellm_params": {
+                    "model": model,
+                },
+            }
+        ],
+        retry_policy=retry_policy,
+    )


 @functools.lru_cache(maxsize=None)
 def cached_litellm_text_completion(request, num_retries: int):
     return litellm_text_completion(
```

Review comment from the PR author on lines +199 to +224 (`_get_litellm_router`): LiteLLM Routers appear to be the only mechanism allowing exponential backoff and configurable retry codes. Docs: https://docs.litellm.ai/docs/routing
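For readers unfamiliar with LiteLLM routing, here is a minimal standalone sketch of the mechanism the PR relies on; the model name and message are placeholders, and the `Router`/`RetryPolicy` construction mirrors `_get_litellm_router` above:

```python
from litellm import Router
from litellm.router import RetryPolicy

model = "openai/gpt-4o-mini"  # placeholder model name

router = Router(
    # Map the public model name to itself, exactly as _get_litellm_router does.
    model_list=[{"model_name": model, "litellm_params": {"model": model}}],
    # Retry only errors that are likely transient; fail fast on the rest.
    retry_policy=RetryPolicy(
        TimeoutErrorRetries=8,
        RateLimitErrorRetries=8,
        InternalServerErrorRetries=8,
        BadRequestErrorRetries=0,
        AuthenticationErrorRetries=0,
        ContentPolicyViolationErrorRetries=0,
    ),
)

# Requests that hit a retryable error are retried with exponential backoff.
response = router.completion(
    model=model,
    messages=[{"role": "user", "content": "Hello, world!"}],
)
print(response.choices[0].message.content)
```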
```diff
@@ -197,6 +240,7 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
     # TODO: Not all the models are in the format of "provider/model"
     model = kwargs.pop("model").split("/", 1)
     provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+    text_completion_model_name = f"text-completion-openai/{model}"

     # Use the API key and base from the kwargs, or from the environment.
     api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
@@ -205,12 +249,12 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    return litellm.text_completion(
+    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
+    return router.text_completion(
         cache=cache,
-        model=f"text-completion-openai/{model}",
+        model=text_completion_model_name,
         api_key=api_key,
         api_base=api_base,
         prompt=prompt,
-        num_retries=num_retries,
         **kwargs,
     )
```
Unit tests for `dspy.clients.lm`:

```diff
@@ -1,25 +1,106 @@
 from unittest import mock

-from dspy.clients.lm import LM
+from litellm.router import RetryPolicy
+
+from dspy.clients.lm import LM, _get_litellm_router


 def test_lm_chat_respects_max_retries():
-    lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17)
+    model_name = "openai/gpt4o"
+    num_retries = 17
+    temperature = 0.5
+    max_tokens = 100
+    prompt = "Hello, world!"
+
+    lm = LM(
+        model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
+    )
+
+    MockRouter = mock.MagicMock()
+    mock_completion = mock.MagicMock()
+    MockRouter.completion = mock_completion

-    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
-        lm(messages=[{"content": "Hello, world!", "role": "user"}])
+    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+        lm(prompt=prompt)

-    assert litellm_completion_api.call_count == 1
-    assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    MockRouterConstructor.assert_called_once_with(
+        model_list=[
+            {
+                "model_name": model_name,
+                "litellm_params": {
+                    "model": model_name,
+                },
+            }
+        ],
+        retry_policy=RetryPolicy(
+            TimeoutErrorRetries=num_retries,
+            RateLimitErrorRetries=num_retries,
+            InternalServerErrorRetries=num_retries,
+            BadRequestErrorRetries=0,
+            AuthenticationErrorRetries=0,
+            ContentPolicyViolationErrorRetries=0,
+        ),
+    )
+    mock_completion.assert_called_once_with(
+        model=model_name,
+        messages=[{"role": "user", "content": prompt}],
+        temperature=temperature,
+        max_tokens=max_tokens,
+        cache=mock.ANY,
+    )


 def test_lm_completions_respects_max_retries():
-    lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17)
+    model_name = "openai/gpt-3.5-turbo"
+    expected_model = "text-completion-" + model_name
+    num_retries = 17
+    temperature = 0.5
+    max_tokens = 100
+    prompt = "Hello, world!"
+    api_base = "http://test.com"
+    api_key = "apikey"
+
+    lm = LM(
+        model=model_name,
+        model_type="text",
+        num_retries=num_retries,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        api_base=api_base,
+        api_key=api_key,
+    )
+
+    MockRouter = mock.MagicMock()
+    mock_text_completion = mock.MagicMock()
+    MockRouter.text_completion = mock_text_completion

-    with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
-        lm(prompt="Hello, world!")
+    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+        lm(prompt=prompt)

-    assert litellm_completion_api.call_count == 1
-    assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    MockRouterConstructor.assert_called_once_with(
+        model_list=[
+            {
+                "model_name": expected_model,
+                "litellm_params": {
+                    "model": expected_model,
+                },
+            }
+        ],
+        retry_policy=RetryPolicy(
+            TimeoutErrorRetries=num_retries,
+            RateLimitErrorRetries=num_retries,
+            InternalServerErrorRetries=num_retries,
+            BadRequestErrorRetries=0,
+            AuthenticationErrorRetries=0,
+            ContentPolicyViolationErrorRetries=0,
+        ),
+    )
+    mock_text_completion.assert_called_once_with(
+        model=expected_model,
+        prompt=prompt + "\n\nBEGIN RESPONSE:",
+        temperature=temperature,
+        max_tokens=max_tokens,
+        api_key=api_key,
+        api_base=api_base,
+        cache=mock.ANY,
+    )
```
Review comment: 3 retries (equivalent to < 5 seconds) is insufficient to overcome rate limiting in many production environments with high traffic.
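From the caller's side, the retry budget remains configurable per LM instance. A hypothetical sketch (placeholder model name, assuming the usual top-level `dspy.LM` / `dspy.configure` entry points):

```python
import dspy

# num_retries is the parameter whose default this PR raises from 3 to 8;
# callers with stricter latency budgets can still lower it explicitly.
lm = dspy.LM(model="openai/gpt-4o-mini", num_retries=8)  # placeholder model name
dspy.configure(lm=lm)
```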