Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates & Bug Fixes #10

Merged
merged 4 commits into from
Aug 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions tap_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,57 @@

# Keys the tap requires in its config. 'user' and 'database' are read by
# get_databases() to look up which databases the user may query.
# NOTE: every entry needs its own trailing comma -- two adjacent string
# literals ('port' 'port') are silently concatenated into one element.
REQUIRED_CONFIG_KEYS = [
    'host',
    'port',
    'user',
    'password',
    'database',
]

# Database names that are never returned by get_databases().
IGNORE_DBS = ['admin', 'system', 'local', 'config']


def get_databases(client, config):
    """Return the names of the databases the configured user can read.

    Runs the ``usersInfo`` command against ``config['database']`` for
    ``config['user']`` and inspects the roles of the matching user:

    * ``read`` / ``readWrite`` on a database adds that database;
    * ``readAnyDatabase`` / ``readWriteAnyDatabase`` means every database
      reported by ``client.list_database_names()`` is used instead.

    Names listed in ``IGNORE_DBS`` are always dropped. Each name appears at
    most once, in the order it was first seen.

    Args:
        client: object supporting ``client[name].command(...)`` and
            ``client.list_database_names()`` (presumably a pymongo
            ``MongoClient`` -- confirm against the caller).
        config: mapping providing the ``'database'`` and ``'user'`` keys.

    Returns:
        A list of database names; empty if exactly one matching user could
        not be found.
    """
    # usersInfo Command returns object in shape:
    # {
    #   <some_other_keys>
    #   'users': [
    #     {
    #       '_id': <auth_db>.<user>,
    #       'db': <auth_db>,
    #       'mechanisms': ['SCRAM-SHA-1', 'SCRAM-SHA-256'],
    #       'roles': [{'db': 'admin', 'role': 'readWriteAnyDatabase'},
    #                 {'db': 'local', 'role': 'read'}],
    #       'user': <user>,
    #       'userId': <userId>
    #     }
    #   ]
    # }
    user_info = client[config['database']].command({'usersInfo': config['user']})

    # `or []` guards against a missing or null 'users' entry in the response.
    users = [u for u in (user_info.get('users') or [])
             if u.get('user') == config['user']]
    if len(users) != 1:
        LOGGER.warning('Could not find any users for %s', config['user'])
        return []

    db_names = []
    for role_info in users[0].get('roles') or []:
        role = role_info.get('role')
        if role in ('readAnyDatabase', 'readWriteAnyDatabase'):
            # If user has either of these roles they can query all databases,
            # so the per-database roles collected so far are irrelevant.
            db_names = client.list_database_names()
            break
        if role in ('read', 'readWrite'):
            db_name = role_info.get('db')
            # A user may hold both 'read' and 'readWrite' on the same
            # database; only record it once so it is not discovered twice.
            if db_name and db_name not in db_names:
                db_names.append(db_name)
        # NOTE(review): other built-in roles (e.g. dbOwner, root) also imply
        # read access per the MongoDB docs but are not recognised here.

    return [d for d in db_names if d not in IGNORE_DBS]



def produce_collection_schema(collection):
collection_name = collection.name
collection_db_name = collection.database.name
Expand Down Expand Up @@ -64,12 +110,10 @@ def produce_collection_schema(collection):
}


def do_discover(client):
def do_discover(client, config):
streams = []

database_names = client.list_database_names()
for db_name in [d for d in database_names
if d not in IGNORE_DBS]:
for db_name in get_databases(client, config):
db = client[db_name]

collection_names = db.list_collection_names()
Expand Down Expand Up @@ -156,7 +200,9 @@ def sync_stream(client, stream, state):
try:
stream_projection = json.loads(stream_projection)
except:
raise common.InvalidProjectionException("The projection provided is not valid JSON")
err_msg = "The projection: {} for stream {} is not valid json"
raise common.InvalidProjectionException(err_msg.format(stream_projection,
tap_stream_id))
else:
LOGGER.warning('There is no projection found for stream %s, all fields will be retrieved.', stream['tap_stream_id'])

Expand Down Expand Up @@ -228,7 +274,7 @@ def main_impl():
common.include_schemas_in_destination_stream_name = (config.get('include_schemas_in_destination_stream_name') == 'true')

if args.discover:
do_discover(client)
do_discover(client, config)
elif args.catalog:
state = args.state or {}
do_sync(client, args.catalog.to_dict(), state)
Expand Down
12 changes: 2 additions & 10 deletions tap_mongodb/sync_strategies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class InvalidProjectionException(Exception):
"""Raised if projection blacklists _id"""

class UnsupportedReplicationKeyType(Exception):
class UnsupportedReplicationKeyTypeException(Exception):
"""Raised if key type is unsupported"""

def calculate_destination_stream_name(stream):
Expand Down Expand Up @@ -53,10 +53,8 @@ def class_to_string(bookmark_value, bookmark_type):
return utils.strftime(utc_datetime)
if bookmark_type == 'Timestamp':
return '{}.{}'.format(bookmark_value.time, bookmark_value.inc)
if bookmark_type in ['int', 'ObjectId', 'Decimal', 'float']:
if bookmark_type in ['int', 'ObjectId']:
return str(bookmark_value)
if bookmark_type == 'str':
return bookmark_value
raise UnsupportedReplicationKeyTypeException("{} is not a supported replication key type".format(bookmark_type))


