Skip to content

Commit

Permalink
check type of attribute keys (#1066)
Browse files Browse the repository at this point in the history
* check type of attribute keys

* introduce deprecation cycle

* fix typo

* stringify keys

* cleanup

* do not cover except
  • Loading branch information
malmans2 committed Sep 8, 2022
1 parent ea7bb11 commit f6e3d61
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
22 changes: 22 additions & 0 deletions zarr/attrs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from collections.abc import MutableMapping

from zarr._storage.store import Store, StoreV3
Expand Down Expand Up @@ -128,6 +129,27 @@ def put(self, d):
self._write_op(self._put_nosync, dict(attributes=d))

def _put_nosync(self, d):

d_to_check = d if self._version == 2 else d["attributes"]
if not all(isinstance(item, str) for item in d_to_check):
# TODO: Raise an error for non-string keys
# raise TypeError("attribute keys must be strings")
warnings.warn(
"only attribute keys of type 'string' will be allowed in the future",
DeprecationWarning,
stacklevel=2
)

try:
d_to_check = {str(k): v for k, v in d_to_check.items()}
except TypeError as ex: # pragma: no cover
raise TypeError("attribute keys can not be stringified") from ex

if self._version == 2:
d = d_to_check
else:
d["attributes"] = d_to_check

if self._version == 2:
self.store[self.key] = json_dumps(d)
if self.cache:
Expand Down
15 changes: 15 additions & 0 deletions zarr/tests/test_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,18 @@ def test_caching_off(self, zarr_version):
get_cnt = 10 if zarr_version == 2 else 12
assert get_cnt == store.counter['__getitem__', attrs_key]
assert 3 == store.counter['__setitem__', attrs_key]

def test_wrong_keys(self, zarr_version):
store = _init_store(zarr_version)
a = self.init_attributes(store, zarr_version=zarr_version)

warning_msg = "only attribute keys of type 'string' will be allowed in the future"

with pytest.warns(DeprecationWarning, match=warning_msg):
a[1] = "foo"

with pytest.warns(DeprecationWarning, match=warning_msg):
a.put({1: "foo"})

with pytest.warns(DeprecationWarning, match=warning_msg):
a.update({1: "foo"})

0 comments on commit f6e3d61

Please sign in to comment.