diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 6a55e4cfd4..3880d35f84 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -3,6 +3,7 @@ import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional
 
@@ -164,6 +165,55 @@ def copy(self, **kwargs):
         return new_instance
 
 
+@dataclass(frozen=True)
+class _ProviderAPIConfig:
+    """
+    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
+    """
+
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
+    # For all other providers, this field is empty
+    azure_ad_token: Optional[str]
+
+
+def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
+    """
+    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
+    provider corresponding to the given model.
+
+    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
+    the input dictionary.
+    """
+    provider = _get_provider(model)
+    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
+    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
+    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
+    if "azure" in provider:
+        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
+    else:
+        azure_ad_token = None
+    return _ProviderAPIConfig(
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        azure_ad_token=azure_ad_token,
+    )
+
+
+def _get_provider(model: str) -> str:
+    """
+    Extract the provider name from the model string of the format "<provider>/<model>",
+    e.g. "openai/gpt-4".
+
+    TODO: Not all the models are in the format of "provider/model"
+    """
+    model = model.split("/", 1)
+    return model[0] if len(model) > 1 else "openai"
+
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
@@ -175,7 +225,8 @@ def cached_litellm_completion(request, num_retries: int):
 
 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
+    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
+    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
     return router.completion(
         cache=cache,
         **kwargs,
@@ -183,7 +234,7 @@ def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-s
 
 
 @functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int) -> Router:
+def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
     """
     Get a LiteLLM router for the given model with the specified number of retries
     for transient errors.
@@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         num_retries: The number of times to retry a request if it fails transiently due to
             network error, rate limiting, etc. Requests are retried with exponential backoff.
+        api_config: The API configurations (keys, base URL, etc.) for the provider
+            (OpenAI, Azure OpenAI, etc.) corresponding to the given model.
+
     Returns:
         A LiteLLM router instance that can be used to query the given model.
     """
 
@@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         ContentPolicyViolationErrorRetries=0,
     )
 
+    # LiteLLM routers must specify a `model_list`, which maps model names passed
+    # to `completions()` into actual LiteLLM model names. For our purposes, the
+    # model name is the same as the LiteLLM model name, so we add a single
+    # entry to the `model_list` that maps the model name to itself
+    litellm_params = {
+        "model": model,
+    }
+    if api_config.api_key is not None:
+        litellm_params["api_key"] = api_config.api_key
+    if api_config.api_base is not None:
+        litellm_params["api_base"] = api_config.api_base
+    if api_config.api_version is not None:
+        litellm_params["api_version"] = api_config.api_version
+    if api_config.azure_ad_token is not None:
+        litellm_params["azure_ad_token"] = api_config.azure_ad_token
+    model_list = [
+        {
+            "model_name": model,
+            "litellm_params": litellm_params,
+        }
+    ]
     return Router(
-        # LiteLLM routers must specify a `model_list`, which maps model names passed
-        # to `completions()` into actual LiteLLM model names. For our purposes, the
-        # model name is the same as the LiteLLM model name, so we add a single
-        # entry to the `model_list` that maps the model name to itself
-        model_list=[
-            {
-                "model_name": model,
-                "litellm_params": {
-                    "model": model,
-                },
-            }
-        ],
+        model_list=model_list,
         retry_policy=retry_policy,
     )
 
@@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int):
 
 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-
-    # Extract the provider and model from the model string.
-    # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
-    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
-    text_completion_model_name = f"text-completion-openai/{model}"
-
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+    model = kwargs.pop("model")
+    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
+    model_name = model.split("/", 1)[-1]
+    text_completion_model_name = f"text-completion-openai/{model_name}"
 
     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
 
-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
+    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
     return router.text_completion(
         cache=cache,
         model=text_completion_model_name,
-        api_key=api_key,
-        api_base=api_base,
         prompt=prompt,
         **kwargs,
     )
diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py
index 61e5828aee..7b3a02481a 100644
--- a/tests/clients/test_lm.py
+++ b/tests/clients/test_lm.py
@@ -1,20 +1,40 @@
 from unittest import mock
 
+import pytest
 from litellm.router import RetryPolicy
 
 from dspy.clients.lm import LM, _get_litellm_router
 
 
-def test_lm_chat_respects_max_retries():
+@pytest.mark.parametrize("keys_in_env_vars", [True, False])
+def test_lm_chat_respects_max_retries(keys_in_env_vars, monkeypatch):
     model_name = "openai/gpt4o"
     num_retries = 17
     temperature = 0.5
     max_tokens = 100
     prompt = "Hello, world!"
+    api_version = "2024-02-01"
+    api_key = "apikey"
+
+    lm_kwargs = {
+        "model": model_name,
+        "model_type": "chat",
+        "num_retries": num_retries,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+    }
+    if keys_in_env_vars:
+        api_base = "http://testfromenv.com"
+        monkeypatch.setenv("OPENAI_API_KEY", api_key)
+        monkeypatch.setenv("OPENAI_API_BASE", api_base)
+        monkeypatch.setenv("OPENAI_API_VERSION", api_version)
+    else:
+        api_base = "http://test.com"
+        lm_kwargs["api_key"] = api_key
+        lm_kwargs["api_base"] = api_base
+        lm_kwargs["api_version"] = api_version
 
-    lm = LM(
-        model=model_name, model_type="chat", num_retries=num_retries, temperature=temperature, max_tokens=max_tokens
-    )
+    lm = LM(**lm_kwargs)
 
     MockRouter = mock.MagicMock()
     mock_completion = mock.MagicMock()
@@ -29,6 +49,9 @@ def test_lm_chat_respects_max_retries():
                 "model_name": model_name,
                 "litellm_params": {
                     "model": model_name,
+                    "api_key": api_key,
+                    "api_base": api_base,
+                    "api_version": api_version,
                 },
             }
         ],
@@ -50,25 +73,39 @@ def test_lm_chat_respects_max_retries():
     )
 
 
-def test_lm_completions_respects_max_retries():
-    model_name = "openai/gpt-3.5-turbo"
-    expected_model = "text-completion-" + model_name
+@pytest.mark.parametrize("keys_in_env_vars", [True, False])
+def test_lm_completions_respects_max_retries(keys_in_env_vars, monkeypatch):
+    model_name = "azure/gpt-3.5-turbo"
+    expected_model = "text-completion-openai/" + model_name.split("/")[-1]
     num_retries = 17
     temperature = 0.5
     max_tokens = 100
     prompt = "Hello, world!"
-    api_base = "http://test.com"
+    api_version = "2024-02-01"
     api_key = "apikey"
+    azure_ad_token = "adtoken"
+
+    lm_kwargs = {
+        "model": model_name,
+        "model_type": "text",
+        "num_retries": num_retries,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+    }
+    if keys_in_env_vars:
+        api_base = "http://testfromenv.com"
+        monkeypatch.setenv("AZURE_API_KEY", api_key)
+        monkeypatch.setenv("AZURE_API_BASE", api_base)
+        monkeypatch.setenv("AZURE_API_VERSION", api_version)
+        monkeypatch.setenv("AZURE_AD_TOKEN", azure_ad_token)
+    else:
+        api_base = "http://test.com"
+        lm_kwargs["api_key"] = api_key
+        lm_kwargs["api_base"] = api_base
+        lm_kwargs["api_version"] = api_version
+        lm_kwargs["azure_ad_token"] = azure_ad_token
 
-    lm = LM(
-        model=model_name,
-        model_type="text",
-        num_retries=num_retries,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        api_base=api_base,
-        api_key=api_key,
-    )
+    lm = LM(**lm_kwargs)
 
     MockRouter = mock.MagicMock()
     mock_text_completion = mock.MagicMock()
@@ -83,6 +120,10 @@ def test_lm_completions_respects_max_retries():
                 "model_name": expected_model,
                 "litellm_params": {
                     "model": expected_model,
+                    "api_key": api_key,
+                    "api_base": api_base,
+                    "api_version": api_version,
+                    "azure_ad_token": azure_ad_token,
                 },
             }
         ],
@@ -100,7 +141,5 @@ def test_lm_completions_respects_max_retries():
         prompt=prompt + "\n\nBEGIN RESPONSE:",
         temperature=temperature,
         max_tokens=max_tokens,
-        api_key=api_key,
-        api_base=api_base,
         cache=mock.ANY,
     )
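For quick reference, a minimal usage sketch of the extraction behavior introduced by this patch (assumes the patch is applied; the model names, keys, and URLs below are placeholder values, not part of the change):

import os

from dspy.clients.lm import _extract_provider_api_config, _get_provider

# Explicit kwargs win over environment variables and are popped out of the dict,
# so only non-API kwargs remain to be forwarded to LiteLLM afterwards.
llm_kwargs = {"api_key": "sk-placeholder", "api_base": "http://localhost:8000", "temperature": 0.7}
config = _extract_provider_api_config("openai/gpt-4o-mini", llm_kwargs)
assert config.api_key == "sk-placeholder"
assert "api_key" not in llm_kwargs and "api_base" not in llm_kwargs

# With no explicit kwargs, values fall back to <PROVIDER>_API_* environment variables;
# Azure models additionally pick up AZURE_AD_TOKEN when it is set.
os.environ["AZURE_API_KEY"] = "placeholder-azure-key"
assert _get_provider("azure/gpt-35-turbo") == "azure"
azure_config = _extract_provider_api_config("azure/gpt-35-turbo", {})
assert azure_config.api_key == "placeholder-azure-key"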