From 355a26f219171d0fb17a1da800f50307c277033f Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Wed, 30 Oct 2024 16:26:25 -0700
Subject: [PATCH 1/5] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py       | 19 +++++++++++++++++++
 requirements.txt         | 10 +++++-----
 tests/clients/test_lm.py |  4 ++--
 3 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 02fe174021..a2b184e43b 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -6,9 +6,11 @@
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional
 
+import importlib_metadata
 import litellm
 import ujson
 from litellm.caching import Cache
+from packaging.version import Version
 
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
@@ -187,6 +189,7 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
+        **_get_litellm_retry_strategy_kwargs(),
         **kwargs,
     )
 
@@ -222,10 +225,26 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
+        **_get_litellm_retry_strategy_kwargs(),
         **kwargs,
     )
 
 
+def _get_litellm_retry_strategy_kwargs() -> Dict[str, str]:
+    """
+    Returns retry strategy kwargs for LiteLLM.
+    """
+    litellm_version = importlib_metadata.version("litellm")
+    if Version(litellm_version) >= Version("1.51.3"):
+        # Enable retries with exponential backoff via the `retry_strategy` flag, which
+        # requires LiteLLM version >= 1.51.3
+        return {"retry_strategy": "exponential_backoff_retry"}
+    else:
+        # Otherwise, use the LiteLLM default retry strategy (fixed interval retries)
+        # in older LiteLLM versions where the `retry_strategy` flag is not available
+        return {}
+
+
 def _green(text: str, end: str = "\n"):
     return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end
 
diff --git a/requirements.txt b/requirements.txt
index 3e0f4501df..60e712d7a2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,11 @@
 backoff
 datasets
+diskcache
+httpx
 joblib~=1.3
-litellm<=1.49.1
+json-repair
+litellm<=1.51.3
+magicattr~=0.1.6
 openai
 optuna
 pandas
@@ -11,7 +15,3 @@ requests
 structlog
 tqdm
 ujson
-httpx
-magicattr~=0.1.6
-diskcache
-json-repair
\ No newline at end of file
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index ef5d85a9ca..8825dab95d 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -11,7 +11,7 @@ def test_lm_chat_respects_max_retries():
 
     assert litellm_completion_api.call_count == 1
     assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
 
 
 def test_lm_completions_respects_max_retries():
@@ -22,4 +22,4 @@ def test_lm_completions_respects_max_retries():
 
     assert litellm_completion_api.call_count == 1
     assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    # assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"

From ed185402055f6f6b6fed79ec8475f9c11a5cdeb5 Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Wed, 30 Oct 2024 16:29:20 -0700
Subject: [PATCH 2/5] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index a2b184e43b..d159229ee9 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -1,4 +1,5 @@
 import functools
+import importlib.metadata
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -6,7 +7,6 @@
 from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional
 
-import importlib_metadata
 import litellm
 import ujson
 from litellm.caching import Cache
@@ -234,7 +234,7 @@ def _get_litellm_retry_strategy_kwargs() -> Dict[str, str]:
     """
     Returns retry strategy kwargs for LiteLLM.
     """
-    litellm_version = importlib_metadata.version("litellm")
+    litellm_version = importlib.metadata.version("litellm")
     if Version(litellm_version) >= Version("1.51.3"):
         # Enable retries with exponential backoff via the `retry_strategy` flag, which
         # requires LiteLLM version >= 1.51.3

From ce046d8c2ade3f741db93e5b9da358740ab725f9 Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Thu, 31 Oct 2024 09:59:11 -0700
Subject: [PATCH 3/5] progress

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 21 ++-------------------
 pyproject.toml     |  2 +-
 requirements.txt   |  2 +-
 3 files changed, 4 insertions(+), 21 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index d159229ee9..25dc84e2d3 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -1,5 +1,4 @@
 import functools
-import importlib.metadata
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
@@ -10,7 +9,6 @@
 import litellm
 import ujson
 from litellm.caching import Cache
-from packaging.version import Version
 
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
 from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
@@ -189,7 +187,7 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s
     return litellm.completion(
         num_retries=num_retries,
         cache=cache,
-        **_get_litellm_retry_strategy_kwargs(),
+        retry_strategy="exponential_backoff_retry",
         **kwargs,
     )
 
@@ -225,26 +223,11 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
         api_base=api_base,
         prompt=prompt,
         num_retries=num_retries,
-        **_get_litellm_retry_strategy_kwargs(),
+        retry_strategy="exponential_backoff_retry",
         **kwargs,
     )
 
 
-def _get_litellm_retry_strategy_kwargs() -> Dict[str, str]:
-    """
-    Returns retry strategy kwargs for LiteLLM.
- """ - litellm_version = importlib.metadata.version("litellm") - if Version(litellm_version) >= Version("1.51.3"): - # Enable retries with exponential backoff via the `retry_strategy` flag, which - # requires LiteLLM version >= 1.51.3 - return {"retry_strategy": "exponential_backoff_retry"} - else: - # Otherwise, use the LiteLLM default retry strategy (fixed interval retries) - # in older LiteLLM versions where the `retry_strategy` flag is not available - return {} - - def _green(text: str, end: str = "\n"): return "\x1b[32m" + str(text).lstrip() + "\x1b[0m" + end diff --git a/pyproject.toml b/pyproject.toml index 0170fce95c..8a990917ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,7 +134,7 @@ llama-index = {version = "^0.10.30", optional = true} snowflake-snowpark-python = { version = "*",optional=true, python = ">=3.9,<3.12" } jinja2 = "^3.1.3" magicattr = "^0.1.6" -litellm = "1.49.1" +litellm = "^1.53.1" diskcache = "^5.6.0" json-repair = "^0.30.0" diff --git a/requirements.txt b/requirements.txt index 60e712d7a2..722cc72644 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diskcache httpx joblib~=1.3 json-repair -litellm<=1.51.3 +litellm~=1.51.3 magicattr~=0.1.6 openai optuna From f42303abab2420edb779bbfb9fb80b715d84a7c3 Mon Sep 17 00:00:00 2001 From: dbczumar Date: Mon, 4 Nov 2024 20:58:01 -0800 Subject: [PATCH 4/5] fix Signed-off-by: dbczumar --- dspy/clients/lm.py | 57 ++++++++++++++++++--- tests/clients/test_lm.py | 107 ++++++++++++++++++++++++++++++++++----- 2 files changed, 143 insertions(+), 21 deletions(-) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 25dc84e2d3..03b56a7769 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -8,7 +8,9 @@ import litellm import ujson +from litellm import Router from litellm.caching import Cache +from litellm.router import RetryPolicy from dspy.clients.finetune import FinetuneJob, TrainingMethod from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class @@ -39,7 +41,7 @@ def __init__( cache: bool = True, launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, - num_retries: int = 3, + num_retries: int = 8, **kwargs, ): """ @@ -184,14 +186,53 @@ def cached_litellm_completion(request, num_retries: int): def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): kwargs = ujson.loads(request) - return litellm.completion( - num_retries=num_retries, + router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries) + return router.completion( cache=cache, - retry_strategy="exponential_backoff_retry", **kwargs, ) +@functools.lru_cache(maxsize=None) +def _get_litellm_router(model: str, num_retries: int) -> Router: + """ + Get a LiteLLM router for the given model with the specified number of retries + for transient errors. + + Args: + model: The name of the LiteLLM model to query (e.g. 'openai/gpt-4'). + num_retries: The number of times to retry a request if it fails transiently due to + network error, rate limiting, etc. Requests are retried with exponential + backoff. + Returns: + A LiteLLM router instance that can be used to query the given model. 
+ """ + retry_policy = RetryPolicy( + TimeoutErrorRetries=num_retries, + RateLimitErrorRetries=num_retries, + InternalServerErrorRetries=num_retries, + BadRequestErrorRetries=0, + AuthenticationErrorRetries=0, + ContentPolicyViolationErrorRetries=0, + ) + + return Router( + # LiteLLM routers must specify a `model_list`, which maps model names passed + # to `completions()` into actual LiteLLM model names. For our purposes, the + # model name is the same as the LiteLLM model name, so we add a single + # entry to the `model_list` that maps the model name to itself + model_list=[ + { + "model_name": model, + "litellm_params": { + "model": model, + }, + } + ], + retry_policy=retry_policy, + ) + + @functools.lru_cache(maxsize=None) def cached_litellm_text_completion(request, num_retries: int): return litellm_text_completion( @@ -208,6 +249,7 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, # TODO: Not all the models are in the format of "provider/model" model = kwargs.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] + text_completion_model_name = f"text-completion-openai/{model}" # Use the API key and base from the kwargs, or from the environment. api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY") @@ -216,14 +258,13 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, # Build the prompt from the messages. prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"]) - return litellm.text_completion( + router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries) + return router.text_completion( cache=cache, - model=f"text-completion-openai/{model}", + model=text_completion_model_name, api_key=api_key, api_base=api_base, prompt=prompt, - num_retries=num_retries, - retry_strategy="exponential_backoff_retry", **kwargs, ) diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 8825dab95d..61e5828aee 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -1,25 +1,106 @@ from unittest import mock -from dspy.clients.lm import LM +from litellm.router import RetryPolicy + +from dspy.clients.lm import LM, _get_litellm_router def test_lm_chat_respects_max_retries(): - lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17) + model_name = "openai/gpt4o" + num_retries = 17 + temperature = 0.5 + max_tokens = 100 + prompt = "Hello, world!" 
+
+    lm = LM(
+        model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
+    )
+
+    MockRouter = mock.MagicMock()
+    mock_completion = mock.MagicMock()
+    MockRouter.completion = mock_completion
 
-    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
-        lm(messages=[{"content": "Hello, world!", "role": "user"}])
+    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+        lm(prompt=prompt)
 
-    assert litellm_completion_api.call_count == 1
-    assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    MockRouterConstructor.assert_called_once_with(
+        model_list=[
+            {
+                "model_name": model_name,
+                "litellm_params": {
+                    "model": model_name,
+                },
+            }
+        ],
+        retry_policy=RetryPolicy(
+            TimeoutErrorRetries=num_retries,
+            RateLimitErrorRetries=num_retries,
+            InternalServerErrorRetries=num_retries,
+            BadRequestErrorRetries=0,
+            AuthenticationErrorRetries=0,
+            ContentPolicyViolationErrorRetries=0,
+        ),
+    )
+    mock_completion.assert_called_once_with(
+        model=model_name,
+        messages=[{"role": "user", "content": prompt}],
+        temperature=temperature,
+        max_tokens=max_tokens,
+        cache=mock.ANY,
+    )
 
 
 def test_lm_completions_respects_max_retries():
-    lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17)
+    model_name = "openai/gpt-3.5-turbo"
+    expected_model = "text-completion-" + model_name
+    num_retries = 17
+    temperature = 0.5
+    max_tokens = 100
+    prompt = "Hello, world!"
+    api_base = "http://test.com"
+    api_key = "apikey"
+
+    lm = LM(
+        model=model_name,
+        model_type="text",
+        num_retries=num_retries,
+        temperature=temperature,
+        max_tokens=max_tokens,
+        api_base=api_base,
+        api_key=api_key,
+    )
+
+    MockRouter = mock.MagicMock()
+    mock_text_completion = mock.MagicMock()
+    MockRouter.text_completion = mock_text_completion
 
-    with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
-        lm(prompt="Hello, world!")
+    with mock.patch("dspy.clients.lm.Router", return_value=MockRouter) as MockRouterConstructor:
+        lm(prompt=prompt)
 
-    assert litellm_completion_api.call_count == 1
-    assert litellm_completion_api.call_args[1]["max_retries"] == 17
-    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+    MockRouterConstructor.assert_called_once_with(
+        model_list=[
+            {
+                "model_name": expected_model,
+                "litellm_params": {
+                    "model": expected_model,
+                },
+            }
+        ],
+        retry_policy=RetryPolicy(
+            TimeoutErrorRetries=num_retries,
+            RateLimitErrorRetries=num_retries,
+            InternalServerErrorRetries=num_retries,
+            BadRequestErrorRetries=0,
+            AuthenticationErrorRetries=0,
+            ContentPolicyViolationErrorRetries=0,
+        ),
+    )
+    mock_text_completion.assert_called_once_with(
+        model=expected_model,
+        prompt=prompt + "\n\nBEGIN RESPONSE:",
+        temperature=temperature,
+        max_tokens=max_tokens,
+        api_key=api_key,
+        api_base=api_base,
+        cache=mock.ANY,
+    )

From 23fc5e5f74119532c3c21f43af4ef30256d9e81b Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Mon, 4 Nov 2024 21:02:04 -0800
Subject: [PATCH 5/5] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index fd2a2bac24..6a55e4cfd4 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -200,6 +200,8 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         TimeoutErrorRetries=num_retries,
         RateLimitErrorRetries=num_retries,
         InternalServerErrorRetries=num_retries,
+        # We don't retry on errors that are unlikely to be transient
+        # (e.g. bad request, invalid auth credentials)
         BadRequestErrorRetries=0,
         AuthenticationErrorRetries=0,
         ContentPolicyViolationErrorRetries=0,
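
Note: a minimal end-to-end sketch of the retry behavior this series converges on.
After PATCH 4/5, `num_retries` (default 8) is forwarded to a LiteLLM `Router` whose
`RetryPolicy` retries timeouts, rate limits, and internal server errors with
exponential backoff, and never retries bad-request, auth, or content-policy errors.
The model name below is hypothetical and assumes an OpenAI API key in the environment:

    import dspy

    # Retries are configured per LM instance; the Router built for this model is
    # cached via lru_cache, so repeated calls reuse the same retry policy.
    lm = dspy.LM("openai/gpt-4o-mini", num_retries=17, temperature=0.5, max_tokens=100)
    dspy.configure(lm=lm)
    print(lm("Hello, world!"))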