Skip to content

Commit

Permalink
[DTensor] Computed DTensorSpec hash lazily (#114322)
Browse files Browse the repository at this point in the history
This is a forward fix for #113781.

We lazily compute the hash so that we do not try to compute the hash on `SymInt`s (for the stride) during Dynamo tracing.

Tested via:
```
python test/distributed/_tensor/test_dtensor_compile.py -k test_2d_fsdp_tp_ac_compile
```
Pull Request resolved: #114322
Approved by: https://github.com/wanchaol
ghstack dependencies: #113919, #113924, #114134, #113925, #113930, #114141, #113915, #114140
  • Loading branch information
awgu authored and pytorchmergebot committed Nov 22, 2023
1 parent c5ddfa7 commit e7326ec
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions torch/distributed/_tensor/placement_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ class DTensorSpec:
def __post_init__(self):
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
self._hash = self._hash_impl()
self._hash: Optional[int] = None

def __setattr__(self, attr: str, value: Any):
super().__setattr__(attr, value)
Expand All @@ -397,7 +397,7 @@ def __setattr__(self, attr: str, value: Any):
if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"):
self._hash = self._hash_impl()

def _hash_impl(self):
def _hash_impl(self) -> int:
# hashing and equality check for DTensorSpec are used to cache the sharding
# propagation results. We only need to consider the mesh, placements, shape
# dtype and stride.
Expand All @@ -416,9 +416,12 @@ def _hash_impl(self):
return hash((self.mesh, self.placements))

def __hash__(self) -> int:
# We eagerly cache the spec to avoid recomputing the hash upon each
# We lazily cache the spec to avoid recomputing the hash upon each
# use, where we make sure to update the hash when the `tensor_meta`
# changes by overriding `__setattr__`.
# changes by overriding `__setattr__`. This must be lazy so that Dynamo
# does not try to hash non-singleton `SymInt`s for the stride.
if self._hash is None:
self._hash = self._hash_impl()
return self._hash

def __eq__(self, __o: object) -> bool:
Expand Down

0 comments on commit e7326ec

Please sign in to comment.