Modify Pipe to return an RRef. #47829

Closed
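
This PR changes Pipe.forward so that it returns an RRef to the output instead of the output tensor itself. That is why every call site in the tests below unwraps the result with .local_value(), and why each test now requests the setup_rpc fixture: the RPC framework has to be initialized before a Pipe model can run. A minimal sketch of the new calling convention, assuming a single-process RPC agent with environment-based rendezvous (the tests themselves use the shared fixture from torch.testing._internal.distributed.pipeline.utils rather than initializing RPC by hand):

import os

import torch
from torch import nn
from torch.distributed import rpc
from torch.distributed._pipeline.sync import Pipe

# Single-worker rendezvous; the address and port here are placeholders.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

model = Pipe(nn.Sequential(nn.Linear(8, 8), nn.ReLU()), chunks=2)
x = torch.rand(4, 8)

rref = model(x)            # forward() now returns an RRef, not a Tensor
y = rref.local_value()     # unwrap locally before using it as a Tensor
y.norm().backward()

rpc.shutdown()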
13 changes: 7 additions & 6 deletions test/distributed/_pipeline/sync/skip/test_gpipe.py
@@ -12,12 +12,13 @@
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.portal import PortalBlue, PortalCopy, PortalOrange
from torch.testing._internal.distributed.pipeline.utils import convert_to_balance
+from torch.testing._internal.distributed.pipeline.utils import setup_rpc


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
-def test_1to3(balance, checkpoint):
+def test_1to3(balance, checkpoint, setup_rpc):
if torch.cuda.device_count() < len(balance):
pytest.skip("at least %d cuda devices required" % len(balance))

@@ -61,14 +62,14 @@ def forward(self, input):

input = torch.rand(30, 3, 224, 224, device=in_device, requires_grad=True)
output = model(input)
-loss = output.mean()
+loss = output.local_value().mean()
loss.backward()

-assert torch.allclose(output.norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
+assert torch.allclose(output.local_value().norm(), torch.tensor(1039.0, device=out_device), atol=6e-1)
assert torch.allclose(input.grad.norm(), torch.tensor(0.0004533053, device=in_device))


-def test_none_skip():
+def test_none_skip(setup_rpc):
@skippable(stash=["none"])
class Stash(nn.Module):
def forward(self, input):
@@ -102,7 +103,7 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=None):
for next_grad_fn, _ in grad_fn.next_functions:
assert_grad_fn_is_not_portal(next_grad_fn, visited)

-assert_grad_fn_is_not_portal(output.grad_fn)
+assert_grad_fn_is_not_portal(output.local_value().grad_fn)

-output.sum().backward()
+output.local_value().sum().backward()
assert input.grad.mean().item() == 1
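
The setup_rpc fixture requested by the tests above (and by the tests in the remaining files) is the shared helper imported from torch.testing._internal.distributed.pipeline.utils; it brings up a single-worker RPC agent for the duration of one test. A hypothetical stand-in with the same shape, for running these snippets outside the PyTorch test suite (the file-based init_method and the fixture body are assumptions, not the exact upstream implementation):

import pytest
from torch.distributed import rpc

@pytest.fixture
def setup_rpc(tmp_path):
    # Bring up a one-worker RPC agent for a single test, then tear it down.
    rpc.init_rpc(
        "worker0",
        rank=0,
        world_size=1,
        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
            init_method="file://{}".format(tmp_path / "rpc_init"),
        ),
    )
    yield
    rpc.shutdown()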
7 changes: 4 additions & 3 deletions test/distributed/_pipeline/sync/skip/test_leak.py
@@ -11,6 +11,7 @@
from torch.distributed._pipeline.sync import Pipe, is_checkpointing, is_recomputing
from torch.distributed._pipeline.sync.skip import pop, skippable, stash
from torch.distributed._pipeline.sync.skip.tracker import current_skip_tracker
+from torch.testing._internal.distributed.pipeline.utils import setup_rpc


@skippable(stash=["skip"])
@@ -29,7 +30,7 @@ def forward(self, input):

@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
@pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"])
-def test_delete_portal_tensor(train, checkpoint):
+def test_delete_portal_tensor(train, checkpoint, setup_rpc):
# Without checkpointing:
# +- Stash --+ +--- Pop ----+ - - - layers
# | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
@@ -97,7 +98,7 @@ def forward(self, input):

if train:
model.train()
-output = model(input)
+output = model(input).local_value()
output.norm().backward()
else:
model.eval()
@@ -106,7 +107,7 @@ def forward(self, input):


@pytest.mark.parametrize("train", [True, False], ids=["train", "eval"])
-def test_no_portal_without_pipe(train, monkeypatch):
+def test_no_portal_without_pipe(train, monkeypatch, setup_rpc):
def deny(*args, **kwargs):
raise AssertionError("tried to create Portal without Pipe")

14 changes: 8 additions & 6 deletions test/distributed/_pipeline/sync/test_bugs.py
@@ -10,9 +10,10 @@
import torch.nn.functional as F

from torch.distributed._pipeline.sync import Pipe
+from torch.testing._internal.distributed.pipeline.utils import setup_rpc


-def test_python_autograd_function():
+def test_python_autograd_function(setup_rpc):
# A Python autograd function might fail with this error:
#
# RuntimeError: Returning Variables sharing storage with other Variables
@@ -41,10 +42,10 @@ def forward(self, input):

x = torch.rand(42)
y = model(x)
-assert torch.allclose(x, y)
+assert torch.allclose(x, y.local_value())


-def test_exception_no_hang():
+def test_exception_no_hang(setup_rpc):
# In v0.0.2, once a failed partition receives a normal message
# (non-closing) for the next micro-batch, a hang occurred. The reason was
# that a failed partition didn't call in_queue.task_done() on a normal
@@ -69,7 +70,7 @@ def forward(self, x):


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required")
-def test_tuple_wait(cuda_sleep):
+def test_tuple_wait(cuda_sleep, setup_rpc):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility
# that gradient accumulations on other tensors are not synchronized
@@ -113,15 +114,15 @@ def forward(self, triple):
b = torch.rand(1024, 3, 32, 32, device=0, requires_grad=True)

y = model((a, b))
-y.norm().backward()
+y.local_value().norm().backward()

torch.cuda.synchronize(0)
torch.cuda.synchronize(1)

assert torch.isclose(b.grad.norm().cpu(), torch.tensor(5.000))


-def test_parallel_randoms():
+def test_parallel_randoms(setup_rpc):
class Dropouts(nn.Module):
def forward(self, x):
for _ in range(100):
@@ -133,6 +134,7 @@ def forward(self, x):
x = torch.rand(10, 10, requires_grad=True)
model = Pipe(model, chunks=10, checkpoint="always")
y = model(x)
+y = y.local_value()
y.norm().backward()

assert y.to(torch.bool).tolist() == x.grad.to(torch.bool).tolist()
13 changes: 7 additions & 6 deletions test/distributed/_pipeline/sync/test_inplace.py
@@ -9,29 +9,30 @@
from torch import nn

from torch.distributed._pipeline.sync import Pipe
+from torch.testing._internal.distributed.pipeline.utils import setup_rpc


-def test_inplace_on_requires_grad():
+def test_inplace_on_requires_grad(setup_rpc):
model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True))
model = Pipe(model, checkpoint="always")

x = torch.rand(1)
-y = model(x)
+y = model(x).local_value()

message = r"a leaf Variable that requires grad .* used in an in-place operation."
with pytest.raises(RuntimeError, match=message):
y.backward()


@pytest.mark.xfail(strict=True)
-def test_inplace_on_not_requires_grad():
+def test_inplace_on_not_requires_grad(setup_rpc):
# In-place operation on a tensor not requiring grad doesn't cause a
# RuntimeError. Currently, we cannot detect this case.
model = nn.Sequential(nn.ReLU(inplace=True))
model = Pipe(model, [1], devices=["cpu"], checkpoint="always")

x = torch.rand(1)
-y = model(x)
+y = model(x).local_value()
del model

message = r"a leaf Variable that requires grad .* used in an in-place operation."
@@ -40,7 +41,7 @@


@pytest.mark.xfail(strict=True)
-def test_inplace_incorrect_grad():
+def test_inplace_incorrect_grad(setup_rpc):
class M(nn.Module):
def forward(self, foo_bar):
# 'foo' requires grad but 'bar' does not. In-place operation on
@@ -62,7 +63,7 @@ def forward(self, foo_bar):
foo = torch.tensor([1.0], requires_grad=True)
bar = torch.tensor([1.0])

-output = model((foo, bar))
+output = model((foo, bar)).local_value()
del model
output.backward()
