Skip to content

Commit

Permalink
[fx] Speedup ShapeEnv cache invalidation checks (#112687)
Browse files Browse the repository at this point in the history
This may seem a bit silly but we spend ~5% of compilation on simply checking if the `ShapeEnv` cache has been invalidated. It isn't necessarily slow, but we call it millions of times per compile so everything adds up.

To improve the situation, I've added a version counter to the shape env that gets incremented whenever the cache key changes. This does require a bit of care in `ShapeEnv` that we don't modify the relevant state without calling `self._update_version_counter()`. However, we already have a similar situation for the translation validation feature which requires `_set_replacement` to be called instead of modifying the replacements directly.

Pull Request resolved: #112687
Approved by: https://github.com/ezyang
ghstack dependencies: #112933
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Nov 7, 2023
1 parent 65ecb36 commit 7715b47
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 9 deletions.
1 change: 1 addition & 0 deletions test/dynamo/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def make_dynamic_cls(cls):
(config, "specialize_int", False),
(fx_config, "translation_validation", TEST_Z3),
(fx_config, "check_shape_env_recorded_events", True),
(fx_config, "validate_shape_env_verison_key", True),
xfail_prop="_expected_failure_dynamic",
)

Expand Down
3 changes: 3 additions & 0 deletions torch/fx/experimental/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
# lists, and incorrectly issue guards.
inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False

# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly
validate_shape_env_verison_key = False

from torch.utils._config_module import install_config_module

install_config_module(sys.modules[__name__])
68 changes: 59 additions & 9 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,17 +794,43 @@ def _lru_cache(fn, maxsize=None):
constraints we know now (i.e. evaluate_expr)
Use _lru_cache otherwise.
Also note that this depends on _update_version_counter being called on the
shape environment whenever the constraints are updated, otherwise the cache
will not be cleared.
"""
fn_cache = lru_cache(maxsize)(fn)
prior_key = None
prior_version = 0

if config.validate_shape_env_verison_key:
prior_key = None

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
nonlocal prior_key
if prior_key != self._get_key():
prior_key = self._get_key()
fn_cache.cache_clear()
return fn_cache(self, *args, **kwargs)
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
nonlocal prior_version, prior_key
if prior_key is None:
prior_key = self._get_key()

if prior_version != self._version_counter:
fn_cache.cache_clear()
prior_version = self._version_counter
prior_key = self._get_key()
else:
assert prior_key == self._get_key(), \
"ShapeEnv cache key changed without version being updated!"

return fn_cache(self, *args, **kwargs)

else:

@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
nonlocal prior_version
if prior_version != self._version_counter:
fn_cache.cache_clear()
prior_version = self._version_counter

return fn_cache(self, *args, **kwargs)

wrapper.cache_clear = fn_cache.cache_clear
wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined]
Expand Down Expand Up @@ -1572,6 +1598,10 @@ def _init(
# signpost_event
self.co_fields = co_fields if co_fields else {}

# Version counter used to invalidate cached values
self._prev_cache_key = self._get_key()
self._version_counter = 0

# Cache for FX nodes.
# Maps an already built node a tuple of:
# 1. node's target
Expand Down Expand Up @@ -1622,6 +1652,8 @@ def check_equal(self, other: "ShapeEnv") -> None:
"tracked_fakes",
"events",
"source_name_to_debug_name",
"_prev_cache_key",
"_version_counter",
)

# Mapping of the value of each to-be-compared field into the values that
Expand Down Expand Up @@ -1821,6 +1853,17 @@ def _get_key(self):
"""
return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts)

def _update_version_counter(self):
# The shape environment is queried orders of magnitude more often than
# it is changed, so we summarise the cache key into a linearly
# increasing version counter which is cheaper to check in _lru_cache

# Only update version counter if the state actually changed
cur_key = self._get_key()
if self._prev_cache_key != cur_key:
self._prev_cache_key = cur_key
self._version_counter += 1

def _produce_dyn_sizes(self,
ex_size: Sequence[int],
source: Source,
Expand Down Expand Up @@ -2958,6 +3001,7 @@ def _update_divisible(self):
new_divisible.add(k)

self.divisible = new_divisible
self._update_version_counter()

@_lru_cache
def simplify(self, expr: "sympy.Expr") -> "sympy.Expr":
Expand Down Expand Up @@ -3052,11 +3096,16 @@ def _set_replacement(self, a: "sympy.Symbol", expr: "sympy.Expr") -> None:
self.log.debug("SPECIALIZATION", stack_info=True)
log.info("set_replacement %s = %s", a, expr)
self.replacements[a] = expr
self._update_version_counter()

# When specializing 'a == expr', the equality should be also conveyed to
# Z3, in case an expression uses 'a'.
self._add_target_expr(sympy.Eq(a, expr))

def _add_divisible(self, expr: "sympy.Expr"):
self.divisible.add(expr)
self._update_version_counter()

@_lru_cache
@record_shapeenv_event()
def _find(self, a: "sympy.Symbol") -> "sympy.Expr":
Expand Down Expand Up @@ -3129,7 +3178,7 @@ def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"], concrete_bool: bo
try:
r = try_solve(expr, mod_expr, floordiv_inequality=False)
if r is not None and r[1] == 0:
self.divisible.add(mod_expr)
self._add_divisible(mod_expr)
except NotImplementedError:
pass
return
Expand Down Expand Up @@ -3408,6 +3457,7 @@ def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:]))
self.deferred_runtime_asserts.setdefault(cands[-1], []).append(ra)
self.num_deferred_runtime_asserts += 1
self._update_version_counter()
# TODO: refine ranges
# Unfortunately, range refinement is probably going to not
# work most of the time, because we don't support symbols
Expand Down

0 comments on commit 7715b47

Please sign in to comment.