From be86c3858c3d817bba584cfa30104707a75f2275 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Mon, 29 Apr 2024 07:46:35 -0700 Subject: [PATCH] Add propagate_real_tensors mode for unbacked Signed-off-by: Edward Z. Yang ghstack-source-id: 5309d4133d565d2a0dde661c195ba52373d1e9d2 Pull Request resolved: https://github.com/pytorch/pytorch/pull/125115 --- test/test_dynamic_shapes.py | 6 ++ test/test_fake_tensor.py | 75 +++++++++++++---- test/test_proxy_tensor.py | 17 ++++ torch/_functorch/config.py | 30 +++++++ torch/_subclasses/fake_tensor.py | 102 +++++++++++++++++++++-- torch/fx/experimental/symbolic_shapes.py | 59 ++++++++++--- 6 files changed, 253 insertions(+), 36 deletions(-) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 284cf85d0103..c3fb76ab2f8a 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -512,6 +512,12 @@ def test_data_dependent_guard(self): s0 = shape_env.create_unbacked_symint() self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0)) + def test_data_dependent_guard_propagate_real_tensors(self): + shape_env = ShapeEnv() + s0 = shape_env.create_unbacked_symint() + shape_env.unbacked_var_to_val[s0.node.expr] = 0 + self.assertEqual(bool(s0 == 0), True) + def test_expect_true_basic(self): shape_env = ShapeEnv() i0 = shape_env.create_unbacked_symint() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 1ffb0a6cb3ed..4d6cbd70104f 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -6,6 +6,7 @@ instantiate_parametrized_tests, TemporaryFileName) import torch import torch._dynamo +from torch._dynamo.testing import make_test_cls_with_patches import itertools import numpy as np from torch.testing._internal.jit_utils import RUN_CUDA @@ -53,6 +54,10 @@ torch._dynamo.config.fake_tensor_cache_enabled = True torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True +def expectedFailurePropagateRealTensors(fn): + fn._expected_failure_propagate_real_tensors = True + return fn + class FakeTensorTest(TestCase): def checkType(self, t, device_str, size): self.assertTrue(isinstance(t, FakeTensor)) @@ -207,6 +212,8 @@ def test_fake_dispatch_keys(self): FileCheck().check("CPU").check("AutocastCPU").run(torch._C._dispatch_key_set(y)) FileCheck().check_not("ADInplaceOrView").check_not("Autograd").run(torch._C._dispatch_key_set(y)) + # TODO: functorch support for propagate real tensors + @expectedFailurePropagateRealTensors def test_batch_tensor(self): x = torch.rand((3, 4, 5)) b = _add_batch_dim(x, 0, 0) @@ -392,10 +399,10 @@ def test_out_multi_device(self): x = torch.rand([4]) y = torch.rand([4], device="cuda") - with self.assertRaisesRegex(Exception, "found two different devices"): + with self.assertRaisesRegex(Exception, "found.+two.+devices"): torch.sin(x, out=y) - with self.assertRaisesRegex(Exception, "found two different devices"): + with self.assertRaisesRegex(Exception, "found.+two.+devices"): x.add_(y) @@ -578,6 +585,9 @@ def test_same_shape_env_preserved(self): self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env) self.assertEqual(str(t2.size(0)), str(t1.size(0))) + # TODO: Support NJT. 
There's also some funny business with dynamic shapes + # which would need to be dealt with as well + @expectedFailurePropagateRealTensors def test_jagged_fake_to_fake_preserved(self): from torch.nested._internal.nested_tensor import jagged_from_list @@ -736,7 +746,9 @@ def test_aten_index_multi_device(self): x2 = torch.rand(4, 4, device="cuda") i1 = torch.tensor([0, 1], device="cuda") i2 = torch.tensor([0, 1], device="cpu") - r1 = torch.ops.aten.index(x1, i1) + # NB: This one does not work: cuda indices not allowed on cpu + # tensor + # r1 = torch.ops.aten.index(x1, i1) r2 = torch.ops.aten.index(x2, i2) y1 = torch.rand(4, device="cpu") @@ -745,7 +757,7 @@ def test_aten_index_multi_device(self): j2 = torch.tensor([2], device="cpu") r3 = torch.ops.aten.index_put.default(x1, j1, y1) r4 = torch.ops.aten.index_put.default(x2, j2, y2) - self.checkType(r1, "cpu", ()) + # self.checkType(r1, "cpu", ()) self.checkType(r2, "cuda", ()) self.checkType(r3, "cpu", (4, 4)) self.checkType(r4, "cuda", (4, 4)) @@ -785,6 +797,23 @@ def forward(self, input): self.assertTrue(isinstance(ep, torch.export.ExportedProgram)) +instantiate_parametrized_tests(FakeTensorTest) + + +def make_propagate_real_tensors_cls(cls): + cls = make_test_cls_with_patches( + cls, + "PropagateRealTensors", + "_propagate_real_tensors", + (torch._functorch.config, "fake_tensor_propagate_real_tensors", True), + xfail_prop="_expected_failure_propagate_real_tensors", + ) + globals()[cls.__name__] = cls + + +make_propagate_real_tensors_cls(FakeTensorTest) + + class FakeTensorConstHandling(TestCase): def assertConst(self, *args): for arg in args: @@ -875,6 +904,10 @@ def test_constant_propagate_through_functions(self): y = torch.div(4, 4, rounding_mode='trunc') self.assertConst(y) + +make_propagate_real_tensors_cls(FakeTensorConstHandling) + + def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type): return maybe_contained_type.isSubtypeOf(type) or any( contains_type(e, maybe_contained_type) for e in type.containedTypes() @@ -891,6 +924,13 @@ def test_fake(self, device, dtype, op): optests.fake_check(op, args, kwargs) +instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=("cpu", "cuda")) + + +# CPU only for efficiency ig +make_propagate_real_tensors_cls(FakeTensorOpInfoTestCPU) # noqa: F821 + + class FakeTensorConverterTest(TestCase): def test_memoized_conversion_to_meta(self): x = torch.rand(2, 2, 2) @@ -1002,16 +1042,17 @@ def test_no_ref_cycle(self): assert y_weak() is None +make_propagate_real_tensors_cls(FakeTensorConverterTest) + + class FakeTensorOperatorInvariants(TestCase): - @staticmethod - def get_aten_op(schema): + def get_aten_op(self, schema): namespace, name = schema.name.split("::") overload = schema.overload_name if schema.overload_name else "default" assert namespace == "aten" return getattr(getattr(torch.ops.aten, name), overload) - @staticmethod - def get_all_aten_schemas(): + def get_all_aten_schemas(self): for schema in torch._C._jit_get_all_schemas(): namespace = schema.name.split("::")[0] if namespace != "aten": @@ -1162,6 +1203,10 @@ def forward(self, arg1, arg2, arg3): # IMPORTANT!!! 
Always run even if CUDA is not available def test_fake_cuda_no_init(self): + # Skip this test, we will try to run CUDA operations to real prop so + # it clearly will not work on CPU runner + if torch._functorch.config.fake_tensor_propagate_real_tensors: + return with FakeTensorMode(): torch.empty(10, device='cuda') torch.ones(10, device='cuda') @@ -1220,6 +1265,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertEqual(mode.count, 0) +make_propagate_real_tensors_cls(FakeTensorOperatorInvariants) + + class FakeTensorPropTest(TestCase): def test_fake_tensor_prop_on_nn_module(self): class ToyNnModuleWithParameters(torch.nn.Module): @@ -1305,9 +1353,11 @@ def forward(self, value, another_value=None, another_optional_value=None): FakeTensorProp(graph_model, fake_mode).propagate(value, None, another_optional_value) + @expectedFailurePropagateRealTensors # TODO: not sure about this one, kinda strange def test_unbacked_shape_realloc(self): def f(x): return x.nonzero() + shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) with fake_mode: @@ -1352,6 +1402,9 @@ def forward(self, x): torch.load(state_dict_file, map_location="cpu") # scenario 2 +make_propagate_real_tensors_cls(FakeTensorPropTest) + + class FakeTensorSerialization(TestCase): def test_serialization(self): x = torch.tensor([0], device="cpu") @@ -1690,11 +1743,5 @@ def test_inference_mode(self): extract_tensor_metadata(res4), ) - -instantiate_parametrized_tests(FakeTensorTest) - -only_for = ("cpu", "cuda") -instantiate_device_type_tests(FakeTensorOpInfoTest, globals(), only_for=only_for) - if __name__ == "__main__": run_tests() diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index fea87eebaa02..15e56f2d4b09 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -25,6 +25,7 @@ from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule from torch.utils._pytree import tree_map from torch import nn +import torch._functorch.config import re import functools @@ -1518,6 +1519,22 @@ def f(a): make_fx(f, tracing_mode="symbolic")(torch.randn(4)) + @torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True) + def test_invalidate_nonzero_propagate_real_tensors(self): + def f(a): + b = a.clone() + x = b.nonzero() + x1 = b.nonzero() + x2 = b.nonzero() + assert x1.shape[0] == x2.shape[0] + b.normal_() + y = b.nonzero() + # Because you're not actually going to generate exactly zero with + # normal_ lol + assert x1.shape[0] == y.shape[0] + + make_fx(f, tracing_mode="symbolic")(torch.randn(4)) + def test_sqrt_size(self): def f(a): return a / a.size(-1) ** 0.5 diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index aa7235034e8c..5749477c6e98 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -106,6 +106,36 @@ # tokens. unlift_effect_tokens = False +# This mode specifies that we should also keep track of the real +# tensor along with the fake tensor, and do real compute. While +# seemingly this eliminates the whole point of fake tensors, there are +# two obvious use cases for it: +# +# 1. When users call item()/other data dependent operations, +# if we propagate_real_tensors we are able to determine what +# the true value is and keep going. +# +# 2. It can be useful for testing, when you want to see if the fake +# and real tensors agree with each other. 
(Note that there are +# currently known inaccuracies in how we clone real tensors, that +# would have to be tightened up for this to be useful in this +# case.) +# +# Note that fake tensors are typically understood to be cheap to store +# indefinitely, so we tend to hold on to them longer than we would +# hold onto the real tensors. So we also support you explicitly +# deallocating the real tensor associated with a fake tensor, at which +# point we will stop propagating real tensors. +# +# One more thing: when you provide a real tensor to fakeify, we will +# clone it, so that we can safely perform mutations on it if necessary. +# This will increase live memory usage. This could potentially be +# optimized by using COW. We also currently do not faithfully +# maintain autograd metadata on the real tensor; this is fine because +# AOTAutograd will only use the fake tensor to determine leafness/etc +# of tensors in question. +fake_tensor_propagate_real_tensors = False + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index fdea8a344186..74581cbc1f2d 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -294,6 +294,8 @@ def from_real_tensor( assert not make_constant def mk_fake_tensor(make_meta_t): + from torch._dynamo.utils import clone_input + # NB: don't use in_kernel_invocation_manager. to # ensure FakeTensor can internally do constant computation # as necessary. Invocation manager is "more correct" as @@ -306,7 +308,21 @@ def mk_fake_tensor(make_meta_t): fake_mode, make_meta_t(), existing_device, + # TODO: callback might be used in recursive contexts, in + # which case using t is wrong! BUG! constant=t if make_constant else None, + # TODO: This won't preserve aliasing relationships, so if + # there is mutation you won't see it reflect elsewhere. + # This is fine because propagate_real_tensors isn't + # intended to give you exact results and some inaccuracy + # is OK, although if its use case expands we would want to + # do something similar to meta converter, but poking in + # real tensors at the storage cloning phase + real_tensor=( + (t if make_constant else clone_input(t)) + if fake_mode.propagate_real_tensors + else None + ), ) out = self.meta_converter( @@ -390,6 +406,7 @@ class FakeTensor(torch.Tensor): fake_device: torch.device fake_mode: "FakeTensorMode" constant: Optional[torch.Tensor] + real_tensor: Optional[torch.Tensor] # This memorizes the unbacked SymInt representing the number of nonzero # elements in this tensor. 
This is helpful if you do something like @@ -478,7 +495,7 @@ def names(self): ) @staticmethod - def __new__(cls, fake_mode, elem, device, constant=None): + def __new__(cls, fake_mode, elem, device, constant=None, real_tensor=None): self = torch.Tensor._make_subclass( cls, elem, @@ -520,6 +537,7 @@ def __new__(cls, fake_mode, elem, device, constant=None): self.fake_device = device # type: ignore[attr-defined] self.fake_mode = fake_mode # type: ignore[attr-defined] self.constant = constant # type: ignore[attr-defined] + self.real_tensor = real_tensor # type: ignore[attr-defined] self._nonzero_memo = None # type: ignore[attr-defined] self._nonzero_memo_vc = None # type: ignore[attr-defined] self._unique_memo = None # type: ignore[attr-defined] @@ -849,11 +867,18 @@ def __init__( import torch._dynamo.config import torch._functorch.config + self.propagate_real_tensors = ( + torch._functorch.config.fake_tensor_propagate_real_tensors + ) + self._allow_unsafe_data_ptr_access = ( torch._functorch.config.fake_tensor_allow_unsafe_data_ptr_access ) self.allow_meta = torch._functorch.config.fake_tensor_allow_meta - self.cache_enabled = torch._dynamo.config.fake_tensor_cache_enabled + self.cache_enabled = ( + torch._dynamo.config.fake_tensor_cache_enabled + and not self.propagate_real_tensors + ) self.cache_crosscheck_enabled = ( torch._dynamo.config.fake_tensor_cache_crosscheck_enabled ) @@ -1427,11 +1452,66 @@ def maybe_to_constant(t): args, kwargs = pytree.tree_unflatten(flat_args, args_spec) self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs) + def maybe_to_real_tensor(t): + if isinstance(t, FakeTensor): + return t.real_tensor + elif isinstance(t, SymTypes): + return t.node.pytype( + t.node.expr.xreplace(self.shape_env.var_to_val).xreplace( + self.shape_env.unbacked_var_to_val + ) + ) + else: + return t + + from torch.fx.experimental.symbolic_shapes import ( + free_unbacked_symbols, + SymTypes, + ) + + nil = object() + + real_out = nil + if ( + self.propagate_real_tensors + and all(e.real_tensor is not None for e in flat_arg_fake_tensors) + and + # TODO: Modify this to handle unbacked symbols with real values + not any( + isinstance(a, torch.SymInt) and free_unbacked_symbols(a) + for a in flat_args + ) + ): + real_flat_args = [maybe_to_real_tensor(a) for a in flat_args] + real_args, real_kwargs = pytree.tree_unflatten(real_flat_args, args_spec) + real_out = func(*real_args, **real_kwargs) + + def maybe_propagate_real_tensors(fake_out): + import sympy + + def go(t, real_t): + if isinstance(t, FakeTensor): + # NB: unconditionally overwrite + t.real_tensor = real_t + elif isinstance(t, SymTypes) and free_unbacked_symbols(t): + if isinstance(t.node.expr, sympy.Symbol): + self.shape_env.unbacked_var_to_val[t.node.expr] = real_t + elif prev != real_t: + log.warning( + "propagate_real_tensors mismatch %s != %s", prev, real_t + ) + return t + + if real_out is not nil: + return tree_map(go, fake_out, real_out) + else: + return fake_out + # Try for fastpath if has_symbolic_sizes: fast_impl = get_fast_op_impls().get(func) if fast_impl is not None: - return fast_impl(self, *args, **kwargs) + return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs)) # If there's a Python meta, prefer that over the decomposition from torch._decomp import meta_table as meta_table @@ -1470,7 +1550,9 @@ def maybe_to_constant(t): and not stride_incorrect_op(func) ): with self: - return func.prim_meta_impl(*args, **kwargs) + return maybe_propagate_real_tensors( + func.prim_meta_impl(*args, 
**kwargs) + ) # Users can register FakeTensor rules for custom operators # Call them if they exist. @@ -1481,7 +1563,7 @@ def maybe_to_constant(t): ctx = torch._library.abstract_impl.AbstractImplCtx(self, func) with torch._library.abstract_impl.set_ctx_getter(lambda: ctx), self: result = maybe_abstract_impl(*args, **kwargs) - return result + return maybe_propagate_real_tensors(result) # special handling for funcs registered through `register_op_impl`, # e.g., manipulating args on constructor calls to construct meta tensors @@ -1490,7 +1572,7 @@ def maybe_to_constant(t): if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) if op_impl_out != NotImplemented: - return op_impl_out + return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): # We infer the meta of a custom ops that return None to just @@ -1508,7 +1590,7 @@ def maybe_run_unsafe_fallback(error=None): # Optimization: If there is no Meta kernel, it takes a surprisingly long # amount of time to catch the NotImplementedError, so we check it here. if not has_meta(func): - return maybe_run_unsafe_fallback() + return maybe_propagate_real_tensors(maybe_run_unsafe_fallback()) # run kernel registered to meta for func, which include # python meta registrations, prims, decomps, and c++ meta fns (structured kernels) @@ -1522,8 +1604,10 @@ def maybe_run_unsafe_fallback(error=None): log.exception("failed while attempting to run meta for %s", func) raise - return self.wrap_meta_outputs_with_default_device_logic( - r, func, flat_args, device=kwargs.get("device") + return maybe_propagate_real_tensors( + self.wrap_meta_outputs_with_default_device_logic( + r, func, flat_args, device=kwargs.get("device") + ) ) # WARNING: DO NOT add any additional namespaces/operators here if they refer to operators diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 6c17b5d87053..7a81d3520de9 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -515,30 +515,34 @@ def compute_unbacked_bindings(shape_env, example_value, old_example_value=None): fs.clear() def free_unbacked_symbols_with_path( - a, path + a, path, real=None ) -> Dict[sympy.Symbol, pytree.KeyPath]: r = {} if isinstance(a, (tuple, list)): for i in range(len(a)): r.update( free_unbacked_symbols_with_path( - a[i], path + (pytree.SequenceKey(i),) + a[i], path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None ) ) elif isinstance(a, torch.Tensor): r.update( free_unbacked_symbols_with_path( - a.size(), path + (CallMethodKey("size"),) + a.size(), path + (CallMethodKey("size"),), + real=a.real_tensor.size() if a.real_tensor is not None else None ) ) r.update( free_unbacked_symbols_with_path( - a.stride(), path + (CallMethodKey("stride"),) + a.stride(), path + (CallMethodKey("stride"),), + real=a.real_tensor.stride() if a.real_tensor is not None else None ) ) r.update( free_unbacked_symbols_with_path( - a.storage_offset(), path + (CallMethodKey("storage_offset"),) + a.storage_offset(), path + (CallMethodKey("storage_offset"),), + real=a.real_tensor.storage_offset() if a.real_tensor is not None else None ) ) @@ -550,6 +554,8 @@ def free_unbacked_symbols_with_path( and s in pending ): r[s] = path + if real is not None: + shape_env.unbacked_var_to_val[s] = real pending.remove(s) # When an unbacked SymInt is perfectly divisible by an integer # constant, we replace it with the integer constant to improve @@ -566,6 +572,8 @@ def 
free_unbacked_symbols_with_path( ): # TODO: DivideByKey needs to test divisibility at runtime! r[s] = path + (DivideByKey(int(lhs)),) + if real is not None: + shape_env.unbacked_var_to_val[s] = real // int(lhs) pending.remove(rhs) # The annoyance here arises from the fact that SymBool is # allocated by allocating a SymInt and then testing if it's equal @@ -579,6 +587,8 @@ def free_unbacked_symbols_with_path( and s.lhs in pending ): r[s.lhs] = path + (ConvertIntKey(),) + if real is not None: + shape_env.unbacked_var_to_val[s] = int(real) pending.remove(s.lhs) return r @@ -592,6 +602,7 @@ def free_unbacked_symbols_with_path( else "" ) ) + # TODO: This is pretty fragile # Normally, the equality test is supposed to be a no-op here, because # you've already called rebind_unbacked first which takes all the old @@ -2173,6 +2184,9 @@ def _init( # Maps symbolic ints to their original concrete values # Currently populated from tensors self.var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} + # Like var_to_val, but only set when propagate_real_tensors is on. + # Used as last resort to avoid GuardOnDataDependent error + self.unbacked_var_to_val: Dict[sympy.Symbol, sympy.Integer] = {} # Maps symbolic ints to their min/max range. These ranges # are conservative: the int MUST fall in the range, but the # range may contain ints which may not actually appear in @@ -2643,7 +2657,7 @@ def _get_key(self): Defines the current "state" of the guards we've accumulated in this ShapeEnv. Determines when we need to invalidate our cache """ - return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts) + return (len(self.replacements), len(self.divisible), self.num_deferred_runtime_asserts, len(self.unbacked_var_to_val)) def _update_version_counter(self): # The shape environment is queried orders of magnitude more often than @@ -4126,6 +4140,13 @@ def size_hint(self, expr: "sympy.Expr", *, allow_none=False): return r if allow_none: return None + + if self.unbacked_var_to_val: + unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) + if not unsound_expr.free_symbols: + log.warning("propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr) + return unsound_expr + raise self._make_data_dependent_error(result_expr, expr) return result_expr @@ -4627,6 +4648,7 @@ def compute_concrete_val(): assert static_expr == hint, f"{static_expr} != {hint}" return static_expr + concrete_val = None if not (expr.free_symbols <= self.var_to_val.keys()): # TODO: dedupe this with _maybe_evaluate_static # Attempt to eliminate the unbacked SymInt @@ -4640,14 +4662,25 @@ def compute_concrete_val(): size_oblivious=True ) - raise self._make_data_dependent_error( - expr.xreplace(self.var_to_val), - expr, - size_oblivious_result=size_oblivious_result - ) - expr = new_expr + # Last ditch + if ( + self.unbacked_var_to_val and + (unsound_result := orig_expr.xreplace(self.unbacked_var_to_val)) and + not unsound_result.free_symbols + ): + log.warning("propagate_real_tensors evaluate_expr(%s) -> %s", orig_expr, unsound_result) + concrete_val = unsound_result + else: + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + size_oblivious_result=size_oblivious_result + ) + else: + expr = new_expr - concrete_val = compute_concrete_val() + if concrete_val is None: + concrete_val = compute_concrete_val() self._check_frozen(expr, concrete_val) if (
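
For reference, a minimal sketch of how the new flag can be exercised end to end, modeled on the tests added in this patch (test_invalidate_nonzero_propagate_real_tensors and test_data_dependent_guard_propagate_real_tensors). The function name, input size, and branch threshold below are illustrative assumptions, not part of the PR:

# Minimal sketch (assumes the patch above is applied): with
# fake_tensor_propagate_real_tensors enabled, a data-dependent branch on an
# unbacked size no longer raises GuardOnDataDependentSymNode, because the
# cloned real tensor supplies a concrete value via unbacked_var_to_val.
import torch
import torch._functorch.config
from torch.fx.experimental.proxy_tensor import make_fx


def f(a):
    nz = a.nonzero()      # shape[0] of nz is an unbacked SymInt
    if nz.shape[0] > 2:   # data-dependent guard: fake tensors alone cannot decide this
        return a.sin()
    return a.cos()


with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
    # The real input is cloned alongside the fake tensor and real compute runs
    # in parallel, so the guard is resolved (unsoundly) from the observed
    # value and the mode logs a propagate_real_tensors warning.
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(8))

print(gm.graph)

Note that the resulting graph is specialized to the branch taken for the example input; as the config comment in torch/_functorch/config.py says, some inaccuracy is accepted in exchange for being able to keep tracing through item()/nonzero()-style data-dependent operations.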