Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions compute_modules/sources_v2/_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def get_source(source_api_name: str): # type: ignore[no-untyped-def]
from external_systems.sources import (
AwsCredentials,
ClientCertificate,
GcpOauthCredentials,
HttpsConnectionParameters,
OauthCredentials,
Source,
SourceCredentials,
SourceParameters,
Expand All @@ -57,6 +59,47 @@ def convert_resolved_source_credentials(
if credentials is None:
return None

cloud_credentials = _maybe_get_cloud_credentials(credentials)
if cloud_credentials is not None:
return cloud_credentials

gcp_oauth_credentials = _maybe_get_gcp_oauth_credentials(credentials)
if gcp_oauth_credentials is not None:
return gcp_oauth_credentials

oauth2_credentials = _maybe_get_oauth_credentials(credentials)
if oauth2_credentials is not None:
return oauth2_credentials

return None

def _maybe_get_oauth_credentials(
credentials: Any,
) -> Optional[SourceCredentials]:
oauth_credentials = credentials.get("oauth2Credentials")
if oauth_credentials is None:
return None

return OauthCredentials(
access_token=oauth_credentials.get("accessToken"),
expiration=datetime.strptime(oauth_credentials.get("expiration"), JAVA_OFFSET_DATETIME_FORMAT),
)

def _maybe_get_gcp_oauth_credentials(
credentials: Any,
) -> Optional[SourceCredentials]:
gcp_oauth_credentials = credentials.get("gcpOauthCredentials", None)
if gcp_oauth_credentials is None:
return None

return GcpOauthCredentials(
access_token=gcp_oauth_credentials.get("accessToken"),
expiration=datetime.strptime(gcp_oauth_credentials.get("expiration"), JAVA_OFFSET_DATETIME_FORMAT),
)

def _maybe_get_cloud_credentials(
credentials: Any,
) -> Optional[SourceCredentials]:
cloud_credentials = credentials.get("cloudCredentials", None)
if cloud_credentials is None:
return None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ packages = [{ include = "compute_modules" }]
[tool.poetry.dependencies]
python = "^3.9"
requests = "^2.32.3"
external-systems = { version = "^0.100.0", optional = true }
external-systems = { version = "^0.107.0", optional = true }
pyyaml = { version = "^6.0.1", optional = true }


Expand Down
148 changes: 143 additions & 5 deletions tests/sources_v2/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from unittest import mock

import pytest
from external_systems.sources import AwsCredentials, Source
from external_systems.sources import AwsCredentials, GcpOauthCredentials, OauthCredentials, Source

from compute_modules.sources_v2 import get_source
from compute_modules.sources_v2._api import (
Expand Down Expand Up @@ -120,12 +120,13 @@ def test_get_source_without_http_connection(mock_source_config_file: Path, mock_
]
with pytest.raises(ValueError):
source.get_https_connection()
with pytest.raises(ValueError):
source.get_session_credentials()


def test_get_source_with_aws_session_credentials(
def test_get_source_get_aws_credentials_with_aws_session_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
# Create a source config with session credentials directly in the JSON
source_config = {
"test_source": {
"secrets": {},
Expand Down Expand Up @@ -164,10 +165,9 @@ def test_get_source_with_aws_session_credentials(
)


def test_get_source_with_aws_basic_credentials(
def test_get_source_get_aws_credentials_with_aws_basic_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
# Create a source config with basic credentials directly in the JSON
source_config = {
"test_source": {
"secrets": {},
Expand Down Expand Up @@ -199,6 +199,144 @@ def test_get_source_with_aws_basic_credentials(
)


def test_get_source_get_session_credentials_with_aws_session_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
source_config = {
"test_source": {
"secrets": {},
"sourceConfiguration": {"type": "s3"},
"resolvedCredentials": {
"cloudCredentials": {
"awsCredentials": {
"sessionCredentials": {
"accessKeyId": "ACCESS_KEY",
"secretAccessKey": "SECRET_KEY",
"sessionToken": "SESSION_TOKEN",
"expiration": "2023-01-01T00:00:00Z",
}
}
}
},
}
}

mock_source_config_file.write_text(json.dumps(source_config))

with mock.patch.dict(
os.environ,
{
SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file),
SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file),
},
):
source = get_source("test_source")
assert isinstance(source, Source)
assert source.get_session_credentials().get() == AwsCredentials(
access_key_id="ACCESS_KEY",
secret_access_key="SECRET_KEY",
session_token="SESSION_TOKEN",
expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT),
)


def test_get_source_get_session_credentials_with_aws_basic_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
source_config = {
"test_source": {
"secrets": {},
"sourceConfiguration": {"type": "s3"},
"resolvedCredentials": {
"cloudCredentials": {
"awsCredentials": {
"basicCredentials": {"accessKeyId": "ACCESS_KEY", "secretAccessKey": "SECRET_KEY"}
}
}
},
}
}

mock_source_config_file.write_text(json.dumps(source_config))

with mock.patch.dict(
os.environ,
{
SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file),
SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file),
},
):
source = get_source("test_source")
assert isinstance(source, Source)
assert source.get_session_credentials().get() == AwsCredentials(
access_key_id="ACCESS_KEY",
secret_access_key="SECRET_KEY",
)


def test_get_source_get_session_credentials_with_gcp_oauth_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
source_config = {
"test_source": {
"secrets": {},
"sourceConfiguration": {"type": "bigquery"},
"resolvedCredentials": {
"gcpOauthCredentials": {
"accessToken": "ACCESS_TOKEN",
"expiration": "2023-01-01T00:00:00Z",
}
},
}
}

mock_source_config_file.write_text(json.dumps(source_config))

with mock.patch.dict(
os.environ,
{
SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file),
SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file),
},
):
source = get_source("test_source")
assert isinstance(source, Source)
assert source.get_session_credentials().get() == GcpOauthCredentials(
access_token="ACCESS_TOKEN",
expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT),
)


def test_get_source_get_session_credentials_with_oauth2_credentials(
mock_source_config_file: Path, mock_service_discovery_file: Path
) -> None:
source_config = {
"test_source": {
"secrets": {},
"sourceConfiguration": {"type": "webhooks-rest"},
"resolvedCredentials": {
"oauth2Credentials": {"accessToken": "ACCESS_TOKEN", "expiration": "2023-01-01T00:00:00Z"}
},
}
}

mock_source_config_file.write_text(json.dumps(source_config))

with mock.patch.dict(
os.environ,
{
SOURCE_CONFIGURATIONS_PATH: str(mock_source_config_file),
SERVICE_DISCOVERY_PATH: str(mock_service_discovery_file),
},
):
source = get_source("test_source")
assert isinstance(source, Source)
assert source.get_session_credentials().get() == OauthCredentials(
access_token="ACCESS_TOKEN",
expiration=datetime.strptime("2023-01-01T00:00:00Z", JAVA_OFFSET_DATETIME_FORMAT),
)


def test_get_source_not_found(mock_source_config_file: Path, mock_service_discovery_file: Path) -> None:
source_config = {"existing_source": {"secrets": {}, "sourceConfiguration": {"type": "test_type"}}}

Expand Down