diff --git a/compute_modules/sources_v2/_sources.py b/compute_modules/sources_v2/_sources.py index 5be779f..9a52dec 100644 --- a/compute_modules/sources_v2/_sources.py +++ b/compute_modules/sources_v2/_sources.py @@ -41,7 +41,9 @@ def get_source(source_api_name: str): # type: ignore[no-untyped-def] from external_systems.sources import ( AwsCredentials, ClientCertificate, + GcpOauthCredentials, HttpsConnectionParameters, + OauthCredentials, Source, SourceCredentials, SourceParameters, @@ -57,6 +59,47 @@ def convert_resolved_source_credentials( if credentials is None: return None + cloud_credentials = _maybe_get_cloud_credentials(credentials) + if cloud_credentials is not None: + return cloud_credentials + + gcp_oauth_credentials = _maybe_get_gcp_oauth_credentials(credentials) + if gcp_oauth_credentials is not None: + return gcp_oauth_credentials + + oauth2_credentials = _maybe_get_oauth_credentials(credentials) + if oauth2_credentials is not None: + return oauth2_credentials + + return None + + def _maybe_get_oauth_credentials( + credentials: Any, + ) -> Optional[SourceCredentials]: + oauth_credentials = credentials.get("oauth2Credentials") + if oauth_credentials is None: + return None + + return OauthCredentials( + access_token=oauth_credentials.get("accessToken"), + expiration=datetime.strptime(oauth_credentials.get("expiration"), JAVA_OFFSET_DATETIME_FORMAT), + ) + + def _maybe_get_gcp_oauth_credentials( + credentials: Any, + ) -> Optional[SourceCredentials]: + gcp_oauth_credentials = credentials.get("gcpOauthCredentials", None) + if gcp_oauth_credentials is None: + return None + + return GcpOauthCredentials( + access_token=gcp_oauth_credentials.get("accessToken"), + expiration=datetime.strptime(gcp_oauth_credentials.get("expiration"), JAVA_OFFSET_DATETIME_FORMAT), + ) + + def _maybe_get_cloud_credentials( + credentials: Any, + ) -> Optional[SourceCredentials]: cloud_credentials = credentials.get("cloudCredentials", None) if cloud_credentials is None: return None diff --git a/pyproject.toml b/pyproject.toml index 429fdb6..029dcf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ packages = [{ include = "compute_modules" }] [tool.poetry.dependencies] python = "^3.9" requests = "^2.32.3" -external-systems = { version = "^0.100.0", optional = true } +external-systems = { version = "^0.107.0", optional = true } pyyaml = { version = "^6.0.1", optional = true } diff --git a/tests/sources_v2/test_sources.py b/tests/sources_v2/test_sources.py index 7b2c364..c607f14 100644 --- a/tests/sources_v2/test_sources.py +++ b/tests/sources_v2/test_sources.py @@ -20,7 +20,7 @@ from unittest import mock import pytest -from external_systems.sources import AwsCredentials, Source +from external_systems.sources import AwsCredentials, GcpOauthCredentials, OauthCredentials, Source from compute_modules.sources_v2 import get_source from compute_modules.sources_v2._api import ( @@ -120,12 +120,13 @@ def test_get_source_without_http_connection(mock_source_config_file: Path, mock_ ] with pytest.raises(ValueError): source.get_https_connection() + with pytest.raises(ValueError): + source.get_session_credentials() -def test_get_source_with_aws_session_credentials( +def test_get_source_get_aws_credentials_with_aws_session_credentials( mock_source_config_file: Path, mock_service_discovery_file: Path ) -> None: - # Create a source config with session credentials directly in the JSON source_config = { "test_source": { "secrets": {}, @@ -164,10 +165,9 @@ def test_get_source_with_aws_session_credentials( ) -def test_get_source_with_aws_basic_credentials( +def test_get_source_get_aws_credentials_with_aws_basic_credentials( mock_source_config_file: Path, mock_service_discovery_file: Path ) -> None: - # Create a source config with basic credentials directly in the JSON source_config = { "test_source": { "secrets": {}, @@ -199,6 +199,144 @@ def test_get_source_with_aws_basic_credentials( ) +def test_get_source_get_session_credentials_with_aws_session_credentials( + mock_source_config_file: Path, mock_service_discovery_file: Path +) -> None: + source_config = { + "test_source": { + "secrets": {}, + "sourceConfiguration": {"type": "s3"}, + "resolvedCredentials": { + "cloudCredentials": { + "awsCredentials": { + "sessionCredentials": { + "accessKeyId": "ACCESS_KEY", + "secretAccessKey": "SECRET_KEY", + "sessionToken": "SESSION_TOKEN", + "expiration": "2023-01-01T00:00:00Z", + } + } + } + }, + } + } + + mock_source_config_file.write_text(json.dumps(source_config)) + + with mock.patch.dict( + os.environ, + { + SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file), + SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file), + }, + ): + source = get_source("test_source") + assert isinstance(source, Source) + assert source.get_session_credentials().get() == AwsCredentials( + access_key_id="ACCESS_KEY", + secret_access_key="SECRET_KEY", + session_token="SESSION_TOKEN", + expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT), + ) + + +def test_get_source_get_session_credentials_with_aws_basic_credentials( + mock_source_config_file: Path, mock_service_discovery_file: Path +) -> None: + source_config = { + "test_source": { + "secrets": {}, + "sourceConfiguration": {"type": "s3"}, + "resolvedCredentials": { + "cloudCredentials": { + "awsCredentials": { + "basicCredentials": {"accessKeyId": "ACCESS_KEY", "secretAccessKey": "SECRET_KEY"} + } + } + }, + } + } + + mock_source_config_file.write_text(json.dumps(source_config)) + + with mock.patch.dict( + os.environ, + { + SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file), + SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file), + }, + ): + source = get_source("test_source") + assert isinstance(source, Source) + assert source.get_session_credentials().get() == AwsCredentials( + access_key_id="ACCESS_KEY", + secret_access_key="SECRET_KEY", + ) + + +def test_get_source_get_session_credentials_with_gcp_oauth_credentials( + mock_source_config_file: Path, mock_service_discovery_file: Path +) -> None: + source_config = { + "test_source": { + "secrets": {}, + "sourceConfiguration": {"type": "bigquery"}, + "resolvedCredentials": { + "gcpOauthCredentials": { + "accessToken": "ACCESS_TOKEN", + "expiration": "2023-01-01T00:00:00Z", + } + }, + } + } + + mock_source_config_file.write_text(json.dumps(source_config)) + + with mock.patch.dict( + os.environ, + { + SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file), + SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file), + }, + ): + source = get_source("test_source") + assert isinstance(source, Source) + assert source.get_session_credentials().get() == GcpOauthCredentials( + access_token="ACCESS_TOKEN", + expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT), + ) + + +def test_get_source_get_session_credentials_with_oauth2_credentials( + mock_source_config_file: Path, mock_service_discovery_file: Path +) -> None: + source_config = { + "test_source": { + "secrets": {}, + "sourceConfiguration": {"type": "webhooks-rest"}, + "resolvedCredentials": { + "oauth2Credentials": {"accessToken": "ACCESS_TOKEN", "expiration": "2023-01-01T00:00:00Z"} + }, + } + } + + mock_source_config_file.write_text(json.dumps(source_config)) + + with mock.patch.dict( + os.environ, + { + SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file), + SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file), + }, + ): + source = get_source("test_source") + assert isinstance(source, Source) + assert source.get_session_credentials().get() == OauthCredentials( + access_token="ACCESS_TOKEN", + expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT), + ) + + def test_get_source_not_found(mock_source_config_file: Path, mock_service_discovery_file: Path) -> None: source_config = {"existing_source": {"secrets": {}, "sourceConfiguration": {"type": "test_type"}}}