Skip to content

Commit

Permalink
wandb.config.update() should not modify passed arg
Browse files Browse the repository at this point in the history
  • Loading branch information
raubitsj committed Jan 15, 2021
1 parent 7696137 commit 8ebcbc3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
11 changes: 8 additions & 3 deletions wandb/sdk/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand Down Expand Up @@ -205,9 +208,11 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(self, config_dict, allow_val_change=None, ignore_keys=None):
sanitized = {}
for k, v in six.iteritems(config_dict):
if k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down
11 changes: 8 additions & 3 deletions wandb/sdk_py27/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand Down Expand Up @@ -205,9 +208,11 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(self, config_dict, allow_val_change=None, ignore_keys=None):
sanitized = {}
for k, v in six.iteritems(config_dict):
if k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down

0 comments on commit 8ebcbc3

Please sign in to comment.