Skip to content

Commit

Permalink
Traceable wrapper subclass support for deferred runtime asserts (#126198)
Browse files Browse the repository at this point in the history

The padded dense -> jagged conversion op has the signature:
```
_fbgemm_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor
```

when `total_L` is not specified, the meta registration has a data-dependent output shape (based on `offsets[0][-1]`). Returning an unbacked SymInt here should work in theory, but traceable wrapper subclass support is missing in later code to handle deferred runtime asserts. This PR fixes this.
Pull Request resolved: #126198
Approved by: https://github.com/ezyang
  • Loading branch information
jbschlosser authored and pytorchmergebot committed May 21, 2024
1 parent 82b4528 commit 31ba6ee
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/fx.experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ torch.fx.experimental.symbolic_shapes
CallMethodKey
PropagateUnbackedSymInts
DivideByKey
InnerTensorKey

hint_int
is_concrete_int
Expand Down
3 changes: 2 additions & 1 deletion test/allowlist_for_publicAPI.json
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,7 @@
"uninteresting_files",
"CallMethodKey",
"DivideByKey",
"InnerTensorKey",
"PropagateUnbackedSymInts",
"ShapeEnvSettings",
"log_lru_cache_stats",
Expand Down Expand Up @@ -2752,4 +2753,4 @@
"torch.utils.hipify.hipify_python": [
"TrieNode"
]
}
}
20 changes: 20 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,18 @@ def get(self, o: Any) -> Any:
return getattr(o, self.name)()


@dataclass(frozen=True)
class InnerTensorKey:
    """Pathing key that selects a named inner tensor of a wrapper subclass.

    Used when walking a traceable tensor subclass: ``inner_name`` is the
    attribute under which the flattened inner tensor lives on the wrapper.
    """

    # Attribute name of the inner tensor on the wrapper object.
    inner_name: str

    def __str__(self) -> str:
        # Render like an attribute access, e.g. ".values".
        return "." + self.inner_name

    def get(self, o: Any) -> Any:
        """Return the inner tensor stored on *o* under ``inner_name``."""
        inner = getattr(o, self.inner_name)
        return inner


@dataclass(frozen=True)
class DivideByKey:
divisor: int
Expand Down Expand Up @@ -538,6 +550,14 @@ def free_unbacked_symbols_with_path(
real=real[i] if real is not None else None
)
)
elif is_traceable_wrapper_subclass(a):
# TODO: Determine if this is correct
attrs, _ = a.__tensor_flatten__()
for attr in attrs:
sub = getattr(a, attr)
r.update(
free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),))
)
elif isinstance(a, torch.Tensor):
r.update(
free_unbacked_symbols_with_path(
Expand Down
8 changes: 8 additions & 0 deletions torch/fx/passes/runtime_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def insert_deferred_runtime_asserts(
ConvertIntKey,
DivideByKey,
free_symbols,
InnerTensorKey,
)
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.reference import PythonReferenceAnalysis
Expand Down Expand Up @@ -225,6 +226,13 @@ def go(node, keypath):
),
keypath[1:],
)
elif isinstance(keypath[0], InnerTensorKey):
return go(
graph.call_function(
getattr, (node, keypath[0].inner_name)
),
keypath[1:],
)
else:
raise AssertionError(f"unrecognized keypath {keypath}")

Expand Down

0 comments on commit 31ba6ee

Please sign in to comment.