Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed May 16, 2023
1 parent 22424b0 commit fad32d0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion torch/distributed/_tensor/ops/math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def mean_rule(op_schema: OpSchema) -> OutputSharding:
for placement in output_sharding.output_spec.placements:
if placement.is_partial():
partial_placement = cast(_Partial, placement)
partial_placement.reduce_op = c10d.ReduceOp.AVG # type: ignore[attr-defined]
partial_placement.reduce_op = c10d.ReduceOp.AVG

return output_sharding

Expand Down
4 changes: 2 additions & 2 deletions torch/distributed/_tensor/placement_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ class _Partial(Placement):
# We can implement custom reductions as needed by subclassing this
# class and override those contracts.

def __init__(self, reduce_op: c10d.ReduceOp = c10d.ReduceOp.SUM): # type: ignore[assignment]
self.reduce_op: c10d.ReduceOp = reduce_op
def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op

def _to_replicate(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
Expand Down

0 comments on commit fad32d0

Please sign in to comment.