Modify Pipe to return an RRef.
Pull Request resolved: #47829

As per the proposal in #44827,
the API needs to return an RRef to support inter-host pipelining.

For now, we just return a local RRef and only support pipelining on a single
host. But having this change in the API upfront ensures we don't make any
BC-breaking changes later.
ghstack-source-id: 118064022

Differential Revision: [D24914022](https://our.internmc.facebook.com/intern/diff/D24914022/)
pritamdamania committed Dec 8, 2020
1 parent 88ebf6f commit fcab424
Showing 8 changed files with 77 additions and 59 deletions.
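The calling convention that runs through all of the test updates below: Pipe.forward now returns an RRef, and a single-host caller unwraps it with RRef.local_value(). The following is a minimal sketch of that pattern, not part of the commit; the Sequential layers, chunk count, and tensor shapes are illustrative assumptions, and the RPC setup mirrors the setup_rpc fixture added to conftest.py.

import tempfile

import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed._pipeline.sync import Pipe

# Single-process RPC setup, mirroring the setup_rpc fixture added in conftest.py.
file = tempfile.NamedTemporaryFile()
rpc.init_rpc(
    name="worker0",
    rank=0,
    world_size=1,
    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
        init_method="file://{}".format(file.name),
    ),
)

# Illustrative model and sizes (not taken from the commit).
model = Pipe(nn.Sequential(nn.Linear(8, 8), nn.ReLU()), chunks=4)

x = torch.rand(16, 8, requires_grad=True)
rref = model(x)            # forward() now returns an RRef instead of a Tensor
y = rref.local_value()     # single-host pipeline: unwrap the local result
y.sum().backward()

rpc.shutdown()

local_value() is valid here because the single worker both owns the RRef and consumes it; with more than one host, the same RRef-based signature is intended to carry remote results.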
16 changes: 16 additions & 0 deletions test/distributed/_pipeline/sync/conftest.py
@@ -5,7 +5,9 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 import pytest
+import tempfile
 import torch
+from torch.distributed import rpc
 
 
 @pytest.fixture(autouse=True)
@@ -35,3 +37,17 @@ def cuda_sleep(seconds):
 
 def pytest_report_header():
     return f"torch: {torch.__version__}"
+
+@pytest.fixture
+def setup_rpc(scope="session"):
+    file = tempfile.NamedTemporaryFile()
+    rpc.init_rpc(
+        name="worker0",
+        rank=0,
+        world_size=1,
+        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+            init_method="file://{}".format(file.name),
+        )
+    )
+    yield
+    rpc.shutdown()
12 changes: 6 additions & 6 deletions test/distributed/_pipeline/sync/skip/test_gpipe.py
@@ -17,7 +17,7 @@
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
 @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-def test_1to3(balance, checkpoint):
+def test_1to3(balance, checkpoint, setup_rpc):
     if torch.cuda.device_count() < len(balance):
         pytest.skip("at least %d cuda devices required" % len(balance))
 
@@ -61,14 +61,14 @@ def forward(self, input):
 
     input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
     output = model(input)
-    loss = output.mean()
+    loss = output.local_value().mean()
     loss.backward()
 
-    assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
+    assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
     assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))
 
 
-def test_none_skip():
+def test_none_skip(setup_rpc):
     @skippable(stash=["none"])
     class Stash(nn.Module):
         def forward(self, input):
@@ -102,7 +102,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None):
         for next_grad_fn, _ in grad_fn.next_functions:
             assert_grad_fn_is_not_portal(next_grad_fn, visited)
 
-    assert_grad_fn_is_not_portal(output.grad_fn)
+    assert_grad_fn_is_not_portal(output.local_value().grad_fn)
 
-    output.sum().backward()
+    output.local_value().sum().backward()
     assert input.grad.mean().item() == 1
6 changes: 3 additions & 3 deletions test/distributed/_pipeline/sync/skip/test_leak.py
@@ -29,7 +29,7 @@ def forward(self, input):
 
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
 @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-def test_delete_portal_tensor(train, checkpoint):
+def test_delete_portal_tensor(train, checkpoint, setup_rpc):
     # Without checkpointing:
     # +- Stash --+ +--- Pop ----+ - - - layers
     # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
@@ -97,7 +97,7 @@ def forward(self, input):
 
     if train:
         model.train()
-        output = model(input)
+        output = model(input).local_value()
         output.norm().backward()
     else:
         model.eval()
@@ -106,7 +106,7 @@ def forward(self, input):
 
 
 @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
-def test_no_portal_without_pipe(train, monkeypatch):
+def test_no_portal_without_pipe(train, monkeypatch, setup_rpc):
     def deny(*args, **kwargs):
         raise AssertionError("tried to create Portal without Pipe")
 
13 changes: 7 additions & 6 deletions test/distributed/_pipeline/sync/test_bugs.py
@@ -12,7 +12,7 @@
 from torch.distributed._pipeline.sync import Pipe
 
 
-def test_python_autograd_function():
+def test_python_autograd_function(setup_rpc):
     # A Python autograd function might fail with this error:
     #
     # RuntimeError: Returning Variables sharing storage with other Variables
@@ -41,10 +41,10 @@ def forward(self, input):
 
     x = torch.rand(42)
     y = model(x)
-    assert torch.allclose(x, y)
+    assert torch.allclose(x, y.local_value())
 
 
-def test_exception_no_hang():
+def test_exception_no_hang(setup_rpc):
     # In v0.0.2, once a failed partition receives a normal message
     # (non-closing) for the next micro-batch, a hang occured. The reason was
     # that a failed partition didn't call in_queue.task_done() on a normal
@@ -69,7 +69,7 @@ def forward(self, x):
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-def test_tuple_wait(cuda_sleep):
+def test_tuple_wait(cuda_sleep, setup_rpc):
     # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
     # Under this behavior, if checkpointing was disabled, there's a possibility
     # that gradient accumulations on other tensors are not synchronized
@@ -113,15 +113,15 @@ def forward(self, triple):
     b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)
 
     y = model((a, b))
-    y.norm().backward()
+    y.local_value().norm().backward()
 
     torch.cuda.synchronize(0)
     torch.cuda.synchronize(1)
 
     assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))
 
 
-def test_parallel_randoms():
+def test_parallel_randoms(setup_rpc):
     class Dropouts(nn.Module):
         def forward(self, x):
             for _ in range(100):
@@ -133,6 +133,7 @@ def forward(self, x):
     x = torch.rand(10, 10, requires_grad=True)
     model = Pipe(model, chunks=10, checkpoint="always")
     y = model(x)
+    y = y.local_value()
     y.norm().backward()
 
     assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()
12 changes: 6 additions & 6 deletions test/distributed/_pipeline/sync/test_inplace.py
@@ -11,27 +11,27 @@
 from torch.distributed._pipeline.sync import Pipe
 
 
-def test_inplace_on_requires_grad():
+def test_inplace_on_requires_grad(setup_rpc):
     model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
     model = Pipe(model, checkpoint="always")
 
     x = torch.rand(1)
-    y = model(x)
+    y = model(x).local_value()
 
     message = r"a leaf Variable that requires grad .* used in an in-place operation."
     with pytest.raises(RuntimeError, match=message):
         y.backward()
 
 
 @pytest.mark.xfail(strict=True)
-def test_inplace_on_not_requires_grad():
+def test_inplace_on_not_requires_grad(setup_rpc):
     # In-place operation on a tensor not requiring grad doesn't cause a
     # RuntimeError. Currently, we cannot detect this case.
     model = nn.Sequential(nn.ReLU(inplace=True))
     model = Pipe(model, [1], devices=["cpu"], checkpoint="always")
 
     x = torch.rand(1)
-    y = model(x)
+    y = model(x).local_value()
     del model
 
     message = r"a leaf Variable that requires grad .* used in an in-place operation."
@@ -40,7 +40,7 @@ def test_inplace_on_not_requires_grad():
 
 
 @pytest.mark.xfail(strict=True)
-def test_inplace_incorrect_grad():
+def test_inplace_incorrect_grad(setup_rpc):
     class M(nn.Module):
         def forward(self, foo_bar):
             # 'foo' requires grad but 'bar' does not. In-place operation on
@@ -62,7 +62,7 @@ def forward(self, foo_bar):
     foo = torch.tensor([1.0], requires_grad=True)
     bar = torch.tensor([1.0])
 
-    output = model((foo, bar))
+    output = model((foo, bar)).local_value()
     del model
     output.backward()
 