Commit
Refactor all top level usages of record_shapeenv_event to ShapeEnv class (#123735)

This ensures that the first argument to record_shapeenv_event is a ShapeEnv,
so we can appropriately short-circuit when recording is not in progress.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: #123735
Approved by: https://github.com/ysiraichi, https://github.com/zou3519, https://github.com/albanD
ghstack dependencies: #124310, #124314, #124316, #124394, #124739, #124782, #124785
ezyang authored and pytorchmergebot committed Apr 25, 2024
1 parent 61e05f2 commit 87bec7d
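
For context on the short-circuit described above: because the decorator can now assume its first positional argument is the ShapeEnv the wrapped method is bound to, it can check is_recording immediately instead of scanning args/kwargs for a ShapeEnv first. The following is a minimal, self-contained sketch of that pattern, not the PyTorch implementation; the names record_event, SimpleEnv, and bump are made up for illustration:

import functools

def record_event(fn):
    # Simplified stand-in for record_shapeenv_event: it assumes fn is a
    # method, so the environment is always the first argument (self).
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if self.is_recording:
            # Fast path: already recording, so skip all bookkeeping.
            return fn(self, *args, **kwargs)
        self.is_recording = True
        try:
            # Log the call so it can be replayed on another environment later.
            self.events.append((fn.__name__, args, kwargs))
            return fn(self, *args, **kwargs)
        finally:
            self.is_recording = False
    return wrapper

class SimpleEnv:
    def __init__(self):
        self.events = []
        self.is_recording = False
        self.counter = 0

    @record_event
    def bump(self, n):
        self.counter += n
        return self.counter

env = SimpleEnv()
env.bump(3)
assert env.events == [("bump", (3,), {})]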
Showing 4 changed files with 138 additions and 103 deletions.
6 changes: 2 additions & 4 deletions torch/_export/serde/serialize.py
@@ -1421,8 +1421,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
             self.shape_env.add_var_to_val(sym, hint)

         if vr := self.symbol_name_to_range.get(val.expr_str):
-            symbolic_shapes._constrain_symbol_range(
-                self.shape_env,
+            self.shape_env.constrain_symbol_range(
                 sym,
                 compiler_min=vr.lower,  # type: ignore[arg-type]
                 compiler_max=vr.upper,  # type: ignore[arg-type]
@@ -1437,8 +1436,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]:
         if s.name not in self.symbol_name_to_symbol:
             self.symbol_name_to_symbol[s.name] = s
         if vr := self.symbol_name_to_range.get(s.name):
-            symbolic_shapes._constrain_symbol_range(
-                self.shape_env,
+            self.shape_env.constrain_symbol_range(
                 s,
                 compiler_min=vr.lower,  # type: ignore[arg-type]
                 compiler_max=vr.upper,  # type: ignore[arg-type]
6 changes: 5 additions & 1 deletion torch/_logging/_registrations.py
@@ -1,7 +1,11 @@
 # flake8: noqa: B950
 from ._internal import register_artifact, register_log

-DYNAMIC = ["torch.fx.experimental.symbolic_shapes", "torch.fx.experimental.sym_node"]
+DYNAMIC = [
+    "torch.fx.experimental.symbolic_shapes",
+    "torch.fx.experimental.sym_node",
+    "torch.fx.experimental.recording",
+]
 DISTRIBUTED = [
     "torch.distributed",
     "torch._dynamo.backends.distributed",
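The DYNAMIC list above registers which logger names the "dynamic" logging alias fans out to; with torch.fx.experimental.recording added, its messages (such as the new log.error in the decorator) are controlled by the same switch. A usage sketch, assuming the standard TORCH_LOGS / torch._logging.set_logs mechanism rather than anything introduced by this commit:

# Enable dynamic-shape logging from the command line:
#   TORCH_LOGS="dynamic" python my_script.py
# or programmatically:
import logging
import torch._logging

torch._logging.set_logs(dynamic=logging.DEBUG)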
93 changes: 53 additions & 40 deletions torch/fx/experimental/recording.py
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import itertools
 import logging
 from dataclasses import dataclass
@@ -220,52 +221,64 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
 def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
     def decorator(fn: Callable) -> Callable:
         assert callable(fn)
+        args = inspect.getfullargspec(fn).args
+        assert args and args[0] == "self", (
+            "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
+            "code so that it calls into a method on ShapeEnv"
+        )
+        name = fn.__name__

         @functools.wraps(fn)
         def wrapper(*args, **kwargs):
             from torch.fx.experimental.symbolic_shapes import ShapeEnv

-            if isinstance(args[0], ShapeEnv) and args[0].is_recording:  # type: ignore[has-type]
-                # If ShapeEnv is already recording an event, call the wrapped
-                # function directly.
-                #
-                # NB: here, we skip the check of whether all ShapeEnv instances
-                # are equal, in favor of a faster dispatch.
-                return fn(*args, **kwargs)
-
-            # Retrieve an instance of ShapeEnv.
-            # Assumption: the collection of args and kwargs may not reference
-            # different ShapeEnv instances.
-            self = _extract_shape_env_and_assert_equal(args, kwargs)
-
-            # If we are calling this function without any ShapeEnv instance
-            # alive in its arguments, we don't record and call the original.
-            if self is None:
-                return fn(*args, **kwargs)
-
-            # Otherwise, start recording and call the function.
-            with self._recording():
-                # Take a snapshot of the current tracked_fakes.
-                tracked_fakes = (
-                    self._snapshot_tracked_fakes() if save_tracked_fakes else None
-                )
-                # Record the event for 'fn'.
-                event = ShapeEnvEvent(
-                    fn, list(args), kwargs, tracked_fakes, name=fn.__name__
-                )
-                # Play the event on this ShapeEnv.
-                # NB: It's important to put the event first, because running
-                # the event can trigger internal events that must be ordered
-                # after this event. However, if an exception happens, we do
-                # NOT want to have the event in the list, so pop it off from
-                # the record if an error happened
-                self.events.append(event)
-                try:
-                    return event.run(self)
-                except Exception:
-                    self.events.pop()
-                    raise
+            assert isinstance(args[0], ShapeEnv)
+
+            try:
+                if args[0].is_recording:  # type: ignore[has-type]
+                    # If ShapeEnv is already recording an event, call the wrapped
+                    # function directly.
+                    #
+                    # NB: here, we skip the check of whether all ShapeEnv instances
+                    # are equal, in favor of a faster dispatch.
+                    return fn(*args, **kwargs)
+
+                # Retrieve an instance of ShapeEnv.
+                # Assumption: the collection of args and kwargs may not reference
+                # different ShapeEnv instances.
+                self = _extract_shape_env_and_assert_equal(args, kwargs)
+
+                # If we are calling this function without any ShapeEnv instance
+                # alive in its arguments, we don't record and call the original.
+                if self is None:
+                    return fn(*args, **kwargs)
+
+                # Otherwise, start recording and call the function.
+                with self._recording():
+                    # Take a snapshot of the current tracked_fakes.
+                    tracked_fakes = (
+                        self._snapshot_tracked_fakes() if save_tracked_fakes else None
+                    )
+                    # Record the event for 'fn'.
+                    event = ShapeEnvEvent(
+                        fn, list(args), kwargs, tracked_fakes, name=fn.__name__
+                    )
+                    # Play the event on this ShapeEnv.
+                    # NB: It's important to put the event first, because running
+                    # the event can trigger internal events that must be ordered
+                    # after this event. However, if an exception happens, we do
+                    # NOT want to have the event in the list, so pop it off from
+                    # the record if an error happened
+                    self.events.append(event)
+                    try:
+                        return event.run(self)
+                    except Exception:
+                        self.events.pop()
+                        raise
+
+            except Exception:
+                log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs)
+                raise

         return wrapper

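The new inspect.getfullargspec check above is what enforces the "only wrap methods on ShapeEnv" contract at decoration time. Here is a standalone sketch of the same idea; Env, do_work, and free_function are hypothetical names used only for illustration, not part of the PR:

import inspect

def assert_wraps_method(fn):
    # Same idea as the new assertion in record_shapeenv_event: the decorated
    # callable must declare `self` as its first parameter.
    params = inspect.getfullargspec(fn).args
    assert params and params[0] == "self", (
        f"{fn.__name__} should be a method; refactor it onto the class"
    )

class Env:
    def do_work(self, x):
        return x * 2

def free_function(x):
    return x * 2

assert_wraps_method(Env.do_work)        # passes: first parameter is `self`
try:
    assert_wraps_method(free_function)  # fails: no `self` parameter
except AssertionError as exc:
    print(exc)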
136 changes: 78 additions & 58 deletions torch/fx/experimental/symbolic_shapes.py
@@ -725,10 +725,6 @@ def guard_scalar(a):
     raise AssertionError(f"unrecognized scalar {a}")


-def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int):
-    shape_env.constrain_symbol_range(s, compiler_min, compiler_max)
-
-
 def _advise_is_size(a):
     """
     Don't use this directly; use torch._check_is_size instead.
@@ -770,7 +766,6 @@ def _advise_is_size(a):
     ):
         _constrain_range_for_size(a)

-@record_shapeenv_event()
 def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
     """
     This function is NOT INTENDED to be used by itself.
@@ -782,27 +777,10 @@ def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None):
     assert isinstance(a, SymInt), "can only constrain range for SymInt"
     assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"

-    if min is None:
-        min = 0
-    if max is None:
-        max = sys.maxsize - 1
-
-    if max < min:
-        raise ValueError(
-            "Maximum value to constrain_as_size can't be less than the specified min value, "
-            "received min={min} and max={max}"
-        )
-
-    a.node.shape_env.constrain_symbol_range(
-        a.node.expr,
-        compiler_min=min,
-        compiler_max=max,
-    )
-    a.node.shape_env.size_like.add(a.node.expr)
+    a.node.shape_env._constrain_range_for_size(a.node.expr, min, max)


 # inclusive both ways
-@record_shapeenv_event()
 def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
     """
     Applies a constraint that the passed in SymInt must lie between min-max
@@ -844,54 +822,24 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None):
             raise ValueError(f"Invalid value {a} for range [{min}:{max}]")
         return

-    if isinstance(a.node.expr, sympy.Integer):
-        if not (min <= int(a.node.expr) <= max):
-            raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]")
-        return
-    assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+    a.node.shape_env._constrain_range(a.node.expr, min, max)

-    # TODO: Shouldn't we install a guard if the symbol is backed? Or is the
-    # semantics that this is an "unchecked" assert (but it this actually
-    # something useful? Might be better to restrict only for unbacked
-    # SymInt).
-    _constrain_symbol_range(
-        a.node.shape_env,
-        a.node.expr,
-        compiler_min=min,
-        compiler_max=max,
-    )


-@record_shapeenv_event()
-def constrain_unify(a, b):
+def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None:
     """
     Given two SymInts, constrain them so that they must be equal. NB:
     this will not work with SymInts that represent nontrivial expressions
     (yet!)
     """
-    # TODO: this does not install a deferred runtime assert yet
-
-    # TODO: Maybe dedupe this with _maybe_guard_rel?
     if not isinstance(a, SymInt):
         if not isinstance(b, SymInt):
             assert a == b
             return
         else:
-            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
             shape_env = b.node.shape_env
-            shape_env.replacements[b.node.expr] = sympy.Integer(a)
     else:
-        # TODO: Actually, we can support this as long as one of them is a symbol.
-        # NB: We can't actually do "unification" as our operators are not
-        # injective
-        assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
         shape_env = a.node.shape_env
-        if not isinstance(b, SymInt):
-            shape_env.replacements[a.node.expr] = sympy.Integer(b)
-        else:
-            assert a.node.shape_env is b.node.shape_env
-            assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
-            new_var = shape_env._find(a.node.expr)
-            shape_env.replacements[b.node.expr] = new_var

+    shape_env._constrain_unify(a, b)

 # Assume that a boolean is true for the purposes of subsequent symbolic
 # reasoning. This will keep track of corresponding runtime checks to verify
@@ -2470,6 +2418,78 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol):
         if dest is not None:
             self._set_replacement(new_s, dest, "rename_unbacked_to_dest")

+    @record_shapeenv_event()
+    def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None):
+        if min is None:
+            min = 0
+        if max is None:
+            max = sys.maxsize - 1
+
+        if max < min:
+            raise ValueError(
+                "Maximum value to constrain_as_size can't be less than the specified min value, "
+                "received min={min} and max={max}"
+            )
+
+        self.constrain_symbol_range(
+            a,
+            compiler_min=min,
+            compiler_max=max,
+        )
+        self.size_like.add(a)
+
+    @record_shapeenv_event()
+    def _constrain_range(self, a: sympy.Expr, min: int, max: int):
+        if isinstance(a, sympy.Integer):
+            if not (min <= int(a) <= max):
+                raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]")
+            return
+        assert isinstance(a, sympy.Symbol), "constraining non-Symbols NYI"
+
+        # TODO: Shouldn't we install a guard if the symbol is backed? Or is the
+        # semantics that this is an "unchecked" assert (but it this actually
+        # something useful? Might be better to restrict only for unbacked
+        # SymInt).
+        self.constrain_symbol_range(
+            a,
+            compiler_min=min,
+            compiler_max=max,
+        )
+
+    @record_shapeenv_event()
+    def _constrain_unify(self, a, b):
+        """
+        Given two SymInts, constrain them so that they must be equal. NB:
+        this will not work with SymInts that represent nontrivial expressions
+        (yet!)
+        """
+        # TODO: this does not install a deferred runtime assert yet
+
+        # TODO: Maybe dedupe this with _maybe_guard_rel?
+        # Update Feb 2024: this is extra important to do, this doesn't handle
+        # unbacked replacements properly nor does it generate deferred runtime
+        # asserts
+        if not isinstance(a, SymInt):
+            if not isinstance(b, SymInt):
+                assert a == b
+            else:
+                assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+                assert b.node.shape_env is self
+                self.replacements[b.node.expr] = sympy.Integer(a)
+        else:
+            # TODO: Actually, we can support this as long as one of them is a symbol.
+            # NB: We can't actually do "unification" as our operators are not
+            # injective
+            assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+            assert a.node.shape_env is self
+            if not isinstance(b, SymInt):
+                self.replacements[a.node.expr] = sympy.Integer(b)
+            else:
+                assert a.node.shape_env is b.node.shape_env
+                assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI"
+                new_var = self._find(a.node.expr)
+                self.replacements[b.node.expr] = new_var
+
     def _ignore_fresh_unbacked_symbols_tls(self):
         return getattr(TLS, "ignore_fresh_unbacked_symbols", False)


1 comment on commit 87bec7d

@pytorchmergebot
Collaborator

Reverted #123735 on behalf of https://github.com/jeanschmidt due to Breaking internal signals, more info in D56587358 (comment)
