Skip to content

Commit

Permalink
Annotate torch.nn.cpp (#46490)
Browse files Browse the repository at this point in the history
Summary:
Fixes #46489

Pull Request resolved: #46490

Reviewed By: zhangguanheng66

Differential Revision: D24509519

Pulled By: ezyang

fbshipit-source-id: edffd32ab2ac17ae4bbd44826b71f5cb9f1da1c5
  • Loading branch information
guilhermeleobas authored and facebook-github-bot committed Oct 24, 2020
1 parent c4892c8 commit 789e935
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 10 deletions.
3 changes: 0 additions & 3 deletions mypy.ini
Expand Up @@ -161,9 +161,6 @@ ignore_errors = True
[mypy-torch.nn.utils.prune]
ignore_errors = True

[mypy-torch.nn.cpp]
ignore_errors = True

[mypy-torch.utils.show_pickle]
ignore_errors = True

Expand Down
6 changes: 4 additions & 2 deletions torch/distributions/kl.py
Expand Up @@ -104,8 +104,10 @@ def _dispatch_kl(type_p, type_q):
if not matches:
return NotImplemented
# Check that the left- and right- lexicographic orders agree.
left_p, left_q = min(_Match(*m) for m in matches).types
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
# mypy isn't smart enough to know that _Match implements __lt__
# see: https://github.com/python/typing/issues/760#issuecomment-710670503
left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore
right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore
left_fun = _KL_REGISTRY[left_p, left_q]
right_fun = _KL_REGISTRY[right_p, right_q]
if left_fun is not right_fun:
Expand Down
9 changes: 5 additions & 4 deletions torch/nn/cpp.py
Expand Up @@ -57,9 +57,9 @@ def __init__(self, cpp_module):
# assigned to in the super class constructor.
self.cpp_module = cpp_module
super(ModuleWrapper, self).__init__()
self._parameters = OrderedDictWrapper(cpp_module, "_parameters")
self._buffers = OrderedDictWrapper(cpp_module, "_buffers")
self._modules = OrderedDictWrapper(cpp_module, "_modules")
self._parameters = OrderedDictWrapper(cpp_module, "_parameters") # type: ignore[assignment]
self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers") # type: ignore[assignment]
self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules") # type: ignore[assignment]
for attr in dir(cpp_module):
# Skip magic methods and the three attributes above.
if not attr.startswith("_"):
Expand All @@ -78,7 +78,8 @@ def _apply(self, fn):

return self

@property
# nn.Module defines training as a boolean
@property # type: ignore[override]
def training(self):
return self.cpp_module.training

Expand Down
1 change: 1 addition & 0 deletions torch/testing/_internal/common_utils.py
Expand Up @@ -1044,6 +1044,7 @@ def _compareScalars(self, a, b, *,
rtol, atol = 0, 0
rtol = cast(float, rtol)
atol = cast(float, atol)
assert atol is not None
atol = max(atol, self.precision)

return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/cpp_extension.py
Expand Up @@ -1560,7 +1560,7 @@ def _import_module_from_library(module_name, path, is_python_module):
# Close the .so file after load.
with file:
if is_python_module:
return imp.load_module(module_name, file, path, description)
return imp.load_module(module_name, file, path, description) # type: ignore
else:
torch.ops.load_library(path)

Expand Down

0 comments on commit 789e935

Please sign in to comment.