Annotate torch.nn.cpp #46490

Closed · wants to merge 3 commits
3 changes: 0 additions & 3 deletions mypy.ini
@@ -161,9 +161,6 @@ ignore_errors = True
 [mypy-torch.nn.utils.prune]
 ignore_errors = True
 
-[mypy-torch.nn.cpp]
-ignore_errors = True
-
 [mypy-torch.utils.show_pickle]
 ignore_errors = True
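
Note: removing the per-module section means torch.nn.cpp is now type-checked like any other module instead of having its errors suppressed. For context, this is how mypy's per-module overrides compose — a minimal sketch (the option names are real mypy settings; the layout is illustrative, not this repo's actual config):

[mypy]
# Global options apply to everything mypy checks.
ignore_missing_imports = True

# A per-module section overrides the globals for matching modules;
# deleting the section below re-enables normal error reporting
# for torch.nn.cpp, which is what this PR does.
[mypy-torch.nn.cpp]
ignore_errors = True
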
6 changes: 4 additions & 2 deletions torch/distributions/kl.py
@@ -104,8 +104,10 @@ def _dispatch_kl(type_p, type_q):
     if not matches:
         return NotImplemented
     # Check that the left- and right- lexicographic orders agree.
-    left_p, left_q = min(_Match(*m) for m in matches).types
-    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
+    # mypy isn't smart enough to know that _Match implements __lt__
+    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
+    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore
+    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore
     left_fun = _KL_REGISTRY[left_p, left_q]
     right_fun = _KL_REGISTRY[right_p, right_q]
     if left_fun is not right_fun:
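
Note: typeshed's signature for min() requires elements that support rich comparison, and per the linked typing issue, mypy at the time could not prove that _Match's comparison methods satisfy that protocol, hence the ignores. A standalone sketch of the pattern, with a hypothetical Match class standing in for _Match:

class Match:
    # Hypothetical stand-in for torch.distributions.kl._Match.
    def __init__(self, *types: type) -> None:
        self.types = types

    def __lt__(self, other: "Match") -> bool:
        # Lexicographic "is-subclass" ordering over the type tuples.
        return all(issubclass(x, y) for x, y in zip(self.types, other.types))

matches = [(bool, int), (int, int)]
# Older mypy rejects this call because it cannot see that Match
# satisfies min()'s comparison protocol, hence the blanket ignore.
best = min(Match(*m) for m in matches)  # type: ignore
print(best.types)  # (<class 'bool'>, <class 'int'>)
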
9 changes: 5 additions & 4 deletions torch/nn/cpp.py
@@ -57,9 +57,9 @@ def __init__(self, cpp_module):
         # assigned to in the super class constructor.
         self.cpp_module = cpp_module
         super(ModuleWrapper, self).__init__()
-        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")
-        self._buffers = OrderedDictWrapper(cpp_module, "_buffers")
-        self._modules = OrderedDictWrapper(cpp_module, "_modules")
+        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
+        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
+        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
         for attr in dir(cpp_module):
             # Skip magic methods and the three attributes above.
             if not attr.startswith("_"):
@@ -78,7 +78,8 @@ def _apply(self, fn):
 
         return self
 
-    @property
+    # nn.Module defines training as a boolean
+    @property  # type: ignore[override]
     def training(self):
         return self.cpp_module.training
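
Note: both ignore styles above are scoped to a specific error code. The assignment ignores are needed because nn.Module declares _parameters, _buffers, and _modules with a different dict type than OrderedDictWrapper, and the override ignore because redefining a plain attribute as a property is an incompatible override in mypy's eyes. A self-contained sketch of the same two situations, using hypothetical classes rather than the real nn.Module:

from collections import OrderedDict

class Base:
    # A plain boolean attribute, analogous to nn.Module.training.
    training: bool = True

    def __init__(self) -> None:
        self._parameters: "OrderedDict[str, int]" = OrderedDict()

class Wrapper(Base):
    def __init__(self) -> None:
        super().__init__()
        # mypy: incompatible types in assignment -- the declared type
        # is OrderedDict, not plain dict; silenced with a scoped ignore.
        self._parameters = {}  # type: ignore[assignment]

    # mypy flags redefining the inherited attribute as a property, so
    # the decorator line carries a scoped ignore, as in torch/nn/cpp.py.
    @property  # type: ignore[override]
    def training(self) -> bool:
        return True

w = Wrapper()
print(w.training, len(w._parameters))  # True 0
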
1 change: 1 addition & 0 deletions torch/testing/_internal/common_utils.py
@@ -1044,6 +1044,7 @@ def _compareScalars(self, a, b, *,
             rtol, atol = 0, 0
         rtol = cast(float, rtol)
         atol = cast(float, atol)
+        assert atol is not None
         atol = max(atol, self.precision)
 
         return _compare_scalars_internal(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
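
Note: the added assert is the idiomatic way to narrow an Optional for mypy: cast() only changes the static view and does nothing at runtime, so the assert both documents the invariant and lets mypy accept the max() call. A minimal sketch with a hypothetical helper:

from typing import Optional

def clamp_atol(atol: Optional[float], precision: float) -> float:
    # The assert narrows Optional[float] to float for mypy, and fails
    # fast at runtime if a None ever reaches this point.
    assert atol is not None
    return max(atol, precision)

print(clamp_atol(1e-5, 1e-4))  # 0.0001
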
2 changes: 1 addition & 1 deletion torch/utils/cpp_extension.py
@@ -1560,7 +1560,7 @@ def _import_module_from_library(module_name, path, is_python_module):
     # Close the .so file after load.
     with file:
         if is_python_module:
-            return imp.load_module(module_name, file, path, description)
+            return imp.load_module(module_name, file, path, description)  # type: ignore
         else:
             torch.ops.load_library(path)
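
Note: this last ignore is the blanket form: it silences every error mypy reports on the line, a pragmatic choice for the long-deprecated imp API whose stubs do not match this call. Where the error is known, the scoped form used elsewhere in this PR is preferable; running mypy with --show-error-codes reveals the code to scope to. A small runnable sketch contrasting the two styles (hypothetical parse helper):

from typing import Optional

def parse(text: str) -> Optional[int]:
    return int(text) if text.isdigit() else None

# Blanket ignore: hides every error on the line, as with imp.load_module.
n = parse("42") + 1  # type: ignore
# Scoped ignore: hides only the named error code, as in torch/nn/cpp.py.
m: int = parse("7")  # type: ignore[assignment]
print(n, m)  # 43 7
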