Add retries to LM calls with LiteLLM #1718
dspy/clients/lm.py
@@ -4,18 +4,15 @@
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional

 import litellm
 import ujson
 from litellm.caching import Cache

 from dspy.clients.finetune import FinetuneJob, TrainingMethod
-from dspy.clients.lm_finetune_utils import (
-    execute_finetune_job,
-    get_provider_finetune_job_class,
-)
-from dspy.utils.callback import with_callbacks
+from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
+from dspy.utils.callback import BaseCallback, with_callbacks
 from dspy.utils.logging import logger

 DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
@@ -27,18 +24,40 @@
 GLOBAL_HISTORY = []


 class LM:
     """
     A language model supporting chat or text completion requests for use with DSPy modules.
     """

     def __init__(
-        self,
-        model,
-        model_type='chat',
-        temperature=0.0,
-        max_tokens=1000,
-        cache=True,
-        launch_kwargs=None,
-        callbacks=None,
-        **kwargs
-    ):
+        self,
+        model: str,
+        model_type: Literal["chat", "text"] = "chat",
+        temperature: float = 0.0,
+        max_tokens: int = 1000,
+        cache: bool = True,
+        launch_kwargs: Optional[Dict[str, Any]] = None,
Comment from the PR author (collaborator) on `launch_kwargs`: I didn't document this parameter because it's not actually used.
+        callbacks: Optional[List[BaseCallback]] = None,
+        num_retries: int = 8,
Comment from the PR author (collaborator) on `num_retries`: Empirically, 8 retries with exponential backoff in LiteLLM takes just over 1 minute with LiteLLM's builtin / default retry strategy. I defer to their exponential backoff retry strategy defaults, rather than defining our own (presumably, they know best).
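For a sense of where that one-minute figure comes from, here is a back-of-the-envelope sketch (not part of this diff). The 0.25 s base delay is an assumed value chosen to illustrate the shape of exponential backoff, not LiteLLM's documented default:

def total_backoff_seconds(num_retries: int = 8, base_delay: float = 0.25) -> float:
    # Cumulative wait under plain exponential backoff: delays of base_delay * 2**attempt
    # for attempt = 0 .. num_retries - 1 (no jitter, no cap).
    return sum(base_delay * (2**attempt) for attempt in range(num_retries))

print(total_backoff_seconds())  # 63.75 -> roughly "just over 1 minute" for 8 retries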
+        **kwargs,
+    ):
+        """
+        Create a new language model instance for use with DSPy modules and programs.
+
+        Args:
+            model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
+                   supported by LiteLLM. For example, ``"openai/gpt-4o"``.
+            model_type: The type of the model, either ``"chat"`` or ``"text"``.
+            temperature: The sampling temperature to use when generating responses.
+            max_tokens: The maximum number of tokens to generate per response.
+            cache: Whether to cache the model responses for reuse to improve performance
+                   and reduce costs.
+            callbacks: A list of callback functions to run before and after each request.
+            num_retries: The number of times to retry a request if it fails transiently due to
+                         network error, rate limiting, etc. Requests are retried with exponential
+                         backoff.
+        """
         # Remember to update LM.copy() if you modify the constructor!
         self.model = model
         self.model_type = model_type
@@ -47,6 +66,7 @@ def __init__(
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
         self.callbacks = callbacks or []
+        self.num_retries = num_retries

         # TODO: Arbitrary model strings could include the substring "o1-". We
         # should find a more robust way to check for the "o1-" family models.
@@ -68,7 +88,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         else:
             completion = cached_litellm_text_completion if cache else litellm_text_completion

-        response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
+        response = completion(
+            request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
+            num_retries=self.num_retries,
+        )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

         # Logging, with removed api key & where `cost` is None on cache hit.
@@ -85,9 +108,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         )
         self.history.append(entry)
         GLOBAL_HISTORY.append(entry)

         return outputs

     def inspect_history(self, n: int = 1):
         _inspect_history(self.history, n)
@@ -104,15 +127,16 @@ def kill(self):
         logger.info(msg)

     def finetune(
-        self,
-        train_data: List[Dict[str, Any]],
-        train_kwargs: Optional[Dict[str, Any]]=None,
-        train_method: TrainingMethod = TrainingMethod.SFT,
-        provider: str = "openai",
-        cache_finetune: bool = True,
-    ) -> FinetuneJob:
Comment from the PR author (collaborator) on original lines 107-113 (the old `finetune` signature): This is just linter formatting.
+        self,
+        train_data: List[Dict[str, Any]],
+        train_kwargs: Optional[Dict[str, Any]] = None,
+        train_method: TrainingMethod = TrainingMethod.SFT,
+        provider: str = "openai",
+        cache_finetune: bool = True,
+    ) -> FinetuneJob:
         """Start model fine-tuning, if supported."""
         from dspy import settings as settings

         err = "Fine-tuning is an experimental feature."
         err += " Set `dspy.settings.experimental` to `True` to use it."
         assert settings.experimental, err
@@ -123,24 +147,20 @@ def finetune(
             train_data=train_data,
             train_kwargs=train_kwargs,
             train_method=train_method,
-            provider=provider
+            provider=provider,
         )

         executor = ThreadPoolExecutor(max_workers=1)
-        executor.submit(
-            execute_finetune_job,
-            finetune_job,
-            lm=self,
-            cache_finetune=cache_finetune
-        )
+        executor.submit(execute_finetune_job, finetune_job, lm=self, cache_finetune=cache_finetune)
         executor.shutdown(wait=False)
Comment from the PR author (collaborator): This is just linter formatting.
|
|
||
| return finetune_job | ||
|
|
||
| def copy(self, **kwargs): | ||
| """Returns a copy of the language model with possibly updated parameters.""" | ||
|
|
||
| import copy | ||
|
|
||
| new_instance = copy.deepcopy(self) | ||
| new_instance.history = [] | ||
|
|
||
|
|
@@ -153,23 +173,35 @@ def copy(self, **kwargs):
         return new_instance


 @functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request):
-    return litellm_completion(request, cache={"no-cache": False, "no-store": False})
+def cached_litellm_completion(request, num_retries: int):
+    return litellm_completion(
+        request,
+        cache={"no-cache": False, "no-store": False},
+        num_retries=num_retries,
+    )


-def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
+def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    return litellm.completion(cache=cache, **kwargs)
+    return litellm.completion(
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
+        cache=cache,
+        **kwargs,
+    )


 @functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request):
-    return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
+def cached_litellm_text_completion(request, num_retries: int):
+    return litellm_text_completion(
+        request,
+        num_retries=num_retries,
+        cache={"no-cache": False, "no-store": False},
+    )


-def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
+def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)

     # Extract the provider and model from the model string.
@@ -190,6 +222,8 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
         api_key=api_key,
         api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
         **kwargs,
     )
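Taken together, these changes let callers configure retry behavior per LM instance. A minimal usage sketch (assuming a valid OpenAI API key is available to LiteLLM; the model name is only an example):

from dspy.clients.lm import LM

# num_retries is the new constructor argument; transient failures (network errors,
# rate limits, etc.) are retried by LiteLLM with exponential backoff.
lm = LM(model="openai/gpt-4o", num_retries=3)

# __call__ returns a list of completion strings (see `return outputs` above).
outputs = lm(messages=[{"role": "user", "content": "Hello, world!"}])
print(outputs[0])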
New test file (25 added lines):
@@ -0,0 +1,25 @@
+from unittest import mock
+
+from dspy.clients.lm import LM
+
+
+def test_lm_chat_respects_max_retries():
+    lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17)
+
+    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
+        lm(messages=[{"content": "Hello, world!", "role": "user"}])
+
+    assert litellm_completion_api.call_count == 1
+    assert litellm_completion_api.call_args[1]["max_retries"] == 17
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+
+
+def test_lm_completions_respects_max_retries():
+    lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17)
+
+    with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
+        lm(prompt="Hello, world!")
+
+    assert litellm_completion_api.call_count == 1
+    assert litellm_completion_api.call_args[1]["max_retries"] == 17
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
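As a complementary sketch (not part of this diff), the same mock pattern could also be used to check that the constructor's new `num_retries` argument reaches LiteLLM, since `litellm_completion` forwards it along with the retry strategy:

def test_lm_chat_forwards_num_retries_sketch():
    lm = LM(model="openai/gpt4o", model_type="chat", num_retries=3)

    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
        lm(messages=[{"content": "Hello, world!", "role": "user"}])

    # The constructor value flows through LM.__call__ -> litellm_completion, which
    # also sets the exponential backoff retry strategy (see the diff above).
    assert litellm_completion_api.call_args[1]["num_retries"] == 3
    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"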