Skip to content
This repository was archived by the owner on Mar 13, 2020. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 87 additions & 14 deletions rdl/data_sources/AWSLambdaDataSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import boto3
import time
import datetime

from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo
from rdl.data_sources.SourceTableInfo import SourceTableInfo
Expand All @@ -11,15 +12,24 @@


class AWSLambdaDataSource(object):
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;role=arn:aws:iam::123456789012:role/RoleName;'
# Scheme prefix that starts every connection string for this data source.
CONNECTION_STRING_PREFIX = "aws-lambda://"
# Separates the individual "key=value" groups of the connection string.
CONNECTION_STRING_GROUP_SEPARATOR = ";"
# Separates the key from the value inside a single group.
CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="

# Keys looked up in the parsed connection data (self.connection_data).
# "role" holds the ARN handed to STS assume_role.
CONNECTION_DATA_ROLE_KEY = "role"
# "function" is passed as FunctionName when the Lambda is invoked.
CONNECTION_DATA_FUNCTION_KEY = "function"
# "tenant" is sent as the integer TenantId of each payload and is also used
# to build the "dwp_<tenant>" role session name.
CONNECTION_DATA_TENANT_KEY = "tenant"

# Service names handed to boto3.client(service_name=...).
AWS_SERVICE_LAMBDA = "lambda"
AWS_SERVICE_S3 = "s3"

def __init__(self, connection_string, logger=None):
self.logger = logger or logging.getLogger(__name__)

if not AWSLambdaDataSource.can_handle_connection_string(connection_string):
raise ValueError(connection_string)

self.connection_string = connection_string
self.connection_data = dict(
kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR)
Expand All @@ -29,8 +39,19 @@ def __init__(self, connection_string, logger=None):
.rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
.split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
)
self.aws_lambda_client = boto3.client("lambda")
self.aws_s3_client = boto3.client("s3")

self.aws_sts_client = boto3.client("sts")
role_credentials = self.__assume_role(
self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
)

self.aws_lambda_client = self.__get_aws_client(
self.AWS_SERVICE_LAMBDA, role_credentials
)
self.aws_s3_client = self.__get_aws_client(
self.AWS_SERVICE_S3, role_credentials
)

