From f82de3865b6382545e8c248d88a960c9f2cfa59c Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Tue, 29 Oct 2024 15:56:13 -0700
Subject: [PATCH 1/6] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 115 ++++++++++++++++++++++++++++-----------
 1 file changed, 73 insertions(+), 42 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index f7350ecc18..c26e12cc19 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -4,18 +4,15 @@
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Literal, Optional
 
 import litellm
 import ujson
 from litellm.caching import Cache
 
 from dspy.clients.finetune import FinetuneJob, TrainingMethod
-from dspy.clients.lm_finetune_utils import (
-    execute_finetune_job,
-    get_provider_finetune_job_class,
-)
-from dspy.utils.callback import with_callbacks
+from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
+from dspy.utils.callback import BaseCallback, with_callbacks
 from dspy.utils.logging import logger
 
 DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
@@ -27,18 +24,37 @@
 
 GLOBAL_HISTORY = []
 
+
 class LM:
     def __init__(
-            self,
-            model,
-            model_type='chat',
-            temperature=0.0,
-            max_tokens=1000,
-            cache=True,
-            launch_kwargs=None,
-            callbacks=None,
-            **kwargs
-    ):
+        self,
+        model: str,
+        model_type: Literal["chat", "completions"] = "chat",
+        temperature: float = 0.0,
+        max_tokens: int = 1000,
+        cache: bool = True,
+        launch_kwargs=None,
+        callbacks: Optional[List[BaseCallback]] = None,
+        num_retries: int = 7,
+        **kwargs,
+    ):
+        """
+        Create a new language model instance for use with DSPy modules and programs.
+
+        Args:
+            model: The model to use. This should be a string of the form "llm_provider/llm_name"
+                supported by LiteLLM. For example, ``"openai/gpt-4o"``.
+            model_type: The type of the model, either ``"chat"`` or ``"completions"``.
+            temperature: The sampling temperature to use when generating responses.
+            max_tokens: The maximum number of tokens to generate per response.
+            cache: Whether to cache the model responses for reuse to improve performance
+                and reduce costs.
+            launch_kwargs: Additional keyword arguments to pass to the model launch request.
+            callbacks: A list of callback functions to run before and after each request.
+            num_retries: The number of times to retry a request if it fails transiently due to
+                network error, rate limiting, etc. Requests are retried with exponential
+                backoff.
+        """
         # Remember to update LM.copy() if you modify the constructor!
         self.model = model
         self.model_type = model_type
@@ -47,6 +63,7 @@ def __init__(
         self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
         self.callbacks = callbacks or []
+        self.num_retries = num_retries
 
         # TODO: Arbitrary model strings could include the substring "o1-". We
         # should find a more robust way to check for the "o1-" family models.
@@ -68,7 +85,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         else:
             completion = cached_litellm_text_completion if cache else litellm_text_completion
 
-        response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
+        response = completion(
+            request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
+            num_retries=self.num_retries,
+        )
         outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
 
         # Logging, with removed api key & where `cost` is None on cache hit.
@@ -85,9 +105,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         )
         self.history.append(entry)
         GLOBAL_HISTORY.append(entry)
-        
+
         return outputs
-    
+
     def inspect_history(self, n: int = 1):
         _inspect_history(self.history, n)
 
@@ -104,15 +124,16 @@ def kill(self):
         logger.info(msg)
 
     def finetune(
-            self,
-            train_data: List[Dict[str, Any]],
-            train_kwargs: Optional[Dict[str, Any]]=None,
-            train_method: TrainingMethod = TrainingMethod.SFT,
-            provider: str = "openai",
-            cache_finetune: bool = True,
-        ) -> FinetuneJob:
+        self,
+        train_data: List[Dict[str, Any]],
+        train_kwargs: Optional[Dict[str, Any]] = None,
+        train_method: TrainingMethod = TrainingMethod.SFT,
+        provider: str = "openai",
+        cache_finetune: bool = True,
+    ) -> FinetuneJob:
         """Start model fine-tuning, if supported."""
         from dspy import settings as settings
+
         err = "Fine-tuning is an experimental feature."
         err += " Set `dspy.settings.experimental` to `True` to use it."
         assert settings.experimental, err
@@ -123,24 +144,20 @@ def finetune(
             train_data=train_data,
             train_kwargs=train_kwargs,
             train_method=train_method,
-            provider=provider
+            provider=provider,
         )
 
         executor = ThreadPoolExecutor(max_workers=1)
-        executor.submit(
-            execute_finetune_job,
-            finetune_job,
-            lm=self,
-            cache_finetune=cache_finetune
-        )
+        executor.submit(execute_finetune_job, finetune_job, lm=self, cache_finetune=cache_finetune)
         executor.shutdown(wait=False)
 
         return finetune_job
-    
+
     def copy(self, **kwargs):
         """Returns a copy of the language model with possibly updated parameters."""
         import copy
+
         new_instance = copy.deepcopy(self)
         new_instance.history = []
 
@@ -153,23 +170,35 @@ def copy(self, **kwargs):
 
         return new_instance
 
-    
 @functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request):
-    return litellm_completion(request, cache={"no-cache": False, "no-store": False})
+def cached_litellm_completion(request, num_retries: int):
+    return litellm_completion(
+        request,
+        cache={"no-cache": False, "no-store": False},
+        num_retries=num_retries,
+    )
 
 
-def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
+def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    return litellm.completion(cache=cache, **kwargs)
+    return litellm.completion(
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
+        cache=cache,
+        **kwargs,
+    )
 
 
 @functools.lru_cache(maxsize=None)
-def cached_litellm_text_completion(request):
-    return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
+def cached_litellm_text_completion(request, num_retries: int):
+    return litellm_text_completion(
+        request,
+        num_retries=num_retries,
+        cache={"no-cache": False, "no-store": False},
+    )
 
 
-def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
+def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
 
     # Extract the provider and model from the model string.
@@ -190,6 +219,8 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
         api_key=api_key,
         api_base=api_base,
         prompt=prompt,
+        num_retries=num_retries,
+        retry_strategy="exponential_backoff_retry",
         **kwargs,
     )

From c4d462ef08127078e824d7d28927b90c51dcb281 Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Tue, 29 Oct 2024 15:58:34 -0700
Subject: [PATCH 2/6] type hint

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index c26e12cc19..984ed39b64 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -33,7 +33,7 @@ def __init__(
         temperature: float = 0.0,
         max_tokens: int = 1000,
         cache: bool = True,
-        launch_kwargs=None,
+        launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
         num_retries: int = 7,
         **kwargs,
@@ -49,7 +49,6 @@ def __init__(
             max_tokens: The maximum number of tokens to generate per response.
             cache: Whether to cache the model responses for reuse to improve performance
                 and reduce costs.
-            launch_kwargs: Additional keyword arguments to pass to the model launch request.
             callbacks: A list of callback functions to run before and after each request.
             num_retries: The number of times to retry a request if it fails transiently due to
                 network error, rate limiting, etc. Requests are retried with exponential

From 898028ad22699d2c9b67e959b26dc6df59b313c2 Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Tue, 29 Oct 2024 16:01:21 -0700
Subject: [PATCH 3/6] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 984ed39b64..d96c75f482 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -26,6 +26,11 @@
 
 
 class LM:
+    """
+    A language model supporting chat and completions requests for use with DSPy
+    modules and programs.
+    """
+
     def __init__(
         self,
         model: str,
@@ -35,7 +40,7 @@ def __init__(
         cache: bool = True,
         launch_kwargs: Optional[Dict[str, Any]] = None,
         callbacks: Optional[List[BaseCallback]] = None,
-        num_retries: int = 7,
+        num_retries: int = 8,
         **kwargs,
     ):
         """

From d2413af1abb2f237263b4fc63ef81c59a0208bf2 Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Tue, 29 Oct 2024 16:04:14 -0700
Subject: [PATCH 4/6] fix

Signed-off-by: dbczumar
---
 dspy/clients/lm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index d96c75f482..88900242aa 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -47,7 +47,7 @@ def __init__(
         Create a new language model instance for use with DSPy modules and programs.
 
         Args:
-            model: The model to use. This should be a string of the form "llm_provider/llm_name"
+            model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
                 supported by LiteLLM. For example, ``"openai/gpt-4o"``.
             model_type: The type of the model, either ``"chat"`` or ``"completions"``.
             temperature: The sampling temperature to use when generating responses.
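Before the test patch below, here is a minimal usage sketch of the constructor documented above. It is illustrative only and not part of the patch series: it assumes the patched dspy/clients/lm.py is importable and that provider credentials (for example OPENAI_API_KEY) are configured; the model name is an example only.

    from dspy.clients.lm import LM

    # After patch 3/6, num_retries defaults to 8; the value is stored on the instance and
    # forwarded to litellm.completion / litellm.text_completion on every request, together
    # with retry_strategy="exponential_backoff_retry".
    lm = LM(model="openai/gpt-4o", model_type="chat", temperature=0.0, max_tokens=1000, num_retries=8)

    # __call__ returns a list of completion strings, one per returned choice.
    outputs = lm(messages=[{"role": "user", "content": "Say hello."}])
    print(outputs[0])
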

From 445b84ab5998161989ace8c76861c5c48bc7efef Mon Sep 17 00:00:00 2001
From: dbczumar
Date: Tue, 29 Oct 2024 18:30:24 -0700
Subject: [PATCH 5/6] fix

Signed-off-by: dbczumar
---
 tests/clients/test_lm.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)
 create mode 100644 tests/clients/test_lm.py

diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
new file mode 100644
index 0000000000..8825dab95d
--- /dev/null
+++ b/tests/clients/test_lm.py
@@ -0,0 +1,25 @@
+from unittest import mock
+
+from dspy.clients.lm import LM
+
+
+def test_lm_chat_respects_max_retries():
+    lm = LM(model="openai/gpt4o", model_type="chat", max_retries=17)
+
+    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
+        lm(messages=[{"content": "Hello, world!", "role": "user"}])
+
+    assert litellm_completion_api.call_count == 1
+    assert litellm_completion_api.call_args[1]["max_retries"] == 17
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"
+
+
+def test_lm_completions_respects_max_retries():
+    lm = LM(model="openai/gpt-3.5-turbo", model_type="completions", max_retries=17)
+
+    with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
+        lm(prompt="Hello, world!")
+
+    assert litellm_completion_api.call_count == 1
+    assert litellm_completion_api.call_args[1]["max_retries"] == 17
+    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"

From e34e82176a075ce426717c5325ef9127dcfdc0f6 Mon Sep 17 00:00:00 2001
From: Omar Khattab
Date: Tue, 29 Oct 2024 18:39:23 -0700
Subject: [PATCH 6/6] Update lm.py

---
 dspy/clients/lm.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 88900242aa..7ba73c600f 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -27,14 +27,13 @@
 
 class LM:
     """
-    A language model supporting chat and completions requests for use with DSPy
-    modules and programs.
+    A language model supporting chat or text completion requests for use with DSPy modules.
     """
 
     def __init__(
         self,
         model: str,
-        model_type: Literal["chat", "completions"] = "chat",
+        model_type: Literal["chat", "text"] = "chat",
         temperature: float = 0.0,
         max_tokens: int = 1000,
         cache: bool = True,
@@ -49,7 +48,7 @@ def __init__(
         Args:
             model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
                 supported by LiteLLM. For example, ``"openai/gpt-4o"``.
-            model_type: The type of the model, either ``"chat"`` or ``"completions"``.
+            model_type: The type of the model, either ``"chat"`` or ``"text"``.
             temperature: The sampling temperature to use when generating responses.
             max_tokens: The maximum number of tokens to generate per response.
             cache: Whether to cache the model responses for reuse to improve performance
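
As a closing illustration of the series' end state (after patch 6/6, model_type accepts "chat" or "text"), the sketch below mirrors the mocking approach of the new tests to confirm that the num_retries value configured on an LM instance reaches LiteLLM without any network call. It is not part of the series; the helper name and model string are hypothetical.

    from unittest import mock

    from dspy.clients.lm import LM

    def check_num_retries_is_forwarded():
        # num_retries set on the instance should be passed through to litellm.completion,
        # alongside retry_strategy="exponential_backoff_retry" (wired up in patch 1/6).
        lm = LM(model="openai/gpt-4o", model_type="chat", num_retries=3)

        with mock.patch("dspy.clients.lm.litellm.completion") as completion_api:
            lm(messages=[{"role": "user", "content": "Hello, world!"}])

        assert completion_api.call_args[1]["num_retries"] == 3
        assert completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"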