New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type annotations to torch._C._distributed_rpc module. #46624
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 980619c (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 145 times. |
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
ghstack-source-id: 990d10a2ae28ffccde4256f6ae207d9d494f3e34 Pull Request resolved: #46624
[ghstack-poisoned]
[ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, but can you please modify mypy.ini
to add torch.distributed.rpc
to the set of modules where type annotation is enforced.
[ghstack-poisoned]
@@ -224,8 +227,8 @@ def _start_record_function(exec_type, func_name, current_worker_name, dest_worke | |||
profile_key = "rpc_{}#{}({} -> {})".format( | |||
exec_type.value, str(func_name), current_worker_name, dest_worker_name | |||
) | |||
rf = torch.autograd._RecordFunction() | |||
torch.autograd._run_before_callbacks(rf, profile_key) | |||
rf = torch.autograd._RecordFunction() # type: ignore[attr-defined] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ignoring this error because I don't find anywhere "_RecordFunction()" and "_run_before_callbacks()" are defined.
pass | ||
except TypeError as exc: | ||
# TypeError: metaclass conflict: the metaclass of a derived class | ||
# must be a (non-strict) subclass of the metaclasses of all its bases | ||
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): | ||
# Mypy doesn't understand __class__ (mypy bug #4177) | ||
class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ignore all type errors because this line (and line 393) will cause multiple type errors.
[ghstack-poisoned]
[ghstack-poisoned]
ghstack-source-id: 653f2dc89bb18de3a01ec64f1b5a18eef9151148 Pull Request resolved: #46624
@@ -379,16 +379,18 @@ def _rref_typeof_on_user(rref): | |||
|
|||
try: | |||
# Combine the implementation class and the type class. | |||
class RRef(PyRRef, GenericWithOneTypeVar): | |||
class RRef(PyRRef, Generic[T]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mypy complains " Variable "torch.distributed.rpc.api.GenericWithOneTypeVar" is not valid as a type [valid-type]" if use GenericWithOneTypeVar here.
[ghstack-poisoned]
@@ -114,6 +120,20 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { | |||
"join", &RpcAgent::join, py::call_guard<py::gil_scoped_release>()) | |||
.def( | |||
"sync", &RpcAgent::sync, py::call_guard<py::gil_scoped_release>()) | |||
.def( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding three definitions because they are used in internal.py and api.py.
Thanks for the review. I have modified |
[ghstack-poisoned]
ghstack-source-id: 5243889c76c5cde8c2c221697080be23d7605446 Pull Request resolved: #46624
[ghstack-poisoned]
ghstack-source-id: cd8c93b49d1c02174f3bc31f3ef6f2a07308a0fa Pull Request resolved: #46624
Differential Revision: [D24761656](https://our.internmc.facebook.com/intern/diff/D24761656) [ghstack-poisoned]
Differential Revision: [D24761656](https://our.internmc.facebook.com/intern/diff/D24761656) [ghstack-poisoned]
ghstack-source-id: 2c9ba154b8c7e2c15a3cd47420f4d3a7e884afe9 Pull Request resolved: #46624
Stack from ghstack:
Differential Revision: D24761656