From 203721fc719b73aae5dd418d9d0745463e7a4741 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Fri, 16 Oct 2020 20:37:42 +0000 Subject: [PATCH 1/3] annotate torch.nn.cpp --- mypy.ini | 3 --- torch/nn/cpp.py | 9 +++++---- torch/testing/_internal/common_utils.py | 2 +- torch/utils/cpp_extension.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mypy.ini b/mypy.ini index 14f3504f6cce..535310411720 100644 --- a/mypy.ini +++ b/mypy.ini @@ -161,9 +161,6 @@ ignore_errors = True [mypy-torch.nn.utils.prune] ignore_errors = True -[mypy-torch.nn.cpp] -ignore_errors = True - [mypy-torch.utils.show_pickle] ignore_errors = True diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index 194c17bd6b5a..25a5bcc446aa 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -57,9 +57,9 @@ def __init__(self, cpp_module): # assigned to in the super class constructor. self.cpp_module = cpp_module super(ModuleWrapper, self).__init__() - self._parameters = OrderedDictWrapper(cpp_module, "_parameters") - self._buffers = OrderedDictWrapper(cpp_module, "_buffers") - self._modules = OrderedDictWrapper(cpp_module, "_modules") + self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment] + self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment] + self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment] for attr in dir(cpp_module): # Skip magic methods and the three attributes above. if not attr.startswith("_"): @@ -78,7 +78,8 @@ def _apply(self, fn): return self - @property + # nn.Module defines training as a boolean + @property # type: ignore[override] def training(self): return self.cpp_module.training diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 65912eba60bc..ebb55d54dd6a 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1044,7 +1044,7 @@ def _compareScalars(self, a, b, *, rtol, atol = 0, 0 rtol = cast(float, rtol) atol = cast(float, atol) - atol = max(atol, self.precision) + atol = max(atol, self.precision) if atol is not None else self.precision return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 1358ca611901..4948e6e33099 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1560,7 +1560,7 @@ def _import_module_from_library(module_name, path, is_python_module): # Close the .so file after load. with file: if is_python_module: - return imp.load_module(module_name, file, path, description) + return imp.load_module(module_name, file, path, description) # type: ignore else: torch.ops.load_library(path) From 923c08f9defd89dee9982891cf76b2f41f419195 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Wed, 21 Oct 2020 21:07:01 +0000 Subject: [PATCH 2/3] add type: ignore to kl.py --- torch/distributions/kl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index c7e079b1f57a..fe64ccc56009 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -104,8 +104,10 @@ def _dispatch_kl(type_p, type_q): if not matches: return NotImplemented # Check that the left- and right- lexicographic orders agree. 
- left_p, left_q = min(_Match(*m) for m in matches).types - right_q, right_p = min(_Match(*reversed(m)) for m in matches).types + # mypy isn't smart enough to know that _Match implements __lt__ + # see: https://github.com/python/typing/issues/760#issuecomment-710670503 + left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore + right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore left_fun = _KL_REGISTRY[left_p, left_q] right_fun = _KL_REGISTRY[right_p, right_q] if left_fun is not right_fun: From 540fcd76f613ab6b8175b15e7fb7f975a1c68e93 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Thu, 22 Oct 2020 16:02:18 +0000 Subject: [PATCH 3/3] move assertion --- torch/testing/_internal/common_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index ebb55d54dd6a..bfb61c0a6981 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -1044,7 +1044,8 @@ def _compareScalars(self, a, b, *, rtol, atol = 0, 0 rtol = cast(float, rtol) atol = cast(float, atol) - atol = max(atol, self.precision) if atol is not None else self.precision + assert atol is not None + atol = max(atol, self.precision) return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
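A note on the two kinds of ignores that patch 1 introduces in torch/nn/cpp.py. The sketch below is a minimal, self-contained reconstruction of the pattern, using hypothetical Base, DictWrapper, and Wrapper classes rather than the real nn.Module, OrderedDictWrapper, and ModuleWrapper: assigning a wrapper object to an attribute that the base class declares with a different type needs ignore[assignment], and re-exposing a plain base-class attribute as a property needs ignore[override].

from typing import Dict

class Base:
    training: bool = False

    def __init__(self) -> None:
        self._parameters: Dict[str, int] = {}

class DictWrapper:
    """Hypothetical stand-in for OrderedDictWrapper; not a Dict subtype."""
    def __init__(self, owner: object, attr: str) -> None:
        self.owner, self.attr = owner, attr

class Wrapper(Base):
    def __init__(self) -> None:
        super().__init__()
        self._training = False
        # Base declares _parameters as Dict[str, int]; DictWrapper is not
        # one, hence mypy's [assignment] error at this call site.
        self._parameters = DictWrapper(self, "_parameters")  # type: ignore[assignment]

    # Base declares training as a plain bool attribute; overriding it
    # with a property is what mypy's [override] error flags.
    @property  # type: ignore[override]
    def training(self) -> bool:
        return self._training

    @training.setter
    def training(self, mode: bool) -> None:
        self._training = mode

w = Wrapper()
w.training = True
print(w.training)  # True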
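On patch 2: the ignores in kl.py work around the typing of min() in typeshed, which requires elements that support "<"; per the linked python/typing issue, mypy at the time could not verify that _Match satisfies that protocol. A rough, simplified sketch of the shape of the code (a hypothetical stand-in, not the real _Match):

class _Match:
    __slots__ = ["types"]

    def __init__(self, *types: type) -> None:
        self.types = types

    def __eq__(self, other: object) -> bool:
        return isinstance(other, _Match) and self.types == other.types

    def __lt__(self, other: "_Match") -> bool:
        # Lexicographic ordering by subclass relationship (simplified).
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True

matches = [(bool, int), (int, int)]
# mypy could not prove _Match satisfies min()'s ordering protocol,
# so both call sites carry a blanket ignore.
left = min(_Match(*m) for m in matches).types  # type: ignore
right = min(_Match(*reversed(m)) for m in matches).types  # type: ignore
print(left, right)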
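On patch 3: cast() is a no-op at runtime, so after atol = cast(float, atol) a None could still reach max(atol, self.precision) and raise a confusing TypeError; the assertion fails fast instead, and it replaces the "if atol is not None" fallback added in patch 1. In code without a cast, the same "assert x is not None" is also the idiomatic way to narrow an Optional for mypy. A small illustration with a hypothetical helper, not the real _compareScalars:

from typing import Optional

def effective_atol(atol: Optional[float], precision: float) -> float:
    # Fail fast if atol is None rather than silently substituting a
    # default; the assert also narrows Optional[float] to float for
    # mypy, so no cast() is needed here.
    assert atol is not None
    return max(atol, precision)

print(effective_atol(1e-05, 0.0001))  # 0.0001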