Skip to content

Commit

Permalink
Updates & Bug Fixes (#10)
Browse files Browse the repository at this point in the history
* Only support int, ObjectId, timestamp, and datetime rep key types

* make sure replication key hasn't changed

* only discover databases that the user has access to

* add comment
  • Loading branch information
nick-mccoy committed Aug 30, 2019
1 parent 2a3ae9a commit 063c3fd
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 45 deletions.
60 changes: 53 additions & 7 deletions tap_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,57 @@

REQUIRED_CONFIG_KEYS = [
'host',
'port'
'port',
'user',
'password',
'database'
]

IGNORE_DBS = ['admin', 'system', 'local', 'config']

def get_databases(client, config):

# usersInfo Command returns object in shape:
# {
# <some_other_keys>
# 'users': [
# {
# '_id': <auth_db>.<user>,
# 'db': <auth_db>,
# 'mechanisms': ['SCRAM-SHA-1', 'SCRAM-SHA-256'],
# 'roles': [{'db': 'admin', 'role': 'readWriteAnyDatabase'},
# {'db': 'local', 'role': 'read'}],
# 'user': <user>,
# 'userId': <userId>
# }
# ]
# }
user_info = client[config['database']].command({'usersInfo': config['user']})

can_read_all = False
db_names = []

users = [u for u in user_info.get('users') if u.get('user')==config['user']]
if len(users)!=1:
LOGGER.warning('Could not find any users for {}'.format(config['user']))
return []

for role_info in users[0].get('roles', []):
if role_info.get('role') in ['read', 'readWrite']:
if role_info.get('db'):
db_names.append(role_info['db'])
elif role_info.get('role') in ['readAnyDatabase', 'readWriteAnyDatabase']:
# If user has either of these roles they can query all databases
can_read_all = True
break

if can_read_all:
db_names = client.list_database_names()

return [d for d in db_names if d not in IGNORE_DBS]



def produce_collection_schema(collection):
collection_name = collection.name
collection_db_name = collection.database.name
Expand Down Expand Up @@ -64,12 +110,10 @@ def produce_collection_schema(collection):
}


def do_discover(client):
def do_discover(client, config):
streams = []

database_names = client.list_database_names()
for db_name in [d for d in database_names
if d not in IGNORE_DBS]:
for db_name in get_databases(client, config):
db = client[db_name]

collection_names = db.list_collection_names()
Expand Down Expand Up @@ -156,7 +200,9 @@ def sync_stream(client, stream, state):
try:
stream_projection = json.loads(stream_projection)
except:
raise common.InvalidProjectionException("The projection provided is not valid JSON")
err_msg = "The projection: {} for stream {} is not valid json"
raise common.InvalidProjectionException(err_msg.format(stream_projection,
tap_stream_id))
else:
LOGGER.warning('There is no projection found for stream %s, all fields will be retrieved.', stream['tap_stream_id'])

Expand Down Expand Up @@ -228,7 +274,7 @@ def main_impl():
common.include_schemas_in_destination_stream_name = (config.get('include_schemas_in_destination_stream_name') == 'true')

if args.discover:
do_discover(client)
do_discover(client, config)
elif args.catalog:
state = args.state or {}
do_sync(client, args.catalog.to_dict(), state)
Expand Down
12 changes: 2 additions & 10 deletions tap_mongodb/sync_strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class InvalidProjectionException(Exception):
    """Raised when a stream's projection is invalid (e.g. not parseable as JSON)"""

class UnsupportedReplicationKeyType(Exception):
class UnsupportedReplicationKeyTypeException(Exception):
    """Raised when a replication key's type is unsupported
    (only int, ObjectId, Timestamp, and datetime are supported)"""

def calculate_destination_stream_name(stream):
Expand Down Expand Up @@ -53,10 +53,8 @@ def class_to_string(bookmark_value, bookmark_type):
return utils.strftime(utc_datetime)
if bookmark_type == 'Timestamp':
return '{}.{}'.format(bookmark_value.time, bookmark_value.inc)
if bookmark_type in ['int', 'ObjectId', 'Decimal', 'float']:
if bookmark_type in ['int', 'ObjectId']:
return str(bookmark_value)
if bookmark_type == 'str':
return bookmark_value
raise UnsupportedReplicationKeyTypeException("{} is not a supported replication key type".format(bookmark_type))


Expand All @@ -67,15 +65,9 @@ def string_to_class(str_value, type_value):
return int(str_value)
if type_value == 'ObjectId':
return objectid.ObjectId(str_value)
if type_value == 'Decimal':
return decimal.Decimal(str_value)
if type_value == 'float':
return float(str_value)
if type_value == 'Timestamp':
split_value = str_value.split('.')
return bson.timestamp.Timestamp(int(split_value[0]), int(split_value[1]))
if type_value == 'str':
return str_val
raise UnsupportedReplicationKeyTypeException("{} is not a supported replication key type".format(bookmark_type))

