Skip to content

Commit

Permalink
Better config management for overwriting settings
Browse files Browse the repository at this point in the history
  • Loading branch information
vierja committed Oct 6, 2017
1 parent 83c5eb3 commit dcc874d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 22 deletions.
19 changes: 10 additions & 9 deletions luminoth/models/fasterrcnn/base_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ train:
# Learning rate config.
learning_rate:
# Learning rate decay method ((empty), "none", piecewise_constant, exponential_decay, polynomial_decay)
decay_method: piecewise_constant
boundaries: [50000, 100000, 150000]
values: [0.001, 0.0005, 0.0003, 0.0001]
# You can define different decay methods using `decay_method` and defining all the necessary arguments.
decay_method:
learning_rate: 0.001

# Optimizer config
optimizer:
# Type of optimizer to use (momentum, adam, gradient_descent, rmsprop)
type: adam
type: momentum
# any options are passed directly to the optimizer as kwarg.
# momentum: 0.9
momentum: 0.9

# Number of epochs (complete dataset batches) to run
num_epochs: 1000
Expand Down Expand Up @@ -106,11 +106,12 @@ base_network:
# Should we download weights if not available
download: True
# Which endpoint layer to use as feature map for network
endpoint: conv5/conv5_1
endpoint:
# Is trainable, how many layers from the endpoint are we training
finetune_num_layers: 3
# Regularization
weight_decay: 0.0005
finetune_num_layers:
arg_scope:
# Regularization
weight_decay: 0.0005

loss:
# Loss weights for calculating the total loss
Expand Down
40 changes: 27 additions & 13 deletions luminoth/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,43 +21,57 @@ def kwargs_to_config(kwargs):
))


def is_basestring(value):
"""
Checks if value is string in both Python2.7 and Python3+
"""
return isinstance(value, (type(u''), str))


def types_compatible(new_config_value, base_config_value):
"""
Checks that config value types are compatible.
"""
# Allow to overwrite None values (explicit or just missing)
if base_config_value is None:
return True
# Allow all None and False values.
# Allow overwrite all None and False values.
# TODO: reconsider this.
if new_config_value is None or new_config_value is False:
return True
# For Python2 compatibility. We want to allow the case when both are
# basestrings (e.g. unicode and str).
if (isinstance(new_config_value, (type(u''), str))
and isinstance(base_config_value, (type(u''), str))):

# Checking strings is different because in Python2 we could get different
# types str vs unicode.
if is_basestring(new_config_value) and is_basestring(base_config_value):
return True

return isinstance(new_config_value, type(base_config_value))


def merge_into(new_config, base_config):
def merge_into(new_config, base_config, overwrite=False):
if type(new_config) is not easydict.EasyDict:
return

for key, value in new_config.items():
if value is None:
continue

# All keys in new_config must be overwriting values in base_config
if key not in base_config:
raise KeyError('Key "{}" is not a valid config key.'.format(key))
# Since we already have the values of base_config we check against them
if (not types_compatible(value, base_config[key])):
if not types_compatible(value, base_config.get(key)):
raise ValueError(
'Incorrect type "{}" for key "{}". Must be "{}"'.format(
type(value), key, type(base_config[key])))
type(value), key, type(base_config.get(key))))

# Recursively merge dicts
if type(value) is easydict.EasyDict:
if (isinstance(value, dict) and
base_config.get(key) is not None and
not overwrite):
# Something
try:
merge_into(new_config[key], base_config[key])
merge_into(
new_config[key], base_config.get(key, {}),
overwrite=key == 'train' # Overwrite train config.
)
except (KeyError, ValueError) as e:
raise e
else:
Expand Down

0 comments on commit dcc874d

Please sign in to comment.