Skip to content

Commit

Permalink
Update pipeline API to accept arbitrary sequence of Tensors and not just Tuple (#48467)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #48467

The current API's forward method accepted only a Tensor or a Tuple of
Tensors. This change makes it more generic by accepting any Sequence of Tensors.
ghstack-source-id: 118436340

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D25181944

fbshipit-source-id: 4db251dad52c01abc69f3d327788f2e4289e6c9d
  • Loading branch information
pritamdamania authored and facebook-github-bot committed Dec 13, 2020
1 parent 33b7970 commit dc4db95
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 21 deletions.
54 changes: 53 additions & 1 deletion test/distributed/_pipeline/sync/test_pipe.py
Expand Up @@ -258,6 +258,31 @@ def forward(self, x):
assert counter == 2


def test_nested_input(setup_rpc):
    """Check that Pipe rejects a sequence nested inside the input tuple.

    The wrapped module simply returns its input, so the only thing being
    exercised is the TypeError raised for a nested tuple or a nested list.
    """

    class NestedInput(nn.Module):
        """Identity module that owns two linear layers it never applies."""

        def __init__(self):
            super().__init__()
            self.fc_a = nn.Linear(1, 1)
            self.fc_b = nn.Linear(1, 1)

        def forward(self, inp):
            return inp

    pipe = Pipe(nn.Sequential(NestedInput()), chunks=2)

    first = torch.rand(10, 1, requires_grad=True)
    second = torch.rand(10, 1, requires_grad=True)

    # TypeError: expected Tensor, but got tuple
    with pytest.raises(TypeError):
        pipe((first, (first, second))).local_value()

    # TypeError: expected Tensor, but got list
    with pytest.raises(TypeError):
        pipe((first, [first, second])).local_value()


def test_input_pair(setup_rpc):
class Two(nn.Module):
def __init__(self):
Expand All @@ -282,6 +307,17 @@ def forward(self, a_and_b):
assert a.grad is not None
assert b.grad is not None

# Test with list.
a.grad = None
b.grad = None
a_out, b_out = model([a, b]).local_value()
loss = (a_out + b_out).mean()
loss.backward()

assert a.grad is not None
assert b.grad is not None



def test_input_singleton(setup_rpc):
class One(nn.Module):
Expand All @@ -305,6 +341,18 @@ def forward(self, only_a):
assert all(p.grad is not None for p in model.parameters())
assert a.grad is not None

# Test with list
a.grad = None
for p in model.parameters():
p.grad = None

(a_out,) = model([a]).local_value()
loss = a_out.mean()
loss.backward()

assert all(p.grad is not None for p in model.parameters())
assert a.grad is not None


def test_input_varargs(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1))
Expand Down Expand Up @@ -336,7 +384,7 @@ def forward(self, _):
model("hello")


def test_non_tensor_tuple(setup_rpc):
def test_non_tensor_sequence(setup_rpc):
class NonTensorTuple(nn.Module):
def forward(self, x):
return (x, "hello")
Expand All @@ -353,6 +401,10 @@ def forward(self, x):
with pytest.raises(TypeError):
model((x, "hello"))

# TypeError: expected Tensor to scatter, but got str
with pytest.raises(TypeError):
model([x, "hello"])


@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_deferred_batch_norm(checkpoint, setup_rpc):
Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/balance/__init__.py
Expand Up @@ -18,7 +18,7 @@
pipe = Pipe(model, balance, chunks=8)
"""
from typing import List, Tuple, Union
from typing import List, Union, Sequence

import torch
from torch import Tensor
Expand All @@ -32,7 +32,7 @@

Device = Union[torch.device, int, str]

Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]


Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/balance/profile.py
Expand Up @@ -7,7 +7,7 @@
"""Per-layer profilers."""
import copy
import time
from typing import Generator, List, Tuple, Union
from typing import Generator, List, Union, Sequence

import torch
from torch import Tensor
Expand All @@ -20,7 +20,7 @@

Device = Union[torch.device, int, str]

Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]


Expand Down
13 changes: 11 additions & 2 deletions torch/distributed/_pipeline/sync/checkpoint.py
Expand Up @@ -27,7 +27,16 @@
from collections import deque
from contextlib import contextmanager
import threading
from typing import TYPE_CHECKING, Deque, Generator, List, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Deque,
Generator,
List,
Optional,
Union,
Sequence,
Tuple
)

import torch
from torch import Tensor
Expand All @@ -40,7 +49,7 @@
__all__ = ["is_checkpointing", "is_recomputing"]


Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

# Types for shared memory between Checkpoint and Recompute.
Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/copy.py
Expand Up @@ -8,7 +8,7 @@
and computation on the same GPU.
"""
from collections import deque
from typing import Deque, List, Optional, Tuple
from typing import Deque, List, Optional, Tuple, Sequence