def transform_value(value):
Expand Down
65 changes: 37 additions & 28 deletions tap_mongodb/sync_strategies/incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ def update_bookmark(row, state, tap_stream_id, replication_key_name):
state = singer.write_bookmark(state,
tap_stream_id,
'replication_key_type',
replication_key_type)
replication_key_type)

def sync_collection(client, stream, state, projection):
tap_stream_id = stream['tap_stream_id']
LOGGER.info('Starting incremental sync for {}'.format(tap_stream_id))

mdata = metadata.to_map(stream['metadata'])
stream_metadata = mdata.get(())
database_name = stream_metadata['database-name']

db = client[database_name]
collection = db[stream['stream']]
stream_metadata = metadata.to_map(stream['metadata']).get(())
collection = client[stream_metadata['database-name']][stream['stream']]

#before writing the table version to state, check if we had one to begin with
first_run = singer.get_bookmark(state, stream['tap_stream_id'], 'version') is None
Expand All @@ -50,7 +46,6 @@ def sync_collection(client, stream, state, projection):
stream['tap_stream_id'],
'version',
stream_version)
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

activate_version_message = singer.ActivateVersionMessage(
stream=common.calculate_destination_stream_name(stream),
Expand All @@ -63,56 +58,70 @@ def sync_collection(client, stream, state, projection):
if first_run:
singer.write_message(activate_version_message)

# get bookmarks if they exist
# get replication key, and bookmarked value/type
stream_state = state.get('bookmarks', {}).get(tap_stream_id, {})
replication_key_bookmark = stream_state.get('replication_key_name')
replication_key_value_bookmark = stream_state.get('replication_key_value')
replication_key_type_bookmark = stream_state.get('replication_key_type')
replication_key_name_bookmark = stream_state.get('replication_key_name')
replication_key_name_md = stream_metadata.get('replication-key')


if not replication_key_bookmark:
replication_key_bookmark = stream_metadata.get('replication-key')
replication_key_value_bookmark = None
if replication_key_name_bookmark == replication_key_name_md:
replication_key_value_bookmark = stream_state.get('replication_key_value')
else:
if replication_key_name_bookmark is not None:
log_msg = "Replication Key changed from {} to {}, will re-replicate entire collection {}"
LOGGER.warning(log_msg.format(replication_key_name_bookmark,
replication_key_name_md,
tap_stream_id))
replication_key_name_bookmark = replication_key_name_md
state = singer.write_bookmark(state,
tap_stream_id,
'replication_key_name',
replication_key_bookmark)
replication_key_name_bookmark)
state = singer.clear_bookmark(state,
tap_stream_id,
'replication_key_value')
state = singer.clear_bookmark(state,
tap_stream_id,
'replication_key_type')

# write state message
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

# create query
find_filter = {}
if replication_key_value_bookmark:
find_filter[replication_key_bookmark] = {}
find_filter[replication_key_bookmark]['$gte'] = common.string_to_class(replication_key_value_bookmark,
replication_key_type_bookmark = stream_state.get('replication_key_type')
find_filter[replication_key_name_bookmark] = {}
find_filter[replication_key_name_bookmark]['$gte'] = common.string_to_class(replication_key_value_bookmark,
replication_key_type_bookmark)


query_message = 'Querying {} with:\n\tFind Parameters: {}'.format(
stream['tap_stream_id'],
find_filter)
# log query
query_message = 'Querying {} with:\n\tFind Parameters: {}'.format(tap_stream_id, find_filter)
if projection:
query_message += '\n\tProjection: {}'.format(projection)
LOGGER.info(query_message)


# query collection
with collection.find(find_filter,
projection,
sort=[(replication_key_bookmark, pymongo.ASCENDING)]) as cursor:
sort=[(replication_key_name_bookmark, pymongo.ASCENDING)]) as cursor:
rows_saved = 0

time_extracted = utils.now()

start_time = time.time()

for row in cursor:
rows_saved += 1

record_message = common.row_to_singer_record(stream,
row,
stream_version,
time_extracted)

singer.write_message(record_message)
rows_saved += 1

update_bookmark(row, state, tap_stream_id, replication_key_name_bookmark)

update_bookmark(row, state, tap_stream_id, replication_key_bookmark)

if rows_saved % common.UPDATE_BOOKMARK_PERIOD == 0:
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

Expand Down

0 comments on commit 063c3fd

Please sign in to comment.