diff --git a/rdl/data_sources/AWSLambdaDataSource.py b/rdl/data_sources/AWSLambdaDataSource.py index b47acea..0913175 100644 --- a/rdl/data_sources/AWSLambdaDataSource.py +++ b/rdl/data_sources/AWSLambdaDataSource.py @@ -3,6 +3,7 @@ import json import boto3 import time +import datetime from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo from rdl.data_sources.SourceTableInfo import SourceTableInfo @@ -11,15 +12,24 @@ class AWSLambdaDataSource(object): - # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;' + # 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName;' CONNECTION_STRING_PREFIX = "aws-lambda://" CONNECTION_STRING_GROUP_SEPARATOR = ";" CONNECTION_STRING_KEY_VALUE_SEPARATOR = "=" + CONNECTION_DATA_ROLE_KEY = "role" + CONNECTION_DATA_FUNCTION_KEY = "function" + CONNECTION_DATA_TENANT_KEY = "tenant" + + AWS_SERVICE_LAMBDA = "lambda" + AWS_SERVICE_S3 = "s3" + def __init__(self, connection_string, logger=None): self.logger = logger or logging.getLogger(__name__) + if not AWSLambdaDataSource.can_handle_connection_string(connection_string): raise ValueError(connection_string) + self.connection_string = connection_string self.connection_data = dict( kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR) @@ -29,8 +39,19 @@ def __init__(self, connection_string, logger=None): .rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) .split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR) ) - self.aws_lambda_client = boto3.client("lambda") - self.aws_s3_client = boto3.client("s3") + + self.aws_sts_client = boto3.client("sts") + role_credentials = self.__assume_role( + self.connection_data[self.CONNECTION_DATA_ROLE_KEY], + f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}", + ) + + self.aws_lambda_client = self.__get_aws_client( + self.AWS_SERVICE_LAMBDA, role_credentials + ) + self.aws_s3_client = self.__get_aws_client( + self.AWS_SERVICE_S3, role_credentials + ) @staticmethod def can_handle_connection_string(connection_string): @@ -87,7 +108,7 @@ def get_table_data_frame( def __get_table_info(self, table_config, last_known_sync_version): pay_load = { "Command": "GetTableInfo", - "TenantId": int(self.connection_data["tenant"]), + "TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]), "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, "CommandPayload": {"LastSyncVersion": last_known_sync_version}, } @@ -113,7 +134,7 @@ def __get_table_data( ): pay_load = { "Command": "GetTableData", - "TenantId": int(self.connection_data["tenant"]), + "TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]), "Table": {"Schema": table_config["schema"], "Name": table_config["name"]}, "CommandPayload": { "AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION, @@ -125,7 +146,7 @@ def __get_table_data( { "Name": col["source_name"], "DataType": col["destination"]["type"], - "IsPrimaryKey": col["destination"]["primary_key"] + "IsPrimaryKey": col["destination"]["primary_key"], } for col in columns_config ], @@ -148,21 +169,69 @@ def __get_table_data( def __get_data_frame(self, data: [[]], column_names: []): return pandas.DataFrame(data=data, columns=column_names) + def __assume_role(self, role_arn, session_name): + self.logger.debug(f"\nAssuming role with ARN: {role_arn}") + + assume_role_response = self.aws_sts_client.assume_role( + RoleArn=role_arn, RoleSessionName=session_name + ) + + role_credentials = assume_role_response["Credentials"] + + self.role_session_expiry = role_credentials["Expiration"] + + return role_credentials + + def __get_aws_client(self, service, credentials): + return boto3.client( + service_name=service, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + def __refresh_aws_clients_if_expired(self): + # this is due to AWS returning their expiry date in UTC + current_datetime = datetime.datetime.now(datetime.timezone.utc) + + if ( + current_datetime > self.role_session_expiry - datetime.timedelta(minutes=5) + and current_datetime < self.role_session_expiry + ): + role_credentials = self.__assume_role( + self.connection_data[self.CONNECTION_DATA_ROLE_KEY], + f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}", + ) + + self.aws_lambda_client = self.__get_aws_client( + self.AWS_SERVICE_LAMBDA, role_credentials + ) + self.aws_s3_client = self.__get_aws_client( + self.AWS_SERVICE_S3, role_credentials + ) + def __invoke_lambda(self, pay_load): max_attempts = Constants.MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS retry_delay = Constants.AWS_LAMBDA_RETRY_DELAY_SECONDS response_payload = None - for current_attempt in list(range(1, max_attempts+1, 1)): + for current_attempt in list(range(1, max_attempts + 1, 1)): + + self.__refresh_aws_clients_if_expired() + if current_attempt > 1: - self.logger.debug(f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds") + self.logger.debug( + f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds" + ) time.sleep((current_attempt - 1) ^ retry_delay) - self.logger.debug(f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:") + self.logger.debug( + f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:" + ) self.logger.debug(pay_load) lambda_response = self.aws_lambda_client.invoke( - FunctionName=self.connection_data["function"], + FunctionName=self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY], InvocationType="RequestResponse", LogType="None", # |'Tail', Set to Tail to include the execution log in the response Payload=json.dumps(pay_load).encode(), @@ -170,7 +239,9 @@ def __invoke_lambda(self, pay_load): response_status_code = int(lambda_response["StatusCode"]) response_function_error = lambda_response.get("FunctionError") - self.logger.debug(f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:") + self.logger.debug( + f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:" + ) self.logger.debug(f'Response - StatusCode = "{response_status_code}"') self.logger.debug(f'Response - FunctionError = "{response_function_error}"') @@ -178,11 +249,13 @@ def __invoke_lambda(self, pay_load): if response_status_code != 200 or response_function_error: self.logger.error( - f'Error in response from aws lambda \'{self.connection_data["function"]}\', ' - f'attempt {current_attempt} of {max_attempts}' + f"Error in response from aws lambda '{self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY]}', " + f"attempt {current_attempt} of {max_attempts}" ) self.logger.error(f"Response - Status Code = {response_status_code}") - self.logger.error(f"Response - Error Function = {response_function_error}") + self.logger.error( + f"Response - Error Function = {response_function_error}" + ) self.logger.error(f"Response - Error Details:") # the below is risky as it may contain actual data if this line is reached in case of success # however, the same Payload field is used to return actual error details in case of failure