From aab10e4de1ecf28f76aaa1b762a28f9c314b2348 Mon Sep 17 00:00:00 2001 From: pritam Date: Wed, 11 Nov 2020 21:38:50 -0800 Subject: [PATCH 1/2] Modify Pipe to return an RRef. As per proposal in https://github.com/pytorch/pytorch/issues/44827, the API needs to return an RRef to support inter-host pipelining. For now, we just return a local RRef and only support pipeline on a single host. But having this change in the API upfront ensures we don't make any BC breaking changes later. Differential Revision: [D24914022](https://our.internmc.facebook.com/intern/diff/D24914022/) [ghstack-poisoned] --- .../_pipeline/sync/skip/test_gpipe.py | 13 ++-- .../_pipeline/sync/skip/test_leak.py | 7 ++- test/distributed/_pipeline/sync/test_bugs.py | 14 +++-- .../_pipeline/sync/test_inplace.py | 13 ++-- test/distributed/_pipeline/sync/test_pipe.py | 63 ++++++++++--------- .../_pipeline/sync/test_transparency.py | 5 +- torch/distributed/_pipeline/sync/pipe.py | 9 +-- .../_internal/distributed/pipeline/utils.py | 19 ++++++ 8 files changed, 85 insertions(+), 58 deletions(-) diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 96ecd84e0d18..4ea75e2d199f 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -12,12 +12,13 @@ from torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange from torch.testing._internal.distributed.pipeline.utils import convert_to_balance +from torch.testing._internal.distributed.pipeline.utils import setup_rpc @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_1to3(balance, checkpoint): +def test_1to3(balance, checkpoint, setup_rpc): if torch.cuda.device_count() < len(balance): pytest.skip("at least %d cuda devices required" % len(balance)) @@ -61,14 +62,14 @@ def forward(self, input): input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True) output = model(input) - loss = output.mean() + loss = output.local_value().mean() loss.backward() - assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) + assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1) assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device)) -def test_none_skip(): +def test_none_skip(setup_rpc): @skippable(stash=["none"]) class Stash(nn.Module): def forward(self, input): @@ -102,7 +103,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None): for next_grad_fn, _ in grad_fn.next_functions: assert_grad_fn_is_not_portal(next_grad_fn, visited) - assert_grad_fn_is_not_portal(output.grad_fn) + assert_grad_fn_is_not_portal(output.local_value().grad_fn) - output.sum().backward() + output.local_value().sum().backward() assert input.grad.mean().item() == 1 diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 31c4ea13b9f1..4c3614c2d5b5 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -11,6 +11,7 @@ from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing from 
torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker +from torch.testing._internal.distributed.pipeline.utils import setup_rpc @skippable(stash=["skip"]) @@ -29,7 +30,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) -def test_delete_portal_tensor(train, checkpoint): +def test_delete_portal_tensor(train, checkpoint, setup_rpc): # Without checkpointing: # +- Stash --+ +--- Pop ----+ - - - layers # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function @@ -97,7 +98,7 @@ def forward(self, input): if train: model.train() - output = model(input) + output = model(input).local_value() output.norm().backward() else: model.eval() @@ -106,7 +107,7 @@ def forward(self, input): @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) -def test_no_portal_without_pipe(train, monkeypatch): +def test_no_portal_without_pipe(train, monkeypatch, setup_rpc): def deny(*args, **kwargs): raise AssertionError("tried to create Portal without Pipe") diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index 4f5346a837b5..ba6b1e1a9ba9 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -10,9 +10,10 @@ import torch.nn.functional as F from torch.distributed._pipeline.sync import Pipe +from torch.testing._internal.distributed.pipeline.utils import setup_rpc -def test_python_autograd_function(): +def test_python_autograd_function(setup_rpc): # A Python autograd function might fail with this error: # # RuntimeError: Returning Variables sharing storage with other Variables @@ -41,10 +42,10 @@ def forward(self, input): x = torch.rand(42) y = model(x) - assert torch.allclose(x, y) + assert torch.allclose(x, y.local_value()) -def test_exception_no_hang(): +def test_exception_no_hang(setup_rpc): # In v0.0.2, once a failed partition receives a normal message # (non-closing) for the next micro-batch, a hang occured. The reason was # that a failed partition didn't call in_queue.task_done() on a normal @@ -69,7 +70,7 @@ def forward(self, x): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") -def test_tuple_wait(cuda_sleep): +def test_tuple_wait(cuda_sleep, setup_rpc): # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. 
# Under this behavior, if checkpointing was disabled, there's a possibility # that gradient accumulations on other tensors are not synchronized @@ -113,7 +114,7 @@ def forward(self, triple): b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True) y = model((a, b)) - y.norm().backward() + y.local_value().norm().backward() torch.cuda.synchronize(0) torch.cuda.synchronize(1) @@ -121,7 +122,7 @@ def forward(self, triple): assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000)) -def test_parallel_randoms(): +def test_parallel_randoms(setup_rpc): class Dropouts(nn.Module): def forward(self, x): for _ in range(100): @@ -133,6 +134,7 @@ def forward(self, x): x = torch.rand(10, 10, requires_grad=True) model = Pipe(model, chunks=10, checkpoint="always") y = model(x) + y = y.local_value() y.norm().backward() assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist() diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 17b3dac4eca8..4b3d28db2c3e 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -9,14 +9,15 @@ from torch import nn from torch.distributed._pipeline.sync import Pipe +from torch.testing._internal.distributed.pipeline.utils import setup_rpc -def test_inplace_on_requires_grad(): +def test_inplace_on_requires_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) model = Pipe(model, checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() message = r"a leaf Variable that requires grad .* used in an in-place operation." with pytest.raises(RuntimeError, match=message): @@ -24,14 +25,14 @@ def test_inplace_on_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_on_not_requires_grad(): +def test_inplace_on_not_requires_grad(setup_rpc): # In-place operation on a tensor not requiring grad doesn't cause a # RuntimeError. Currently, we cannot detect this case. model = nn.Sequential(nn.ReLU(inplace=True)) model = Pipe(model, [1], devices=["cpu"], checkpoint="always") x = torch.rand(1) - y = model(x) + y = model(x).local_value() del model message = r"a leaf Variable that requires grad .* used in an in-place operation." @@ -40,7 +41,7 @@ def test_inplace_on_not_requires_grad(): @pytest.mark.xfail(strict=True) -def test_inplace_incorrect_grad(): +def test_inplace_incorrect_grad(setup_rpc): class M(nn.Module): def forward(self, foo_bar): # 'foo' requires grad but 'bar' does not. 
In-place operation on @@ -62,7 +63,7 @@ def forward(self, foo_bar): foo = torch.tensor([1.0], requires_grad=True) bar = torch.tensor([1.0]) - output = model((foo, bar)) + output = model((foo, bar)).local_value() del model output.backward() diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index 9c2964940576..60f4061a5a99 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -13,6 +13,7 @@ from torch import nn from torch.distributed._pipeline.sync import Pipe +from torch.testing._internal.distributed.pipeline.utils import setup_rpc skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @@ -68,7 +69,7 @@ def test_chunks_less_than_1(): with pytest.raises(ValueError): Pipe(model, chunks=-1) -def test_batch_size_indivisible(): +def test_batch_size_indivisible(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -79,7 +80,7 @@ def test_batch_size_indivisible(): assert not record -def test_batch_size_small(): +def test_batch_size_small(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=4) @@ -90,7 +91,7 @@ def test_batch_size_small(): assert not record -def test_checkpoint_mode(): +def test_checkpoint_mode(setup_rpc): def count_grad_fn(grad_fn, name, visited=None): if visited is None: visited = set() @@ -119,9 +120,9 @@ def count_grad_fn(grad_fn, name, visited=None): except_last_output = except_last(input) never_output = never(input) - assert count_grad_fn(always_output.grad_fn, "CheckpointBackward") == 2 - assert count_grad_fn(except_last_output.grad_fn, "CheckpointBackward") == 1 - assert count_grad_fn(never_output.grad_fn, "CheckpointBackward") == 0 + assert count_grad_fn(always_output.local_value().grad_fn, "CheckpointBackward") == 2 + assert count_grad_fn(except_last_output.local_value().grad_fn, "CheckpointBackward") == 1 + assert count_grad_fn(never_output.local_value().grad_fn, "CheckpointBackward") == 0 def test_checkpoint_mode_invalid(): @@ -140,7 +141,7 @@ def test_checkpoint_mode_when_chunks_1(): Pipe(model, chunks=1, checkpoint="never") -def test_checkpoint_eval(): +def test_checkpoint_eval(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -157,16 +158,16 @@ def find_grad_fn(grad_fn, name): model.train() train_output = model(input) - assert find_grad_fn(train_output.grad_fn, "CheckpointBackward") - assert find_grad_fn(train_output.grad_fn, "RecomputeBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "CheckpointBackward") + assert find_grad_fn(train_output.local_value().grad_fn, "RecomputeBackward") model.eval() eval_output = model(input) - assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward") - assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "CheckpointBackward") + assert not find_grad_fn(eval_output.local_value().grad_fn, "RecomputeBackward") -def test_checkpoint_non_float_input(): +def test_checkpoint_non_float_input(setup_rpc): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) @@ -183,7 +184,7 @@ def forward(self, input): output.backward() -def test_no_grad(): +def test_no_grad(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model, chunks=2) input = torch.rand(2, 1) @@ -206,7 +207,7 @@ def hook(module, input, output): assert latent.grad_fn is None 
-def test_exception(): +def test_exception(setup_rpc): class ExpectedException(Exception): pass @@ -221,7 +222,7 @@ def forward(self, *_): model(torch.rand(1)) -def test_exception_early_stop_asap(): +def test_exception_early_stop_asap(setup_rpc): """Even the first partitions have finished to process, the partition before the failed partition should be killed as soon as possible. """ @@ -258,7 +259,7 @@ def forward(self, x): assert counter == 2 -def test_input_pair(): +def test_input_pair(setup_rpc): class Two(nn.Module): def __init__(self): super().__init__() @@ -275,7 +276,7 @@ def forward(self, a_and_b): a = torch.rand(10, 1, requires_grad=True) b = torch.rand(10, 1, requires_grad=True) - a_out, b_out = model((a, b)) + a_out, b_out = model((a, b)).local_value() loss = (a_out + b_out).mean() loss.backward() @@ -283,7 +284,7 @@ def forward(self, a_and_b): assert b.grad is not None -def test_input_singleton(): +def test_input_singleton(setup_rpc): class One(nn.Module): def __init__(self): super().__init__() @@ -298,7 +299,7 @@ def forward(self, only_a): a = torch.rand(10, 1, requires_grad=True) - (a_out,) = model((a,)) + (a_out,) = model((a,)).local_value() loss = a_out.mean() loss.backward() @@ -306,7 +307,7 @@ def forward(self, only_a): assert a.grad is not None -def test_input_varargs(): +def test_input_varargs(setup_rpc): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe(model) @@ -318,7 +319,7 @@ def test_input_varargs(): model(a, b) -def test_non_tensor(): +def test_non_tensor(setup_rpc): class NonTensor(nn.Module): def forward(self, _): return "hello" @@ -336,7 +337,7 @@ def forward(self, _): model("hello") -def test_non_tensor_tuple(): +def test_non_tensor_tuple(setup_rpc): class NonTensorTuple(nn.Module): def forward(self, x): return (x, "hello") @@ -355,7 +356,7 @@ def forward(self, x): @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) -def test_deferred_batch_norm(checkpoint): +def test_deferred_batch_norm(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -363,7 +364,7 @@ def test_deferred_batch_norm(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert torch.allclose(pipe[0].running_mean, bn.running_mean, atol=1e-4) @@ -371,7 +372,7 @@ def test_deferred_batch_norm(checkpoint): @pytest.mark.parametrize("checkpoint", ["never", "always"]) -def test_deferred_batch_norm_params(checkpoint): +def test_deferred_batch_norm_params(checkpoint, setup_rpc): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe = Pipe( @@ -379,7 +380,7 @@ def test_deferred_batch_norm_params(checkpoint): ) x = torch.rand(4, 3, 10, 10) - pipe(x).mean().backward() + pipe(x).local_value().mean().backward() bn(x).mean().backward() assert pipe[0].weight.grad is not None @@ -455,7 +456,7 @@ def test_deny_moving(): model.to(dtype=torch.float) -def test_empty_module(): +def test_empty_module(setup_rpc): # Empty sequential module is not illegal. 
model = nn.Sequential() model = Pipe(model) @@ -518,7 +519,7 @@ def __init__(self, param1, param2): @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need atleast two GPUs") -def test_verify_nested_modules(): +def test_verify_nested_modules(setup_rpc): model = nn.Sequential( nn.Sequential( nn.Linear(32, 16).cuda(0), @@ -532,8 +533,8 @@ def test_verify_nested_modules(): pipe = Pipe(model) out = pipe(torch.rand(10, 32).cuda(0)) - assert out.device == torch.device("cuda:1") - assert out.size() == torch.Size([10, 2]) + assert out.local_value().device == torch.device("cuda:1") + assert out.local_value().size() == torch.Size([10, 2]) def test_verify_module_duplicate_parameters_on_same_device(): class Surrogate(nn.Module): @@ -547,7 +548,7 @@ def __init__(self, module): Pipe(model) -def test_forward_lockstep(): +def test_forward_lockstep(setup_rpc): timeline = [] class DelayedLog(nn.Module): diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 3d2c77e8fef4..12ffe80e5015 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -8,9 +8,10 @@ from torch import nn from torch.distributed._pipeline.sync import Pipe +from torch.testing._internal.distributed.pipeline.utils import setup_rpc -def test_simple_linears(): +def test_simple_linears(setup_rpc): def sum_grad(parameters): return sum([p.grad.sum() for p in parameters if p.grad is not None]) @@ -33,7 +34,7 @@ def zero_grad(parameters): # With Pipe model = Pipe(model, chunks=4) - outputs = model(inputs) + outputs = model(inputs).local_value() loss = outputs.mean() loss.backward() diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index 68906958cc0e..a86d4bb7dcbe 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -10,6 +10,7 @@ import torch from torch import Tensor, nn +from torch.distributed.rpc import RRef import torch.autograd import torch.cuda @@ -304,7 +305,7 @@ def _ensure_copy_streams(self) -> List[List[AbstractStream]]: return self._copy_streams - def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ignore """:class:`Pipe` is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module. But there's type restriction. Input and output have to be a @@ -312,10 +313,10 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore applied at partition boundaries too. Args: - input (torch.Tensor or tensors): input mini-batch + input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch Returns: - tensor or tensors: output mini-batch + ``RRef`` to the output of the mini-batch Raises: TypeError: input is not a tensor or tensors. @@ -335,4 +336,4 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore # Merge the micro-batches into one mini-batch. 
output = microbatch.gather(batches) - return output + return RRef(output) diff --git a/torch/testing/_internal/distributed/pipeline/utils.py b/torch/testing/_internal/distributed/pipeline/utils.py index 2bf4829b8223..3d19359b748a 100644 --- a/torch/testing/_internal/distributed/pipeline/utils.py +++ b/torch/testing/_internal/distributed/pipeline/utils.py @@ -4,8 +4,27 @@ # LICENSE file in the root directory of this source tree. from torch import nn +from torch.distributed import rpc from typing import List +import pytest +import tempfile + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(file.name), + ) + ) + yield + rpc.shutdown() + + def convert_to_balance(pipe: nn.Sequential, balance: List[int]): device_idx = 0 pipe_idx = 0 From 4ef9120b3bd62f4bfdea8b3e8170e005f6d57d60 Mon Sep 17 00:00:00 2001 From: pritam Date: Tue, 24 Nov 2020 16:13:10 -0800 Subject: [PATCH 2/2] Update on "Modify Pipe to return an RRef." As per proposal in https://github.com/pytorch/pytorch/issues/44827, the API needs to return an RRef to support inter-host pipelining. For now, we just return a local RRef and only support pipeline on a single host. But having this change in the API upfront ensures we don't make any BC breaking changes later. Differential Revision: [D24914022](https://our.internmc.facebook.com/intern/diff/D24914022/) [ghstack-poisoned] --- test/distributed/_pipeline/sync/conftest.py | 16 ++++++++++++++++ .../_pipeline/sync/skip/test_gpipe.py | 1 - .../_pipeline/sync/skip/test_leak.py | 1 - test/distributed/_pipeline/sync/test_bugs.py | 1 - .../_pipeline/sync/test_inplace.py | 1 - test/distributed/_pipeline/sync/test_pipe.py | 1 - .../_pipeline/sync/test_transparency.py | 1 - torch/distributed/_pipeline/sync/pipe.py | 2 +- .../_internal/distributed/pipeline/utils.py | 19 ------------------- 9 files changed, 17 insertions(+), 26 deletions(-) diff --git a/test/distributed/_pipeline/sync/conftest.py b/test/distributed/_pipeline/sync/conftest.py index 315431d0b644..561c41d11350 100644 --- a/test/distributed/_pipeline/sync/conftest.py +++ b/test/distributed/_pipeline/sync/conftest.py @@ -5,7 +5,9 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. 
import pytest +import tempfile import torch +from torch.distributed import rpc @pytest.fixture(autouse=True) @@ -35,3 +37,17 @@ def cuda_sleep(seconds): def pytest_report_header(): return f"torch: {torch.__version__}" + +@pytest.fixture +def setup_rpc(scope="session"): + file = tempfile.NamedTemporaryFile() + rpc.init_rpc( + name="worker0", + rank=0, + world_size=1, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions( + init_method="file://{}".format(file.name), + ) + ) + yield + rpc.shutdown() diff --git a/test/distributed/_pipeline/sync/skip/test_gpipe.py b/test/distributed/_pipeline/sync/skip/test_gpipe.py index 4ea75e2d199f..90ecd7613d67 100644 --- a/test/distributed/_pipeline/sync/skip/test_gpipe.py +++ b/test/distributed/_pipeline/sync/skip/test_gpipe.py @@ -12,7 +12,6 @@ from torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange from torch.testing._internal.distributed.pipeline.utils import convert_to_balance -from torch.testing._internal.distributed.pipeline.utils import setup_rpc @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") diff --git a/test/distributed/_pipeline/sync/skip/test_leak.py b/test/distributed/_pipeline/sync/skip/test_leak.py index 4c3614c2d5b5..7d03a4e9db49 100644 --- a/test/distributed/_pipeline/sync/skip/test_leak.py +++ b/test/distributed/_pipeline/sync/skip/test_leak.py @@ -11,7 +11,6 @@ from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing from torch.distributed._pipeline.sync.skip import pop, skippable, stash from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker -from torch.testing._internal.distributed.pipeline.utils import setup_rpc @skippable(stash=["skip"]) diff --git a/test/distributed/_pipeline/sync/test_bugs.py b/test/distributed/_pipeline/sync/test_bugs.py index ba6b1e1a9ba9..a66b7d006ae1 100644 --- a/test/distributed/_pipeline/sync/test_bugs.py +++ b/test/distributed/_pipeline/sync/test_bugs.py @@ -10,7 +10,6 @@ import torch.nn.functional as F from torch.distributed._pipeline.sync import Pipe -from torch.testing._internal.distributed.pipeline.utils import setup_rpc def test_python_autograd_function(setup_rpc): diff --git a/test/distributed/_pipeline/sync/test_inplace.py b/test/distributed/_pipeline/sync/test_inplace.py index 4b3d28db2c3e..3b842dbfb9ab 100644 --- a/test/distributed/_pipeline/sync/test_inplace.py +++ b/test/distributed/_pipeline/sync/test_inplace.py @@ -9,7 +9,6 @@ from torch import nn from torch.distributed._pipeline.sync import Pipe -from torch.testing._internal.distributed.pipeline.utils import setup_rpc def test_inplace_on_requires_grad(setup_rpc): diff --git a/test/distributed/_pipeline/sync/test_pipe.py b/test/distributed/_pipeline/sync/test_pipe.py index 60f4061a5a99..877fa7dfffbc 100644 --- a/test/distributed/_pipeline/sync/test_pipe.py +++ b/test/distributed/_pipeline/sync/test_pipe.py @@ -13,7 +13,6 @@ from torch import nn from torch.distributed._pipeline.sync import Pipe -from torch.testing._internal.distributed.pipeline.utils import setup_rpc skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") diff --git a/test/distributed/_pipeline/sync/test_transparency.py b/test/distributed/_pipeline/sync/test_transparency.py index 12ffe80e5015..56ad86de081b 100644 --- a/test/distributed/_pipeline/sync/test_transparency.py +++ b/test/distributed/_pipeline/sync/test_transparency.py @@ -8,7 +8,6 @@ from torch 
import nn from torch.distributed._pipeline.sync import Pipe -from torch.testing._internal.distributed.pipeline.utils import setup_rpc def test_simple_linears(setup_rpc): diff --git a/torch/distributed/_pipeline/sync/pipe.py b/torch/distributed/_pipeline/sync/pipe.py index a86d4bb7dcbe..3433fbf573b0 100644 --- a/torch/distributed/_pipeline/sync/pipe.py +++ b/torch/distributed/_pipeline/sync/pipe.py @@ -316,7 +316,7 @@ def forward(self, input: TensorOrTensors) -> RRef[TensorOrTensors]: # type: ign input (torch.Tensor or Tuple[torch.Tensor, ...]): input mini-batch Returns: - ``RRef`` to the output of the mini-batch + :class:`~torch.distributed.rpc.RRef` to the output of the mini-batch Raises: TypeError: input is not a tensor or tensors. diff --git a/torch/testing/_internal/distributed/pipeline/utils.py b/torch/testing/_internal/distributed/pipeline/utils.py index 3d19359b748a..2bf4829b8223 100644 --- a/torch/testing/_internal/distributed/pipeline/utils.py +++ b/torch/testing/_internal/distributed/pipeline/utils.py @@ -4,27 +4,8 @@ # LICENSE file in the root directory of this source tree. from torch import nn -from torch.distributed import rpc from typing import List -import pytest -import tempfile - -@pytest.fixture -def setup_rpc(scope="session"): - file = tempfile.NamedTemporaryFile() - rpc.init_rpc( - name="worker0", - rank=0, - world_size=1, - rpc_backend_options=rpc.TensorPipeRpcBackendOptions( - init_method="file://{}".format(file.name), - ) - ) - yield - rpc.shutdown() - - def convert_to_balance(pipe: nn.Sequential, balance: List[int]): device_idx = 0 pipe_idx = 0
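
Usage note (not part of the patch): a minimal, single-host sketch of how a caller interacts with the modified API after this change. Pipe.forward now returns an RRef rather than a tensor, and on a single host the result is retrieved with .local_value(), exactly as the updated tests do. The RPC initialization below mirrors the setup_rpc fixture added in conftest.py; the worker name, two-layer model, and tensor shapes are arbitrary illustration, not values mandated by the patch.

    # Illustrative sketch only; mirrors the single-host setup used by the tests above.
    import tempfile

    import torch
    from torch import nn
    from torch.distributed import rpc
    from torch.distributed._pipeline.sync import Pipe

    # Single-process RPC init, as in the setup_rpc fixture (file-based rendezvous).
    file = tempfile.NamedTemporaryFile()
    rpc.init_rpc(
        name="worker0",
        rank=0,
        world_size=1,
        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
            init_method="file://{}".format(file.name),
        ),
    )

    # A small CPU pipeline; the layers here are placeholders.
    model = Pipe(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), chunks=2)

    x = torch.rand(8, 4, requires_grad=True)
    output_rref = model(x)              # forward() now returns an RRef, not a Tensor
    output = output_rref.local_value()  # single-host case: fetch the local value
    output.mean().backward()

    rpc.shutdown()

On a single host the RRef is purely local, so .local_value() is cheap; the point of returning an RRef up front is that the same call signature can later refer to an output living on another host without a BC-breaking change.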