Make FakeProcessGroup traceable #113314

Closed
wants to merge 2 commits
Changes from 1 commit
15 changes: 15 additions & 0 deletions test/distributed/test_functional_api.py
@@ -15,6 +15,7 @@

from functorch import make_fx
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton

if not dist.is_available():
@@ -575,6 +576,20 @@ def allreduce(t, pg):
compiled_allreduce = torch.compile(allreduce, fullgraph=True)
compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
def test_tracing_with_fakepg(self):
def allreduce(t, pg):
return ft_c.all_reduce(t, "sum", pg)

compiled_allreduce = torch.compile(allreduce, fullgraph=True)
dist.init_process_group(
backend="fake",
rank=0,
world_size=8,
store=FakeStore(),
)
compiled_allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)


class TestOpWaitiness(MultiThreadedTestCase):
@property
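For reference, a minimal end-to-end sketch of what the new test exercises (not part of the diff; it assumes the "fake" backend is registered as a side effect of importing torch.testing._internal.distributed.fake_pg, which is what the test relies on):

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
from torch.testing._internal.distributed.fake_pg import FakeStore

# A single process pretends to be rank 0 of an 8-rank job; no other processes,
# NCCL/Gloo backend, or real communication is involved.
dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())

def allreduce(t, pg):
    return ft_c.all_reduce(t, "sum", pg)

# With this PR, dynamo recognizes the FakeProcessGroup argument instead of
# failing on an unsupported variable type.
compiled = torch.compile(allreduce, fullgraph=True)
# The test gates on triton and a recent GPU, so a CUDA tensor is assumed here.
out = compiled(torch.randn(8, device="cuda"), dist.group.WORLD)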
7 changes: 7 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -92,6 +92,7 @@
)
from .distributed import (
DeviceMeshVariable,
FakeProcessGroupVariable,
PlacementClassVariable,
PlacementVariable,
ProcessGroupVariable,
@@ -660,6 +661,12 @@ def index_source(key):
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
elif FakeProcessGroupVariable.is_process_group(value):
Contributor:
Could we make this check directly inside ProcessGroupVariable? I think FakePG is just another type of PG, so we should try to make that work directly.

Contributor Author:
Good point. I guess we won't actually do anything special for FakePG. Let me change it.

return FakeProcessGroupVariable(
value,
source=self.source,
guards=self.make_guards(GuardBuilder.ID_MATCH),
)
elif DeviceMeshVariable.is_device_mesh(value):
# TODO: see if we need to add custom guard instead
# of a simple ID_MATCH
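A minimal sketch of the consolidation suggested in the review thread above (an assumed follow-up revision of torch/_dynamo/variables/distributed.py, not code in this commit; other methods of the class are unchanged and DistributedVariable/istype come from that module): since istype accepts a tuple of types, ProcessGroupVariable.is_process_group can recognize both classes, and the separate FakeProcessGroupVariable branch in builder.py becomes unnecessary.

class ProcessGroupVariable(DistributedVariable):
    @staticmethod
    def is_process_group(value):
        if not DistributedVariable.is_available():
            return False

        from torch._C._distributed_c10d import ProcessGroup
        from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

        # istype() accepts a tuple and does an exact-type match against each entry,
        # so this covers both real and fake process groups without subclass tricks.
        return istype(value, (ProcessGroup, FakeProcessGroup))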
11 changes: 11 additions & 0 deletions torch/_dynamo/variables/distributed.py
@@ -213,3 +213,14 @@ def is_process_group(value):
from torch._C._distributed_c10d import ProcessGroup

return istype(value, ProcessGroup)


class FakeProcessGroupVariable(ProcessGroupVariable):
@staticmethod
def is_process_group(value):
if not DistributedVariable.is_available():
return False

from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

return istype(value, FakeProcessGroup)
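For context on why the existing check misses FakePG (my reading of the code, not stated in the diff): istype in torch._dynamo.utils matches the exact type, and FakeProcessGroup is a Python subclass of dist.ProcessGroup, so istype(value, ProcessGroup) returns False for it. A tiny sketch of the distinction:

import torch.distributed as dist
from torch._C._distributed_c10d import ProcessGroup
from torch._dynamo.utils import istype
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

fake = FakeProcessGroup(rank=0, world_size=8)

assert isinstance(fake, ProcessGroup)    # FakePG subclasses ProcessGroup
assert not istype(fake, ProcessGroup)    # exact-type check misses the subclass
assert istype(fake, FakeProcessGroup)    # hence the dedicated check above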
9 changes: 7 additions & 2 deletions torch/_dynamo/variables/torch.py
@@ -42,7 +42,12 @@
NullContextVariable,
TorchFunctionDisableVariable,
)
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .distributed import (
FakeProcessGroupVariable,
is_constant_pg_functions,
is_from_local,
ProcessGroupVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
@@ -584,7 +589,7 @@ def call_function(
# We desugar it at trace-time into ranks by directly calling util
# bake the result into the trace
assert len(args) == 1, "Expected one arg (pg)"
assert isinstance(args[0], ProcessGroupVariable)
assert isinstance(args[0], (ProcessGroupVariable, FakeProcessGroupVariable))

invocation_result = self.value(args[0].as_python_constant())
# Note - while we *could* cook up sources around invocations, like a FunctionSource
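To illustrate the trace-time desugaring that the widened assert guards (a sketch, not part of the PR; it assumes the fake process group set up as in the test above): the constant-PG util is evaluated eagerly during tracing and its result is baked into the graph as a Python constant.

import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())

def fn(t, pg):
    # get_process_group_ranks is one of the constant-PG functions; dynamo calls it
    # while tracing and embeds the resulting list of ranks as a constant, so the
    # compiled graph only sees len(ranks) == 8, not a call into c10d.
    ranks = dist.get_process_group_ranks(pg)
    return t * len(ranks)

compiled = torch.compile(fn, fullgraph=True)
out = compiled(torch.ones(4), dist.group.WORLD)  # expected: tensor([8., 8., 8., 8.])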
3 changes: 3 additions & 0 deletions torch/testing/_internal/distributed/fake_pg.py
@@ -41,6 +41,9 @@ def __init__(self, rank, world_size):
def allreduce(self, tensor_list, opts=AllreduceOptions()):
return ret_work(tensor_list)

def allreduce_coalesced(self, tensor_list, opts=AllreduceOptions()):
return ret_work(tensor_list)

def allgather(self, output_tensors, input_tensor, opts=AllgatherOptions()):
# NOTE: in general it's not good form to try to make FakePG work with 'real data',
# but the reasoning here is that we want FakePG to work with DeviceMesh's init
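The new allreduce_coalesced override mirrors the existing allreduce stub: it performs no communication and just wraps the input tensors in an already-completed work object via ret_work. A minimal usage sketch (an assumption about how the override gets exercised, not code from the PR):

import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group(backend="fake", rank=0, world_size=8, store=FakeStore())

# Routed to FakeProcessGroup.allreduce_coalesced; the tensors come back unchanged.
tensors = [torch.ones(2), torch.zeros(3)]
work = dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM, async_op=True)
work.wait()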