110 changes: 83 additions & 27 deletions dspy/clients/lm.py
@@ -3,6 +3,7 @@
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

@@ -164,6 +165,55 @@ def copy(self, **kwargs):
return new_instance


@dataclass(frozen=True)
class _ProviderAPIConfig:
"""
API configurations for a provider (e.g. OpenAI, Azure OpenAI)
"""

api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]
# Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
# For all other providers, this field is None
azure_ad_token: Optional[str]


def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
"""
Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
provider corresponding to the given model.

Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
the input dictionary.
Comment on lines +187 to +188 (Collaborator Author):
This avoids duplicating code in litellm_text_completion, which pops the api key and base from the user-specified kwargs.
"""
provider = _get_provider(model)
api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
if "azure" in provider:
azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
else:
azure_ad_token = None
return _ProviderAPIConfig(
api_key=api_key,
api_base=api_base,
api_version=api_version,
azure_ad_token=azure_ad_token,
)
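
For illustration (not part of this diff), a minimal sketch of how the extraction consumes API settings from the caller's kwargs, assuming the private helper is importable from dspy.clients.lm and using a placeholder key:

```python
from dspy.clients.lm import _extract_provider_api_config

# Hypothetical caller kwargs; "sk-test" is a placeholder, not a real key.
llm_kwargs = {"api_key": "sk-test", "temperature": 0.7}
config = _extract_provider_api_config(model="openai/gpt-4o", llm_kwargs=llm_kwargs)

assert "api_key" not in llm_kwargs       # popped: llm_kwargs is mutated in place
assert config.api_key == "sk-test"       # explicit kwarg wins over OPENAI_API_KEY
assert config.azure_ad_token is None     # only populated for azure providers
```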


def _get_provider(model: str) -> str:
"""
Extract the provider name from the model string of the format "<provider_name>/<model_name>",
e.g. "openai/gpt-4".

TODO: Not all models follow the "<provider_name>/<model_name>" format
"""
model = model.split("/", 1)
return model[0] if len(model) > 1 else "openai"


@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request, num_retries: int):
return litellm_completion(
@@ -175,15 +225,16 @@ def cached_litellm_completion(request, num_retries: int):

def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
return router.completion(
cache=cache,
**kwargs,
)
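
A hedged aside (names from the diff above): `request` is a JSON string rather than a dict so that the `functools.lru_cache` wrapper can hash it. A minimal sketch of the two entry points, assuming valid provider credentials are configured:

```python
import ujson

request = ujson.dumps({
    "model": "openai/gpt-4o",
    "messages": [{"role": "user", "content": "Hello!"}],
})

# Cached path: identical request strings hit the same lru_cache entry.
response = cached_litellm_completion(request, num_retries=3)

# Uncached path: the default cache dict ({"no-cache": True, "no-store": True})
# opts the request out of LiteLLM's caching.
response = litellm_completion(request, num_retries=3)
```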


@functools.lru_cache(maxsize=None)
def _get_litellm_router(model: str, num_retries: int) -> Router:
def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
"""
Get a LiteLLM router for the given model with the specified number of retries
for transient errors.
@@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
num_retries: The number of times to retry a request if it fails transiently due to
network error, rate limiting, etc. Requests are retried with exponential
backoff.
api_config: The API configurations (keys, base URL, etc.) for the provider
(OpenAI, Azure OpenAI, etc.) corresponding to the given model.

Returns:
A LiteLLM router instance that can be used to query the given model.
"""
@@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
ContentPolicyViolationErrorRetries=0,
)

# LiteLLM routers must specify a `model_list`, which maps model names passed
# to `completions()` into actual LiteLLM model names. For our purposes, the
# model name is the same as the LiteLLM model name, so we add a single
# entry to the `model_list` that maps the model name to itself
litellm_params = {
"model": model,
}
if api_config.api_key is not None:
litellm_params["api_key"] = api_config.api_key
if api_config.api_base is not None:
litellm_params["api_base"] = api_config.api_base
if api_config.api_version is not None:
litellm_params["api_version"] = api_config.api_version
if api_config.azure_ad_token is not None:
litellm_params["azure_ad_token"] = api_config.azure_ad_token
Comment on lines +271 to +278 (Collaborator Author):
@okhat For OpenAI and Azure OpenAI, API keys must be defined in the Router's model_list directly, due to initialization logic in https://github.com/BerriAI/litellm/blob/0fe8cde7c78d6b975b936f7802a4124b58bd253a/litellm/router.py#L3999-L4001
model_list = [
{
"model_name": model,
"litellm_params": litellm_params,
}
]
return Router(
# LiteLLM routers must specify a `model_list`, which maps model names passed
# to `completions()` into actual LiteLLM model names. For our purposes, the
# model name is the same as the LiteLLM model name, so we add a single
# entry to the `model_list` that maps the model name to itself
model_list=[
{
"model_name": model,
"litellm_params": {
"model": model,
},
}
],
model_list=model_list,
retry_policy=retry_policy,
)
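
An observation worth making explicit (not stated in the diff): `_get_litellm_router` is wrapped in `functools.lru_cache`, so every argument must be hashable. `_ProviderAPIConfig` being a frozen dataclass is what makes this work, since `@dataclass(frozen=True)` generates `__eq__` and `__hash__`. A sketch of the resulting caching behavior:

```python
# Two equal configs hash identically, so they map to the same cached Router.
cfg_a = _ProviderAPIConfig(api_key="k", api_base=None, api_version=None, azure_ad_token=None)
cfg_b = _ProviderAPIConfig(api_key="k", api_base=None, api_version=None, azure_ad_token=None)

router_1 = _get_litellm_router(model="openai/gpt-4o", num_retries=3, api_config=cfg_a)
router_2 = _get_litellm_router(model="openai/gpt-4o", num_retries=3, api_config=cfg_b)
assert router_1 is router_2  # lru_cache returns the same instance for equal keys
```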

@@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int):

def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = kwargs.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]
text_completion_model_name = f"text-completion-openai/{model}"

# Use the API key and base from the kwargs, or from the environment.
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
Comment on lines -245 to -247 (Collaborator Author):
These are now part of the router definition, so they don't need to be passed to the text_completion() API.
model = kwargs.pop("model")
api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
model_name = model.split("/", 1)[-1]
text_completion_model_name = f"text-completion-openai/{model_name}"

# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
return router.text_completion(
cache=cache,
model=text_completion_model_name,
api_key=api_key,
api_base=api_base,
prompt=prompt,
**kwargs,
)
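
For clarity (illustrative only), the prompt construction above flattens the chat messages into a single text-completion prompt:

```python
# Mirrors the "\n\n".join(...) line above: message contents are joined with
# blank lines and a "BEGIN RESPONSE:" sentinel is appended.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hello, world!"},
]
prompt = "\n\n".join([m["content"] for m in messages] + ["BEGIN RESPONSE:"])
# -> "You are terse.\n\nHello, world!\n\nBEGIN RESPONSE:"
```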
77 changes: 58 additions & 19 deletions tests/clients/test_lm.py
@@ -1,20 +1,40 @@
from unittest import mock

import pytest
from litellm.router import RetryPolicy

from dspy.clients.lm import LM, _get_litellm_router


def test_lm_chat_respects_max_retries():
@pytest.mark.parametrize("keys_in_env_vars", [True, False])
def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch):
model_name = "openai/gpt4o"
num_retries = 17
temperature = 0.5
max_tokens = 100
prompt = "Hello, world!"
api_version = "2024-02-01"
api_key = "apikey"

lm_kwargs = {
"model": model_name,
"model_type": "chat",
"num_retries": num_retries,
"temperature": temperature,
"max_tokens": max_tokens,
}
if keys_in_env_vars:
api_base = "http://testfromenv.com"
monkeypatch.setenv("OPENAI_API_KEY", api_key)
monkeypatch.setenv("OPENAI_API_BASE", api_base)
monkeypatch.setenv("OPENAI_API_VERSION", api_version)
else:
api_base = "http://test.com"
lm_kwargs["api_key"] = api_key
lm_kwargs["api_base"] = api_base
lm_kwargs["api_version"] = api_version

lm = LM(
model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
)
lm = LM(**lm_kwargs)

MockRouter = mock.MagicMock()
mock_completion = mock.MagicMock()
@@ -29,6 +49,9 @@ def test_lm_chat_respects_max_retries():
"model_name": model_name,
"litellm_params": {
"model": model_name,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
},
}
],
@@ -50,25 +73,39 @@
)


def test_lm_completions_respects_max_retries():
model_name = "openai/gpt-3.5-turbo"
expected_model = "text-completion-" + model_name
@pytest.mark.parametrize("keys_in_env_vars", [True, False])
def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch):
model_name = "azure/gpt-3.5-turbo"
expected_model = "text-completion-openai/" + model_name.split("/")[-1]
num_retries = 17
temperature = 0.5
max_tokens = 100
prompt = "Hello, world!"
api_base = "http://test.com"
api_version = "2024-02-01"
api_key = "apikey"
azure_ad_token = "adtoken"

lm_kwargs = {
"model": model_name,
"model_type": "text",
"num_retries": num_retries,
"temperature": temperature,
"max_tokens": max_tokens,
}
if keys_in_env_vars:
api_base = "http://testfromenv.com"
monkeypatch.setenv("AZURE_API_KEY", api_key)
monkeypatch.setenv("AZURE_API_BASE", api_base)
monkeypatch.setenv("AZURE_API_VERSION", api_version)
monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token)
else:
api_base = "http://test.com"
lm_kwargs["api_key"] = api_key
lm_kwargs["api_base"] = api_base
lm_kwargs["api_version"] = api_version
lm_kwargs["azure_ad_token"] = azure_ad_token

lm = LM(
model=model_name,
model_type="text",
num_retries=num_retries,
temperature=temperature,
max_tokens=max_tokens,
api_base=api_base,
api_key=api_key,
)
lm = LM(**lm_kwargs)

MockRouter = mock.MagicMock()
mock_text_completion = mock.MagicMock()
@@ -83,6 +120,10 @@ def test_lm_completions_respects_max_retries():
"model_name": expected_model,
"litellm_params": {
"model": expected_model,
"api_key": api_key,
"api_base": api_base,
"api_version": api_version,
"azure_ad_token": azure_ad_token,
},
}
],
@@ -100,7 +141,5 @@ def test_lm_completions_respects_max_retries():
prompt=prompt + "\n\nBEGIN RESPONSE:",
temperature=temperature,
max_tokens=max_tokens,
api_key=api_key,
api_base=api_base,
cache=mock.ANY,
)