import torch
from torch import Tensor
Expand All @@ -18,7 +18,7 @@
__all__: List[str] = []


Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]


# Common interface between :class:`Copy` and :class:`Wait`.
Expand Down
11 changes: 6 additions & 5 deletions torch/distributed/_pipeline/sync/microbatch.py
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.
"""Manipulation of micro-batches."""
import typing
from typing import Callable, Iterable, Iterator, List, Tuple, Union, cast
from typing import Callable, Iterable, Iterator, List, Union, cast, Sequence

import torch
from torch import Tensor
Expand All @@ -15,7 +15,7 @@
__all__: List[str] = []


Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]
Function = Callable[[TensorOrTensors], TensorOrTensors]

Expand Down Expand Up @@ -110,7 +110,7 @@ def __setitem__(self, index: Union[int, slice], value: TensorOrTensors) -> None:
def _setitem_by_index(self, index: int, value: Tensor) -> None:
if not self.atomic:
i = index
self.value = self.value[:i] + (value,) + self.value[i + 1 :]
self.value = self.value[:i] + (value,) + self.value[i + 1 :] # type: ignore
return

if index != 0:
Expand Down Expand Up @@ -139,9 +139,10 @@ def check(input: TensorOrTensors) -> None:
TypeError: input is not a tensor or tensors.
"""
if isinstance(input, tuple):
if isinstance(input, Sequence):
for x in input:
check(x)
if not isinstance(x, Tensor):
raise TypeError(f"expected Tensor, but got {input.__class__.__name__}")
return

if not isinstance(input, Tensor):
Expand Down
8 changes: 4 additions & 4 deletions torch/distributed/_pipeline/sync/pipe.py
Expand Up @@ -6,7 +6,7 @@
# LICENSE file in the root directory of this source tree.
"""The Pipe interface."""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, cast, Sequence

import torch
from torch import Tensor, nn
Expand All @@ -27,7 +27,7 @@
Device = Union[torch.device, int, str]
Devices = Union[Iterable[Device], List[Device]]

Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

if TYPE_CHECKING:
Expand Down Expand Up @@ -310,11 +310,11 @@ def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ign
""":class:`Pipe` is a fairly transparent module wrapper. It doesn't
modify the input and output signature of the underlying module. But
there's type restriction. Input and output have to be a
:class:`~torch.Tensor` or a tuple of tensors. This restriction is
:class:`~torch.Tensor` or a sequence of tensors. This restriction is
applied at partition boundaries too.
Args:
input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch
input (torch.Tensor or Sequence[torch.Tensor]): input mini-batch
Returns:
:class:`~torch.distributed.rpc.RRef` to the output of the mini-batch
Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_pipeline/sync/pipeline.py
Expand Up @@ -7,7 +7,7 @@
"""The pipeline parallelism of Pipe."""
from queue import Queue
from types import TracebackType
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Type, Union, cast, Sequence

import torch
from torch import Tensor, nn
Expand All @@ -25,7 +25,7 @@
__all__: List[str] = []


Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType]
Expand Down
3 changes: 2 additions & 1 deletion torch/distributed/_pipeline/sync/skip/skippable.py
Expand Up @@ -17,6 +17,7 @@
List,
Optional,
Set,
Sequence,
Tuple,
Type,
TypeVar,
Expand All @@ -33,7 +34,7 @@
__all__ = ["skippable", "stash", "pop", "verify_skippables"]


Tensors = Tuple[Tensor, ...]
Tensors = Sequence[Tensor]
TensorOrTensors = Union[Tensor, Tensors]

StashPop = Union["stash", "pop"]
Expand Down

0 comments on commit dc4db95

Please sign in to comment.