Skip to content

Commit

Permalink
[CLI-641] wandb.config.update() should not modify passed arg (#1706)
Browse files Browse the repository at this point in the history
* wandb.config.update() should not modify passed arg
* add unittest for sideeffect
* add missing tests and a type
* more tests, fix config.setdefaults()
  • Loading branch information
raubitsj committed Jan 16, 2021
1 parent 649c262 commit 98e363d
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 30 deletions.
107 changes: 85 additions & 22 deletions tests/wandb_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,102 @@
from wandb import wandb_sdk


def callback_func(key=None, val=None, data=None):
print(key, val, data)
def get_callback(d):
def callback_func(key=None, val=None, data=None):
print("CONFIG", key, val, data)
if data:
d.update(data)
if key:
d[key] = val

return callback_func

def test_attrib_get():
s = wandb_sdk.Config()
s._set_callback(callback_func)
s.this = 2
assert s.this == 2

@pytest.fixture()
def consolidated():
return {}

def test_locked_set():
s = wandb_sdk.Config()
s.update_locked(dict(this=2, that=4), "sweep")
s.this = 8
assert s.this == 2
assert s.that == 4

@pytest.fixture()
def callback(consolidated):
return get_callback(consolidated)


def test_locked_update():
@pytest.fixture()
def config(callback):
s = wandb_sdk.Config()
s.update_locked(dict(this=2, that=4), "sweep")
s.update(dict(this=8))
assert s.this == 2
assert s.that == 4
s._set_callback(callback)
return s


def test_attrib_set(consolidated, config):
config.this = 2
assert dict(config) == dict(this=2)
assert consolidated == dict(config)


def test_locked_set_attr(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config.this = 8
assert config.this == 2
assert config.that == 4
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_locked_set_key(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config["this"] = 8
assert config["this"] == 2
assert config["that"] == 4
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_update(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.update(dict(that=4))
assert dict(config) == dict(this=8, that=4)
assert consolidated == dict(config)


def test_setdefaults(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.setdefaults(dict(extra=2, another=4))
assert dict(config) == dict(this=8, extra=2, another=4)
assert consolidated == dict(config)


def test_setdefaults_existing(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.setdefaults(dict(extra=2, this=4))
assert dict(config) == dict(this=8, extra=2)
assert consolidated == dict(config)


def test_locked_update(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config.update(dict(this=8))
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_locked_no_sideeffect(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
update_arg = dict(this=8)
config.update(update_arg)
assert update_arg == dict(this=8)
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_load_config_default():
test_path = "config-defaults.yaml"
yaml_dict = {"epochs": {"value": 32}, "size_batch": {"value": 32}}
with open(test_path, "w") as f:
yaml.dump(yaml_dict, f, default_flow_style=False)
s = wandb_sdk.Config()
expected = sorted([("epochs", 32), ("size_batch", 32)], key=lambda x: x[0])
actual = sorted(s.items(), key=lambda x: x[0])
assert actual == expected
config = wandb_sdk.Config()
assert dict(config) == dict(epochs=32, size_batch=32)
16 changes: 12 additions & 4 deletions wandb/sdk/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand All @@ -178,9 +181,10 @@ def persist(self):

def setdefaults(self, d):
d = wandb_helper.parse_config(d)
d = self._sanitize_dict(d)
# strip out keys already configured
d = {k: v for k, v in six.iteritems(d) if k not in self._items}
d = self._sanitize_dict(d)
self._items.update(d)
if self._callback:
self._callback(data=d)

Expand All @@ -205,9 +209,13 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(
self, config_dict, allow_val_change=None, ignore_keys: set = None
):
sanitized = {}
for k, v in six.iteritems(config_dict):
if ignore_keys and k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down
16 changes: 12 additions & 4 deletions wandb/sdk_py27/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand All @@ -178,9 +181,10 @@ def persist(self):

def setdefaults(self, d):
d = wandb_helper.parse_config(d)
d = self._sanitize_dict(d)
# strip out keys already configured
d = {k: v for k, v in six.iteritems(d) if k not in self._items}
d = self._sanitize_dict(d)
self._items.update(d)
if self._callback:
self._callback(data=d)

Expand All @@ -205,9 +209,13 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(
self, config_dict, allow_val_change=None, ignore_keys = None
):
sanitized = {}
for k, v in six.iteritems(config_dict):
if ignore_keys and k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down

0 comments on commit 98e363d

Please sign in to comment.