Skip to content
This repository was archived by the owner on Mar 13, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ usage: py -m rdl process [-h] [-f [FORCE_FULL_REFRESH_MODELS]]

positional arguments:
source-connection-string
The source connections string as a 64bit ODBC system
dsn. Eg: mssql+pyodbc://dwsource
The source connection string as either:
- (a) 64bit ODBC system dsn.
eg: `mssql+pyodbc://dwsource`.
- (b) AWS Lambda Function.
eg: `aws-lambda://tenant={databaseIdentifier};function={awsAccountNumber}:function:{functionName}`
destination-connection-string
The destination database connection string. Provide in
PostgreSQL + Psycopg format. Eg: 'postgresql+psycopg2:
Expand Down
2 changes: 1 addition & 1 deletion rdl/BatchDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def load_batch(self, batch_key_tracker):
batch_tracker.load_completed_successfully()

for primary_key in batch_key_tracker.primary_keys:
batch_key_tracker.set_bookmark(primary_key, data_frame.iloc[-1][primary_key])
batch_key_tracker.set_bookmark(primary_key, int(data_frame.iloc[-1][primary_key]))

self.logger.info(f"Batch keys '{batch_key_tracker.bookmarks}' completed. {batch_tracker.get_statistics()}")

Expand Down
6 changes: 5 additions & 1 deletion rdl/RelationalDataLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ def get_arguments(self):
'source_connection_string',
metavar='source-connection-string',
type=self.raw_connection_string_to_valid_source_connection_string,
help='The source connections string as a 64bit ODBC system dsn. Eg: mssql+pyodbc://dwsource')
help='The source connection string as either '
'(a) 64bit ODBC system dsn. '
'Eg: mssql+pyodbc://dwsource. '
'(b) AWS Lambda Function. '
'Eg: aws-lambda://tenant={databaseIdentifier};function={awsAccountNumber}:function:{functionName}')

process_command_parser.add_argument(
'destination_connection_string',
Expand Down
165 changes: 165 additions & 0 deletions rdl/data_sources/AWSLambdaDataSource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import logging
import pandas
import json
import boto3

from rdl.data_sources.ChangeTrackingInfo import ChangeTrackingInfo
from rdl.data_sources.SourceTableInfo import SourceTableInfo
from rdl.shared import Providers
from rdl.shared.Utils import prevent_senstive_data_logging


class AWSLambdaDataSource(object):
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
CONNECTION_STRING_PREFIX = "aws-lambda://"
CONNECTION_STRING_GROUP_SEPARATOR = ";"
CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="

def __init__(self, connection_string, logger=None):
self.logger = logger or logging.getLogger(__name__)
if not AWSLambdaDataSource.can_handle_connection_string(connection_string):
raise ValueError(connection_string)
self.connection_string = connection_string
self.connection_data = dict(
kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR)
for kv in self.connection_string.lstrip(
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
)
.rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
.split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
)
self.aws_lambda_client = boto3.client("lambda")

@staticmethod
def can_handle_connection_string(connection_string):
return connection_string.startswith(
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
) and len(connection_string) != len(
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
)

@staticmethod
def get_connection_string_prefix():
return AWSLambdaDataSource.CONNECTION_STRING_PREFIX

def get_table_info(self, table_config, last_known_sync_version):
column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync \
= self.__get_table_info(table_config, last_known_sync_version)
columns_in_database = column_names
change_tracking_info = ChangeTrackingInfo(
last_sync_version=last_sync_version,
sync_version=sync_version,
force_full_load=full_refresh_required,
data_changed_since_last_sync=data_changed_since_last_sync,
)
source_table_info = SourceTableInfo(columns_in_database, change_tracking_info)
return source_table_info

@prevent_senstive_data_logging
def get_table_data_frame(
self,
table_config,
columns_config,
batch_config,
batch_tracker,
batch_key_tracker,
full_refresh,
change_tracking_info,
):
self.logger.debug(f"Starting read data from lambda.. : \n{None}")
column_names, data = self.__get_table_data(
table_config,
batch_config,
change_tracking_info,
full_refresh,
columns_config,
batch_key_tracker,
)
self.logger.debug(f"Finished read data from lambda.. : \n{None}")
# should we log size of data extracted?
data_frame = self.__get_data_frame(data, column_names)
batch_tracker.extract_completed_successfully(len(data_frame))
return data_frame

def __get_table_info(self, table_config, last_known_sync_version):
pay_load = {
"Command": "GetTableInfo",
"TenantId": int(self.connection_data["tenant"]),
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
"CommandPayload": {"LastSyncVersion": last_known_sync_version},
}

result = self.__invoke_lambda(pay_load)

return result["ColumnNames"], \
result["LastSyncVersion"], \
result["CurrentSyncVersion"], \
result["FullRefreshRequired"], \
result["DataChangedSinceLastSync"]

