From 87bec7db4e55f329e077eb7003af2f4817cd4210 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 25 Apr 2024 06:36:44 -0700 Subject: [PATCH] Refactor all top level usages of record_shapeenv_event to ShapeEnv class (#123735) This ensures that first argument to record_shapeenv_event is a ShapeEnv so we can appropriately short circuit when recording is not in progress. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/123735 Approved by: https://github.com/ysiraichi, https://github.com/zou3519, https://github.com/albanD ghstack dependencies: #124310, #124314, #124316, #124394, #124739, #124782, #124785 --- torch/_export/serde/serialize.py | 6 +- torch/_logging/_registrations.py | 6 +- torch/fx/experimental/recording.py | 93 +++++++++------- torch/fx/experimental/symbolic_shapes.py | 136 +++++++++++++---------- 4 files changed, 138 insertions(+), 103 deletions(-) diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index b71a221a5317..aa9d69236e6f 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1421,8 +1421,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: self.shape_env.add_var_to_val(sym, hint) if vr := self.symbol_name_to_range.get(val.expr_str): - symbolic_shapes._constrain_symbol_range( - self.shape_env, + self.shape_env.constrain_symbol_range( sym, compiler_min=vr.lower, # type: ignore[arg-type] compiler_max=vr.upper, # type: ignore[arg-type] @@ -1437,8 +1436,7 @@ def deserialize_sym_int(self, s: SymInt) -> Union[int, torch.SymInt]: if s.name not in self.symbol_name_to_symbol: self.symbol_name_to_symbol[s.name] = s if vr := self.symbol_name_to_range.get(s.name): - symbolic_shapes._constrain_symbol_range( - self.shape_env, + self.shape_env.constrain_symbol_range( s, compiler_min=vr.lower, # type: ignore[arg-type] compiler_max=vr.upper, # type: ignore[arg-type] diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index 5ff3372feb8d..4b87a8b592d6 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -1,7 +1,11 @@ # flake8: noqa: B950 from ._internal import register_artifact, register_log -DYNAMIC = ["torch.fx.experimental.symbolic_shapes", "torch.fx.experimental.sym_node"] +DYNAMIC = [ + "torch.fx.experimental.symbolic_shapes", + "torch.fx.experimental.sym_node", + "torch.fx.experimental.recording", +] DISTRIBUTED = [ "torch.distributed", "torch._dynamo.backends.distributed", diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index c200c10e6f2d..4bf9ebab17b3 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -1,4 +1,5 @@ import functools +import inspect import itertools import logging from dataclasses import dataclass @@ -220,52 +221,64 @@ def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable: def decorator(fn: Callable) -> Callable: assert callable(fn) + args = inspect.getfullargspec(fn).args + assert args and args[0] == "self", ( + "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " + "code so that it calls into a method on ShapeEnv" + ) name = fn.__name__ @functools.wraps(fn) def wrapper(*args, **kwargs): from torch.fx.experimental.symbolic_shapes import ShapeEnv - if isinstance(args[0], ShapeEnv) and args[0].is_recording: # type: ignore[has-type] - # If ShapeEnv is already recording an event, call the wrapped - # function directly. - # - # NB: here, we skip the check of whether all ShapeEnv instances - # are equal, in favor of a faster dispatch. - return fn(*args, **kwargs) - - # Retrieve an instance of ShapeEnv. - # Assumption: the collection of args and kwargs may not reference - # different ShapeEnv instances. - self = _extract_shape_env_and_assert_equal(args, kwargs) - - # If we are calling this function without any ShapeEnv instance - # alive in its arguments, we don't record and call the original. - if self is None: - return fn(*args, **kwargs) - - # Otherwise, start recording and call the function. - with self._recording(): - # Take a snapshot of the current tracked_fakes. - tracked_fakes = ( - self._snapshot_tracked_fakes() if save_tracked_fakes else None - ) - # Record the event for 'fn'. - event = ShapeEnvEvent( - fn, list(args), kwargs, tracked_fakes, name=fn.__name__ - ) - # Play the event on this ShapeEnv. - # NB: It's important to put the event first, because running - # the event can trigger internal events that must be ordered - # after this event. However, if an exception happens, we do - # NOT want to have the event in the list, so pop it off from - # the record if an error happened - self.events.append(event) - try: - return event.run(self) - except Exception: - self.events.pop() - raise + assert isinstance(args[0], ShapeEnv) + + try: + if args[0].is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return fn(*args, **kwargs) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return fn(*args, **kwargs) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, list(args), kwargs, tracked_fakes, name=fn.__name__ + ) + # Play the event on this ShapeEnv. + # NB: It's important to put the event first, because running + # the event can trigger internal events that must be ordered + # after this event. However, if an exception happens, we do + # NOT want to have the event in the list, so pop it off from + # the record if an error happened + self.events.append(event) + try: + return event.run(self) + except Exception: + self.events.pop() + raise + + except Exception: + log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + raise return wrapper diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 8d61e3205f76..842843895c8d 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -725,10 +725,6 @@ def guard_scalar(a): raise AssertionError(f"unrecognized scalar {a}") -def _constrain_symbol_range(shape_env, s: sympy.Symbol, compiler_min: int, compiler_max: int): - shape_env.constrain_symbol_range(s, compiler_min, compiler_max) - - def _advise_is_size(a): """ Don't use this directly; use torch._check_is_size instead. @@ -770,7 +766,6 @@ def _advise_is_size(a): ): _constrain_range_for_size(a) -@record_shapeenv_event() def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = None): """ This function is NOT INTENDED to be used by itself. @@ -782,27 +777,10 @@ def _constrain_range_for_size(a, min: Optional[int] = None, max: Optional[int] = assert isinstance(a, SymInt), "can only constrain range for SymInt" assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" - if min is None: - min = 0 - if max is None: - max = sys.maxsize - 1 - - if max < min: - raise ValueError( - "Maximum value to constrain_as_size can't be less than the specified min value, " - "received min={min} and max={max}" - ) - - a.node.shape_env.constrain_symbol_range( - a.node.expr, - compiler_min=min, - compiler_max=max, - ) - a.node.shape_env.size_like.add(a.node.expr) + a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) # inclusive both ways -@record_shapeenv_event() def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): """ Applies a constraint that the passed in SymInt must lie between min-max @@ -844,54 +822,24 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): raise ValueError(f"Invalid value {a} for range [{min}:{max}]") return - if isinstance(a.node.expr, sympy.Integer): - if not (min <= int(a.node.expr) <= max): - raise ValueRangeError(f"Invalid value {int(a.node.expr)} for range [{min}:{max}]") - return - assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + a.node.shape_env._constrain_range(a.node.expr, min, max) - # TODO: Shouldn't we install a guard if the symbol is backed? Or is the - # semantics that this is an "unchecked" assert (but it this actually - # something useful? Might be better to restrict only for unbacked - # SymInt). - _constrain_symbol_range( - a.node.shape_env, - a.node.expr, - compiler_min=min, - compiler_max=max, - ) - - -@record_shapeenv_event() -def constrain_unify(a, b): +def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: """ Given two SymInts, constrain them so that they must be equal. NB: this will not work with SymInts that represent nontrivial expressions (yet!) """ - # TODO: this does not install a deferred runtime assert yet - - # TODO: Maybe dedupe this with _maybe_guard_rel? if not isinstance(a, SymInt): if not isinstance(b, SymInt): assert a == b + return else: - assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" shape_env = b.node.shape_env - shape_env.replacements[b.node.expr] = sympy.Integer(a) else: - # TODO: Actually, we can support this as long as one of them is a symbol. - # NB: We can't actually do "unification" as our operators are not - # injective - assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" shape_env = a.node.shape_env - if not isinstance(b, SymInt): - shape_env.replacements[a.node.expr] = sympy.Integer(b) - else: - assert a.node.shape_env is b.node.shape_env - assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" - new_var = shape_env._find(a.node.expr) - shape_env.replacements[b.node.expr] = new_var + + shape_env._constrain_unify(a, b) # Assume that a boolean is true for the purposes of subsequent symbolic # reasoning. This will keep track of corresponding runtime checks to verify @@ -2470,6 +2418,78 @@ def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol): if dest is not None: self._set_replacement(new_s, dest, "rename_unbacked_to_dest") + @record_shapeenv_event() + def _constrain_range_for_size(self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None): + if min is None: + min = 0 + if max is None: + max = sys.maxsize - 1 + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + self.size_like.add(a) + + @record_shapeenv_event() + def _constrain_range(self, a: sympy.Expr, min: int, max: int): + if isinstance(a, sympy.Integer): + if not (min <= int(a) <= max): + raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") + return + assert isinstance(a, sympy.Symbol), "constraining non-Symbols NYI" + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but it this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + + @record_shapeenv_event() + def _constrain_unify(self, a, b): + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + # Update Feb 2024: this is extra important to do, this doesn't handle + # unbacked replacements properly nor does it generate deferred runtime + # asserts + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert b.node.shape_env is self + self.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert a.node.shape_env is self + if not isinstance(b, SymInt): + self.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + new_var = self._find(a.node.expr) + self.replacements[b.node.expr] = new_var + def _ignore_fresh_unbacked_symbols_tls(self): return getattr(TLS, "ignore_fresh_unbacked_symbols", False)