Update pipeline API to accept arbitrary sequence of Tensors and not just Tuple

Pull Request resolved: #48467

Previously, the API's forward method accepted only a Tensor or a Tuple of
Tensors. This change makes it more generic by accepting any Sequence of Tensors.
ghstack-source-id: 118100143

Differential Revision: [D25181944](https://our.internmc.facebook.com/intern/diff/D25181944/)
pritamdamania committed Dec 8, 2020
1 parent fcab424 commit 9fb4c3b
Showing 9 changed files with 83 additions and 20 deletions.
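To make the change concrete, here is a minimal sketch of the call-site difference, modeled on the test_input_pair test below. The body of Two.forward and the RPC initialization that Pipe requires are assumptions, since the diff elides them:

import torch
import torch.nn as nn
from torch.distributed._pipeline.sync import Pipe

class Two(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_a = nn.Linear(1, 1)
        self.fc_b = nn.Linear(1, 1)

    def forward(self, a_and_b):
        # Assumed body: one Linear per input tensor.
        a, b = a_and_b
        return (self.fc_a(a), self.fc_b(b))

model = Pipe(nn.Sequential(Two()), chunks=2)

a = torch.rand(10, 1)
b = torch.rand(10, 1)

a_out, b_out = model((a, b)).local_value()  # tuple input: accepted before and after
a_out, b_out = model([a, b]).local_value()  # list input: accepted only after this change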
54 changes: 53 additions & 1 deletion test/distributed/_pipeline/sync/test_pipe.py
@@ -258,6 +258,31 @@ def forward(self, x):
    assert counter == 2


def test_nested_input(setup_rpc):
    class NestedInput(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, inp):
            return inp

    model = nn.Sequential(NestedInput())
    model = Pipe(model, chunks=2)

    a = torch.rand(10, 1, requires_grad=True)
    b = torch.rand(10, 1, requires_grad=True)

    # TypeError: expected Tensor, but got tuple
    with pytest.raises(TypeError):
        model((a, (a, b))).local_value()

    # TypeError: expected Tensor, but got list
    with pytest.raises(TypeError):
        model((a, [a, b])).local_value()


def test_input_pair(setup_rpc):
    class Two(nn.Module):
        def __init__(self):
@@ -282,6 +307,17 @@ def forward(self, a_and_b):
    assert a.grad is not None
    assert b.grad is not None

    # Test with list.
    a.grad = None
    b.grad = None
    a_out, b_out = model([a, b]).local_value()
    loss = (a_out + b_out).mean()
    loss.backward()

    assert a.grad is not None
    assert b.grad is not None



def test_input_singleton(setup_rpc):
    class One(nn.Module):
@@ -305,6 +341,18 @@ def forward(self, only_a):
    assert all(p.grad is not None for p in model.parameters())
    assert a.grad is not None

    # Test with list
    a.grad = None
    for p in model.parameters():
        p.grad = None

    (a_out,) = model([a]).local_value()
    loss = a_out.mean()
    loss.backward()

    assert all(p.grad is not None for p in model.parameters())
    assert a.grad is not None


def test_input_varargs(setup_rpc):
    model = nn.Sequential(nn.Linear(1, 1))
@@ -336,7 +384,7 @@ def forward(self, _):
        model("hello")


-def test_non_tensor_tuple(setup_rpc):
+def test_non_tensor_sequence(setup_rpc):
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")
@@ -353,6 +401,10 @@ def forward(self, x):
    with pytest.raises(TypeError):
        model((x, "hello"))

    # TypeError: expected Tensor to scatter, but got str
    with pytest.raises(TypeError):
        model([x, "hello"])


@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/balance/__init__.py
@@ -18,7 +18,7 @@
    pipe = Pipe(model, balance, chunks=8)
"""
-from typing import List, Tuple, Union
+from typing import List, Union, Sequence

import torch
from torch import Tensor
@@ -32,7 +32,7 @@

Device = Union[torch.device, int, str]

-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]


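The alias change above (Tensors = Tuple[Tensor, ...] to Tensors = Sequence[Tensor]) recurs in every file of this diff and is the crux of the API relaxation: Sequence covers tuples and lists alike, both for the type checker and at runtime. A small self-contained illustration in plain Python (no PyTorch required):

from typing import Sequence, Tuple

def takes_tuple(xs: Tuple[int, ...]) -> int:
    return sum(xs)

def takes_sequence(xs: Sequence[int]) -> int:
    return sum(xs)

takes_tuple((1, 2))     # OK
takes_tuple([1, 2])     # runs, but a type checker rejects the list argument
takes_sequence((1, 2))  # OK
takes_sequence([1, 2])  # OK: list and tuple are both Sequences

# Runtime counterpart: unparameterized typing.Sequence supports isinstance
# and accepts tuples, lists, and also str (whose elements would later be
# rejected for not being Tensors).
assert isinstance((1, 2), Sequence)
assert isinstance([1, 2], Sequence)
assert isinstance("hi", Sequence)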
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/balance/profile.py
@@ -7,7 +7,7 @@
"""Per-layer profilers."""
import copy
import time
-from typing import Generator, List, Tuple, Union
+from typing import Generator, List, Union, Sequence

import torch
from torch import Tensor
@@ -20,7 +20,7 @@

Device = Union[torch.device, int, str]

-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]


13 changes: 11 additions & 2 deletions torch/distributed/_pipeline/sync/checkpoint.py
@@ -27,7 +27,16 @@
from collections import deque
from contextlib import contextmanager
import threading
-from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Deque,
+    Generator,
+    List,
+    Optional,
+    Union,
+    Sequence,
+    Tuple
+)

import torch
from torch import Tensor
@@ -40,7 +49,7 @@
__all__ = ["is_checkpointing", "is_recomputing"]


-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

# Types for shared memory between Checkpoint and Recompute.
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/copy.py
@@ -8,7 +8,7 @@
and computation on the same GPU.
"""
from collections import deque
-from typing import Deque, List, Optional, Tuple
+from typing import Deque, List, Optional, Tuple, Sequence

import torch
from torch import Tensor
@@ -18,7 +18,7 @@
__all__: List[str] = []


-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]


# Common interface between :class:`Copy` and :class:`Wait`.
9 changes: 5 additions & 4 deletions torch/distributed/_pipeline/sync/microbatch.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
-from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
+from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence

import torch
from torch import Tensor
@@ -15,7 +15,7 @@
__all__: List[str] = []


-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]

@@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None:
        TypeError: input is not a tensor or tensors.
    """
-    if isinstance(input, tuple):
+    if isinstance(input, Sequence):
        for x in input:
-            check(x)
+            if not isinstance(x, Tensor):
+                raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
        return

    if not isinstance(input, Tensor):
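Assembled, the updated check helper would read roughly as below. The trailing raise sits behind the fold, so its exact form is an assumption, and note that the new in-loop message reports the container's type (input) rather than the offending element's (x), per the diff as shown:

from typing import Sequence, Union

from torch import Tensor

Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

def check(input: TensorOrTensors) -> None:
    """Checks whether the input is a tensor or tensors.

    Raises:
        TypeError: input is not a tensor or tensors.
    """
    if isinstance(input, Sequence):
        for x in input:
            # Deliberately no recursion: a nested sequence such as
            # (a, (a, b)) is rejected here, matching test_nested_input.
            if not isinstance(x, Tensor):
                raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
        return

    if not isinstance(input, Tensor):
        raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")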
8 changes: 4 additions & 4 deletions torch/distributed/_pipeline/sync/pipe.py
@@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.
"""The Pipe interface."""
from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence

import torch
from torch import Tensor, nn
@@ -27,7 +27,7 @@
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
@@ -310,11 +310,11 @@ def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ign
        """:class:`Pipe` is a fairly transparent module wrapper. It doesn't
        modify the input and output signature of the underlying module. But
        there's type restriction. Input and output have to be a
-        :class:`~torch.Tensor` or a tuple of tensors. This restriction is
+        :class:`~torch.Tensor` or a sequence of tensors. This restriction is
        applied at partition boundaries too.

        Args:
-            input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch
+            input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch

        Returns:
            :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
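A hedged end-to-end sketch of the forward contract documented above, on a single host; it assumes Pipe still requires the RPC framework to be initialized, as the tests' setup_rpc fixture suggests:

import torch
import torch.nn as nn
import torch.distributed.rpc as rpc
from torch.distributed._pipeline.sync import Pipe

rpc.init_rpc("worker0", rank=0, world_size=1)

# One partition; each mini-batch is split into 4 micro-batches.
model = Pipe(nn.Sequential(nn.Linear(16, 8), nn.ReLU()), chunks=4)

x = torch.rand(32, 16)
rref = model(x)          # forward returns an RRef to the output
y = rref.local_value()   # fetch the gathered output on the caller
assert y.shape == (32, 8)

rpc.shutdown()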
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/pipeline.py
@@ -7,7 +7,7 @@
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence

import torch
from torch import Tensor, nn
@@ -25,7 +25,7 @@
__all__: List[str] = []


-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
3 changes: 2 additions & 1 deletion torch/distributed/_pipeline/sync/skip/skippable.py
@@ -17,6 +17,7 @@
    List,
    Optional,
    Set,
+    Sequence,
    Tuple,
    Type,
    TypeVar,
@@ -33,7 +34,7 @@
__all__ = ["skippable", "stash", "pop", "verify_skippables"]


-Tensors = Tuple[Tensor, ...]
+Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

StashPop = Union["stash", "pop"]
