Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Enable checking of untyped definitions #1612

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ Maintenance

.. _release_2.17.1:

* Enabled checking of untyped definitions by mypy in all modules apart from
``zarr.tests`` and ``zarr._storage``.
By :user:`David Stansby <dstansby>` :issue:`1612`.

2.17.1
------

Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,16 @@ ignore_missing_imports = true
warn_unused_configs = true
warn_redundant_casts = true
warn_unused_ignores = true
strict_equality = true
strict_concatenate = true
check_untyped_defs = true

[[tool.mypy.overrides]]
module = [
"zarr.tests.*",
"zarr._storage.*"
]
check_untyped_defs = false

[tool.pytest.ini_options]
doctest_optionflags = [
Expand Down
2 changes: 1 addition & 1 deletion zarr/_storage/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ def __init__(self, store, max_size: int):
self._keys_cache = None
self._contains_cache = {}
self._listdir_cache: Dict[Path, Any] = dict()
self._values_cache: Dict[Path, Any] = OrderedDict()
self._values_cache: OrderedDict[Path, Any] = OrderedDict()
self._mutex = Lock()
self.hits = self.misses = 0

Expand Down
3 changes: 2 additions & 1 deletion zarr/attrs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections.abc import MutableMapping
from typing import Any

from zarr._storage.store import Store, StoreV3
from zarr.util import json_dumps
Expand Down Expand Up @@ -41,7 +42,7 @@ def _get_nosync(self):
try:
data = self.store[self.key]
except KeyError:
d = dict()
d: dict[str, Any] = dict()
if self._version > 2:
d["attributes"] = {}
else:
Expand Down
20 changes: 8 additions & 12 deletions zarr/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,11 +1981,11 @@ def _set_selection(self, indexer, value, fields=None):
chunk_value = value[out_selection]
# handle missing singleton dimensions
if indexer.drop_axes:
item: list[slice | None]
item = [slice(None)] * self.ndim
for a in indexer.drop_axes:
item[a] = np.newaxis
item = tuple(item)
chunk_value = chunk_value[item]
chunk_value = chunk_value[tuple(item)]

# put data
self._chunk_setitem(chunk_coords, chunk_selection, chunk_value, fields=fields)
Expand All @@ -2004,8 +2004,7 @@ def _set_selection(self, indexer, value, fields=None):
item = [slice(None)] * self.ndim
for a in indexer.drop_axes:
item[a] = np.newaxis
item = tuple(item)
cv = chunk_value[item]
cv = chunk_value[tuple(item)]
chunk_values.append(cv)

self._chunk_setitems(lchunk_coords, lchunk_selection, chunk_values, fields=fields)
Expand Down Expand Up @@ -2133,6 +2132,7 @@ def _chunk_getitems(
# Keys to retrieve
ckeys = [self._chunk_key(ch) for ch in lchunk_coords]

cdatas: dict[str, PartialReadBuffer | UncompressedPartialReadBufferV3]
# Check if we can do a partial read
if (
self._partial_decompress
Expand Down Expand Up @@ -2171,6 +2171,7 @@ def _chunk_getitems(
cdatas = {key: value for key, value in zip(ckeys, values) if value is not None}
else:
partial_read_decode = False
contexts: dict[str, Context] | ConstantMap
contexts = {}
if not isinstance(self._meta_array, np.ndarray):
contexts = ConstantMap(ckeys, constant=Context(meta_array=self._meta_array))
Expand Down Expand Up @@ -2323,7 +2324,7 @@ def _process_for_setitem(self, ckey, chunk_selection, value, fields=None):

return chunk

def _chunk_key(self, chunk_coords):
def _chunk_key(self, chunk_coords) -> str:
if self._version == 3:
# _chunk_key() corresponds to data_key(P, i, j, ...) example in the spec
# where P = self._key_prefix, i, j, ... = chunk_coords
Expand Down Expand Up @@ -2542,12 +2543,7 @@ def hexdigest(self, hashname="sha1"):
"""

checksum = binascii.hexlify(self.digest(hashname=hashname))

# This is a bytes object on Python 3 and we want a str.
if not isinstance(checksum, str):
checksum = checksum.decode("utf8")

return checksum
return checksum.decode("utf8")

def __getstate__(self):
return {
Expand All @@ -2565,7 +2561,7 @@ def __getstate__(self):
}

def __setstate__(self, state):
self.__init__(**state)
self.__init__(**state) # type: ignore[misc]

def _synchronized_op(self, f, *args, **kwargs):
if self._synchronizer is None:
Expand Down
13 changes: 9 additions & 4 deletions zarr/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,19 @@ def __init__(
*,
meta_array=None,
):
store: BaseStore = _normalize_store_arg(store, zarr_version=zarr_version)
store: BaseStore = _normalize_store_arg( # type: ignore[no-redef]
store, zarr_version=zarr_version
)
if zarr_version is None:
zarr_version = getattr(store, "_store_version", DEFAULT_ZARR_VERSION)

if zarr_version != 2:
assert_zarr_v3_api_available()

if chunk_store is not None:
chunk_store: BaseStore = _normalize_store_arg(chunk_store, zarr_version=zarr_version)
chunk_store: BaseStore = _normalize_store_arg( # type: ignore[no-redef]
chunk_store, zarr_version=zarr_version
)
self._store = store
self._chunk_store = chunk_store
self._path = normalize_storage_path(path)
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
self._meta = self._store._metadata_class.decode_group_metadata(meta_bytes)

# setup attributes
akey: str | None
if self._version == 2:
akey = self._key_prefix + attrs_key
else:
Expand Down Expand Up @@ -410,7 +415,7 @@ def __getstate__(self):
}

def __setstate__(self, state):
self.__init__(**state)
self.__init__(**state) # type: ignore[misc]

def _item_path(self, item):
absolute = isinstance(item, str) and item and item[0] == "/"
Expand Down Expand Up @@ -541,7 +546,7 @@ def __getattr__(self, item):

def __dir__(self):
# noinspection PyUnresolvedReferences
base = super().__dir__()
base = list(super().__dir__())
keys = sorted(set(base + list(self)))
keys = [k for k in keys if is_valid_python_name(k)]
return keys
Expand Down
16 changes: 13 additions & 3 deletions zarr/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import numpy as np

from typing import Union, Optional, Tuple, List


from zarr.errors import (
ArrayIndexError,
Expand Down Expand Up @@ -330,6 +332,7 @@ def __init__(self, selection, array):
# setup per-dimension indexers
dim_indexers = []
for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):
dim_indexer: Union[IntDimIndexer, SliceDimIndexer]
if is_integer(dim_sel):
dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)

Expand Down Expand Up @@ -520,9 +523,11 @@ def __iter__(self):
else:
start = self.chunk_nitems_cumsum[dim_chunk_ix - 1]
stop = self.chunk_nitems_cumsum[dim_chunk_ix]
dim_out_sel: Union[slice, np.ndarray]
if self.order == Order.INCREASING:
dim_out_sel = slice(start, stop)
else:
assert self.dim_out_sel is not None
dim_out_sel = self.dim_out_sel[start:stop]

# find region in chunk
Expand Down Expand Up @@ -576,11 +581,11 @@ def oindex_set(a, selection, value):
selection = ix_(selection, a.shape)
if not np.isscalar(value) and drop_axes:
value = np.asanyarray(value)
value_selection: List[Union[slice, None]]
value_selection = [slice(None)] * len(a.shape)
for i in drop_axes:
value_selection[i] = np.newaxis
value_selection = tuple(value_selection)
value = value[value_selection]
value = value[tuple(value_selection)]
a[selection] = value


Expand All @@ -595,6 +600,8 @@ def __init__(self, selection, array):

# setup per-dimension indexers
dim_indexers = []
dim_indexer: Union[IntDimIndexer, SliceDimIndexer, IntArrayDimIndexer, BoolArrayDimIndexer]

for dim_sel, dim_len, dim_chunk_len in zip(selection, array._shape, array._chunks):
if is_integer(dim_sel):
dim_indexer = IntDimIndexer(dim_sel, dim_len, dim_chunk_len)
Expand All @@ -621,6 +628,7 @@ def __init__(self, selection, array):
self.dim_indexers = dim_indexers
self.shape = tuple(s.nitems for s in self.dim_indexers if not isinstance(s, IntDimIndexer))
self.is_advanced = not is_basic_selection(selection)
self.drop_axes: Optional[Tuple[int, ...]]
if self.is_advanced:
self.drop_axes = tuple(
i
Expand Down Expand Up @@ -796,7 +804,7 @@ def __init__(self, selection, array):
boundscheck_indices(dim_sel, dim_len)

# compute chunk index for each point in the selection
chunks_multi_index = tuple(
chunks_multi_index = list(
dim_sel // dim_chunk_len for (dim_sel, dim_chunk_len) in zip(selection, array._chunks)
)

Expand Down Expand Up @@ -847,6 +855,7 @@ def __iter__(self):
else:
start = self.chunk_nitems_cumsum[chunk_rix - 1]
stop = self.chunk_nitems_cumsum[chunk_rix]
out_selection: Union[slice, np.ndarray]
if self.sel_sort is None:
out_selection = slice(start, stop)
else:
Expand Down Expand Up @@ -949,6 +958,7 @@ def check_no_multi_fields(fields):


def pop_fields(selection):
fields: Union[str, List[str], None]
if isinstance(selection, str):
# single field selection
fields = selection
Expand Down
4 changes: 2 additions & 2 deletions zarr/n5.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ def __init__(self, *args, **kwargs):
if "dimension_separator" in kwargs:
kwargs.pop("dimension_separator")
warnings.warn("Keyword argument `dimension_separator` will be ignored")
dimension_separator = "."
super().__init__(*args, dimension_separator=dimension_separator, **kwargs)
kwargs["dimension_separator"] = "."
super().__init__(*args, **kwargs)

@staticmethod
def _swap_separator(key: str):
Expand Down
29 changes: 16 additions & 13 deletions zarr/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from numcodecs.registry import codec_registry
from zarr.context import Context
from zarr.types import PathLike as Path, DIMENSION_SEPARATOR
from zarr.util import NoLock

from zarr.errors import (
MetadataError,
Expand All @@ -57,6 +56,7 @@
buffer_size,
json_loads,
nolock,
NoLock,
normalize_chunks,
normalize_dimension_separator,
normalize_dtype,
Expand Down Expand Up @@ -835,7 +835,7 @@ def __getstate__(self):

def __setstate__(self, state):
root, cls = state
self.__init__(root=root, cls=cls)
self.__init__(root=root, cls=cls) # type: ignore[misc]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These ignore statements are all for the following warning. I'm not sure if there's a good way to fix instead of ignoring the warning?

Accessing "__init__" on an instance is unsound, since instance.__init__ could be from an incompatible subclass

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pretty sure the right thing here is to define `__reduce__`, which I started over in #1089. It's stalled, but very simple to pick up I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 - do you want #1089 to be finished before merging this PR, or would it be okay to merge this with the # type: ignore comments and then get rid of them when #1089 is done?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think blinding mypy to our sins with `# type: ignore` is fine for now


def _get_parent(self, item: str):
parent = self.root
Expand Down Expand Up @@ -1815,7 +1815,9 @@ def __setstate__(self, state):
# get clobbered
if mode in "wx":
mode = "a"
self.__init__(path=path, compression=compression, allowZip64=allowZip64, mode=mode)
self.__init__( # type: ignore[misc]
path=path, compression=compression, allowZip64=allowZip64, mode=mode
)

def close(self):
"""Closes the underlying zip file, ensuring all records are written."""
Expand Down Expand Up @@ -2095,7 +2097,7 @@ def __init__(
self.mode = mode
self.open = open
self.write_lock = write_lock
self.write_mutex: Union[Lock, NoLock]
self.write_mutex: Lock | NoLock
if write_lock:
# This may not be required as some dbm implementations manage their own
# locks, but err on the side of caution.
Expand All @@ -2117,7 +2119,9 @@ def __setstate__(self, state):
path, flag, mode, open, write_lock, open_kws = state
if flag[0] == "n":
flag = "c" + flag[1:] # don't clobber an existing database
self.__init__(path=path, flag=flag, mode=mode, open=open, write_lock=write_lock, **open_kws)
self.__init__( # type: ignore[misc]
path=path, flag=flag, mode=mode, open=open, write_lock=write_lock, **open_kws
)

def close(self):
"""Closes the underlying database file."""
Expand Down Expand Up @@ -2308,7 +2312,7 @@ def __getstate__(self):

def __setstate__(self, state):
path, buffers, kwargs = state
self.__init__(path=path, buffers=buffers, **kwargs)
self.__init__(path=path, buffers=buffers, **kwargs) # type: ignore[misc]

def close(self):
"""Closes the underlying database."""
Expand Down Expand Up @@ -2419,10 +2423,10 @@ def __init__(self, store: StoreLike, max_size: int):
self._store: BaseStore = BaseStore._ensure_store(store)
self._max_size = max_size
self._current_size = 0
self._keys_cache = None
self._keys_cache: None | list = None
self._contains_cache: Dict[Any, Any] = {}
self._listdir_cache: Dict[Path, Any] = dict()
self._values_cache: Dict[Path, Any] = OrderedDict()
self._values_cache: OrderedDict[Path, Any] = OrderedDict()
self._mutex = Lock()
self.hits = self.misses = 0

Expand Down Expand Up @@ -2659,7 +2663,7 @@ def __getstate__(self):

def __setstate__(self, state):
path, kwargs = state
self.__init__(path=path, **kwargs)
self.__init__(path=path, **kwargs) # type: ignore[misc]

def close(self):
"""Closes the underlying database."""
Expand Down Expand Up @@ -2737,8 +2741,7 @@ def listdir(self, path=None):
""",
(path, path),
)
keys = list(map(operator.itemgetter(0), keys))
return keys
return list(map(operator.itemgetter(0), keys))

def getsize(self, path=None):
path = normalize_storage_path(path)
Expand Down Expand Up @@ -2850,7 +2853,7 @@ def __getstate__(self):

def __setstate__(self, state):
database, collection, kwargs = state
self.__init__(database=database, collection=collection, **kwargs)
self.__init__(database=database, collection=collection, **kwargs) # type: ignore[misc]

def close(self):
"""Cleanup client resources and disconnect from MongoDB."""
Expand Down Expand Up @@ -2924,7 +2927,7 @@ def __getstate__(self):

def __setstate__(self, state):
prefix, kwargs = state
self.__init__(prefix=prefix, **kwargs)
self.__init__(prefix=prefix, **kwargs) # type: ignore[misc]

def clear(self):
for key in self.keys():
Expand Down
2 changes: 1 addition & 1 deletion zarr/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __getstate__(self):

def __setstate__(self, *args):
# reinitialize from scratch
self.__init__()
self.__init__() # type: ignore[misc]


class ProcessSynchronizer(Synchronizer):
Expand Down