def __get_table_data(
self,
table_config,
batch_config,
change_tracking_info,
full_refresh,
columns_config,
batch_key_tracker,
):
pay_load = {
"Command": "GetTableData",
"TenantId": int(self.connection_data["tenant"]),
"Table": {"Schema": table_config["schema"], "Name": table_config["name"]},
"CommandPayload": {
"AuditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION,
"AuditColumnNameForDeletionFlag": Providers.AuditColumnsNames.IS_DELETED,
"BatchSize": batch_config["size"],
"LastSyncVersion": change_tracking_info.last_sync_version,
"FullRefresh": full_refresh,
"ColumnNames": list(map(lambda cfg: cfg['source_name'], columns_config)),
"PrimaryKeyColumnNames": table_config["primary_keys"],
"LastBatchPrimaryKeys": [
{"Key": k, "Value": v} for k, v in batch_key_tracker.bookmarks.items()
],
},
}

result = self.__invoke_lambda(pay_load)

return result["ColumnNames"], result["Data"]

def __get_data_frame(self, data: [[]], column_names: []):
return pandas.DataFrame(data=data, columns=column_names)

def __invoke_lambda(self, pay_load):
self.logger.debug('\nRequest being sent to Lambda:')
self.logger.debug(pay_load)

lambda_response = self.aws_lambda_client.invoke(
FunctionName=self.connection_data["function"],
InvocationType="RequestResponse",
LogType="None", # |'Tail', Set to Tail to include the execution log in the response
Payload=json.dumps(pay_load).encode(),
)

response_status_code = int(lambda_response['StatusCode'])
response_function_error = lambda_response.get("FunctionError")
self.logger.debug('\nResponse received from Lambda:')
self.logger.debug(f'Response - StatusCode = "{response_status_code}"')
self.logger.debug(f'Response - FunctionError = "{response_function_error}"')

response_payload = json.loads(lambda_response['Payload'].read())

if response_status_code != 200 or response_function_error:
self.logger.error(F'Error in response from aws lambda {self.connection_data["function"]}')
self.logger.error(f'Response - Status Code = {response_status_code}')
self.logger.error(f'Response - Error Function = {response_function_error}')
self.logger.error(f'Response - Error Details:')
# the below is risky as it may contain actual data if this line is reached in case of a successful result
# however, the same Payload field is used to return actual error details in case of real errors
# i.e. StatusCode is 200 (since AWS could invoke the lambda)
# BUT the lambda barfed with an error and therefore the FunctionError would not be None
self.logger.error(response_payload)
raise Exception('Error received when invoking AWS Lambda. See logs for further details.')

return response_payload
6 changes: 3 additions & 3 deletions rdl/data_sources/DataSourceFactory.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
from rdl.data_sources.MsSqlDataSource import MsSqlDataSource

from rdl.data_sources.AWSLambdaDataSource import AWSLambdaDataSource

class DataSourceFactory(object):

def __init__(self, logger=None):
self.logger = logger or logging.getLogger(__name__)
self.sources = [MsSqlDataSource]
self.sources = [MsSqlDataSource, AWSLambdaDataSource]

def create_source(self, connection_string):
for source in self.sources:
Expand All @@ -23,4 +23,4 @@ def is_prefix_supported(self, connection_string):
return False

def get_supported_source_prefixes(self):
return list(map(lambda source: source.connection_string_prefix(), self.sources))
return list(map(lambda source: source.get_connection_string_prefix(), self.sources))
2 changes: 1 addition & 1 deletion rdl/data_sources/MsSqlDataSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def can_handle_connection_string(connection_string):
return MsSqlDataSource.__connection_string_regex_match(connection_string) is not None

@staticmethod
def connection_string_prefix():
def get_connection_string_prefix():
return 'mssql+pyodbc://'

def get_table_info(self, table_config, last_known_sync_version):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
'SQLAlchemy==1.3.3',
'sqlalchemy-citext==1.3.post0',
'alembic==1.0.9',
'boto3==1.9.187',
],
package_data={
'': ['alembic.ini', 'alembic/*.py', 'alembic/**/*.py'],
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/test_AWSLambdaDataSource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest

from rdl.data_sources.AWSLambdaDataSource import AWSLambdaDataSource


class TestAWSLambdaDataSource(unittest.TestCase):
data_source = None
table_configs = []

@classmethod
def setUpClass(cls):
TestAWSLambdaDataSource.data_source = AWSLambdaDataSource(
"aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;"
)

@classmethod
def tearDownClass(cls):
TestAWSLambdaDataSource.data_source = None

def test_can_handle_valid_connection_string(self):
self.assertTrue(
self.data_source.can_handle_connection_string(
"aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;"
)
)

def test_can_handle_invalid_connection_string(self):
self.assertFalse(
self.data_source.can_handle_connection_string(
"lambda-aws://tenant=543_dc2;function=123456789012:function:my-function;"
)
)


if __name__ == "__main__":
unittest.main()