Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Traceable wrapper subclass support for deferred runtime asserts #126198

Closed
wants to merge 6 commits into from
20 changes: 20 additions & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,18 @@ def get(self, o: Any) -> Any:
return getattr(o, self.name)()


@dataclass(frozen=True)
class InnerTensorKey:
inner_name: str

def __str__(self) -> str:
return f".{self.inner_name}"

def get(self, o: Any) -> Any:
"""Call the method on object"""
jbschlosser marked this conversation as resolved.
Show resolved Hide resolved
return getattr(o, self.inner_name)


@dataclass(frozen=True)
class DivideByKey:
divisor: int
Expand Down Expand Up @@ -534,6 +546,14 @@ def free_unbacked_symbols_with_path(
real=real[i] if real is not None else None
)
)
elif is_traceable_wrapper_subclass(a):
# TODO: Determine if this is correct
attrs, _ = a.__tensor_flatten__()
for attr in attrs:
sub = getattr(a, attr)
r.update(
free_unbacked_symbols_with_path(sub, path + (InnerTensorKey(attr),))
)
elif isinstance(a, torch.Tensor):
r.update(
free_unbacked_symbols_with_path(
Expand Down
8 changes: 8 additions & 0 deletions torch/fx/passes/runtime_assert.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def insert_deferred_runtime_asserts(
cast_symbool_to_symint_guardless,
ConvertIntKey,
DivideByKey,
InnerTensorKey,
free_symbols,
)
from torch.utils._sympy.interp import sympy_interp
Expand Down Expand Up @@ -225,6 +226,13 @@ def go(node, keypath):
),
keypath[1:],
)
elif isinstance(keypath[0], InnerTensorKey):
return go(
graph.call_function(
getattr, (node, keypath[0].inner_name)
),
keypath[1:],
)
else:
raise AssertionError(f"unrecognized keypath {keypath}")

Expand Down
Loading