Fix mypy errors

XuehaiPan committed May 16, 2023
1 parent 22424b0 commit 303a6d1
Showing 3 changed files with 10 additions and 14 deletions.
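All three files apply the same fix: parameters and attributes that default to ReduceOp.SUM are re-annotated as ReduceOp.RedOpType, the enum type of members such as ReduceOp.SUM and ReduceOp.AVG, so mypy accepts the defaults and the .name accesses and the "# type: ignore" suppressions can be dropped. Below is a minimal sketch of the pattern, assuming a PyTorch build that exposes ReduceOp.RedOpType (as these diffs do); demo_all_reduce is a hypothetical standalone helper, not the DeviceMesh.all_reduce method itself.

import torch
from torch.distributed import ReduceOp


def demo_all_reduce(
    tensor: torch.Tensor,
    op: ReduceOp.RedOpType = ReduceOp.SUM,  # default matches the annotated type: no ignore[assignment]
    mesh_dim: int = 0,
) -> str:
    # RedOpType members expose .name, so reading it needs no ignore[attr-defined].
    op_name: str = op.name
    return op_name


print(demo_all_reduce(torch.zeros(1)))  # prints "SUM"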
8 changes: 4 additions & 4 deletions torch/distributed/_tensor/device_mesh.py
@@ -390,7 +390,7 @@ def all_gather(
     def all_reduce(
         self,
         tensor: torch.Tensor,
-        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
         mesh_dim: int = 0,
         async_op: bool = False,
     ) -> torch.Tensor:
@@ -409,7 +409,7 @@ def all_reduce(
             A :class:`torch.Tensor` object
         """
         dim_group = self._dim_groups[mesh_dim]
-        op_name: str = op.name  # type: ignore[attr-defined]
+        op_name: str = op.name
         return funcol.all_reduce(
             tensor,
             reduceOp=op_name,
@@ -422,7 +422,7 @@ def all_reduce(
     def reduce_scatter(
         self,
         input: torch.Tensor,
-        op: ReduceOp = ReduceOp.SUM,  # type: ignore[assignment]
+        op: ReduceOp.RedOpType = ReduceOp.SUM,
         mesh_dim: int = 0,
         scatter_dim: int = 0,
     ) -> torch.Tensor:
@@ -441,7 +441,7 @@ def reduce_scatter(
         Returns:
             A :class:`torch.Tensor` object
         """
-        op_name: str = op.name  # type: ignore[attr-defined]
+        op_name: str = op.name
         if self._backend == "nccl" or self._backend == "threaded":
             dim_group = self._dim_groups[mesh_dim]
             scatter_tensor = funcol.reduce_scatter_tensor(
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/math_ops.py
@@ -101,7 +101,7 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
     for placement in output_sharding.output_spec.placements:
         if placement.is_partial():
             partial_placement = cast(_Partial, placement)
-            partial_placement.reduce_op = c10d.ReduceOp.AVG  # type: ignore[attr-defined]
+            partial_placement.reduce_op = c10d.ReduceOp.AVG

     return output_sharding

14 changes: 5 additions & 9 deletions torch/distributed/_tensor/placement_types.py
@@ -175,7 +175,7 @@ def _reduce_shard_tensor(
         self,
         tensor: torch.Tensor,
         mesh: DeviceMesh,
-        reduce_op: c10d.ReduceOp,
+        reduce_op: c10d.ReduceOp.RedOpType,
         mesh_dim: int,
     ) -> torch.Tensor:
         """
@@ -328,15 +328,13 @@ class _Partial(Placement):
     # We can implement custom reductions as needed by subclassing this
     # class and override those contracts.

-    def __init__(self, reduce_op: c10d.ReduceOp = c10d.ReduceOp.SUM):  # type: ignore[assignment]
-        self.reduce_op: c10d.ReduceOp = reduce_op
+    def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
+        self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op

     def _to_replicate(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
-        return mesh.all_reduce(
-            tensor, self.reduce_op, mesh_dim=mesh_dim  # type: ignore[call-arg]
-        )
+        return mesh.all_reduce(tensor, self.reduce_op, mesh_dim=mesh_dim)

     def _to_shard(
         self,
@@ -347,9 +345,7 @@ def _to_shard(
     ) -> torch.Tensor:
         # by default call reduce_shard_tensor of the shard_spec.
         shard_spec = cast(Shard, shard_spec)
-        return shard_spec._reduce_shard_tensor(
-            tensor, mesh, self.reduce_op, mesh_dim  # type: ignore[call-arg]
-        )
+        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, _Partial):
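The math_ops.py and placement_types.py changes rely on the same idea on the attribute side: once reduce_op is annotated as c10d.ReduceOp.RedOpType, assigning members such as c10d.ReduceOp.AVG type-checks without suppressions. A rough sketch under the same assumption; PartialLike is a hypothetical stand-in for the private _Partial placement, not the real class.

import torch.distributed.distributed_c10d as c10d


class PartialLike:
    def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
        # Annotating with RedOpType (not ReduceOp) is what removes the
        # ignore[assignment] comments in the diffs above.
        self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op


p = PartialLike()
p.reduce_op = c10d.ReduceOp.AVG  # accepted by mypy with the RedOpType annotation
print(p.reduce_op.name)  # prints "AVG"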
