Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed May 16, 2023
1 parent 22424b0 commit b574250
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
8 changes: 4 additions & 4 deletions torch/distributed/_tensor/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def all_gather(
def all_reduce(
self,
tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment]
op: ReduceOp.RedOpType = ReduceOp.SUM,
mesh_dim: int = 0,
async_op: bool = False,
) -> torch.Tensor:
Expand All @@ -409,7 +409,7 @@ def all_reduce(
A :class:`torch.Tensor` object
"""
dim_group = self._dim_groups[mesh_dim]
op_name: str = op.name # type: ignore[attr-defined]
op_name: str = op.name
return funcol.all_reduce(
tensor,
reduceOp=op_name,
Expand All @@ -422,7 +422,7 @@ def all_reduce(
def reduce_scatter(
self,
input: torch.Tensor,
op: ReduceOp = ReduceOp.SUM, # type: ignore[assignment]
op: ReduceOp.RedOpType = ReduceOp.SUM,
mesh_dim: int = 0,
scatter_dim: int = 0,
) -> torch.Tensor:
Expand All @@ -441,7 +441,7 @@ def reduce_scatter(
Returns:
A :class:`torch.Tensor` object
"""
op_name: str = op.name # type: ignore[attr-defined]
op_name: str = op.name
if self._backend == "nccl" or self._backend == "threaded":
dim_group = self._dim_groups[mesh_dim]
scatter_tensor = funcol.reduce_scatter_tensor(
Expand Down
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
for placement in output_sharding.output_spec.placements:
if placement.is_partial():
partial_placement = cast(_Partial, placement)
partial_placement.reduce_op = c10d.ReduceOp.AVG # type: ignore[attr-defined]
partial_placement.reduce_op = c10d.ReduceOp.AVG

return output_sharding

Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_tensor/placement_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ class _Partial(Placement):
# We can implement custom reductions as needed by subclassing this
# class and override those contracts.

def __init__(self, reduce_op: c10d.ReduceOp = c10d.ReduceOp.SUM): # type: ignore[assignment]
self.reduce_op: c10d.ReduceOp = reduce_op
def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op

def _to_replicate(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
Expand Down

0 comments on commit b574250

Please sign in to comment.