Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add type annotations to torch.nn.parallel._functions #49687

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions mypy.ini
Expand Up @@ -86,9 +86,6 @@ ignore_errors = True
[mypy-torch.nn.modules.pooling]
ignore_errors = True

[mypy-torch.nn.parallel._functions]
ignore_errors = True

[mypy-torch.nn.qat.modules.activations]
ignore_errors = True

Expand Down
11 changes: 6 additions & 5 deletions torch/nn/parallel/_functions.py
Expand Up @@ -4,6 +4,7 @@
from . import comm
from torch.autograd import Function
from torch._utils import _get_device_index
from typing import List, Optional


class Broadcast(Function):
Expand Down Expand Up @@ -39,9 +40,9 @@ class ReduceAddCoalesced(Function):
def forward(ctx, destination, num_inputs, *grads):
ctx.target_gpus = [grads[i].get_device() for i in range(0, len(grads), num_inputs)]

grads = [grads[i:i + num_inputs]
for i in range(0, len(grads), num_inputs)]
return comm.reduce_add_coalesced(grads, destination)
grads_ = [grads[i:i + num_inputs]
for i in range(0, len(grads), num_inputs)]
return comm.reduce_add_coalesced(grads_, destination)

@staticmethod
def backward(ctx, *grad_outputs):
Expand Down Expand Up @@ -105,10 +106,10 @@ def backward(ctx, *grad_output):


# background streams used for copying
_streams = None
_streams: Optional[List[Optional[torch.cuda.Stream]]] = None


def _get_stream(device):
def _get_stream(device: int):
"""Gets a background stream for copying between CPU and GPU"""
global _streams
if device == -1:
Expand Down