Support tracing native functional collective via python APIs (#119103)
Summary:
- Inlined `torch.distributed.distributed_c10d._get_group_size_by_name`
- Updated all torch.compile tests in test_c10d_functional_native.py to use the funcol python APIs (as opposed to the dispatcher ops)

Pull Request resolved: #119103
Approved by: https://github.com/wconstab, https://github.com/fegin, https://github.com/wanchaol
yifuwang authored and pytorchmergebot committed Feb 15, 2024
1 parent 5f9b432 commit cd08dc3
Showing 3 changed files with 16 additions and 19 deletions.
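
For context, the pattern this change enables looks roughly like the sketch below: calling the funcol python wrappers directly inside a torch.compile'd function instead of the torch.ops._c10d_functional dispatcher ops. This is a hedged illustration, not code from the PR; it assumes a process group registered under the name "default" (as the tests below set up) and one device per rank.

import torch
import torch.distributed._functional_collectives as funcol

def func(arg: torch.Tensor) -> torch.Tensor:
    # The python wrapper looks up the group size by name at trace time
    # (via the now-constant-folded _get_group_size_by_name), so no
    # explicit world_size argument is needed.
    ag0 = funcol.all_gather_tensor(arg, 0, "default")
    return funcol.wait_tensor(ag0)

compiled = torch.compile(func, fullgraph=True)
# Run under torchrun with an initialized process group, e.g.:
# out = compiled(torch.rand(4, 4, device="cuda"))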
28 changes: 11 additions & 17 deletions test/distributed/test_c10d_functional_native.py
@@ -499,10 +499,8 @@ def test_inductor_all_gather_into_tensor_single(self):
         self._init_process_group()

         def func(arg: torch.Tensor) -> torch.Tensor:
-            ag0 = torch.ops._c10d_functional.all_gather_into_tensor(
-                arg, self.world_size, "default"
-            )
-            ag0 = torch.ops._c10d_functional.wait_tensor(ag0)
+            ag0 = funcol.all_gather_tensor(arg, 0, "default")
+            ag0 = funcol.wait_tensor(ag0)
             return ag0

         arg = torch.rand(4, 4, device=self.device)
@@ -533,10 +531,8 @@ def test_inductor_all_gather_into_tensor_coalesced(self):
         self._init_process_group()

         def func(args: List[torch.Tensor]) -> torch.Tensor:
-            ag0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
-                args, self.world_size, "default"
-            )
-            ag0 = [torch.ops._c10d_functional.wait_tensor(out) for out in ag0]
+            ag0 = funcol.all_gather_into_tensor_coalesced(args, "default")
+            ag0 = [funcol.wait_tensor(out) for out in ag0]
             return ag0

         args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
@@ -575,10 +571,8 @@ def test_inductor_reduce_scatter_tensor_single(self):
         self._init_process_group()

         def func(arg: torch.Tensor) -> torch.Tensor:
-            rs0 = torch.ops._c10d_functional.reduce_scatter_tensor(
-                arg, "avg", self.world_size, "default"
-            )
-            rs0 = torch.ops._c10d_functional.wait_tensor(rs0)
+            rs0 = funcol.reduce_scatter_tensor(arg, "avg", 0, "default")
+            rs0 = funcol.wait_tensor(rs0)
             return rs0

         arg = torch.rand(4, 4, device=self.device)
@@ -609,10 +603,10 @@ def test_inductor_reduce_scatter_tensor_coalesced(self):
         self._init_process_group()

         def func(args: List[torch.Tensor]) -> torch.Tensor:
-            rs0 = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
-                args, "avg", self.world_size, "default"
+            rs0 = funcol.reduce_scatter_tensor_coalesced(
+                args, "avg", [0] * len(args), "default"
             )
-            rs0 = [torch.ops._c10d_functional.wait_tensor(out) for out in rs0]
+            rs0 = [funcol.wait_tensor(out) for out in rs0]
             return rs0

         args = [torch.rand(4, 4, device=self.device) for _ in range(4)]
@@ -662,13 +656,13 @@ def func(
             output_split_sizes: torch.Tensor,
             input_split_sizes: torch.Tensor,
         ) -> torch.Tensor:
-            output = torch.ops._c10d_functional.all_to_all_single(
+            output = funcol.all_to_all_single(
                 input,
                 _tolist_with_constrain_as_size(output_split_sizes),
                 _tolist_with_constrain_as_size(input_split_sizes),
                 "default",
             )
-            return torch.ops._c10d_functional.wait_tensor(output)
+            return funcol.wait_tensor(output)

         torch.manual_seed(42)
         send_sz_matrix = torch.randint(0, 20, (self.world_size, self.world_size))
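
For reference, a hedged eager-mode sketch (not part of the diff) of the funcol signatures the updated tests now exercise; it assumes a process group registered under the name "default", as _init_process_group arranges in this file.

import torch
import torch.distributed._functional_collectives as funcol

x = torch.ones(4, 4)
# all_gather_tensor(input, gather_dim, group): the world size is derived
# from the group rather than passed explicitly as in the dispatcher op.
ag = funcol.wait_tensor(funcol.all_gather_tensor(x, 0, "default"))
# reduce_scatter_tensor(input, reduceOp, scatter_dim, group)
rs = funcol.wait_tensor(funcol.reduce_scatter_tensor(ag, "avg", 0, "default"))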
4 changes: 3 additions & 1 deletion torch/_dynamo/variables/distributed.py
@@ -50,13 +50,15 @@ def is_constant_pg_functions(value):
         return False

     from torch.distributed.distributed_c10d import (
+        _get_group_size_by_name,
         _get_group_tag,
         get_process_group_ranks,
     )

     constant_processgroup_functions = [
-        get_process_group_ranks,
+        _get_group_size_by_name,
         _get_group_tag,
+        get_process_group_ranks,
     ]

     return inspect.isfunction(value) and value in constant_processgroup_functions
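
With _get_group_size_by_name registered as a constant pg function, Dynamo evaluates it at trace time and bakes the resulting int into the graph, as it already does for _get_group_tag and get_process_group_ranks. A hedged illustration (scale_by_world_size is a hypothetical helper, and "default" again names the tests' group):

import torch
from torch.distributed.distributed_c10d import _get_group_size_by_name

def scale_by_world_size(x: torch.Tensor) -> torch.Tensor:
    # Under torch.compile this call is folded at trace time;
    # only the constant int reaches the traced graph.
    group_size = _get_group_size_by_name("default")
    return x / group_size

compiled = torch.compile(scale_by_world_size)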
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/torch.py
@@ -510,7 +510,8 @@ def call_function(
             # We desugar it at trace-time into ranks by directly calling util
             # bake the result into the trace
             assert len(args) == 1, "Expected one arg (pg)"
-            assert isinstance(args[0], ProcessGroupVariable)
+            # Some constant pg functions address a pg via its name
+            assert isinstance(args[0], (ProcessGroupVariable, ConstantVariable))

             invocation_result = self.value(args[0].as_python_constant())
             # Note - while we *could* cook up sources around invocations, like a FunctionSource
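
The widened assert matters because _get_group_size_by_name addresses a group by *name*, so its argument traces as a ConstantVariable (a python string) rather than a ProcessGroupVariable. A hedged sketch of the two call shapes this now admits:

import torch.distributed as dist
from torch.distributed.distributed_c10d import (
    _get_group_size_by_name,
    get_process_group_ranks,
)

ranks = get_process_group_ranks(dist.group.WORLD)  # pg object -> ProcessGroupVariable
size = _get_group_size_by_name("default")          # str -> ConstantVariable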