diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py
index f45281f5796d..5bc62399114b 100644
--- a/test/distributed/_pipeline/sync/test_pipe.py
+++ b/test/distributed/_pipeline/sync/test_pipe.py
@@ -258,6 +258,31 @@ def forward(self, x):
     assert counter == 2
 
 
+def test_nested_input(setup_rpc):
+    class NestedInput(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc_a = nn.Linear(1, 1)
+            self.fc_b = nn.Linear(1, 1)
+
+        def forward(self, inp):
+            return inp
+
+    model = nn.Sequential(NestedInput())
+    model = Pipe(model, chunks=2)
+
+    a = torch.rand(10, 1, requires_grad=True)
+    b = torch.rand(10, 1, requires_grad=True)
+
+    # TypeError: expected Tensor, but got tuple
+    with pytest.raises(TypeError):
+        model((a, (a, b))).local_value()
+
+    # TypeError: expected Tensor, but got list
+    with pytest.raises(TypeError):
+        model((a, [a, b])).local_value()
+
+
 def test_input_pair(setup_rpc):
     class Two(nn.Module):
         def __init__(self):
@@ -282,6 +307,17 @@ def forward(self, a_and_b):
     assert a.grad is not None
     assert b.grad is not None
 
+    # Test with list.
+    a.grad = None
+    b.grad = None
+    a_out, b_out = model([a, b]).local_value()
+    loss = (a_out + b_out).mean()
+    loss.backward()
+
+    assert a.grad is not None
+    assert b.grad is not None
+
+
 
 def test_input_singleton(setup_rpc):
     class One(nn.Module):
@@ -305,6 +341,18 @@ def forward(self, only_a):
     assert all(p.grad is not None for p in model.parameters())
     assert a.grad is not None
 
+    # Test with list
+    a.grad = None
+    for p in model.parameters():
+        p.grad = None
+
+    (a_out,) = model([a]).local_value()
+    loss = a_out.mean()
+    loss.backward()
+
+    assert all(p.grad is not None for p in model.parameters())
+    assert a.grad is not None
+
 
 def test_input_varargs(setup_rpc):
     model = nn.Sequential(nn.Linear(1, 1))
@@ -336,7 +384,7 @@ def forward(self, _):
         model("hello")
 
 
-def test_non_tensor_tuple(setup_rpc):
+def test_non_tensor_sequence(setup_rpc):
     class NonTensorTuple(nn.Module):
         def forward(self, x):
             return (x, "hello")
@@ -353,6 +401,10 @@ def forward(self, x):
     with pytest.raises(TypeError):
         model((x, "hello"))
 
+    # TypeError: expected Tensor to scatter, but got str
+    with pytest.raises(TypeError):
+        model([x, "hello"])
+
 
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
 def test_deferred_batch_norm(checkpoint, setup_rpc):
diff --git a/torch/distributed/_pipeline/sync/balance/__init__.py b/torch/distributed/_pipeline/sync/balance/__init__.py
index 15aa53bc1a2c..8c6da586657f 100644
--- a/torch/distributed/_pipeline/sync/balance/__init__.py
+++ b/torch/distributed/_pipeline/sync/balance/__init__.py
@@ -18,7 +18,7 @@
     pipe = Pipe(model, balance, chunks=8)
 
 """
-from typing import List, Tuple, Union
+from typing import List, Union, Sequence
 
 import torch
 from torch import Tensor
@@ -32,7 +32,7 @@
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 
diff --git a/torch/distributed/_pipeline/sync/balance/profile.py b/torch/distributed/_pipeline/sync/balance/profile.py
index 737dda60f6fa..382da988e808 100644
--- a/torch/distributed/_pipeline/sync/balance/profile.py
+++ b/torch/distributed/_pipeline/sync/balance/profile.py
@@ -7,7 +7,7 @@
 """Per-layer profilers."""
 import copy
 import time
-from typing import Generator, List, Tuple, Union
+from typing import Generator, List, Union, Sequence
 
 import torch
 from torch import Tensor
@@ -20,7 +20,7 @@
 
 Device = Union[torch.device, int, str]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 
diff --git a/torch/distributed/_pipeline/sync/checkpoint.py b/torch/distributed/_pipeline/sync/checkpoint.py
index bad5eec19469..3f9240793183 100644
--- a/torch/distributed/_pipeline/sync/checkpoint.py
+++ b/torch/distributed/_pipeline/sync/checkpoint.py
@@ -27,7 +27,16 @@
 from collections import deque
 from contextlib import contextmanager
 import threading
-from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Deque,
+    Generator,
+    List,
+    Optional,
+    Union,
+    Sequence,
+    Tuple
+)
 
 import torch
 from torch import Tensor
@@ -40,7 +49,7 @@
 
 __all__ = ["is_checkpointing", "is_recomputing"]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 # Types for shared memory between Checkpoint and Recompute.
diff --git a/torch/distributed/_pipeline/sync/copy.py b/torch/distributed/_pipeline/sync/copy.py
index 3d330f59eeee..07e71a87ce08 100644
--- a/torch/distributed/_pipeline/sync/copy.py
+++ b/torch/distributed/_pipeline/sync/copy.py
@@ -8,7 +8,7 @@
 and computation on the same GPU.
 """
 from collections import deque
-from typing import Deque, List, Optional, Tuple
+from typing import Deque, List, Optional, Tuple, Sequence
 
 import torch
 from torch import Tensor
@@ -18,7 +18,7 @@
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 
 
 # Common interface between :class:`Copy` and :class:`Wait`.
diff --git a/torch/distributed/_pipeline/sync/microbatch.py b/torch/distributed/_pipeline/sync/microbatch.py
index d38cb6d3b85c..190e1e1f987b 100644
--- a/torch/distributed/_pipeline/sync/microbatch.py
+++ b/torch/distributed/_pipeline/sync/microbatch.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """Manipulation of micro-batches."""
 import typing
-from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
+from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence
 
 import torch
 from torch import Tensor
@@ -15,7 +15,7 @@
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 Function = Callable[[TensorOrTensors], TensorOrTensors]
 
@@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None:
         TypeError: input is not a tensor or tensors.
 
     """
-    if isinstance(input, tuple):
+    if isinstance(input, Sequence):
         for x in input:
-            check(x)
+            if not isinstance(x, Tensor):
+                raise TypeError(f"expected Tensor, but got {x.__class__.__name__}")
         return
 
     if not isinstance(input, Tensor):
diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py
index a097e8aa1a9e..82db93060d91 100644
--- a/torch/distributed/_pipeline/sync/pipe.py
+++ b/torch/distributed/_pipeline/sync/pipe.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 """The Pipe interface."""
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn
@@ -27,7 +27,7 @@
 Device = Union[torch.device, int, str]
 Devices = Union[Iterable[Device], List[Device]]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 if TYPE_CHECKING:
@@ -310,11 +310,11 @@ def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]:  # type: ign
         """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
         modify the input and output signature of the underlying module. But
         there's type restriction. Input and output have to be a
-        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
+        :class:`~torch.Tensor` or a sequence of tensors. This restriction is
         applied at partition boundaries too.
 
         Args:
-            input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch
+            input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch
 
         Returns:
             :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
diff --git a/torch/distributed/_pipeline/sync/pipeline.py b/torch/distributed/_pipeline/sync/pipeline.py
index 86c8dfddebeb..72c04c6f28d0 100644
--- a/torch/distributed/_pipeline/sync/pipeline.py
+++ b/torch/distributed/_pipeline/sync/pipeline.py
@@ -7,7 +7,7 @@
 """The pipeline parallelism of Pipe."""
 from queue import Queue
 from types import TracebackType
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence
 
 import torch
 from torch import Tensor, nn
@@ -25,7 +25,7 @@
 
 __all__: List[str] = []
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
diff --git a/torch/distributed/_pipeline/sync/skip/skippable.py b/torch/distributed/_pipeline/sync/skip/skippable.py
index 9bb258382b9b..e0b0dae584a2 100644
--- a/torch/distributed/_pipeline/sync/skip/skippable.py
+++ b/torch/distributed/_pipeline/sync/skip/skippable.py
@@ -17,6 +17,7 @@
     List,
     Optional,
     Set,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -33,7 +34,7 @@
 
 __all__ = ["skippable", "stash", "pop", "verify_skippables"]
 
-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
 TensorOrTensors = Union[Tensor, Tensors]
 
 StashPop = Union["stash", "pop"]
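
Reviewer note, not part of the patch: with Tensors widened from Tuple[Tensor, ...] to Sequence[Tensor], Pipe.forward now accepts a list of tensors anywhere a tuple was previously required. A minimal usage sketch follows, mirroring test_input_pair above; it assumes Pipe is importable from torch.distributed._pipeline.sync as in this tree and that torch.distributed.rpc has been initialized the way the setup_rpc fixture does.

    import torch
    from torch import nn
    from torch.distributed._pipeline.sync import Pipe

    class Two(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, a_and_b):
            a, b = a_and_b
            return (self.fc_a(a), self.fc_b(b))

    model = Pipe(nn.Sequential(Two()), chunks=2)
    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    # Both spellings are now valid; the sequence is scattered into 2 chunks
    # along dim 0 and the outputs are gathered back behind an RRef.
    a_out, b_out = model([a, b]).local_value()   # list, newly supported
    a_out, b_out = model((a, b)).local_value()   # tuple, as before

One design note on the microbatch.py hunk: str is itself a Sequence, so the old recursive check(x) would recurse without end on a string element once the isinstance test is widened from tuple to Sequence. Checking each element directly avoids that, and it also rejects nested tuples and lists, which is what test_nested_input and test_non_tensor_sequence assert.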