Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLI-641] wandb.config.update() should not modify passed arg #1706

Merged
merged 6 commits into from
Jan 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 85 additions & 22 deletions tests/wandb_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,102 @@
from wandb import wandb_sdk


def callback_func(key=None, val=None, data=None):
print(key, val, data)
def get_callback(d):
def callback_func(key=None, val=None, data=None):
print("CONFIG", key, val, data)
if data:
d.update(data)
if key:
d[key] = val

return callback_func

def test_attrib_get():
s = wandb_sdk.Config()
s._set_callback(callback_func)
s.this = 2
assert s.this == 2

@pytest.fixture()
def consolidated():
return {}

def test_locked_set():
s = wandb_sdk.Config()
s.update_locked(dict(this=2, that=4), "sweep")
s.this = 8
assert s.this == 2
assert s.that == 4

@pytest.fixture()
def callback(consolidated):
return get_callback(consolidated)


def test_locked_update():
@pytest.fixture()
def config(callback):
s = wandb_sdk.Config()
s.update_locked(dict(this=2, that=4), "sweep")
s.update(dict(this=8))
assert s.this == 2
assert s.that == 4
s._set_callback(callback)
return s


def test_attrib_set(consolidated, config):
config.this = 2
assert dict(config) == dict(this=2)
assert consolidated == dict(config)


def test_locked_set_attr(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config.this = 8
assert config.this == 2
assert config.that == 4
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_locked_set_key(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config["this"] = 8
assert config["this"] == 2
assert config["that"] == 4
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_update(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.update(dict(that=4))
assert dict(config) == dict(this=8, that=4)
assert consolidated == dict(config)


def test_setdefaults(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.setdefaults(dict(extra=2, another=4))
assert dict(config) == dict(this=8, extra=2, another=4)
assert consolidated == dict(config)


def test_setdefaults_existing(consolidated, config):
config.update(dict(this=8))
assert dict(config) == dict(this=8)
config.setdefaults(dict(extra=2, this=4))
assert dict(config) == dict(this=8, extra=2)
assert consolidated == dict(config)


def test_locked_update(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
config.update(dict(this=8))
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_locked_no_sideeffect(consolidated, config):
config.update_locked(dict(this=2, that=4), "sweep")
update_arg = dict(this=8)
config.update(update_arg)
assert update_arg == dict(this=8)
assert dict(config) == dict(this=2, that=4)
assert consolidated == dict(config)


def test_load_config_default():
test_path = "config-defaults.yaml"
yaml_dict = {"epochs": {"value": 32}, "size_batch": {"value": 32}}
with open(test_path, "w") as f:
yaml.dump(yaml_dict, f, default_flow_style=False)
s = wandb_sdk.Config()
expected = sorted([("epochs", 32), ("size_batch", 32)], key=lambda x: x[0])
actual = sorted(s.items(), key=lambda x: x[0])
assert actual == expected
config = wandb_sdk.Config()
assert dict(config) == dict(epochs=32, size_batch=32)
16 changes: 12 additions & 4 deletions wandb/sdk/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand All @@ -178,9 +181,10 @@ def persist(self):

def setdefaults(self, d):
d = wandb_helper.parse_config(d)
d = self._sanitize_dict(d)
# strip out keys already configured
d = {k: v for k, v in six.iteritems(d) if k not in self._items}
d = self._sanitize_dict(d)
self._items.update(d)
if self._callback:
self._callback(data=d)

Expand All @@ -205,9 +209,13 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(
self, config_dict, allow_val_change=None, ignore_keys: set = None
):
sanitized = {}
for k, v in six.iteritems(config_dict):
if ignore_keys and k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down
16 changes: 12 additions & 4 deletions wandb/sdk_py27/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,13 @@ def __contains__(self, key):

def _update(self, d, allow_val_change=None, ignore_locked=None):
parsed_dict = wandb_helper.parse_config(d)
locked_keys = set()
for key in list(parsed_dict):
if self._check_locked(key, ignore_locked=ignore_locked):
del parsed_dict[key]
sanitized = self._sanitize_dict(parsed_dict, allow_val_change)
locked_keys.add(key)
sanitized = self._sanitize_dict(
parsed_dict, allow_val_change, ignore_keys=locked_keys
)
self._items.update(sanitized)
return sanitized

Expand All @@ -178,9 +181,10 @@ def persist(self):

def setdefaults(self, d):
d = wandb_helper.parse_config(d)
d = self._sanitize_dict(d)
# strip out keys already configured
d = {k: v for k, v in six.iteritems(d) if k not in self._items}
d = self._sanitize_dict(d)
self._items.update(d)
if self._callback:
self._callback(data=d)

Expand All @@ -205,9 +209,13 @@ def _load_defaults(self):
if conf_dict is not None:
self.update(conf_dict)

def _sanitize_dict(self, config_dict, allow_val_change=None):
def _sanitize_dict(
self, config_dict, allow_val_change=None, ignore_keys = None
):
sanitized = {}
for k, v in six.iteritems(config_dict):
if ignore_keys and k in ignore_keys:
continue
k, v = self._sanitize(k, v, allow_val_change)
sanitized[k] = v
return sanitized
Expand Down