Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAI api refactor #391

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
more comments
  • Loading branch information
you-n-g committed Sep 29, 2024
commit 8408dbd4a5794ca2ab3e6139cad323e1bd84a738
69 changes: 69 additions & 0 deletions rdagent/oai/backends/az.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
TODO:
It is not complete now.

Please refer to rdagent/oai/llm_utils.py:APIBackend for the future design
"""

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
import openai
from pydantic_settings import BaseSettings


class AzureConf(BaseSettings):
"""
TODO: move more settings here
"""
use_azure_token_provider: bool = False
managed_identity_client_id: str | None = None
chat_model: str = "gpt-4-turbo"

chat_azure_api_base: str = ""
chat_azure_api_version: str = ""


class BaseAPI:
"""
TOOD: there may be some more shared methods in the BaseAPI
"""
pass


class AzureAPI(BaseAPI):

def _get_credential(self):
dac_kwargs = {}
if AZURE_CONF.managed_identity_client_id is not None:
dac_kwargs["managed_identity_client_id"] = self.managed_identity_client_id
credential = DefaultAzureCredential(**dac_kwargs)
return credential

def _get_client(self):
kwargs = {}
if AZURE_CONF.use_azure_token_provider:
kwargs["azure_ad_token_provider"]= get_bearer_token_provider(
self._get_credential(),
"https://cognitiveservices.azure.com/.default",
)
return openai.AzureOpenAI(
api_version=AZURE_CONF.chat_azure_api_version,
azure_endpoint=AZURE_CONF.chat_azure_api_base,
**kwargs,
)

# def list_deployments(self):
# client = self._get_client()
# try:
# deployments = client.deployments.list()
# return [deployment for deployment in deployments]
# except Exception as e:
# print(f"An error occurred while listing deployments: {e}")
# return []

AZURE_CONF = AzureConf()


# if __name__ == "__main__":
# api = AzureAPI()
# deployments = api.list_deployments()
# print(deployments)
10 changes: 6 additions & 4 deletions rdagent/oai/llm_utils.py
Original file line number Diff line number Diff line change
@@ -235,6 +235,12 @@ def display_history(self) -> None:


class APIBackend:
"""
This is a unified interface for different backends.

(xiao) thinks integerate all kinds of API in a single class is not a good design.
So we should split them into different classes in `oai/backends/` in the future.
"""
# FIXME: (xiao) I think we should skip using self.xxxx
# We can use self.cfg directly. If it is hard to 兼容 different settings of backends. We can split it into multiple BaseSettings.
def __init__( # noqa: C901, PLR0912, PLR0915
@@ -384,10 +390,6 @@ def __init__( # noqa: C901, PLR0912, PLR0915
self.use_gcr_endpoint = self.cfg.use_gcr_endpoint
self.retry_wait_seconds = self.cfg.retry_wait_seconds

def list_available_deployments(self):
if self.use_azure:
# TODO:

def build_chat_session(
self,
conversation_id: str | None = None,
Loading
Oops, something went wrong.