diff --git a/CHANGELOG.md b/CHANGELOG.md index 83320ed91c915..3b68072555958 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### Bug Fixes / Nits - tune prompt to get rid of KeyError in SubQ engine (#7039) +- Fix validation of Azure OpenAI keys (#7042) ## [0.7.12] - 2023-07-25 diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py index c033b496894b2..2c7a11bbe052d 100644 --- a/llama_index/embeddings/openai.py +++ b/llama_index/embeddings/openai.py @@ -245,7 +245,9 @@ def __init__( # Validate that either the openai.api_key property # or OPENAI_API_KEY env variable are set to a valid key # Raises ValueError if missing or doesn't match valid format - validate_openai_api_key(kwargs.get("api_key", None)) + validate_openai_api_key( + kwargs.get("api_key", None), kwargs.get("api_type", None) + ) """Init params.""" super().__init__(embed_batch_size, tokenizer, callback_manager) diff --git a/llama_index/llms/openai.py b/llama_index/llms/openai.py index de6e20d3e69a2..30b47048272a6 100644 --- a/llama_index/llms/openai.py +++ b/llama_index/llms/openai.py @@ -43,7 +43,9 @@ class OpenAI(LLM, BaseModel): max_retries: int = 10 def __init__(self, *args: Any, **kwargs: Any) -> None: - validate_openai_api_key(kwargs.get("api_key", None)) + validate_openai_api_key( + kwargs.get("api_key", None), kwargs.get("api_type", None) + ) super().__init__(*args, **kwargs) @property diff --git a/llama_index/llms/openai_utils.py b/llama_index/llms/openai_utils.py index 2594be2591b9f..fbced5e1f1935 100644 --- a/llama_index/llms/openai_utils.py +++ b/llama_index/llms/openai_utils.py @@ -256,9 +256,17 @@ def to_openai_function(pydantic_class: Type[BaseModel]) -> Dict[str, Any]: } -def validate_openai_api_key(api_key: Optional[str] = None) -> None: +def validate_openai_api_key( + api_key: Optional[str] = None, api_type: Optional[str] = None +) -> None: openai_api_key = api_key or os.environ.get("OPENAI_API_KEY", "") or openai.api_key + openai_api_type = ( + api_type or os.environ.get("OPENAI_API_TYPE", "") or openai.api_type + ) + if not openai_api_key: raise ValueError(MISSING_API_KEY_ERROR_MESSAGE) - elif not OPENAI_API_KEY_FORMAT.search(openai_api_key): + elif openai_api_type == "open_ai" and not OPENAI_API_KEY_FORMAT.search( + openai_api_key + ): raise ValueError(INVALID_API_KEY_ERROR_MESSAGE) diff --git a/tests/conftest.py b/tests/conftest.py index 2c4d292c680f3..24e0158631eb9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -93,8 +93,8 @@ def mock_openai_credentials() -> None: class CachedOpenAIApiKeys: """ - Saves the users' OpenAI API key either in the environment variable - or set to the library itself. + Saves the users' OpenAI API key and OpenAI API type either in + the environment variable or set to the library itself. This allows us to run tests by setting it without plowing over the local environment. """ @@ -104,17 +104,25 @@ def __init__( set_env_key_to: Optional[str] = "", set_library_key_to: Optional[str] = None, set_fake_key: bool = False, + set_env_type_to: Optional[str] = "", + set_library_type_to: str = "open_ai", # default value in openai package ): self.set_env_key_to = set_env_key_to self.set_library_key_to = set_library_key_to self.set_fake_key = set_fake_key + self.set_env_type_to = set_env_type_to + self.set_library_type_to = set_library_type_to def __enter__(self) -> None: self.api_env_variable_was = os.environ.get("OPENAI_API_KEY", "") + self.api_env_type_was = os.environ.get("OPENAI_API_TYPE", "") self.openai_api_key_was = openai.api_key + self.openai_api_type_was = openai.api_type os.environ["OPENAI_API_KEY"] = str(self.set_env_key_to) + os.environ["OPENAI_API_TYPE"] = str(self.set_env_type_to) openai.api_key = self.set_library_key_to + openai.api_type = self.set_library_type_to if self.set_fake_key: openai.api_key = "sk-" + "a" * 48 @@ -122,4 +130,6 @@ def __enter__(self) -> None: # No matter what, set the environment variable back to what it was def __exit__(self, *exc: Any) -> None: os.environ["OPENAI_API_KEY"] = str(self.api_env_variable_was) + os.environ["OPENAI_API_TYPE"] = str(self.api_env_type_was) openai.api_key = self.openai_api_key_was + openai.api_type = self.openai_api_type_was diff --git a/tests/embeddings/test_base.py b/tests/embeddings/test_base.py index 2fe279edd411d..099e6ac33faa0 100644 --- a/tests/embeddings/test_base.py +++ b/tests/embeddings/test_base.py @@ -128,8 +128,18 @@ def test_validates_api_key_format_from_env() -> None: with pytest.raises(ValueError, match="Invalid OpenAI API key."): OpenAIEmbedding() + with CachedOpenAIApiKeys( + set_env_key_to="api-hf47930g732gf372", set_env_type_to="azure" + ): + assert OpenAIEmbedding() + def test_validates_api_key_format_in_library() -> None: with CachedOpenAIApiKeys(set_library_key_to="api-hf47930g732gf372"): with pytest.raises(ValueError, match="Invalid OpenAI API key."): OpenAIEmbedding() + + with CachedOpenAIApiKeys( + set_library_key_to="api-hf47930g732gf372", set_library_type_to="azure" + ): + assert OpenAIEmbedding()