@staticmethod
def can_handle_connection_string(connection_string):
Expand Down Expand Up @@ -87,7 +108,7 @@ def get_table_data_frame(
def __get_table_info(self, table_config, last_known_sync_version):
pay_load = {
"Command": "GetTableInfo",
"TenantId": int(self.connection_data["tenant"]),
"TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]),
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
"CommandPayload": {"LastSyncVersion": last_known_sync_version},
}
Expand All @@ -113,7 +134,7 @@ def __get_table_data(
):
pay_load = {
"Command": "GetTableData",
"TenantId": int(self.connection_data["tenant"]),
"TenantId": int(self.connection_data[self.CONNECTION_DATA_TENANT_KEY]),
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
"CommandPayload": {
"AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION,
Expand All @@ -125,7 +146,7 @@ def __get_table_data(
{
"Name": col["source_name"],
"DataType": col["destination"]["type"],
"IsPrimaryKey": col["destination"]["primary_key"]
"IsPrimaryKey": col["destination"]["primary_key"],
}
for col in columns_config
],
Expand All @@ -148,41 +169,93 @@ def __get_table_data(
def __get_data_frame(self, data: [[]], column_names: []):
    """Wrap ``data`` in a pandas DataFrame labelled with ``column_names``.

    :param data: the rows to load, passed straight through as ``data=``.
    :param column_names: the column labels, passed straight through as
        ``columns=``.
    :return: the newly constructed ``pandas.DataFrame``.
    """
    frame = pandas.DataFrame(columns=column_names, data=data)
    return frame

def __assume_role(self, role_arn, session_name):
    """Assume ``role_arn`` through the STS client and return its credentials.

    Side effect: the ``Expiration`` value of the returned credentials is
    stored on ``self.role_session_expiry``.

    :param role_arn: ARN of the role, sent as ``RoleArn``.
    :param session_name: sent as ``RoleSessionName``.
    :return: the ``"Credentials"`` entry of the ``assume_role`` response.
    """
    self.logger.debug(f"\nAssuming role with ARN: {role_arn}")

    credentials = self.aws_sts_client.assume_role(
        RoleArn=role_arn, RoleSessionName=session_name
    )["Credentials"]

    # Remember when this session lapses so it can be renewed later.
    self.role_session_expiry = credentials["Expiration"]
    return credentials

def __get_aws_client(self, service, credentials):
    """Create a boto3 client for ``service`` using explicit credentials.

    :param service: service name passed to ``boto3.client`` as
        ``service_name``.
    :param credentials: mapping that provides ``"AccessKeyId"``,
        ``"SecretAccessKey"`` and ``"SessionToken"``.
    :return: whatever ``boto3.client`` returns for those arguments.
    """
    client_kwargs = {
        "service_name": service,
        "aws_access_key_id": credentials["AccessKeyId"],
        "aws_secret_access_key": credentials["SecretAccessKey"],
        "aws_session_token": credentials["SessionToken"],
    }
    return boto3.client(**client_kwargs)

def __refresh_aws_clients_if_expired(self):
    """Re-assume the role and rebuild the AWS clients when the session is stale.

    The session counts as stale once the current UTC time is later than
    five minutes before ``self.role_session_expiry``. This includes the case
    where the expiry has already passed. Previously the refresh only ran
    inside the final five-minute window, so a session that had fully expired
    (for example after a long pause between invocations) was never renewed.

    When a refresh happens, ``self.aws_lambda_client`` and
    ``self.aws_s3_client`` are replaced with clients built from the new
    credentials; otherwise nothing changes.
    """
    # AWS returns the expiry in UTC, so compare against an aware UTC "now".
    current_datetime = datetime.datetime.now(datetime.timezone.utc)
    refresh_after = self.role_session_expiry - datetime.timedelta(minutes=5)

    if current_datetime <= refresh_after:
        # Still comfortably valid: keep the existing clients.
        return

    role_credentials = self.__assume_role(
        self.connection_data[self.CONNECTION_DATA_ROLE_KEY],
        f"dwp_{self.connection_data[self.CONNECTION_DATA_TENANT_KEY]}",
    )

    self.aws_lambda_client = self.__get_aws_client(
        self.AWS_SERVICE_LAMBDA, role_credentials
    )
    self.aws_s3_client = self.__get_aws_client(
        self.AWS_SERVICE_S3, role_credentials
    )

def __invoke_lambda(self, pay_load):
max_attempts = Constants.MAX_AWS_LAMBDA_INVOKATION_ATTEMPTS
retry_delay = Constants.AWS_LAMBDA_RETRY_DELAY_SECONDS
response_payload = None

for current_attempt in list(range(1, max_attempts+1, 1)):
for current_attempt in list(range(1, max_attempts + 1, 1)):

self.__refresh_aws_clients_if_expired()

if current_attempt > 1:
self.logger.debug(f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds")
self.logger.debug(
f"\nDelaying retry for {(current_attempt - 1) ^ retry_delay} seconds"
)
time.sleep((current_attempt - 1) ^ retry_delay)

self.logger.debug(f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:")
self.logger.debug(
f"\nRequest being sent to Lambda, attempt {current_attempt} of {max_attempts}:"
)
self.logger.debug(pay_load)

lambda_response = self.aws_lambda_client.invoke(
FunctionName=self.connection_data["function"],
FunctionName=self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY],
InvocationType="RequestResponse",
LogType="None", # |'Tail', Set to Tail to include the execution log in the response
Payload=json.dumps(pay_load).encode(),
)

response_status_code = int(lambda_response["StatusCode"])
response_function_error = lambda_response.get("FunctionError")
self.logger.debug(f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:")
self.logger.debug(
f"\nResponse received from Lambda, attempt {current_attempt} of {max_attempts}:"
)
self.logger.debug(f'Response - StatusCode = "{response_status_code}"')
self.logger.debug(f'Response - FunctionError = "{response_function_error}"')

response_payload = json.loads(lambda_response["Payload"].read())

if response_status_code != 200 or response_function_error:
self.logger.error(
f'Error in response from aws lambda \'{self.connection_data["function"]}\', '
f'attempt {current_attempt} of {max_attempts}'
f"Error in response from aws lambda '{self.connection_data[self.CONNECTION_DATA_FUNCTION_KEY]}', "
f"attempt {current_attempt} of {max_attempts}"
)
self.logger.error(f"Response - Status Code = {response_status_code}")
self.logger.error(f"Response - Error Function = {response_function_error}")
self.logger.error(
f"Response - Error Function = {response_function_error}"
)
self.logger.error(f"Response - Error Details:")
# the below is risky as it may contain actual data if this line is reached in case of success
# however, the same Payload field is used to return actual error details in case of failure
Expand Down