[DTensor] Made _Partial, Replicate frozen dataclasses (#113919)
This is part of the larger stack working toward being able to cache hashes for `DTensorSpec`.
Pull Request resolved: #113919
Approved by: https://github.com/wanchaol
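
Why this helps: `@dataclass(frozen=True)` with the default `eq=True` makes Python generate a field-based `__hash__`, so placement objects, and specs built from them, can serve as stable cache keys. A minimal sketch of the idea using toy stand-ins rather than the real PyTorch classes (a plain string replaces `c10d.ReduceOp.RedOpType`, and the cache is purely illustrative):

from dataclasses import dataclass

# Toy placements mirroring the field layout in the diff below.
@dataclass(frozen=True)
class Replicate:
    pass

@dataclass(frozen=True)
class _Partial:
    reduce_op: str = "sum"  # real code: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM

# frozen=True plus the default eq=True makes dataclass generate __hash__
# from the fields, so equal placements compare and hash equally:
assert _Partial("sum") == _Partial("sum")
assert hash(_Partial("sum")) == hash(_Partial("sum"))

# That allows placement tuples (the building blocks of a DTensorSpec)
# to act as stable dict/cache keys:
cache = {(Replicate(), _Partial("sum")): "cached propagation result"}
assert cache[(Replicate(), _Partial("sum"))] == "cached propagation result"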
awgu authored and pytorchmergebot committed Nov 20, 2023
1 parent 97d2b43 commit 77e058f
Showing 1 changed file with 3 additions and 3 deletions.

torch/distributed/_tensor/placement_types.py
@@ -275,6 +275,7 @@ def __str__(self) -> str:
         return f"S({self.dim})"
 
 
+@dataclass(frozen=True)
 class Replicate(Placement):
     # replicate placement
     def __eq__(self, other: object) -> bool:
@@ -315,6 +316,7 @@ def _replicate_tensor(
         return tensor
 
 
+@dataclass(frozen=True)
 class _Partial(Placement):
     # This is a default partial placement with element-wise reduce op
     # when doing reduction it follows the contract of `_to_replicate`
@@ -323,9 +325,7 @@ class _Partial(Placement):
     #
     # We can implement custom reductions as needed by subclassing this
     # class and override those contracts.
-
-    def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM):
-        self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op
+    reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM
 
     def _to_replicate(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
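
Behaviorally, the last hunk replaces the hand-written `__init__` with an annotated class-level field, so the dataclass machinery generates an equivalent constructor and existing call sites keep working, while `frozen=True` rejects mutation after construction. A standalone sketch of that contract (a plain-string `reduce_op` and a stub `Placement` base stand in for the real torch types):

from dataclasses import dataclass, FrozenInstanceError

class Placement:
    # Stub standing in for the real Placement base class.
    pass

@dataclass(frozen=True)
class _Partial(Placement):
    # The annotated field replaces the old __init__; dataclass generates
    # __init__(self, reduce_op="sum") from it automatically.
    reduce_op: str = "sum"

p = _Partial()                   # default, as before the change
q = _Partial(reduce_op="max")    # keyword form still works
assert p.reduce_op == "sum" and q.reduce_op == "max"

try:
    p.reduce_op = "max"          # frozen=True: assignment now raises
except FrozenInstanceError:
    pass  # instances are immutable, so cached hashes cannot go stale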
