Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion skyflow/error/_skyflow_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ def __init__(self,
self.http_status = http_status if http_status else SkyflowMessages.HttpStatus.BAD_REQUEST.value
self.details = details
self.request_id = request_id
log_error(message, http_code, request_id, grpc_code, http_status, details)
super().__init__()
2 changes: 1 addition & 1 deletion skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Error(Enum):
EMPTY_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid token.Specify a valid credentials token."
INVALID_CREDENTIALS_TOKEN_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid credentials token for {{}} with id {{}}. Expected token to be a string."
INVALID_CREDENTIALS_TOKEN = f"{error_prefix} Initialization failed. Invalid credentials token. Expected token to be a string."
EXPIRED_TOKEN = f"${error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
EXPIRED_TOKEN = f"{error_prefix} Initialization failed. Given token is expired. Specify a valid credentials token."
EMPTY_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}.Specify a valid api key."
EMPTY_API_KEY= f"{error_prefix} Initialization failed. Invalid api key.Specify a valid api key."
INVALID_API_KEY_IN_CONFIG = f"{error_prefix} Initialization failed. Invalid api key for {{}} with id {{}}. Expected api key to be a string."
Expand Down
22 changes: 7 additions & 15 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,18 @@
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
dotenv.load_dotenv()
dotenv_path = dotenv.find_dotenv(usecwd=True)
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds
if common_skyflow_creds:
return common_skyflow_creds
dotenv_path = dotenv.find_dotenv(usecwd=True)
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = env_skyflow_credentials.replace('\n', '\\n')
return {
'credentials_string': env_creds
}
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)
env_creds = env_skyflow_credentials.strip().replace('\n', '\\n')
return {'credentials_string': env_creds}
raise SkyflowError(SkyflowMessages.Error.INVALID_CREDENTIALS.value, invalid_input_error_code)

def validate_api_key(api_key: str, logger = None) -> bool:
if len(api_key) != 42:
Expand Down
6 changes: 3 additions & 3 deletions skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def validate_credentials(logger, credentials, config_id_type=None, config_id=Non
)
if is_expired(credentials.get("token"), logger):
raise SkyflowError(
SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value.format(config_id_type, config_id)
if config_id_type and config_id else SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value,
SkyflowMessages.Error.EXPIRED_TOKEN.value
if config_id_type and config_id else SkyflowMessages.Error.EXPIRED_TOKEN.value,
invalid_input_error_code
)
elif "api_key" in credentials:
Expand Down Expand Up @@ -389,7 +389,7 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest):
if hasattr(request, 'wait_time') and request.wait_time is not None:
if not isinstance(request.wait_time, (int, float)):
raise SkyflowError(SkyflowMessages.Error.INVALID_WAIT_TIME.value, invalid_input_error_code)
if request.wait_time < 0 and request.wait_time > 64:
if request.wait_time < 0 or request.wait_time > 64:
raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code)

def validate_insert_request(logger, request):
Expand Down
43 changes: 26 additions & 17 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
self.__logger = None
self.__is_config_updated = False
self.__bearer_token = None
self.__credentials = None
self.__vault_url = None
self.__is_static_token = None

def set_common_skyflow_credentials(self, credentials):
self.__common_skyflow_credentials = credentials
Expand All @@ -23,16 +26,27 @@
self.__logger = logger

def initialize_client_configuration(self):
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
logger = self.__logger)
self.initialize_api_client(vault_url, token)

def initialize_api_client(self, vault_url, token):
self.__api_client = Skyflow(base_url=vault_url, token=token)
if self.__api_client is not None and not self.__is_config_updated:
if self.__is_static_token:
return
if self.__bearer_token is not None and not is_expired(self.__bearer_token):
return

needs_reinit = self.__api_client is None or self.__is_config_updated
if needs_reinit:
self.__credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger=self.__logger)
self.__vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
self.__config.get("vault_id"),
logger=self.__logger)
self.__is_static_token = 'token' in self.__credentials or 'api_key' in self.__credentials
bearer_token = self.get_bearer_token(self.__credentials)
if needs_reinit:
self.initialize_api_client(self.__vault_url, bearer_token)

def initialize_api_client(self, vault_url, bearer_token):
token_provider = lambda: self.__bearer_token if self.__bearer_token else bearer_token # noqa: E731
self.__api_client = Skyflow(base_url=vault_url, token=token_provider)

Check failure

Code scanning / Semgrep OSS

Semgrep Finding: semgreprules.check-sensitive-info Error

Potential sensitive information found: token
Comment thread
saileshwar-skyflow marked this conversation as resolved.
Dismissed

def get_records_api(self):
return self.__api_client.records
Expand Down Expand Up @@ -63,11 +77,10 @@
"ctx": self.__config.get("ctx")
}

if self.__bearer_token is None or self.__is_config_updated:
if self.__bearer_token is None or self.__is_config_updated or is_expired(self.__bearer_token):
if 'path' in credentials:
path = credentials.get("path")
self.__bearer_token, _ = generate_bearer_token(
path,
credentials.get("path"),
options,
self.__logger
)
Expand All @@ -83,10 +96,6 @@
else:
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

return self.__bearer_token

def update_config(self, config):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/validations/test__validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_validate_credentials_with_expired_token(self):
with patch('skyflow.service_account.is_expired', return_value=True):
with self.assertRaises(SkyflowError) as context:
validate_credentials(self.logger, credentials)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value)
self.assertEqual(context.exception.message, SkyflowMessages.Error.EXPIRED_TOKEN.value)

def test_validate_credentials_empty_credentials(self):
credentials = {}
Expand Down
Loading
Loading