From 7715b47f443a7fe64a2f2f73069c93ae5262b159 Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Mon, 6 Nov 2023 16:26:54 +0000 Subject: [PATCH] [fx] Speedup ShapeEnv cache invalidation checks (#112687) This may seem a bit silly, but we spend ~5% of compilation time simply checking whether the `ShapeEnv` cache has been invalidated. The check itself isn't slow, but it is called millions of times per compile, so the cost adds up. To improve the situation, I've added a version counter to the shape env that gets incremented whenever the cache key changes. This does require some care in `ShapeEnv`: the relevant state must not be modified without calling `self._update_version_counter()`. However, we already have a similar situation for the translation validation feature, which requires `_set_replacement` to be called instead of modifying the replacements directly. Pull Request resolved: https://github.com/pytorch/pytorch/pull/112687 Approved by: https://github.com/ezyang ghstack dependencies: #112933 --- test/dynamo/test_dynamic_shapes.py | 1 + torch/fx/experimental/_config.py | 3 ++ torch/fx/experimental/symbolic_shapes.py | 68 ++++++++++++++++++++---- 3 files changed, 63 insertions(+), 9 deletions(-) diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index f5f2dc9cc11ec..abd0f6678beca 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -47,6 +47,7 @@ def make_dynamic_cls(cls): (config, "specialize_int", False), (fx_config, "translation_validation", TEST_Z3), (fx_config, "check_shape_env_recorded_events", True), + (fx_config, "validate_shape_env_verison_key", True), xfail_prop="_expected_failure_dynamic", ) diff --git a/torch/fx/experimental/_config.py b/torch/fx/experimental/_config.py index 1d13c4062bbb8..0f2bffa475f48 100644 --- a/torch/fx/experimental/_config.py +++ b/torch/fx/experimental/_config.py @@ -34,6 +34,9 @@ # lists, and incorrectly issue guards. 
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_verison_key = False + from torch.utils._config_module import install_config_module install_config_module(sys.modules[__name__]) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 4e8ccdb10ff90..b9d0ecf576e89 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -794,17 +794,43 @@ def _lru_cache(fn, maxsize=None): constraints we know now (i.e. evaluate_expr) Use _lru_cache otherwise. + + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. """ fn_cache = lru_cache(maxsize)(fn) - prior_key = None + prior_version = 0 + + if config.validate_shape_env_verison_key: + prior_key = None - @functools.wraps(fn) - def wrapper(self, *args, **kwargs): - nonlocal prior_key - if prior_key != self._get_key(): - prior_key = self._get_key() - fn_cache.cache_clear() - return fn_cache(self, *args, **kwargs) + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), \ + "ShapeEnv cache key changed without version being updated!" 
+ + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) wrapper.cache_clear = fn_cache.cache_clear wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] @@ -1572,6 +1598,10 @@ def _init( # signpost_event self.co_fields = co_fields if co_fields else {} + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + # Cache for FX nodes. # Maps an already built node a tuple of: # 1. node's target @@ -1622,6 +1652,8 @@ def check_equal(self, other: "ShapeEnv") -> None: "tracked_fakes", "events", "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", ) # Mapping of the value of each to-be-compared field into the values that @@ -1821,6 +1853,17 @@ def _get_key(self): """ return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts) + def _update_version_counter(self): + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + def _produce_dyn_sizes(self, ex_size: Sequence[int], source: Source, @@ -2958,6 +3001,7 @@ def _update_divisible(self): new_divisible.add(k) self.divisible = new_divisible + self._update_version_counter() @_lru_cache def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": @@ -3052,11 +3096,16 @@ def _set_replacement(self, a: "sympy.Symbol", expr: "sympy.Expr") -> None: self.log.debug("SPECIALIZATION", stack_info=True) log.info("set_replacement %s = %s", a, expr) 
self.replacements[a] = expr + self._update_version_counter() # When specializing 'a == expr', the equality should be also conveyed to # Z3, in case an expression uses 'a'. self._add_target_expr(sympy.Eq(a, expr)) + def _add_divisible(self, expr: "sympy.Expr"): + self.divisible.add(expr) + self._update_version_counter() + @_lru_cache @record_shapeenv_event() def _find(self, a: "sympy.Symbol") -> "sympy.Expr": @@ -3129,7 +3178,7 @@ def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bo try: r = try_solve(expr, mod_expr, floordiv_inequality=False) if r is not None and r[1] == 0: - self.divisible.add(mod_expr) + self._add_divisible(mod_expr) except NotImplementedError: pass return @@ -3408,6 +3457,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None): cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:])) self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra) self.num_deferred_runtime_asserts += 1 + self._update_version_counter() # TODO: refine ranges # Unfortunately, range refinement is probably going to not # work most of the time, because we don't support symbols