# Fixes for OpenAI / Azure OpenAI compatibility with LiteLLM router (#1760)
## Changes from all commits
```diff
@@ -3,6 +3,7 @@
 import os
 import uuid
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, Dict, List, Literal, Optional
```
```diff
@@ -164,6 +165,55 @@ def copy(self, **kwargs):
         return new_instance


+@dataclass(frozen=True)
+class _ProviderAPIConfig:
+    """
+    API configurations for a provider (e.g. OpenAI, Azure OpenAI)
+    """
+
+    api_key: Optional[str]
+    api_base: Optional[str]
+    api_version: Optional[str]
+    # Azure OpenAI with Azure AD auth requires an Azure AD token for authentication.
+    # For all other providers, this field is empty
+    azure_ad_token: Optional[str]
+
+
+def _extract_provider_api_config(model: str, llm_kwargs: Dict[str, Any]) -> _ProviderAPIConfig:
+    """
+    Extract the API configurations from the specified LLM keyword arguments (`llm_kwargs`) for the
+    provider corresponding to the given model.
+
+    Note: The API configurations are removed from the specified `llm_kwargs`, if present, mutating
+    the input dictionary.
+    """
+    provider = _get_provider(model)
+    api_key = llm_kwargs.pop("api_key", None) or os.getenv(f"{provider.upper()}_API_KEY")
+    api_base = llm_kwargs.pop("api_base", None) or os.getenv(f"{provider.upper()}_API_BASE")
+    api_version = llm_kwargs.pop("api_version", None) or os.getenv(f"{provider.upper()}_API_VERSION")
+    if "azure" in provider:
+        azure_ad_token = llm_kwargs.pop("azure_ad_token", None) or os.getenv("AZURE_AD_TOKEN")
+    else:
+        azure_ad_token = None
+    return _ProviderAPIConfig(
+        api_key=api_key,
+        api_base=api_base,
+        api_version=api_version,
+        azure_ad_token=azure_ad_token,
+    )
+
+
+def _get_provider(model: str) -> str:
+    """
+    Extract the provider name from the model string of the format "<provider_name>/<model_name>",
+    e.g. "openai/gpt-4".
+
+    TODO: Not all the models are in the format of "provider/model"
+    """
+    model = model.split("/", 1)
+    return model[0] if len(model) > 1 else "openai"
+
+
 @functools.lru_cache(maxsize=None)
 def cached_litellm_completion(request, num_retries: int):
     return litellm_completion(
```
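To make the new helper's contract concrete, here is a small usage sketch (hypothetical: it assumes `_extract_provider_api_config` is importable from the patched module, and all credential values are fake):

```python
# Hypothetical usage of the helper added above; the import path and all
# credential values are illustrative only.
llm_kwargs = {
    "model": "azure/gpt-4",
    "api_key": "fake-key",
    "api_base": "https://example-resource.openai.azure.com/",
    "temperature": 0.0,
}
config = _extract_provider_api_config(model="azure/gpt-4", llm_kwargs=llm_kwargs)

# The api_* entries were popped off, mutating llm_kwargs in place, so they
# won't be forwarded to router.completion() a second time:
assert llm_kwargs == {"model": "azure/gpt-4", "temperature": 0.0}
assert config.api_key == "fake-key"
# Fields not passed explicitly fall back to environment variables such as
# AZURE_API_VERSION (and AZURE_AD_TOKEN for Azure AD auth).
```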
```diff
@@ -175,15 +225,16 @@ def cached_litellm_completion(request, num_retries: int):


 def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)
-    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries)
+    api_config = _extract_provider_api_config(model=kwargs["model"], llm_kwargs=kwargs)
+    router = _get_litellm_router(model=kwargs["model"], num_retries=num_retries, api_config=api_config)
     return router.completion(
         cache=cache,
         **kwargs,
     )


 @functools.lru_cache(maxsize=None)
-def _get_litellm_router(model: str, num_retries: int) -> Router:
+def _get_litellm_router(model: str, num_retries: int, api_config: _ProviderAPIConfig) -> Router:
     """
     Get a LiteLLM router for the given model with the specified number of retries
     for transient errors.
```
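Worth noting for reviewers: `_get_litellm_router` stays wrapped in `functools.lru_cache`, so the new `api_config` parameter must be hashable. `@dataclass(frozen=True)` provides exactly that, as this standalone sketch (illustrative names only) demonstrates:

```python
import functools
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Config:  # stand-in for _ProviderAPIConfig
    api_key: Optional[str]
    api_base: Optional[str]


calls = 0


@functools.lru_cache(maxsize=None)
def get_router(model: str, config: Config) -> str:
    global calls
    calls += 1
    return f"router-for-{model}"


# Equal frozen dataclasses hash equally, so the second call is a cache hit
# and only one router is ever built per (model, config) pair.
get_router("openai/gpt-4", Config("k", None))
get_router("openai/gpt-4", Config("k", None))
assert calls == 1
```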
```diff
@@ -193,6 +244,9 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         num_retries: The number of times to retry a request if it fails transiently due to
                      network error, rate limiting, etc. Requests are retried with exponential
                      backoff.
+        api_config: The API configurations (keys, base URL, etc.) for the provider
+                    (OpenAI, Azure OpenAI, etc.) corresponding to the given model.

     Returns:
         A LiteLLM router instance that can be used to query the given model.
     """
```
```diff
@@ -207,19 +261,29 @@ def _get_litellm_router(model: str, num_retries: int) -> Router:
         ContentPolicyViolationErrorRetries=0,
     )

+    # LiteLLM routers must specify a `model_list`, which maps model names passed
+    # to `completions()` into actual LiteLLM model names. For our purposes, the
+    # model name is the same as the LiteLLM model name, so we add a single
+    # entry to the `model_list` that maps the model name to itself
+    litellm_params = {
+        "model": model,
+    }
+    if api_config.api_key is not None:
+        litellm_params["api_key"] = api_config.api_key
+    if api_config.api_base is not None:
+        litellm_params["api_base"] = api_config.api_base
+    if api_config.api_version is not None:
+        litellm_params["api_version"] = api_config.api_version
+    if api_config.azure_ad_token is not None:
+        litellm_params["azure_ad_token"] = api_config.azure_ad_token
```
**Review comment on lines +271 to +278** (Collaborator, Author):

> @okhat For OpenAI and Azure OpenAI, API keys must be defined in the Router's `model_list`.
```diff
+    model_list = [
+        {
+            "model_name": model,
+            "litellm_params": litellm_params,
+        }
+    ]
     return Router(
-        # LiteLLM routers must specify a `model_list`, which maps model names passed
-        # to `completions()` into actual LiteLLM model names. For our purposes, the
-        # model name is the same as the LiteLLM model name, so we add a single
-        # entry to the `model_list` that maps the model name to itself
-        model_list=[
-            {
-                "model_name": model,
-                "litellm_params": {
-                    "model": model,
-                },
-            }
-        ],
+        model_list=model_list,
         retry_policy=retry_policy,
     )
```
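To illustrate the reviewer's point above in isolation, here is a minimal sketch of a LiteLLM `Router` carrying Azure OpenAI credentials inside `model_list` (all values are fake placeholders):

```python
from litellm import Router

# Per this PR's rationale, credentials belong in the litellm_params of a
# model_list entry rather than in per-request completion() kwargs.
router = Router(
    model_list=[
        {
            "model_name": "azure/gpt-4",
            "litellm_params": {
                "model": "azure/gpt-4",
                "api_key": "fake-azure-key",
                "api_base": "https://example-resource.openai.azure.com/",
                "api_version": "2024-02-01",
            },
        }
    ]
)
response = router.completion(
    model="azure/gpt-4",
    messages=[{"role": "user", "content": "ping"}],
)
```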
```diff
@@ -235,26 +299,18 @@ def cached_litellm_text_completion(request, num_retries: int):


 def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
     kwargs = ujson.loads(request)

-    # Extract the provider and model from the model string.
-    # TODO: Not all the models are in the format of "provider/model"
-    model = kwargs.pop("model").split("/", 1)
-    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
-    text_completion_model_name = f"text-completion-openai/{model}"
-
-    # Use the API key and base from the kwargs, or from the environment.
-    api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
-    api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
```
**Review comment on lines -245 to -247** (Collaborator, Author):

> These are now part of the router definition, so they no longer need to be passed to the `text_completion` call.
```diff
+    model = kwargs.pop("model")
+    api_config = _extract_provider_api_config(model=model, llm_kwargs=kwargs)
+    model_name = model.split("/", 1)[-1]
+    text_completion_model_name = f"text-completion-openai/{model_name}"

     # Build the prompt from the messages.
     prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])

-    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries)
+    router = _get_litellm_router(model=text_completion_model_name, num_retries=num_retries, api_config=api_config)
     return router.text_completion(
         cache=cache,
         model=text_completion_model_name,
-        api_key=api_key,
-        api_base=api_base,
         prompt=prompt,
         **kwargs,
     )
```
**Review comment:**

> This avoids duplicating code in `litellm_text_completion`, which pops the API key and base from the user-specified kwargs.
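Finally, since `_extract_provider_api_config` falls back to `f"{provider.upper()}_API_KEY"`-style environment variables, callers can configure Azure OpenAI entirely through the environment. A hypothetical caller-side sketch (all values are fake):

```python
import os

# For model "azure/gpt-4", _get_provider() returns "azure", so the helper's
# environment fallbacks are the variables below (values are placeholders):
os.environ["AZURE_API_KEY"] = "fake-azure-key"
os.environ["AZURE_API_BASE"] = "https://example-resource.openai.azure.com/"
os.environ["AZURE_API_VERSION"] = "2024-02-01"
# Optional, consulted only for Azure providers (Azure AD auth):
os.environ["AZURE_AD_TOKEN"] = "fake-aad-token"

# With these set, requests for "azure/gpt-4" need no explicit api_* kwargs;
# _extract_provider_api_config picks the values up, and _get_litellm_router
# bakes them into the Router's model_list.
```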