3 changes: 3 additions & 0 deletions test/dynamo/test_misc.py
@@ -8030,6 +8030,9 @@ def test_shape_env_equal_create_symbolic_sizes_strides_storage_offset(self):
==> source_to_symbol: values don't match.
> Left: {x.size()[0]: x.size()[0], x.size()[1]: x.size()[1], x.storage_offset(): x.storage_offset(), x.stride()[0]: x.stride()[0], x.stride()[1]: x.stride()[1]}
> Right: {}
==> source_to_symint_node_cache: values don't match.
> Left: {x.size()[0]: s0, x.size()[1]: s1, x.storage_offset(): 0, x.stride()[0]: s1, x.stride()[1]: 1}
> Right: {}
==> val_to_var: values don't match.
> Left: {0: 0, 1: 1, 2: s1, 3: s0}
> Right: {0: 0, 1: 1}
4 changes: 4 additions & 0 deletions test/dynamo/test_subclasses.py
@@ -360,6 +360,10 @@ def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
fake_inp = fake_mode.from_tensor(
inp, dynamic_dims=[dim_dynamic for i in range(x.dim())]
)
# Clear the cache, so that we can properly cover the first dynamic compile.
# Alternatively, remove this line and lower the frame counts by 1 below, as the
# first recompile is accounted for via caching.
shape_env.source_to_symint_node_cache.clear()
opt_f(fake_inp)
self.assertEqual(cnt.frame_count, exp_frame_count)
self.assertEqual(cnt.op_count, exp_op_count)
13 changes: 12 additions & 1 deletion torch/_dynamo/backends/distributed.py
@@ -409,7 +409,18 @@ def run_node(self, n: Node) -> Any:
if isinstance(arg, torch.Tensor) and not isinstance(
arg, torch._subclasses.FakeTensor
):
new_args.append(fake_mode.from_tensor(arg))
# See Note - [On fake tensor policy and fresh fake modes for backends]
if arg in fake_mode.policy_cache:
policy = fake_mode.policy_cache[arg]
new_args.append(
fake_mode.from_tensor(
arg,
ignore_subclass=policy.ignore_subclass,
source=policy.source,
)
)
else:
new_args.append(fake_mode.from_tensor(arg))
else:
new_args.append(arg)

7 changes: 7 additions & 0 deletions torch/_dynamo/output_graph.py
@@ -28,6 +28,7 @@
Source,
TracingContext,
)
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._utils_internal import signpost_event
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv
from torch.utils.weak import WeakIdKeyDictionary
@@ -1020,6 +1021,12 @@ def compile_and_call_fx_graph(self, tx, rv, root):
"%s", LazyString(lambda: self.get_graph_sizes_log_str(name))
)
self.call_cleanup_hooks()
if not self.export:
prior_fake_mode = self.tracing_context.fake_mode
self.tracing_context.fake_mode = FakeTensorMode(
shape_env=prior_fake_mode.shape_env,
policy_cache=prior_fake_mode.policy_cache,
)
Contributor
This needs to be scoped to restore on exit

Contributor
(I'm not sure I actually want this logic here at all, but in general, a bare some_global_state = blah should set off warning bells in your head)
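A minimal sketch of what scoped restoration could look like, assuming the swap stays at this call site; the helper name _backend_fake_mode is illustrative and not part of this PR:

import contextlib

from torch._subclasses.fake_tensor import FakeTensorMode

@contextlib.contextmanager
def _backend_fake_mode(tracing_context, policy_cache):
    # Install a fresh FakeTensorMode for the backend call and restore the
    # prior mode on exit, even if the user compiler raises.
    prior = tracing_context.fake_mode
    tracing_context.fake_mode = FakeTensorMode(
        shape_env=prior.shape_env,
        policy_cache=policy_cache,
    )
    try:
        yield tracing_context.fake_mode
    finally:
        tracing_context.fake_mode = prior

With a helper like this, the bare assignment above becomes a with-block around the call_user_compiler call, so the prior fake mode is always restored on exit.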

with self.restore_global_state():
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
6 changes: 3 additions & 3 deletions torch/_dynamo/symbolic_convert.py
@@ -672,6 +672,8 @@ def skip(v: VariableTracker):
def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
if isinstance(oldvar.mutable_local, side_effects.MutableSideEffects):
newvar = self.output.side_effects.mutation(oldvar, newvar)
elif isinstance(oldvar.mutable_local, side_effects.AttributeMutationExisting):
newvar = oldvar
else:
assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
@@ -1917,7 +1919,7 @@ def store_global_weakref(self, name, value):

@property
def fake_mode(self):
return self._fake_mode
return self.output.tracing_context.fake_mode

def find_symbolic_locals_name(self, tensor_variable):
for key, value in self.symbolic_locals.items():
@@ -1992,8 +1994,6 @@ def __init__(
# Flag to indicate whether tracing is used for export.
self.export = export

self._fake_mode = output.tracing_context.fake_mode

self.current_speculation = None
self.random_calls = []

1 change: 1 addition & 0 deletions torch/_dynamo/variables/builder.py
@@ -1743,6 +1743,7 @@ def wrap_to_fake_tensor_and_record(
constraint_dims=constraint_dims,
)
)
# See Note - [On fake tensor policy and fresh fake modes for backends]
if is_tensor and not (static_shapes and source.is_nn_module()):
tx.output.tracked_fakes.append(TrackedFake(fake_e, source, constraint_dims))
tx.output.tracked_fakes_id_to_source[id(e)].append(source)
19 changes: 10 additions & 9 deletions torch/_functorch/aot_autograd.py
@@ -42,7 +42,6 @@
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs, Source


original_zip = zip

def strict_zip(*iterables, strict=True, **kwargs):
@@ -4420,14 +4419,16 @@ def convert(idx, x):
if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
return x
# TODO: Ensure that this codepath is never exercised from
# Dynamo
if (
idx < aot_config.num_params_buffers
and config.static_weight_shapes
):
return fake_mode.from_tensor(x, static_shapes=True)
return fake_mode.from_tensor(x, static_shapes=False)
static_shapes = idx < aot_config.num_params_buffers and config.static_weight_shapes
# See Note - [On fake tensor policy and fresh fake modes for backends]
if x in fake_mode.policy_cache:
policy = fake_mode.policy_cache[x]
ignore_subclass = policy.ignore_subclass
source = policy.source
out = fake_mode.from_tensor(x, static_shapes=static_shapes, ignore_subclass=ignore_subclass, source=source)
else:
out = fake_mode.from_tensor(x, static_shapes=static_shapes)
return out

return [convert(idx, x) for idx, x in enumerate(flat_args)]

56 changes: 54 additions & 2 deletions torch/_subclasses/fake_tensor.py
@@ -47,7 +47,7 @@

from torch.utils._pytree import PyTree, tree_map
from torch.utils._stats import count, count_label
from torch.utils.weak import WeakIdRef
from torch.utils.weak import WeakIdKeyDictionary, WeakIdRef

if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features
@@ -372,6 +372,7 @@ def mk_fake_tensor(make_meta_t):
)
if out is NotImplemented:
raise UnsupportedFakeTensorException("meta converter nyi")

if make_constant:
self.add_constant_storage_mapping(out)
# NB: meta_converter set the memo
@@ -1313,6 +1314,49 @@ def tolist(self):
# memory should not significantly increase.


# Note - [On fake tensor policy and fresh fake modes for backends]
#
# FakeTensorMode does memoization - this memoization is generally fine, but there are some cases
# where we want to avoid this memoization in favor of producing fake tensors anew. In dynamo, this
# happens when we call a backend: all backends are invoked with a fresh fake_mode, because after
# dynamo trace the memoized tensors reflect the state of the fake tensor *at the end* of trace,
# rather than at the beginning.
# Consider a motivating example of .data setting mutating metadata
#
# def foo(x, y):
# x.data = y
# return x
#
# Where x is size([6]) and y is size([3]). If we run foo(x, y), then x.data is size([3]), and the
# fake tensor at the beginning of our backend, as memoized, is size([3]). However, this means that
# the backend sees a tensor of size([3]) which has been resized by the user, and so is not reflective
# of the state of the tensor at the start of trace! In the case of aot_autograd, this causes us to
# produce incorrect code (concretely in this example, a view into x of size([3])).
#
# The fix is to hand the backend a fresh fake mode and re-fakify its inputs from the real tensors.
# However, the original fakification made policy decisions (e.g. which source to attach and whether
# to ignore subclasses), and if we do not faithfully preserve those policy decisions through to the
# new fake_mode, we will produce fake tensors in a slightly different way, with different dynamic
# indices, which in turn may create new symbols where there should not be any.
#
# The solution, therefore, is a combination of a fresh fake mode for backends, with a shared shape_env
Contributor
What would break if we instead maintained a set of Tensors which have had their metadata mutated, and returned a fresh fake tensor when you try to convert a tensor in that set? IMO this would be simpler

# that caches our source->symbol decisions. This ensures that we have a consistent view of the world
# w/r/t tensor shape dynamism, while also preserving other policies like ignoring subclass.

# NOTE - an alternative design was considered where we pass a FakificationPolicyStore to the FakeTensorMode
# constructor, allowing us to query that as the source of truth instead of always passing it through the top
# at from_tensor creation time. We rejected this because reconciling arguments that may differ between the
# stored policy and the user provided value for dynamic_dims, constraint_dims, and source is non-trivial.
# Forcing the user to provide the values in a single place moves this decision closer to the callsite, and
# forces the caller of from_tensor to think about what it is they want, instead of relying on some hidden
# internal policy storage.


@dataclass
class FakificationPolicy:
source: Source
ignore_subclass: bool = False
Contributor
When is this ever True?



class FakeTensorMode(TorchDispatchMode):
def __init__(
self,
@@ -1321,6 +1365,7 @@ def __init__(
allow_non_fake_inputs=False,
shape_env=None,
static_shapes=None,
policy_cache=None,
):
log.debug("create_mode 0x%x", id(self))
self.allow_fallback_kernels = allow_fallback_kernels
@@ -1365,6 +1410,8 @@ def __init__(
# this is an "infra" mode with lower dispatching precedence.
self._mode_key = torch._C._TorchDispatchModeKey.FAKE

self.policy_cache = WeakIdKeyDictionary() if not policy_cache else policy_cache
Contributor
Is there a reason this can't be in the tracing context?


# Typically, there is only one fake tensor mode and you test for it by
# doing an isinstance test. However, in some situations, there might be
# TWO fake tensor modes. The canonical example of this is exporting
@@ -1819,6 +1866,7 @@ def cpp_meta_supports_symint(self, func):
aten.view_as_real.default,
aten.view_as_complex.default,
aten.set_.source_Storage_storage_offset,
aten.set_.source_Tensor,
aten._sparse_coo_tensor_with_dims_and_tensors.default,
]

@@ -1873,7 +1921,7 @@ def from_tensor(
dynamic_dims is None
), "cannot set both static_shapes and dynamic_dims"
shape_env = None
return self.fake_tensor_converter(
result = self.fake_tensor_converter(
self,
tensor,
shape_env=shape_env,
@@ -1883,6 +1931,10 @@ def cpp_meta_supports_symint(self, func):
constraint_dims=constraint_dims,
memoized_only=memoized_only,
)
self.policy_cache[tensor] = FakificationPolicy(
source=source, ignore_subclass=ignore_subclass
)
Contributor
There needs to be some sort of way to make sure people don't add new arguments to this function without also updating policy if it is necessary. Maybe a comment on the kwargs list is enough

return result


# NB: returns fake tensors
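To make the Note above concrete, here is a hedged sketch of how a backend is expected to consume the policy cache when re-fakifying its inputs under a fresh mode, mirroring the distributed.py and aot_autograd.py changes in this PR; the helper name refakify_backend_inputs is illustrative:

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def refakify_backend_inputs(args, prior_fake_mode):
    # Fresh mode for the backend: it shares the ShapeEnv (and therefore the
    # source_to_symint_node_cache) plus the recorded fakification policies, so
    # re-fakified tensors keep consistent symbols, sources, and subclass
    # handling, but are built from the current real tensors rather than the
    # end-of-trace memoized fakes.
    backend_fake_mode = FakeTensorMode(
        shape_env=prior_fake_mode.shape_env,
        policy_cache=prior_fake_mode.policy_cache,
    )
    new_args = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and not isinstance(arg, FakeTensor):
            # See Note - [On fake tensor policy and fresh fake modes for backends]
            if arg in backend_fake_mode.policy_cache:
                policy = backend_fake_mode.policy_cache[arg]
                new_args.append(
                    backend_fake_mode.from_tensor(
                        arg,
                        ignore_subclass=policy.ignore_subclass,
                        source=policy.source,
                    )
                )
            else:
                new_args.append(backend_fake_mode.from_tensor(arg))
        else:
            new_args.append(arg)
    return new_args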
19 changes: 16 additions & 3 deletions torch/fx/experimental/symbolic_shapes.py
@@ -1631,6 +1631,8 @@ def _init(
self.fx_node_cache: Dict[Tuple[Callable, Tuple[Any, ...]], torch.fx.Node] = {}
self.source_to_symbol: Dict[str, sympy.Symbol] = {}

self.source_to_symint_node_cache : Dict[TensorPropertySource, SymInt] = {}
Contributor
the type annotation and the variable name don't agree

Contributor
I think it's not good to store a SymInt here. The problem is that SymInt holds a strong reference to ShapeEnv, so this will cause a cycle that will prevent the ShapeEnv from ever getting deallocated. Thinking about how to deal with constants...

Collaborator Author
ok, let's chat offline :)


from torch.fx.experimental.validator import translation_validation_enabled
self._translation_validation_enabled = translation_validation_enabled()

@@ -2089,6 +2091,10 @@ def create_symintnode(
hint: Optional[int],
source: Optional[Source] = None,
):
source_name = source.name() if source else None
if source_name and source_name in self.source_to_symint_node_cache:
return self.source_to_symint_node_cache[source_name]

if self._translation_validation_enabled and source is not None:
# Create a new symbol for this source.
symbol = self._create_symbol_for_source(source)
@@ -2102,11 +2108,16 @@
else:
fx_node = None

out = None
if isinstance(sym, sympy.Integer):
if hint is not None:
assert int(sym) == hint
return int(sym)
return SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
out = int(sym)
else:
out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node))
if source_name:
self.source_to_symint_node_cache[source_name] = out
return out

@record_shapeenv_event()
def create_unspecified_symint_and_symbol(self, value, source, dynamic_dim):
@@ -2210,7 +2221,8 @@ def create_symbol(
# We don't expect to ever reach here even the user specifies
# dynamic=False, because automatic_dynamic skipped for
# nested tensors.
return sympy.Integer(val)
out = sympy.Integer(val)
return out

elif dynamic_dim is DimDynamic.DUCK:
# duck_shape can be used to globally turn off duck shaping, even
@@ -2445,6 +2457,7 @@ def produce_guards(

symbol_to_source = collections.defaultdict(list)
symbol_to_constraints = collections.defaultdict(set)

constraint_violations : List[Tuple[bool, Callable[[], str]]] = []

def record_constraint_violation(warn_only, debug_name, msg, hint=None):