Expand All @@ -67,15 +65,9 @@ def string_to_class(str_value, type_value):
return int(str_value)
if type_value == 'ObjectId':
return objectid.ObjectId(str_value)
if type_value == 'Decimal':
return decimal.Decimal(str_value)
if type_value == 'float':
return float(str_value)
if type_value == 'Timestamp':
split_value = str_value.split('.')
return bson.timestamp.Timestamp(int(split_value[0]), int(split_value[1]))
if type_value == 'str':
return str_val
raise UnsupportedReplicationKeyTypeException("{} is not a supported replication key type".format(bookmark_type))

def transform_value(value):
Expand Down
65 changes: 37 additions & 28 deletions tap_mongodb/sync_strategies/incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,14 @@ def update_bookmark(row, state, tap_stream_id, replication_key_name):
state = singer.write_bookmark(state,
tap_stream_id,
'replication_key_type',
replication_key_type)
replication_key_type)

def sync_collection(client, stream, state, projection):
tap_stream_id = stream['tap_stream_id']
LOGGER.info('Starting incremental sync for {}'.format(tap_stream_id))

mdata = metadata.to_map(stream['metadata'])
stream_metadata = mdata.get(())
database_name = stream_metadata['database-name']

db = client[database_name]
collection = db[stream['stream']]
stream_metadata = metadata.to_map(stream['metadata']).get(())
collection = client[stream_metadata['database-name']][stream['stream']]

#before writing the table version to state, check if we had one to begin with
first_run = singer.get_bookmark(state, stream['tap_stream_id'], 'version') is None
Expand All @@ -50,7 +46,6 @@ def sync_collection(client, stream, state, projection):
stream['tap_stream_id'],
'version',
stream_version)
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

activate_version_message = singer.ActivateVersionMessage(
stream=common.calculate_destination_stream_name(stream),
Expand All @@ -63,56 +58,70 @@ def sync_collection(client, stream, state, projection):
if first_run:
singer.write_message(activate_version_message)

# get bookmarks if they exist
# get replication key, and bookmarked value/type
stream_state = state.get('bookmarks', {}).get(tap_stream_id, {})
replication_key_bookmark = stream_state.get('replication_key_name')
replication_key_value_bookmark = stream_state.get('replication_key_value')
replication_key_type_bookmark = stream_state.get('replication_key_type')
replication_key_name_bookmark = stream_state.get('replication_key_name')
replication_key_name_md = stream_metadata.get('replication-key')


if not replication_key_bookmark:
replication_key_bookmark = stream_metadata.get('replication-key')
replication_key_value_bookmark = None
if replication_key_name_bookmark == replication_key_name_md:
replication_key_value_bookmark = stream_state.get('replication_key_value')
else:
if replication_key_name_bookmark is not None:
log_msg = "Replication Key changed from {} to {}, will re-replicate entire collection {}"
LOGGER.warning(log_msg.format(replication_key_name_bookmark,
replication_key_name_md,
tap_stream_id))
replication_key_name_bookmark = replication_key_name_md
state = singer.write_bookmark(state,
tap_stream_id,
'replication_key_name',
replication_key_bookmark)
replication_key_name_bookmark)
state = singer.clear_bookmark(state,
tap_stream_id,
'replication_key_value')
state = singer.clear_bookmark(state,
tap_stream_id,
'replication_key_type')

# write state message
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

# create query
find_filter = {}
if replication_key_value_bookmark:
find_filter[replication_key_bookmark] = {}
find_filter[replication_key_bookmark]['$gte'] = common.string_to_class(replication_key_value_bookmark,
replication_key_type_bookmark = stream_state.get('replication_key_type')
find_filter[replication_key_name_bookmark] = {}
find_filter[replication_key_name_bookmark]['$gte'] = common.string_to_class(replication_key_value_bookmark,
replication_key_type_bookmark)


query_message = 'Querying {} with:\n\tFind Parameters: {}'.format(
stream['tap_stream_id'],
find_filter)
# log query
query_message = 'Querying {} with:\n\tFind Parameters: {}'.format(tap_stream_id, find_filter)
if projection:
query_message += '\n\tProjection: {}'.format(projection)
LOGGER.info(query_message)


# query collection
with collection.find(find_filter,
projection,
sort=[(replication_key_bookmark, pymongo.ASCENDING)]) as cursor:
sort=[(replication_key_name_bookmark, pymongo.ASCENDING)]) as cursor:
rows_saved = 0

time_extracted = utils.now()

start_time = time.time()

for row in cursor:
rows_saved += 1

record_message = common.row_to_singer_record(stream,
row,
stream_version,
time_extracted)

singer.write_message(record_message)
rows_saved += 1

update_bookmark(row, state, tap_stream_id, replication_key_name_bookmark)

update_bookmark(row, state, tap_stream_id, replication_key_bookmark)

if rows_saved % common.UPDATE_BOOKMARK_PERIOD == 0:
singer.write_message(singer.StateMessage(value=copy.deepcopy(state)))

Expand Down