Skip to content

Commit

Permalink
Sync flow for Step Functions (aws#380)
Browse files Browse the repository at this point in the history
* Sync flow for Step Functions

* Added generator mapping for step functions resources

* Improve documentation and logging

* Improve logging

* Updated loading mechanism and tests
  • Loading branch information
qingchm committed Aug 3, 2021
1 parent 8c0ec73 commit 0705772
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 16 deletions.
8 changes: 5 additions & 3 deletions samcli/lib/sync/flows/http_api_sync_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def set_up(self) -> None:
def sync(self) -> None:
api_physical_id = self.get_physical_id(self._api_identifier)
if self._definition_uri is None:
msg = "fails since no DefinitionUri defined in the template, \
if you are using DefinitionBody please run sam sync --infra"
LOG.error("%sImport HttpApi %s", self.log_prefix, msg)
LOG.error(
"%sImport HttpApi fails since no DefinitionUri defined in the template, \
if you are using DefinitionBody please run sam sync --infra",
self.log_prefix,
)
raise UriNotFoundException(self._api_identifier, "DefinitionUri")
LOG.debug("%sTrying to import HttpAPI through client", self.log_prefix)
response = self._api_client.reimport_api(ApiId=api_physical_id, Body=self._swagger_body)
Expand Down
8 changes: 5 additions & 3 deletions samcli/lib/sync/flows/rest_api_sync_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def set_up(self) -> None:
def sync(self) -> None:
api_physical_id = self.get_physical_id(self._api_identifier)
if self._definition_uri is None:
msg = "fails since no DefinitionUri defined in the template, \
if you are using DefinitionBody please run sam sync --infra"
LOG.error("%sPut RestApi %s", self.log_prefix, msg)
LOG.error(
"%sImport HttpApi fails since no DefinitionUri defined in the template, \
if you are using DefinitionBody please run sam sync --infra",
self.log_prefix,
)
raise UriNotFoundException(self._api_identifier, "DefinitionUri")
LOG.debug("%sTrying to put RestAPI through client", self.log_prefix)
response = self._api_client.put_rest_api(restApiId=api_physical_id, mode="overwrite", body=self._swagger_body)
Expand Down
109 changes: 109 additions & 0 deletions samcli/lib/sync/flows/stepfunctions_sync_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Base SyncFlow for StepFunctions"""
import logging
from typing import Any, Dict, List, TYPE_CHECKING, cast, Optional


from boto3.session import Session

from samcli.lib.providers.provider import Stack, get_resource_by_id, ResourceIdentifier
from samcli.lib.sync.exceptions import UriNotFoundException
from samcli.lib.sync.sync_flow import SyncFlow, ResourceAPICall

if TYPE_CHECKING:
from samcli.commands.deploy.deploy_context import DeployContext
from samcli.commands.build.build_context import BuildContext

LOG = logging.getLogger(__name__)


class StepFunctionsSyncFlow(SyncFlow):
_state_machine_identifier: str
_stepfunctions_client: Any
_definition_uri: Optional[str]
_stacks: List[Stack]
_states_definition: Optional[str]

def __init__(
self,
state_machine_identifier: str,
build_context: "BuildContext",
deploy_context: "DeployContext",
physical_id_mapping: Dict[str, str],
stacks: List[Stack],
):
"""
Parameters
----------
state_machine_identifier : str
State Machine resource identifier that need to be synced.
build_context : BuildContext
BuildContext used for build related parameters
deploy_context : BuildContext
DeployContext used for this deploy related parameters
physical_id_mapping : Dict[str, str]
Mapping between resource logical identifier and physical identifier
stacks : List[Stack], optional
List of stacks containing a root stack and optional nested stacks
"""
super().__init__(
build_context,
deploy_context,
physical_id_mapping,
log_name="StepFunctions " + state_machine_identifier,
stacks=stacks,
)
self._state_machine_identifier = state_machine_identifier
self._stepfunctions_client = None

def set_up(self) -> None:
super().set_up()
self._stepfunctions_client = cast(Session, self._session).client("stepfunctions")

def gather_resources(self) -> None:
self._definition_uri = self._get_definition_file(self._state_machine_identifier)
self._states_definition = self._process_definition_file()

def _process_definition_file(self) -> Optional[str]:
if self._definition_uri is None:
return None
with open(self._definition_uri, "r", encoding="utf-8") as states_file:
states_data = states_file.read()
return states_data

def _get_definition_file(self, state_machine_identifier: str) -> Optional[str]:
state_machine_resource = get_resource_by_id(self._stacks, ResourceIdentifier(state_machine_identifier))
if state_machine_resource is None:
return None
properties = state_machine_resource.get("Properties", {})
definition_file = properties.get("DefinitionUri")
return cast(Optional[str], definition_file)

def compare_remote(self) -> bool:
# Not comparing with remote right now, instead only making update api calls
# Note: describe state machine has a better rate limit then update state machine
# So if we face any throttling issues, comparing should be desired
return False

def gather_dependencies(self) -> List[SyncFlow]:
return []

def _get_resource_api_calls(self) -> List[ResourceAPICall]:
return []

def _equality_keys(self):
return self._state_machine_identifier

def sync(self) -> None:
state_machine_arn = self.get_physical_id(self._state_machine_identifier)
if self._definition_uri is None:
LOG.error(
"%sUpdate State Machine fails since no DefinitionUri defined in the template, \
if you are using inline Definition please run sam sync --infra",
self.log_prefix,
)
raise UriNotFoundException(self._state_machine_identifier, "DefinitionUri")
LOG.debug("%sTrying to update State Machine definition", self.log_prefix)
response = self._stepfunctions_client.update_state_machine(
stateMachineArn=state_machine_arn, definition=self._states_definition
)
LOG.debug("%sUpdate State Machine: %s", self.log_prefix, response)
24 changes: 24 additions & 0 deletions samcli/lib/sync/sync_flow_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from samcli.lib.sync.flows.image_function_sync_flow import ImageFunctionSyncFlow
from samcli.lib.sync.flows.rest_api_sync_flow import RestApiSyncFlow
from samcli.lib.sync.flows.http_api_sync_flow import HttpApiSyncFlow
from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow
from samcli.lib.utils.boto_utils import get_boto_resource_provider_with_config
from samcli.lib.utils.cloudformation import get_physical_id_mapping

Expand Down Expand Up @@ -110,6 +111,25 @@ def _create_api_flow(self, resource_identifier: ResourceIdentifier, resource: Di
self._stacks,
)

def _create_stepfunctions_flow(
self, resource_identifier: ResourceIdentifier, resource: Dict[str, Any]
) -> Optional[SyncFlow]:
definition_substitutions = resource.get("Properties", dict()).get("DefinitionSubstitutions", None)
if definition_substitutions:
LOG.warning(
"DefinitionSubstitutions property is specified in resource %s. Skipping this resource. "
"Code sync for StepFunctions does not go through CFN, please run sam sync --infra to update.",
resource_identifier,
)
return None
return StepFunctionsSyncFlow(
str(resource_identifier),
self._build_context,
self._deploy_context,
self._physical_id_mapping,
self._stacks,
)

GeneratorFunction = Callable[["SyncFlowFactory", ResourceIdentifier, Dict[str, Any]], Optional[SyncFlow]]
GENERATOR_MAPPING: Dict[str, GeneratorFunction] = {
SamBaseProvider.LAMBDA_FUNCTION: _create_lambda_flow,
Expand All @@ -120,6 +140,10 @@ def _create_api_flow(self, resource_identifier: ResourceIdentifier, resource: Di
CfnApiProvider.APIGATEWAY_RESTAPI: _create_rest_api_flow,
SamApiProvider.SERVERLESS_HTTP_API: _create_api_flow,
CfnApiProvider.APIGATEWAY_V2_API: _create_api_flow,
# Using strings for resource names for now, looking for a solution to
# have a place that stores all resource names like command/_utils/resources.py
"AWS::Serverless::StateMachine": _create_stepfunctions_flow,
"AWS::StepFunctions::StateMachine": _create_stepfunctions_flow,
}
# SyncFlow mapping between resource type and creation function
# Ignoring no-self-use as PyLint has a bug with Generic Abstract Classes
Expand Down
12 changes: 7 additions & 5 deletions tests/unit/lib/sync/flows/test_http_api_sync_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ def test_sync_direct(self, session_mock):
sync_flow._get_definition_file.return_value = "file.yaml"

sync_flow.set_up()
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
sync_flow.gather_resources()

sync_flow._api_client.reimport_api.return_value = {"Response": "success"}

sync_flow.sync()

sync_flow._api_client.reimport_api.assert_called_once_with(ApiId="PhysicalApi1", Body=ANY)
sync_flow._api_client.reimport_api.assert_called_once_with(
ApiId="PhysicalApi1", Body='{"key": "value"}'.encode("utf-8")
)

@patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id")
def test_get_definition_file(self, get_resource_mock):
Expand All @@ -60,9 +62,9 @@ def test_get_definition_file(self, get_resource_mock):
def test_process_definition_file(self):
sync_flow = self.create_sync_flow()
sync_flow._definition_uri = "path"
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
data = sync_flow._process_definition_file()
self.assertEqual(data, "data")
self.assertEqual(data, '{"key": "value"}'.encode("utf-8"))

@patch("samcli.lib.sync.sync_flow.Session")
def test_failed_gather_resources(self, session_mock):
Expand All @@ -77,6 +79,6 @@ def test_failed_gather_resources(self, session_mock):
sync_flow.set_up()
sync_flow._definition_uri = None

with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
with self.assertRaises(UriNotFoundException):
sync_flow.sync()
12 changes: 7 additions & 5 deletions tests/unit/lib/sync/flows/test_rest_api_sync_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,16 @@ def test_sync_direct(self, session_mock):
sync_flow._get_definition_file.return_value = "file.yaml"

sync_flow.set_up()
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
sync_flow.gather_resources()

sync_flow._api_client.put_rest_api.return_value = {"Response": "success"}

sync_flow.sync()

sync_flow._api_client.put_rest_api.assert_called_once_with(restApiId="PhysicalApi1", mode="overwrite", body=ANY)
sync_flow._api_client.put_rest_api.assert_called_once_with(
restApiId="PhysicalApi1", mode="overwrite", body='{"key": "value"}'.encode("utf-8")
)

@patch("samcli.lib.sync.flows.generic_api_sync_flow.get_resource_by_id")
def test_get_definition_file(self, get_resource_mock):
Expand All @@ -60,9 +62,9 @@ def test_get_definition_file(self, get_resource_mock):
def test_process_definition_file(self):
sync_flow = self.create_sync_flow()
sync_flow._definition_uri = "path"
with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
data = sync_flow._process_definition_file()
self.assertEqual(data, "data")
self.assertEqual(data, '{"key": "value"}'.encode("utf-8"))

@patch("samcli.lib.sync.sync_flow.Session")
def test_failed_gather_resources(self, session_mock):
Expand All @@ -77,6 +79,6 @@ def test_failed_gather_resources(self, session_mock):
sync_flow.set_up()
sync_flow._definition_uri = None

with patch("builtins.open", mock_open(read_data="data")) as mock_file:
with patch("builtins.open", mock_open(read_data='{"key": "value"}'.encode("utf-8"))) as mock_file:
with self.assertRaises(UriNotFoundException):
sync_flow.sync()
84 changes: 84 additions & 0 deletions tests/unit/lib/sync/flows/test_stepfunctions_sync_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from unittest import TestCase
from unittest.mock import ANY, MagicMock, mock_open, patch

from samcli.lib.sync.flows.stepfunctions_sync_flow import StepFunctionsSyncFlow
from samcli.lib.sync.exceptions import UriNotFoundException


class TestStepFunctionsSyncFlow(TestCase):
def create_sync_flow(self):
sync_flow = StepFunctionsSyncFlow(
"StateMachine1",
build_context=MagicMock(),
deploy_context=MagicMock(),
physical_id_mapping={},
stacks=[MagicMock()],
)
sync_flow._get_resource_api_calls = MagicMock()
return sync_flow

@patch("samcli.lib.sync.sync_flow.Session")
def test_set_up(self, session_mock):
sync_flow = self.create_sync_flow()
sync_flow.set_up()
session_mock.return_value.client.assert_any_call("stepfunctions")

@patch("samcli.lib.sync.sync_flow.Session")
def test_sync_direct(self, session_mock):
sync_flow = self.create_sync_flow()

sync_flow.get_physical_id = MagicMock()
sync_flow.get_physical_id.return_value = "PhysicalId1"

sync_flow._get_definition_file = MagicMock()
sync_flow._get_definition_file.return_value = "file.yaml"

sync_flow.set_up()
with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
sync_flow.gather_resources()

sync_flow._stepfunctions_client.update_state_machine.return_value = {"Response": "success"}

sync_flow.sync()

sync_flow._stepfunctions_client.update_state_machine.assert_called_once_with(
stateMachineArn="PhysicalId1", definition='{"key": "value"}'
)

@patch("samcli.lib.sync.flows.stepfunctions_sync_flow.get_resource_by_id")
def test_get_definition_file(self, get_resource_mock):
sync_flow = self.create_sync_flow()

get_resource_mock.return_value = {"Properties": {"DefinitionUri": "test_uri"}}
result_uri = sync_flow._get_definition_file("test")

self.assertEqual(result_uri, "test_uri")

get_resource_mock.return_value = {"Properties": {}}
result_uri = sync_flow._get_definition_file("test")

self.assertEqual(result_uri, None)

def test_process_definition_file(self):
sync_flow = self.create_sync_flow()
sync_flow._definition_uri = "path"
with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
data = sync_flow._process_definition_file()
self.assertEqual(data, '{"key": "value"}')

@patch("samcli.lib.sync.sync_flow.Session")
def test_failed_gather_resources(self, session_mock):
sync_flow = self.create_sync_flow()

sync_flow.get_physical_id = MagicMock()
sync_flow.get_physical_id.return_value = "PhysicalApi1"

sync_flow._get_definition_file = MagicMock()
sync_flow._get_definition_file.return_value = "file.yaml"

sync_flow.set_up()
sync_flow._definition_uri = None

with patch("builtins.open", mock_open(read_data='{"key": "value"}')) as mock_file:
with self.assertRaises(UriNotFoundException):
sync_flow.sync()
6 changes: 6 additions & 0 deletions tests/unit/lib/sync/test_sync_flow_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def test_create_api_flow(self, http_api_sync_mock):
result = factory._create_api_flow("API1", {})
self.assertEqual(result, http_api_sync_mock.return_value)

@patch("samcli.lib.sync.sync_flow_factory.StepFunctionsSyncFlow")
def test_create_stepfunctions_flow(self, stepfunctions_sync_mock):
factory = self.create_factory()
result = factory._create_stepfunctions_flow("StateMachine1", {})
self.assertEqual(result, stepfunctions_sync_mock.return_value)

@patch("samcli.lib.sync.sync_flow_factory.get_resource_by_id")
def test_create_sync_flow(self, get_resource_by_id_mock):
factory = self.create_factory()
Expand Down

0 comments on commit 0705772

Please sign in to comment.