From 599bcf583bd25afc69c6af629205e334c2a0302f Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 27 Apr 2026 17:22:18 +0530 Subject: [PATCH 1/3] SK-2777: update client re-init logic --- skyflow/error/_skyflow_error.py | 1 - skyflow/utils/_utils.py | 22 ++++-------- skyflow/utils/validations/_validations.py | 6 ++-- skyflow/vault/client/client.py | 41 ++++++++++++++--------- 4 files changed, 36 insertions(+), 34 deletions(-) diff --git a/skyflow/error/_skyflow_error.py b/skyflow/error/_skyflow_error.py index 7b917fae..fca43935 100644 --- a/skyflow/error/_skyflow_error.py +++ b/skyflow/error/_skyflow_error.py @@ -15,5 +15,4 @@ def __init__(self, self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value self.details = details self.request_id = request_id - log_error(message, http_code, request_id, grpc_code, http_status, details) super().__init__() \ No newline at end of file diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 4278357e..7caa0c84 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -30,26 +30,18 @@ invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None): - dotenv.load_dotenv() - dotenv_path = dotenv.find_dotenv(usecwd=True) - if dotenv_path: - load_dotenv(dotenv_path) - env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") if config_level_creds: return config_level_creds if common_skyflow_creds: return common_skyflow_creds + dotenv_path = dotenv.find_dotenv(usecwd=True) + if dotenv_path: + load_dotenv(dotenv_path) + env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS") if env_skyflow_credentials: - env_skyflow_credentials.strip() - try: - env_creds = env_skyflow_credentials.replace('\n', '\\n') - return { - 'credentials_string': env_creds - } - except json.JSONDecodeError: - raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code) - else: - raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) + env_creds = env_skyflow_credentials.strip().replace('\n', '\\n') + return {'credentials_string': env_creds} + raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code) def validate_api_key(api_key: str, logger = None) -> bool: if len(api_key) != 42: diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index f3428f45..acca531f 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non ) if is_expired(credentials.get("token"), logger): raise SkyflowError( - SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id) - if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value, + SkyflowMessages.Error.EXPIRED_TOKEN.value + if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value, invalid_input_error_code ) elif "api_key" in credentials: @@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if hasattr(request, 'wait_time') and request.wait_time is not None: if not isinstance(request.wait_time, (int, float)): raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code) - if request.wait_time < 0 and request.wait_time > 64: + if request.wait_time < 0 or request.wait_time > 64: raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index f47a525c..38d17011 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -14,6 +14,9 @@ def __init__(self, config): self.__logger = None self.__is_config_updated = False self.__bearer_token = None + self.__credentials = None + self.__vault_url = None + self.__is_static_token = None def set_common_skyflow_credentials(self, credentials): self.__common_skyflow_credentials = credentials @@ -23,16 +26,29 @@ def set_logger(self, log_level, logger): self.__logger = logger def initialize_client_configuration(self): - credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger) - token = self.get_bearer_token(credentials) - vault_url = get_vault_url(self.__config.get("cluster_id"), - self.__config.get("env"), - self.__config.get("vault_id"), - logger = self.__logger) - self.initialize_api_client(vault_url, token) + if self.__api_client is not None and not self.__is_config_updated: + if self.__is_static_token: + return + if self.__bearer_token is not None and not is_expired(self.__bearer_token): + return + + needs_reinit = self.__api_client is None or self.__is_config_updated + if needs_reinit: + self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger) + self.__vault_url = get_vault_url(self.__config.get("cluster_id"), + self.__config.get("env"), + self.__config.get("vault_id"), + logger=self.__logger) + self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials + token = self.get_bearer_token(self.__credentials) + if needs_reinit: + self.initialize_api_client(self.__vault_url, token) def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow(base_url=vault_url, token=token) + self.__api_client = Skyflow( + base_url=vault_url, + token=lambda: self.__bearer_token if self.__bearer_token else token, + ) def get_records_api(self): return self.__api_client.records @@ -63,11 +79,10 @@ def get_bearer_token(self, credentials): "ctx": self.__config.get("ctx") } - if self.__bearer_token is None or self.__is_config_updated: + if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token): if 'path' in credentials: - path = credentials.get("path") self.__bearer_token, _ = generate_bearer_token( - path, + credentials.get("path"), options, self.__logger ) @@ -83,10 +98,6 @@ def get_bearer_token(self, credentials): else: log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger) - if is_expired(self.__bearer_token): - self.__is_config_updated = True - raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value) - return self.__bearer_token def update_config(self, config): From 763af7294cd603c51205d3115ff7f217f816ff7d Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 27 Apr 2026 17:39:55 +0530 Subject: [PATCH 2/3] SK-2777: update unit tests --- skyflow/vault/client/client.py | 12 +- tests/vault/client/test__client.py | 316 ++++++++++++++++++++++++----- 2 files changed, 275 insertions(+), 53 deletions(-) diff --git a/skyflow/vault/client/client.py b/skyflow/vault/client/client.py index 38d17011..0304c11a 100644 --- a/skyflow/vault/client/client.py +++ b/skyflow/vault/client/client.py @@ -40,15 +40,13 @@ def initialize_client_configuration(self): self.__config.get("vault_id"), logger=self.__logger) self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials - token = self.get_bearer_token(self.__credentials) + bearer_token = self.get_bearer_token(self.__credentials) if needs_reinit: - self.initialize_api_client(self.__vault_url, token) + self.initialize_api_client(self.__vault_url, bearer_token) - def initialize_api_client(self, vault_url, token): - self.__api_client = Skyflow( - base_url=vault_url, - token=lambda: self.__bearer_token if self.__bearer_token else token, - ) + def initialize_api_client(self, vault_url, bearer_token): + token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731 + self.__api_client = Skyflow(base_url=vault_url, token=token_provider) def get_records_api(self): return self.__api_client.records diff --git a/tests/vault/client/test__client.py b/tests/vault/client/test__client.py index 565b1e6f..9d0d2520 100644 --- a/tests/vault/client/test__client.py +++ b/tests/vault/client/test__client.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, call from skyflow.vault.client.client import VaultClient CONFIG = { @@ -12,11 +12,19 @@ } CREDENTIALS_WITH_API_KEY = {"api_key": "dummy_api_key"} +CREDENTIALS_WITH_TOKEN = {"token": "dummy_static_token"} +CREDENTIALS_WITH_PATH = {"path": "/some/path/credentials.json"} +CREDENTIALS_WITH_STRING = {"credentials_string": '{"clientID": "x"}'} + class TestVaultClient(unittest.TestCase): def setUp(self): self.vault_client = VaultClient(CONFIG) + # ------------------------------------------------------------------ # + # Basic setters / getters # + # ------------------------------------------------------------------ # + def test_set_common_skyflow_credentials(self): credentials = {"api_key": "dummy_api_key"} self.vault_client.set_common_skyflow_credentials(credentials) @@ -28,73 +36,289 @@ def test_set_logger(self): self.assertEqual(self.vault_client.get_log_level(), "INFO") self.assertEqual(self.vault_client.get_logger(), mock_logger) + def test_get_vault_id(self): + self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + + def test_get_config(self): + self.assertEqual(self.vault_client.get_config(), CONFIG) + + def test_get_common_skyflow_credentials(self): + credentials = {"api_key": "dummy_api_key"} + self.vault_client.set_common_skyflow_credentials(credentials) + self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + + def test_get_log_level(self): + self.vault_client.set_logger("DEBUG", MagicMock()) + self.assertEqual(self.vault_client.get_log_level(), "DEBUG") + + def test_get_logger(self): + mock_logger = MagicMock() + self.vault_client.set_logger("INFO", mock_logger) + self.assertEqual(self.vault_client.get_logger(), mock_logger) + + # ------------------------------------------------------------------ # + # initialize_client_configuration — first call (slow path) # + # ------------------------------------------------------------------ # + @patch("skyflow.vault.client.client.get_credentials") @patch("skyflow.vault.client.client.get_vault_url") @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") - def test_initialize_client_configuration(self, mock_init_api_client, mock_get_vault_url, mock_get_credentials): - mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY) + def test_initialize_client_configuration_first_call( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY mock_get_vault_url.return_value = "https://test-vault-url.com" self.vault_client.initialize_client_configuration() - mock_get_credentials.assert_called_once_with(CONFIG["credentials"], None, logger=None) - mock_get_vault_url.assert_called_once_with(CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None) + mock_get_credentials.assert_called_once_with( + CONFIG["credentials"], None, logger=None + ) + mock_get_vault_url.assert_called_once_with( + CONFIG["cluster_id"], CONFIG["env"], CONFIG["vault_id"], logger=None + ) mock_init_api_client.assert_called_once() - @patch("skyflow.vault.client.client.Skyflow") - def test_initialize_api_client(self, mock_api_client): - self.vault_client.initialize_api_client("https://test-vault-url.com", "dummy_token") - mock_api_client.assert_called_once_with(base_url="https://test-vault-url.com", token="dummy_token") + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (static token) # + # ------------------------------------------------------------------ # - def test_get_records_api(self): + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_api_key( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with api_key, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + # Side-effect simulates initialize_api_client actually setting __api_client + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() # first call — slow path + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() # second call — fast path + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_static_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """Once initialized with a static token, subsequent calls skip all work.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_TOKEN + mock_get_vault_url.return_value = "https://test-vault-url.com" + mock_init_api_client.side_effect = lambda *_: setattr( + self.vault_client, "_VaultClient__api_client", MagicMock() + ) + + self.vault_client.initialize_client_configuration() + mock_get_credentials.reset_mock() + mock_get_vault_url.reset_mock() + mock_init_api_client.reset_mock() + + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — fast path (service account) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.is_expired", return_value=False) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_fast_path_valid_sa_token( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, mock_is_expired + ): + """Service account with a still-valid token skips get_bearer_token entirely.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Seed the cached bearer token as if first call already ran self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.records = MagicMock() - records_api = self.vault_client.get_records_api() - self.assertIsNotNone(records_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "cached_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_tokens_api(self): + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_not_called() + mock_get_vault_url.assert_not_called() + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — token expiry (no client reinit) # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_sa_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_expired_token_no_reinit( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials, + mock_is_expired, mock_generate_bearer_token + ): + """Expired service account token is regenerated in-place; httpx client is NOT recreated.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_PATH + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Client already initialized — simulate warm state with an expired token self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.tokens = MagicMock() - tokens_api = self.vault_client.get_tokens_api() - self.assertIsNotNone(tokens_api) + self.vault_client._VaultClient__is_static_token = False + self.vault_client._VaultClient__bearer_token = "expired_sa_token" + self.vault_client._VaultClient__credentials = CREDENTIALS_WITH_PATH - def test_get_query_api(self): + self.vault_client.initialize_client_configuration() + + # Token was regenerated + mock_generate_bearer_token.assert_called_once() + self.assertEqual( + self.vault_client._VaultClient__bearer_token, "new_sa_token" + ) + # httpx client was NOT recreated + mock_init_api_client.assert_not_called() + + # ------------------------------------------------------------------ # + # initialize_client_configuration — config update forces reinit # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.get_credentials") + @patch("skyflow.vault.client.client.get_vault_url") + @patch("skyflow.vault.client.client.VaultClient.initialize_api_client") + def test_initialize_client_configuration_reinit_after_update_config( + self, mock_init_api_client, mock_get_vault_url, mock_get_credentials + ): + """update_config() marks the client stale; next call must recreate it.""" + mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY + mock_get_vault_url.return_value = "https://test-vault-url.com" + + # Simulate already-initialized client self.vault_client._VaultClient__api_client = MagicMock() - self.vault_client._VaultClient__api_client.query = MagicMock() - query_api = self.vault_client.get_query_api() - self.assertIsNotNone(query_api) + self.vault_client._VaultClient__is_static_token = True - def test_get_vault_id(self): - self.assertEqual(self.vault_client.get_vault_id(), CONFIG["vault_id"]) + self.vault_client.update_config({"cluster_id": "new_cluster"}) + self.vault_client.initialize_client_configuration() + + mock_get_credentials.assert_called_once() + mock_get_vault_url.assert_called_once() + mock_init_api_client.assert_called_once() + + # ------------------------------------------------------------------ # + # initialize_api_client — lambda token provider # + # ------------------------------------------------------------------ # + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_passes_callable_token(self, mock_skyflow): + """initialize_api_client must pass a callable (lambda) as token, not a string.""" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + args, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["base_url"], "https://test-vault-url.com") + self.assertTrue(callable(kwargs["token"]), "token must be a callable (lambda)") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_returns_cached_bearer_token(self, mock_skyflow): + """Lambda returns __bearer_token when it is set (interceptor behaviour).""" + self.vault_client._VaultClient__bearer_token = "refreshed_token" + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "refreshed_token") + + @patch("skyflow.vault.client.client.Skyflow") + def test_initialize_api_client_lambda_falls_back_to_initial_token(self, mock_skyflow): + """Lambda falls back to the initial token when __bearer_token is None.""" + self.vault_client._VaultClient__bearer_token = None + self.vault_client.initialize_api_client("https://test-vault-url.com", "initial_token") + + _, kwargs = mock_skyflow.call_args + self.assertEqual(kwargs["token"](), "initial_token") + + # ------------------------------------------------------------------ # + # get_bearer_token # + # ------------------------------------------------------------------ # + + def test_get_bearer_token_with_api_key(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) + self.assertEqual(result, "dummy_api_key") + + def test_get_bearer_token_with_static_token(self): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_TOKEN) + self.assertEqual(result, "dummy_static_token") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("sa_token", None)) + def test_get_bearer_token_generates_from_path_on_first_call(self, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token") + self.assertEqual(self.vault_client._VaultClient__bearer_token, "sa_token") + + @patch("skyflow.vault.client.client.generate_bearer_token_from_creds", return_value=("sa_token_str", None)) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_generates_from_credentials_string(self, mock_log, mock_generate): + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_STRING) + mock_generate.assert_called_once() + self.assertEqual(result, "sa_token_str") + + @patch("skyflow.vault.client.client.generate_bearer_token", return_value=("new_token", None)) + @patch("skyflow.vault.client.client.is_expired", return_value=True) + @patch("skyflow.vault.client.client.log_info") + def test_get_bearer_token_regenerates_on_expiry(self, mock_log, mock_is_expired, mock_generate): + """Expired token is regenerated silently — no exception raised.""" + self.vault_client._VaultClient__bearer_token = "expired_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_called_once() + self.assertEqual(result, "new_token") @patch("skyflow.vault.client.client.generate_bearer_token") - @patch("skyflow.vault.client.client.generate_bearer_token_from_creds") + @patch("skyflow.vault.client.client.is_expired", return_value=False) @patch("skyflow.vault.client.client.log_info") - def test_get_bearer_token_with_api_key(self, mock_log_info, mock_generate_bearer_token, - mock_generate_bearer_token_from_creds): - token = self.vault_client.get_bearer_token(CREDENTIALS_WITH_API_KEY) - self.assertEqual(token, CREDENTIALS_WITH_API_KEY["api_key"]) - - def test_update_config(self): - new_config = {"credentials": "new_credentials"} - self.vault_client.update_config(new_config) + def test_get_bearer_token_reuses_valid_cached_token(self, mock_log, mock_is_expired, mock_generate): + """Valid cached token is reused without calling generate_bearer_token.""" + self.vault_client._VaultClient__bearer_token = "valid_token" + result = self.vault_client.get_bearer_token(CREDENTIALS_WITH_PATH) + mock_generate.assert_not_called() + self.assertEqual(result, "valid_token") + + # ------------------------------------------------------------------ # + # update_config # + # ------------------------------------------------------------------ # + + def test_update_config_sets_flag(self): + self.vault_client.update_config({"credentials": "new_credentials"}) self.assertTrue(self.vault_client._VaultClient__is_config_updated) self.assertEqual(self.vault_client.get_config()["credentials"], "new_credentials") - def test_get_config(self): - self.assertEqual(self.vault_client.get_config(), CONFIG) + # ------------------------------------------------------------------ # + # API accessor stubs # + # ------------------------------------------------------------------ # - def test_get_common_skyflow_credentials(self): - credentials = {"api_key": "dummy_api_key"} - self.vault_client.set_common_skyflow_credentials(credentials) - self.assertEqual(self.vault_client.get_common_skyflow_credentials(), credentials) + def test_get_records_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_records_api()) - def test_get_log_level(self): - log_level = "DEBUG" - self.vault_client.set_logger(log_level, MagicMock()) - self.assertEqual(self.vault_client.get_log_level(), log_level) + def test_get_tokens_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_tokens_api()) - def test_get_logger(self): - mock_logger = MagicMock() - self.vault_client.set_logger("INFO", mock_logger) - self.assertEqual(self.vault_client.get_logger(), mock_logger) \ No newline at end of file + def test_get_query_api(self): + self.vault_client._VaultClient__api_client = MagicMock() + self.assertIsNotNone(self.vault_client.get_query_api()) + + +if __name__ == "__main__": + unittest.main() From 934fc657400d8e63868b4f002508f26c3f073150 Mon Sep 17 00:00:00 2001 From: saileshwar-skyflow Date: Mon, 27 Apr 2026 17:44:26 +0530 Subject: [PATCH 3/3] SK-2777: update unit tests --- skyflow/utils/_skyflow_messages.py | 2 +- tests/utils/validations/test__validations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 3672cfa8..8aea3b8b 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -47,7 +47,7 @@ class Error(Enum): EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token." INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string." INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string." - EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." + EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token." EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key." EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key." INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string." diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py index 48332a55..b1247ebc 100644 --- a/tests/utils/validations/test__validations.py +++ b/tests/utils/validations/test__validations.py @@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self): with patch('skyflow.service_account.is_expired', return_value=True): with self.assertRaises(SkyflowError) as context: validate_credentials(self.logger, credentials) - self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value) def test_validate_credentials_empty_credentials(self): credentials = {}