From b5f496d481354b32ebe5719f82bf330c188fa838 Mon Sep 17 00:00:00 2001 From: Walter Ngo Date: Mon, 30 Sep 2019 13:26:55 +1000 Subject: [PATCH 1/3] [SP-333] Assume role when invoking lambda function --- rdl/data_sources/AWSLambdaDataSource.py | 89 ++++++++++++++++++++++--- 1 file changed, 79 insertions(+), 10 deletions(-) diff --git a/rdl/data_sources/AWSLambdaDataSource.py b/rdl/data_sources/AWSLambdaDataSource.py index b47acea..e676611 100644 --- a/rdl/data_sources/AWSLambdaDataSource.py +++ b/rdl/data_sources/AWSLambdaDataSource.py @@ -3,6 +3,7 @@ import json import boto3 import time +import datetime from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo from rdl.data_sources.SourceTableInfo import SourceTableInfo @@ -11,15 +12,21 @@ class AWSLambdaDataSource(object): - # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;' + # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName;' CONNECTION_STRING_PREFIX = "aws-lambda://" CONNECTION_STRING_GROUP_SEPARATOR = ";" CONNECTION_STRING_KEY_VALUE_SEPARATOR = "=" + CONNECTION_STRING_ROLE_KEY = "role" + + AWS_SERVICE_LAMBDA = "lambda" + AWS_SERVICE_S3 = "s3" def __init__(self, connection_string, logger=None): self.logger = logger or logging.getLogger(__name__) + if not AWSLambdaDataSource.can_handle_connection_string(connection_string): raise ValueError(connection_string) + self.connection_string = connection_string self.connection_data = dict( kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR) @@ -29,8 +36,19 @@ def __init__(self, connection_string, logger=None): .rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) .split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) ) - self.aws_lambda_client = boto3.client("lambda") - self.aws_s3_client = boto3.client("s3") + + self.aws_sts_client = boto3.client("sts") + role_credentials = self.__assume_role( + self.connection_data[self.CONNECTION_STRING_ROLE_KEY], + f'dwp_{self.connection_data["tenant"]}', + ) + + self.aws_lambda_client = self.__get_aws_client( + self.AWS_SERVICE_LAMBDA, role_credentials + ) + self.aws_s3_client = self.__get_aws_client( + self.AWS_SERVICE_S3, role_credentials + ) @staticmethod def can_handle_connection_string(connection_string): @@ -125,7 +143,7 @@ def __get_table_data( { "Name": col["source_name"], "DataType": col["destination"]["type"], - "IsPrimaryKey": col["destination"]["primary_key"] + "IsPrimaryKey": col["destination"]["primary_key"], } for col in columns_config ], @@ -148,17 +166,64 @@ def __get_table_data( def __get_data_frame(self, data: [[]], column_names: []): return pandas.DataFrame(data=data, columns=column_names) + def __assume_role(self, role_arn, session_name): + self.logger.debug(f"\nAssuming role with ARN: {role_arn}") + + assume_role_response = self.aws_sts_client.assume_role( + RoleArn=role_arn, RoleSessionName=session_name + ) + + role_credentials = assume_role_response["Credentials"] + + self.role_session_expiry = role_credentials["Expiration"] + + return role_credentials + + def __get_aws_client(self, service, credentials): + return boto3.client( + service_name=service, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + def __refresh_aws_clients_if_expired(self): + current_datetime = datetime.datetime.now() + + if ( + current_datetime > self.role_session_expiry - datetime.timedelta(minutes=5) + and current_datetime < self.role_session_expiry + ): + role_credentials = self.__assume_role( + self.connection_data[self.CONNECTION_STRING_ROLE_KEY], + f'dwp_{self.connection_data["tenant"]}', + ) + + self.aws_lambda_client = self.__get_aws_client( + self.AWS_SERVICE_LAMBDA, role_credentials + ) + self.aws_s3_client = self.__get_aws_client( + self.AWS_SERVICE_S3, role_credentials + ) + def __invoke_lambda(self, pay_load): max_attempts = Constants.MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS retry_delay = Constants.AWS_LAMBDA_RETRY_DELAY_SECONDS response_payload = None - for current_attempt in list(range(1, max_attempts+1, 1)): + for current_attempt in list(range(1, max_attempts + 1, 1)): + + self.__refresh_aws_clients_if_expired() + if current_attempt > 1: - self.logger.debug(f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds") + self.logger.debug( + f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds" + ) time.sleep((current_attempt - 1) ^ retry_delay) - self.logger.debug(f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:") + self.logger.debug( + f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:" + ) self.logger.debug(pay_load) lambda_response = self.aws_lambda_client.invoke( @@ -170,7 +235,9 @@ def __invoke_lambda(self, pay_load): response_status_code = int(lambda_response["StatusCode"]) response_function_error = lambda_response.get("FunctionError") - self.logger.debug(f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:") + self.logger.debug( + f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:" + ) self.logger.debug(f'Response - StatusCode = "{response_status_code}"') self.logger.debug(f'Response - FunctionError = "{response_function_error}"') @@ -179,10 +246,12 @@ def __invoke_lambda(self, pay_load): if response_status_code != 200 or response_function_error: self.logger.error( f'Error in response from aws lambda \'{self.connection_data["function"]}\', ' - f'attempt {current_attempt} of {max_attempts}' + f"attempt {current_attempt} of {max_attempts}" ) self.logger.error(f"Response - Status Code = {response_status_code}") - self.logger.error(f"Response - Error Function = {response_function_error}") + self.logger.error( + f"Response - Error Function = {response_function_error}" + ) self.logger.error(f"Response - Error Details:") # the below is risky as it may contain actual data if this line is reached in case of success # however, the same Payload field is used to return actual error details in case of failure From 9e7cf1fcd47a4d0277fe6136d4b8f4d159780501 Mon Sep 17 00:00:00 2001 From: Walter Ngo Date: Mon, 30 Sep 2019 13:30:09 +1000 Subject: [PATCH 2/3] [SP-333] Use connection data keys instead of magic strings --- rdl/data_sources/AWSLambdaDataSource.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/rdl/data_sources/AWSLambdaDataSource.py b/rdl/data_sources/AWSLambdaDataSource.py index e676611..67d714d 100644 --- a/rdl/data_sources/AWSLambdaDataSource.py +++ b/rdl/data_sources/AWSLambdaDataSource.py @@ -16,7 +16,10 @@ class AWSLambdaDataSource(object): CONNECTION_STRING_PREFIX = "aws-lambda://" CONNECTION_STRING_GROUP_SEPARATOR = ";" CONNECTION_STRING_KEY_VALUE_SEPARATOR = "=" - CONNECTION_STRING_ROLE_KEY = "role" + + CONNECTION_DATA_ROLE_KEY = "role" + CONNECTION_DATA_FUNCTION_KEY = "function" + CONNECTION_DATA_TENANT_KEY = "tenant" AWS_SERVICE_LAMBDA = "lambda" AWS_SERVICE_S3 = "s3" @@ -39,8 +42,8 @@ def __init__(self, connection_string, logger=None): self.aws_sts_client = boto3.client("sts") role_credentials = self.__assume_role( - self.connection_data[self.CONNECTION_STRING_ROLE_KEY], - f'dwp_{self.connection_data["tenant"]}', + self.connection_data[self.CONNECTION_DATA_ROLE_KEY], + f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}", ) self.aws_lambda_client = self.__get_aws_client( @@ -105,7 +108,7 @@ def get_table_data_frame( def __get_table_info(self, table_config, last_known_sync_version): pay_load = { "Command": "GetTableInfo", - "TenantId": int(self.connection_data["tenant"]), + "TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]), "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, "CommandPayload": {"LastSyncVersion": last_known_sync_version}, } @@ -131,7 +134,7 @@ def __get_table_data( ): pay_load = { "Command": "GetTableData", - "TenantId": int(self.connection_data["tenant"]), + "TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]), "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, "CommandPayload": { "AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION, @@ -195,8 +198,8 @@ def __refresh_aws_clients_if_expired(self): and current_datetime < self.role_session_expiry ): role_credentials = self.__assume_role( - self.connection_data[self.CONNECTION_STRING_ROLE_KEY], - f'dwp_{self.connection_data["tenant"]}', + self.connection_data[self.CONNECTION_DATA_ROLE_KEY], + f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}", ) self.aws_lambda_client = self.__get_aws_client( @@ -227,7 +230,7 @@ def __invoke_lambda(self, pay_load): self.logger.debug(pay_load) lambda_response = self.aws_lambda_client.invoke( - FunctionName=self.connection_data["function"], + FunctionName=self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY], InvocationType="RequestResponse", LogType="None", # |'Tail', Set to Tail to include the execution log in the response Payload=json.dumps(pay_load).encode(), @@ -245,7 +248,7 @@ def __invoke_lambda(self, pay_load): if response_status_code != 200 or response_function_error: self.logger.error( - f'Error in response from aws lambda \'{self.connection_data["function"]}\', ' + f"Error in response from aws lambda '{self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY]}', " f"attempt {current_attempt} of {max_attempts}" ) self.logger.error(f"Response - Status Code = {response_status_code}") From ba42aed7a33cbd847dd869ef225f24dfda905c68 Mon Sep 17 00:00:00 2001 From: Walter Ngo Date: Mon, 30 Sep 2019 13:59:45 +1000 Subject: [PATCH 3/3] [SP-333] Fix error when comparing datetimes --- rdl/data_sources/AWSLambdaDataSource.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rdl/data_sources/AWSLambdaDataSource.py b/rdl/data_sources/AWSLambdaDataSource.py index 67d714d..0913175 100644 --- a/rdl/data_sources/AWSLambdaDataSource.py +++ b/rdl/data_sources/AWSLambdaDataSource.py @@ -191,7 +191,8 @@ def __get_aws_client(self, service, credentials): ) def __refresh_aws_clients_if_expired(self): - current_datetime = datetime.datetime.now() + # this is due to AWS returning their expiry date in UTC + current_datetime = datetime.datetime.now(datetime.timezone.utc) if ( current_datetime > self.role_session_expiry - datetime.timedelta(minutes=5)