[dynamo][aot_autograd] Always create a fresh fake mode for backends in dynamo, policy preservation #113605
Diff: first file (fake tensor mode)

@@ -47,7 +47,7 @@
 from torch.utils._pytree import PyTree, tree_map
 from torch.utils._stats import count, count_label
-from torch.utils.weak import WeakIdRef
+from torch.utils.weak import WeakIdKeyDictionary, WeakIdRef

 if TYPE_CHECKING:
     # Import the following modules during type checking to enable code intelligence features
@@ -372,6 +372,7 @@ def mk_fake_tensor(make_meta_t):
             )
             if out is NotImplemented:
                 raise UnsupportedFakeTensorException("meta converter nyi")
+
             if make_constant:
                 self.add_constant_storage_mapping(out)
             # NB: meta_converter set the memo
@@ -1313,6 +1314,49 @@ def tolist(self):
     # memory should not significantly increase.


+# Note - [On fake tensor policy and fresh fake modes for backends]
+#
+# FakeTensorMode does memoization - this memoization is generally fine, but there are some cases
+# where we want to avoid it in favor of producing fake tensors anew. In dynamo, this happens when
+# we call a backend. All backends are invoked with a fresh fake_mode, because after dynamo trace,
+# the memoized tensors reflect the state of the fake tensor *at the end* of trace rather than at
+# the beginning.
+# Consider a motivating example of .data setting mutating metadata:
+#
+# def foo(x, y):
+#     x.data = y
+#     return x
+#
+# where x is size([6]) and y is size([3]). If we run foo(x, y), then x.data is size([3]), and the
+# fake tensor at the beginning of our backend, as memoized, is size([3]). This means the backend
+# sees a tensor of size([3]) which has been resized by the user, and so is not reflective of the
+# state of the tensor at the start of trace! In the case of aot_autograd, this causes us to produce
+# incorrect code (concretely, in this example, a view into x sized at ([3])).
+#
+# The fresh fake mode must still honor the fakification policy decisions (source, dynamic and
+# constraint dims, subclass handling) that were made when the tensors were first fakified during
+# trace. If we do not faithfully preserve those policy decisions through to the new fake_mode, we
+# will produce fake tensors in a slightly different way, with different dynamic indices, which in
+# turn may create new symbols where there should not be any.
+#
+# The solution, therefore, is a combination of a fresh fake mode for backends, with a shared shape_env

Review comment (on the line above): What would break if we instead maintained a set of Tensors which have had metadata mutated, and if you try to convert a tensor which has a corresponding metadata tensor that has been mutated, return a fresh tensor instead? IMO this would be simpler.

+# that caches our source->symbol decisions. This ensures that we have a consistent view of the world
+# w/r/t tensor shape dynamism, while also preserving other policies like ignoring subclass.
+
+# NOTE - an alternative design was considered where we pass a FakificationPolicyStore to the FakeTensorMode
+# constructor, allowing us to query that as the source of truth instead of always passing it in from the top
+# at from_tensor creation time. We rejected this because reconciling arguments that may differ between the
+# stored policy and the user-provided values for dynamic_dims, constraint_dims, and source is non-trivial.
+# Forcing the user to provide the values in a single place moves this decision closer to the callsite, and
+# forces the caller of from_tensor to think about what it is they want, instead of relying on some hidden
+# internal policy storage.
+
+
+@dataclass
+class FakificationPolicy:
+    source: Source
+    ignore_subclass: bool = False

Review comment (on ignore_subclass): When is this ever

+
+
 class FakeTensorMode(TorchDispatchMode):
     def __init__(
         self,
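To make the note above concrete, here is a small sketch (an editorial illustration, not part of the diff) of the per-mode memoization and of the fresh-mode-plus-shared-ShapeEnv idea. It only assumes the public FakeTensorMode and ShapeEnv constructors and from_tensor; exact behavior may vary across PyTorch versions.

```python
# Minimal sketch of the memoization the note describes, and of the proposed fix:
# a fresh FakeTensorMode that shares the original mode's ShapeEnv.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
mode = FakeTensorMode(shape_env=shape_env)

x = torch.randn(6)
a = mode.from_tensor(x)
b = mode.from_tensor(x)
assert a is b  # conversion is memoized per real tensor within a single mode

# A backend that must see start-of-trace state gets a *fresh* mode, so it
# re-fakifies instead of reusing the (possibly mutated) memoized fake tensor,
# while the shared shape_env keeps source->symbol decisions consistent.
fresh_mode = FakeTensorMode(shape_env=shape_env)
c = fresh_mode.from_tensor(x)
assert c is not a
```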
@@ -1321,6 +1365,7 @@ def __init__(
         allow_non_fake_inputs=False,
         shape_env=None,
         static_shapes=None,
+        policy_cache=None,
     ):
         log.debug("create_mode 0x%x", id(self))
         self.allow_fallback_kernels = allow_fallback_kernels
@@ -1365,6 +1410,8 @@ def __init__(
         # this is an "infra" mode with lower dispatching precedence.
         self._mode_key = torch._C._TorchDispatchModeKey.FAKE

+        self.policy_cache = WeakIdKeyDictionary() if not policy_cache else policy_cache

Review comment (on the line above): Is there a reason this can't be in the tracing context?

+
         # Typically, there is only one fake tensor mode and you test for it by
         # doing an isinstance test. However, in some situations, there might be
         # TWO fake tensor modes. The canonical example of this is exporting
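An aside on the container choice (editorial, not from the PR): WeakIdKeyDictionary keys entries on tensor identity and drops them when the tensor dies, so caching per-tensor policy here should not keep input tensors alive. A rough sketch of that behavior:

```python
# Sketch: entries keyed on a tensor disappear once the tensor is garbage collected.
import gc
import torch
from torch.utils.weak import WeakIdKeyDictionary

cache = WeakIdKeyDictionary()
t = torch.randn(3)
cache[t] = "some-policy"   # stand-in for a FakificationPolicy record
assert len(cache) == 1

del t                      # drop the last strong reference
gc.collect()
assert len(cache) == 0     # the weak key let the entry be reclaimed
```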
@@ -1819,6 +1866,7 @@ def cpp_meta_supports_symint(self, func):
             aten.view_as_real.default,
             aten.view_as_complex.default,
             aten.set_.source_Storage_storage_offset,
+            aten.set_.source_Tensor,
             aten._sparse_coo_tensor_with_dims_and_tensors.default,
         ]
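For context on why a set_ overload shows up here (editorial note): assigning to .data swaps a tensor's metadata in place, which is the mutation the motivating example above relies on, and the diff presumably adds the source_Tensor overload because the traced assignment exercises it under fake tensors. A tiny eager-mode illustration of the metadata swap itself:

```python
# Eager-mode illustration of the metadata mutation from the motivating example:
# after the assignment, x reports y's size.
import torch

x = torch.randn(6)
y = torch.randn(3)
x.data = y
print(x.shape)  # torch.Size([3])
```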
@@ -1873,7 +1921,7 @@ def from_tensor(
             dynamic_dims is None
         ), "cannot set both static_shapes and dynamic_dims"
         shape_env = None
-        return self.fake_tensor_converter(
+        result = self.fake_tensor_converter(
             self,
             tensor,
             shape_env=shape_env,
@@ -1883,6 +1931,10 @@ def from_tensor(
             constraint_dims=constraint_dims,
             memoized_only=memoized_only,
         )
+        self.policy_cache[tensor] = FakificationPolicy(
+            source=source, ignore_subclass=ignore_subclass
+        )

Review comment (on the lines above): There needs to be some sort of way to make sure people don't add new arguments to this function without also updating the policy if necessary. Maybe a comment on the kwargs list is enough.

+        return result


    # NB: returns fake tensors
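A hypothetical sketch (not code from this PR; the dynamo-side wiring is not in this excerpt) of how a caller might consume the recorded policy when re-fakifying inputs under a fresh mode. It uses the policy_cache argument and FakificationPolicy fields introduced above; refakify_for_backend is an illustrative name.

```python
# Hypothetical consumer of the policy cache: build a fresh FakeTensorMode that
# shares the ShapeEnv and the policy cache, then re-fakify each real input with
# the policy recorded when it was first fakified.
from torch._subclasses.fake_tensor import FakeTensorMode

def refakify_for_backend(old_mode: FakeTensorMode, real_inputs):
    fresh_mode = FakeTensorMode(
        shape_env=old_mode.shape_env,        # shared: reuse source->symbol decisions
        policy_cache=old_mode.policy_cache,  # shared: reuse per-tensor fakification policy
    )
    fake_inputs = []
    for t in real_inputs:
        policy = fresh_mode.policy_cache.get(t)
        fake_inputs.append(
            fresh_mode.from_tensor(
                t,
                source=policy.source if policy else None,
                ignore_subclass=policy.ignore_subclass if policy else False,
            )
        )
    return fresh_mode, fake_inputs
```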
Diff: second file (ShapeEnv / symbolic shapes)

@@ -1631,6 +1631,8 @@ def _init(
         self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
         self.source_to_symbol: Dict[str, sympy.Symbol] = {}

+        self.source_to_symint_node_cache: Dict[TensorPropertySource, SymInt] = {}

Review comment (on the line above): the type annotation and the variable name don't agree

Review comment: I think it's not good to store a SymInt here. The problem is that SymInt holds a strong reference to ShapeEnv, so this will cause a cycle that will prevent the ShapeEnv from ever getting deallocated. Thinking about how to deal with constants...

Review comment: ok, let's chat offline :)

+
         from torch.fx.experimental.validator import translation_validation_enabled
         self._translation_validation_enabled = translation_validation_enabled()
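A torch-free sketch of the cycle the reviewer is worried about (assumptions: Owner stands in for ShapeEnv, Value for a SymInt that strongly references it). Once the owner caches such values, reference counting alone can no longer free it:

```python
# Owner caches Values that point back at it: Owner -> cache -> Value -> Owner.
# Only the cycle collector can reclaim the Owner, not plain refcounting.
import gc
import weakref

class Owner:              # stand-in for ShapeEnv
    def __init__(self):
        self.cache = {}

class Value:              # stand-in for a SymInt holding its ShapeEnv
    def __init__(self, owner):
        self.owner = owner

o = Owner()
o.cache["s0"] = Value(o)
ref = weakref.ref(o)

del o
assert ref() is not None  # still alive: the cycle keeps it reachable
gc.collect()
assert ref() is None      # reclaimed only by the garbage collector
```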
@@ -2089,6 +2091,10 @@ def create_symintnode(
         hint: Optional[int],
         source: Optional[Source] = None,
     ):
+        source_name = source.name() if source else None
+        if source_name and source_name in self.source_to_symint_node_cache:
+            return self.source_to_symint_node_cache[source_name]
+
         if self._translation_validation_enabled and source is not None:
             # Create a new symbol for this source.
             symbol = self._create_symbol_for_source(source)
@@ -2102,11 +2108,16 @@ def create_symintnode(
         else:
             fx_node = None

+        out = None
         if isinstance(sym, sympy.Integer):
             if hint is not None:
                 assert int(sym) == hint
-            return int(sym)
-        return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
+            out = int(sym)
+        else:
+            out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
+        if source_name:
+            self.source_to_symint_node_cache[source_name] = out
+        return out

     @record_shapeenv_event()
     def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
@@ -2210,7 +2221,8 @@ def create_symbol(
             # We don't expect to ever reach here even the user specifies
             # dynamic=False, because automatic_dynamic skipped for
             # nested tensors.
-            return sympy.Integer(val)
+            out = sympy.Integer(val)
+            return out

         elif dynamic_dim is DimDynamic.DUCK:
             # duck_shape can be used to globally turn off duck shaping, even
@@ -2445,6 +2457,7 @@ def produce_guards(

         symbol_to_source = collections.defaultdict(list)
+        symbol_to_constraints = collections.defaultdict(set)

         constraint_violations : List[Tuple[bool, Callable[[], str]]] = []

         def record_constraint_violation(warn_only, debug_name, msg, hint=None):
Review comment: This needs to be scoped to restore on exit.

Review comment: (I'm not sure I actually want this logic here at all, but in general, a bare

    some_global_state = blah

should set off warning bells in your head.)
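For the "scoped to restore on exit" point, a generic sketch of the pattern the reviewer is asking for (names are illustrative, not from the PR):

```python
# Save the previous value, install the new one, and always restore on exit,
# even if the body raises; contextlib keeps the call site tidy.
import contextlib

_some_global_state = None  # illustrative module-level state

@contextlib.contextmanager
def scoped_state(new_value):
    global _some_global_state
    saved = _some_global_state
    _some_global_state = new_value
    try:
        yield
    finally:
        _some_global_state = saved

with scoped_state("during_produce_guards"):
    assert _some_global_state == "during_produce_guards"
assert _some_global_state is None  # restored after the with-block
```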