118 changes: 76 additions & 42 deletions dspy/clients/lm.py
@@ -4,18 +4,15 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

import litellm
import ujson
from litellm.caching import Cache

from dspy.clients.finetune import FinetuneJob, TrainingMethod
from dspy.clients.lm_finetune_utils import (
execute_finetune_job,
get_provider_finetune_job_class,
)
Comment on lines -14 to -17 (Collaborator Author): This is just linter formatting.

from dspy.utils.callback import with_callbacks
from dspy.clients.lm_finetune_utils import execute_finetune_job, get_provider_finetune_job_class
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.logging import logger

DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
@@ -27,18 +24,40 @@

GLOBAL_HISTORY = []


class LM:
"""
A language model supporting chat or text completion requests for use with DSPy modules.
"""

def __init__(
self,
model,
model_type='chat',
temperature=0.0,
max_tokens=1000,
cache=True,
launch_kwargs=None,
callbacks=None,
**kwargs
):
self,
model: str,
model_type: Literal["chat", "text"] = "chat",
temperature: float = 0.0,
max_tokens: int = 1000,
cache: bool = True,
launch_kwargs: Optional[Dict[str, Any]] = None,
Comment (Collaborator Author): I didn't document this parameter because it's not actually used.

callbacks: Optional[List[BaseCallback]] = None,
num_retries: int = 8,
Comment (Collaborator Author): Empirically, 8 retries with exponential backoff in LiteLLM takes just over 1 minute with LiteLLM's built-in default retry strategy. I defer to their exponential backoff defaults rather than defining our own (presumably, they know best).

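For a rough sense of how those eight backed-off attempts add up, here is an illustrative calculation; the 0.25-second base delay and doubling factor below are assumptions for illustration, not LiteLLM's actual retry parameters:

# Illustrative only: assumed base delay and doubling factor, not LiteLLM's real settings.
base_delay = 0.25  # seconds before the first retry (assumed)
delays = [base_delay * (2 ** attempt) for attempt in range(8)]
print(delays)       # [0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0]
print(sum(delays))  # 63.75 seconds -> "just over 1 minute" of cumulative waiting
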
**kwargs,
):
"""
Create a new language model instance for use with DSPy modules and programs.

Args:
model: The model to use. This should be a string of the form ``"llm_provider/llm_name"``
supported by LiteLLM. For example, ``"openai/gpt-4o"``.
model_type: The type of the model, either ``"chat"`` or ``"text"``.
temperature: The sampling temperature to use when generating responses.
max_tokens: The maximum number of tokens to generate per response.
cache: Whether to cache the model responses for reuse to improve performance
and reduce costs.
callbacks: A list of callback functions to run before and after each request.
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
"""
# Remember to update LM.copy() if you modify the constructor!
self.model = model
self.model_type = model_type
@@ -47,6 +66,7 @@ def __init__(
self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
self.history = []
self.callbacks = callbacks or []
self.num_retries = num_retries

# TODO: Arbitrary model strings could include the substring "o1-". We
# should find a more robust way to check for the "o1-" family models.
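
Taken together, the new constructor reads as below in use. This is a minimal sketch based on the signature and docstring above; the call pattern mirrors the new unit tests, and a real call would need provider credentials (e.g. an OpenAI API key):

from dspy.clients.lm import LM

# Minimal sketch of the updated constructor; values mirror the documented defaults.
lm = LM(
    model="openai/gpt-4o",   # "llm_provider/llm_name" string understood by LiteLLM
    model_type="chat",
    temperature=0.0,
    max_tokens=1000,
    num_retries=8,           # transient failures retried with exponential backoff
)
outputs = lm(messages=[{"role": "user", "content": "Hello, world!"}])
# `outputs` is a list of completion strings; each request/response pair is also
# appended to lm.history and GLOBAL_HISTORY for inspection via inspect_history().
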
@@ -68,7 +88,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
else:
completion = cached_litellm_text_completion if cache else litellm_text_completion

response = completion(ujson.dumps(dict(model=self.model, messages=messages, **kwargs)))
response = completion(
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
num_retries=self.num_retries,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]

# Logging, with removed api key & where `cost` is None on cache hit.
@@ -85,9 +108,9 @@ def __call__(self, prompt=None, messages=None, **kwargs):
)
self.history.append(entry)
GLOBAL_HISTORY.append(entry)

return outputs

def inspect_history(self, n: int = 1):
_inspect_history(self.history, n)

@@ -104,15 +127,16 @@ def kill(self):
logger.info(msg)

def finetune(
self,
train_data: List[Dict[str, Any]],
train_kwargs: Optional[Dict[str, Any]]=None,
train_method: TrainingMethod = TrainingMethod.SFT,
provider: str = "openai",
cache_finetune: bool = True,
) -> FinetuneJob:
Comment on lines -107 to -113 (Collaborator Author): This is just linter formatting.

self,
train_data: List[Dict[str, Any]],
train_kwargs: Optional[Dict[str, Any]] = None,
train_method: TrainingMethod = TrainingMethod.SFT,
provider: str = "openai",
cache_finetune: bool = True,
) -> FinetuneJob:
"""Start model fine-tuning, if supported."""
from dspy import settings as settings

err = "Fine-tuning is an experimental feature."
err += " Set `dspy.settings.experimental` to `True` to use it."
assert settings.experimental, err
@@ -123,24 +147,20 @@ def finetune(
train_data=train_data,
train_kwargs=train_kwargs,
train_method=train_method,
provider=provider
provider=provider,
)

executor = ThreadPoolExecutor(max_workers=1)
executor.submit(
execute_finetune_job,
finetune_job,
lm=self,
cache_finetune=cache_finetune
)
executor.submit(execute_finetune_job, finetune_job, lm=self, cache_finetune=cache_finetune)
executor.shutdown(wait=False)
Comment (Collaborator Author): This is just linter formatting.


return finetune_job

def copy(self, **kwargs):
"""Returns a copy of the language model with possibly updated parameters."""

import copy

new_instance = copy.deepcopy(self)
new_instance.history = []

@@ -153,23 +173,35 @@ def copy(self, **kwargs):
return new_instance



@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request):
return litellm_completion(request, cache={"no-cache": False, "no-store": False})
def cached_litellm_completion(request, num_retries: int):
return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)


def litellm_completion(request, cache={"no-cache": True, "no-store": True}):
def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
return litellm.completion(cache=cache, **kwargs)
return litellm.completion(
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
cache=cache,
**kwargs,
)


@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request):
return litellm_text_completion(request, cache={"no-cache": False, "no-store": False})
def cached_litellm_text_completion(request, num_retries: int):
return litellm_text_completion(
request,
num_retries=num_retries,
cache={"no-cache": False, "no-store": False},
)


def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}):
def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

# Extract the provider and model from the model string.
@@ -190,6 +222,8 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True})
api_key=api_key,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
retry_strategy="exponential_backoff_retry",
**kwargs,
)

25 changes: 25 additions & 0 deletions tests/clients/test_lm.py
@@ -0,0 +1,25 @@
from unittest import mock

from dspy.clients.lm import LM


def test_lm_chat_respects_max_retries():
    lm = LM(model="openai/gpt-4o", model_type="chat", num_retries=17)

    with mock.patch("dspy.clients.lm.litellm.completion") as litellm_completion_api:
        lm(messages=[{"content": "Hello, world!", "role": "user"}])

    assert litellm_completion_api.call_count == 1
    assert litellm_completion_api.call_args[1]["num_retries"] == 17
    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"


def test_lm_completions_respects_max_retries():
    lm = LM(model="openai/gpt-3.5-turbo", model_type="text", num_retries=17)

    with mock.patch("dspy.clients.lm.litellm.text_completion") as litellm_completion_api:
        lm(prompt="Hello, world!")

    assert litellm_completion_api.call_count == 1
    assert litellm_completion_api.call_args[1]["num_retries"] == 17
    assert litellm_completion_api.call_args[1]["retry_strategy"] == "exponential_backoff_retry"