Add missing torch.distributed.ReduceOp.AVG in type stubs (#101534)
Add the missing `AVG` member to the `torch.distributed.ReduceOp` enum in the type stubs.

Ref: https://github.com/pytorch/pytorch/blob/88b6a4577bad670c608424018caeba8698ad5f97/torch/csrc/distributed/c10d/Types.hpp#L35-L47
Pull Request resolved: #101534
Approved by: https://github.com/Skylion007
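As a minimal usage sketch (not part of this commit), the newly stubbed member lets a call like the one below pass type checking. It assumes an already initialized NCCL process group, since AVG is only available on backends that implement it:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group("nccl", ...) has run and a GPU is bound to this rank.
t = torch.ones(4, device="cuda")
# ReduceOp.AVG is now present in the stub, so mypy no longer reports a missing attribute.
dist.all_reduce(t, op=dist.ReduceOp.AVG)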
XuehaiPan authored and jcaip committed May 22, 2023
Parent: ab09162, commit: ddbc1f4
Showing 4 changed files with 20 additions and 23 deletions.
torch/_C/_distributed_c10d.pyi (10 additions, 9 deletions)
@@ -96,15 +96,16 @@ class DebugLevel(Enum):
 class ReduceOp:
     def __init__(self, op: "RedOpType"): ...

-    SUM = ...
-    PRODUCT = ...
-    MIN = ...
-    MAX = ...
-    BAND = ...
-    BOR = ...
-    BXOR = ...
-    PREMUL_SUM = ...
-    UNUSED = ...
+    SUM: "RedOpType" = ...
+    AVG: "RedOpType" = ...
+    PRODUCT: "RedOpType" = ...
+    MIN: "RedOpType" = ...
+    MAX: "RedOpType" = ...
+    BAND: "RedOpType" = ...
+    BOR: "RedOpType" = ...
+    BXOR: "RedOpType" = ...
+    PREMUL_SUM: "RedOpType" = ...
+    UNUSED: "RedOpType" = ...

     class RedOpType(Enum): ...
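A small type-checking sketch (not from the diff) of what the new annotations express: every member, including AVG, is a RedOpType value, so it can be used wherever that type is expected without a cast:

from torch.distributed import ReduceOp

def pick_op(average: bool) -> ReduceOp.RedOpType:
    # Under the updated stub, both branches are RedOpType values.
    return ReduceOp.AVG if average else ReduceOp.SUM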

torch/distributed/_tensor/device_mesh.py (4 additions, 4 deletions)
@@ -390,7 +390,7 @@ def all_gather(
     def all_reduce(
         self,
         tensor: torch.Tensor,
-        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
         mesh_dim: int = 0,
         async_op: bool = False,
     ) -> torch.Tensor:
@@ -409,7 +409,7 @@ def all_reduce(
             A :class:`torch.Tensor` object
         """
         dim_group = self._dim_groups[mesh_dim]
-        op_name: str = op.name  # type: ignore[attr-defined]
+        op_name: str = op.name
         return funcol.all_reduce(
             tensor,
             reduceOp=op_name,
@@ -422,7 +422,7 @@ def reduce_scatter(
     def reduce_scatter(
         self,
         input: torch.Tensor,
-        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
         mesh_dim: int = 0,
         scatter_dim: int = 0,
     ) -> torch.Tensor:
@@ -441,7 +441,7 @@ def reduce_scatter(
         Returns:
             A :class:`torch.Tensor` object
         """
-        op_name: str = op.name  # type: ignore[attr-defined]
+        op_name: str = op.name
         if self._backend == "nccl" or self._backend == "threaded":
             dim_group = self._dim_groups[mesh_dim]
             scatter_tensor = funcol.reduce_scatter_tensor(
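A hedged sketch (not part of the diff) of calling the internal DeviceMesh.all_reduce with the now-typed op parameter; the 4-rank mesh and an initialized default process group are assumptions for illustration:

import torch
from torch.distributed import ReduceOp
from torch.distributed._tensor import DeviceMesh

# Hypothetical 1-D mesh over ranks 0..3; requires torch.distributed to be initialized.
mesh = DeviceMesh("cuda", list(range(4)))
local = torch.rand(8, device="cuda")
# op is annotated as ReduceOp.RedOpType, so no "type: ignore" is needed at the call site.
averaged = mesh.all_reduce(local, op=ReduceOp.AVG, mesh_dim=0)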
torch/distributed/_tensor/ops/math_ops.py (1 addition, 1 deletion)
@@ -101,7 +101,7 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
     for placement in output_sharding.output_spec.placements:
         if placement.is_partial():
             partial_placement = cast(_Partial, placement)
-            partial_placement.reduce_op = c10d.ReduceOp.AVG  # type: ignore[attr-defined]
+            partial_placement.reduce_op = c10d.ReduceOp.AVG

     return output_sharding

torch/distributed/_tensor/placement_types.py (5 additions, 9 deletions)
@@ -175,7 +175,7 @@ def _reduce_shard_tensor(
         self,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
-        reduce_op: c10d.ReduceOp,
+        reduce_op: c10d.ReduceOp.RedOpType,
         mesh_dim: int,
     ) -> torch.Tensor:
         """
@@ -328,15 +328,13 @@ class _Partial(Placement):
     # We can implement custom reductions as needed by subclassing this
     # class and override those contracts.

-    def __init__(self, reduce_op: c10d.ReduceOp = c10d.ReduceOp.SUM):  # type: ignore[assignment]
-        self.reduce_op: c10d.ReduceOp = reduce_op
+    def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
+        self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op

     def _to_replicate(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
-        return mesh.all_reduce(
-            tensor, self.reduce_op, mesh_dim=mesh_dim  # type: ignore[call-arg]
-        )
+        return mesh.all_reduce(tensor, self.reduce_op, mesh_dim=mesh_dim)

     def _to_shard(
         self,
@@ -347,9 +345,7 @@ def _to_shard(
     ) -> torch.Tensor:
         # by default call reduce_shard_tensor of the shard_spec.
         shard_spec = cast(Shard, shard_spec)
-        return shard_spec._reduce_shard_tensor(
-            tensor, mesh, self.reduce_op, mesh_dim  # type: ignore[call-arg]
-        )
+        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, _Partial):
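A short sketch (assumptions noted in the comments) of constructing the internal _Partial placement with the AVG reduction, which is exactly what mean_rule above assigns:

from torch.distributed import ReduceOp
from torch.distributed._tensor.placement_types import _Partial

# _Partial is internal API; with the tightened annotation, passing AVG matches
# the declared ReduceOp.RedOpType parameter without a "type: ignore[assignment]".
avg_partial = _Partial(reduce_op=ReduceOp.AVG)
# Its _to_replicate hook would all-reduce a local tensor with this op across the mesh dimension.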
