From 02093b6c6ae1046368e2500881d0bb5880873386 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Fri, 10 May 2024 18:38:31 -0300
Subject: [PATCH] Keep track of `ViewMeta` with symbolic inputs. (#125876)

Fix: #125387

This PR keeps track of whether an instantiated `ViewMeta` has symbolic values
as input or not. This is used for checking whether we can use the AOTAutograd
`ViewMeta`-replay execution path, since that path doesn't support tensors whose
`ViewMeta` sequence has symbolic inputs.

In summary, the changes are:

- Add the field `ViewMeta::has_symbolic_inputs` and make it a required
  constructor parameter
- Add the field `FunctionalTensorWrapper::is_symbolic_` and the method
  `FunctionalTensorWrapper::maybe_mark_symbolic`
  - Marks a `FunctionalTensorWrapper` as symbolic iff any of its `ViewMeta`
    have symbolic inputs
- Add the plumbing of `FunctionalTensorWrapper::is_symbolic` to the Python API
- Codegen the computation of `ViewMeta::has_symbolic_inputs` for each view
  operation
- Use the AOTAutograd `ViewMeta`-replay path if:
  - `target_functional_tensor` is not `None`; and
  - `target_functional_tensor` is not symbolic (instead of using a functorch
    config)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125876
Approved by: https://github.com/ezyang
---
 aten/src/ATen/FunctionalStorageImpl.cpp       |  2 +-
 aten/src/ATen/FunctionalStorageImpl.h         |  7 +-
 aten/src/ATen/FunctionalTensorWrapper.cpp     |  8 ++-
 aten/src/ATen/FunctionalTensorWrapper.h       | 10 +++
 aten/src/ATen/FunctionalizeFallbackKernel.cpp |  8 ++-
 test/functorch/test_aotdispatch.py            | 31 ++++++++-
 tools/pyi/gen_pyi.py                          |  3 +
 .../_aot_autograd/functional_utils.py         | 44 +++++-------
 .../python_torch_functions_manual.cpp         | 29 +++++++-
 torchgen/gen_functionalization_type.py        | 67 ++++++++++++++++++-
 10 files changed, 171 insertions(+), 38 deletions(-)

diff --git a/aten/src/ATen/FunctionalStorageImpl.cpp b/aten/src/ATen/FunctionalStorageImpl.cpp
index 6484a6f3fbb2..3275c8f447f7 100644
--- a/aten/src/ATen/FunctionalStorageImpl.cpp
+++ b/aten/src/ATen/FunctionalStorageImpl.cpp
@@ -10,7 +10,7 @@ namespace at::functionalization {
 ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
   if (out_idx == this->out_index) return *this;
-  return ViewMeta(forward_fn, reverse_fn, is_multi_output, is_as_strided, out_idx);
+  return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
 }

 // Note [Functionalization: Alias Removal Part 2]
diff --git a/aten/src/ATen/FunctionalStorageImpl.h b/aten/src/ATen/FunctionalStorageImpl.h
index 65a76db0b1e6..5d993072473c 100644
--- a/aten/src/ATen/FunctionalStorageImpl.h
+++ b/aten/src/ATen/FunctionalStorageImpl.h
@@ -31,6 +31,7 @@ struct ViewMeta {
   ViewMeta(
       std::function<Tensor(const Tensor&, int64_t)> forward,
       std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse,
+      bool has_symbolic_inputs,
       bool is_multi_output = false,
       bool is_as_strided = false,
       int64_t out_idx = 0)
@@ -38,7 +39,8 @@ struct ViewMeta {
         reverse_fn(std::move(reverse)),
         out_index(out_idx),
         is_multi_output(is_multi_output),
-        is_as_strided(is_as_strided) {}
+        is_as_strided(is_as_strided),
+        has_symbolic_inputs(has_symbolic_inputs) {}

   std::function<Tensor(const Tensor&, int64_t)> forward_fn;
   std::function<Tensor(const Tensor&, const Tensor&, int64_t)> reverse_fn;
@@ -50,6 +52,9 @@ struct ViewMeta {
   bool is_as_strided;

+  // Tells us if this view operation has any symbolic inputs
+  bool has_symbolic_inputs;
+
   // Returns a copy of the current ViewMeta, if out_idx matches the current
   // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse
   // functions, but a new out index.
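Illustrative sketch (not part of the diffs above or below): how a view kernel is now expected to construct a `ViewMeta` with the required `has_symbolic_inputs` flag. It mirrors the `_unsafe_view` fallback kernel changed later in this patch; the `size` variable (a `c10::SymIntArrayRef` taken from the op's arguments) and the particular forward/reverse lambdas are assumptions made for illustration, and `<algorithm>` plus the usual ATen headers are assumed to be included.

    // Compute the flag from the op's SymInt arguments: true iff any of them
    // is actually symbolic rather than a plain, concrete integer.
    bool has_symbolic_inputs = std::any_of(
        size.begin(), size.end(),
        [](const c10::SymInt& s) { return s.is_symbolic(); });
    // Pass it as the (now required) third constructor argument.
    at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
        [size = size.vec()](const at::Tensor& base, int64_t mutated_view_idx) -> at::Tensor {
          return at::_unsafe_view_symint(base, size);
        },
        [](const at::Tensor& base, const at::Tensor& mutated_view, int64_t mutated_view_idx) -> at::Tensor {
          return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
        },
        /*has_symbolic_inputs=*/has_symbolic_inputs);

Any `ViewMeta` created this way then propagates the flag into `FunctionalTensorWrapper::is_symbolic_` through `maybe_mark_symbolic`, as the next two files show.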
diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp
index e82b697b5900..c9ef28dbf56e 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.cpp
+++ b/aten/src/ATen/FunctionalTensorWrapper.cpp
@@ -137,7 +137,8 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     ),
     value_(view_value),
     is_multi_output_view_(base->is_multi_output_view_ || meta.is_multi_output),
-    was_storage_changed_(base->was_storage_changed_)
+    was_storage_changed_(base->was_storage_changed_),
+    is_symbolic_(base->is_symbolic_)
 {
   TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(value_));
   TORCH_INTERNAL_ASSERT(!value_.key_set().has(c10::DispatchKey::Functionalize));
@@ -147,6 +148,7 @@ FunctionalTensorWrapper::FunctionalTensorWrapper(const Tensor& view_value, const
     view_metas_ = base->view_metas_;  // copy
   }
   view_metas_.push_back(meta);
+  maybe_mark_symbolic(meta);
   storage_ = base->storage_; // alias this tensor's storage with the base tensor's
 }

@@ -178,6 +180,8 @@ void FunctionalTensorWrapper::mutate_view_meta(const at::functionalization::View
   view_metas_.push_back(meta);
   // Manually track the fact that this tensor recieved a metadata mutation!
   has_metadata_mutation_ = true;
+  // Mark this tensor as being symbolic if there are any symbolic inputs used by the view operation.
+  maybe_mark_symbolic(meta);
   // Note [Functionalization Pass - Inplace View Ops]
   // So, these ops are special - they're mutation AND view ops. They get special codegen.
   // An example is transpose_, e.g. `a.transpose_()`
@@ -257,6 +261,7 @@ void FunctionalTensorWrapper::set__impl(const FunctionalTensorWrapper* other) {
   value_ = other->value_;
   generation_ = other->generation_;
   view_metas_ = other->view_metas_;
+  is_symbolic_ = other->is_symbolic_;
   // FREEZE the old storage, preventing mutations to it.
   // this is a huge pain to handle properly in all cases, so we ban it.
   functional_storage_impl()->freeze();
@@ -414,6 +419,7 @@ void FunctionalTensorWrapper::copy_tensor_metadata(
     dest_impl->has_metadata_mutation_ = src_impl->has_metadata_mutation_;
     dest_impl->is_multi_output_view_ = src_impl->is_multi_output_view_;
     dest_impl->was_storage_changed_ = src_impl->was_storage_changed_;
+    dest_impl->is_symbolic_ = src_impl->is_symbolic_;
     dest_impl->generation_ = src_impl->generation_;
     dest_impl->view_metas_ = src_impl->view_metas_;
   }
diff --git a/aten/src/ATen/FunctionalTensorWrapper.h b/aten/src/ATen/FunctionalTensorWrapper.h
index 8ae9326c6fe9..95d6afe5f0be 100644
--- a/aten/src/ATen/FunctionalTensorWrapper.h
+++ b/aten/src/ATen/FunctionalTensorWrapper.h
@@ -97,6 +97,14 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
       ->are_all_mutations_under_no_grad_or_inference_mode();
   }

+  void maybe_mark_symbolic(const functionalization::ViewMeta& meta) {
+    is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs;
+  }
+
+  bool is_symbolic() const {
+    return is_symbolic_;
+  }
+
   // Runs the forward_fn of every ViewMeta collected in the current instance
   // to some other base.
   Tensor apply_view_metas(const Tensor& base);
@@ -250,6 +258,8 @@ struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl {
   bool is_multi_output_view_ = false;
   // Did the tensor experience a set_() call.
   bool was_storage_changed_ = false;
+  // Did the tensor experience any view operation with symbolic int.
+  bool is_symbolic_ = false;
   size_t generation_ = 0;
   std::vector<at::functionalization::ViewMeta> view_metas_;
diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
index 4e9d9fe82bda..8b26c875fc02 100644
--- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp
+++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -178,7 +178,8 @@ static const at::Tensor & resize__functionalization(c10::DispatchKeySet dispatch
     },
     [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
       return base.as_strided_scatter(mutated_view, size, c10::contiguous_strides(size));
-    }
+    },
+    /*has_symbolic_inputs=*/false
   );
   at::functionalization::impl::mutate_view_meta(self, view_meta);
   return self;
@@ -298,13 +299,16 @@ static at::Tensor _unsafe_view_functionalize(const at::Tensor & self, at::SymInt
     tmp_output = at::_unsafe_view_symint(self_, size);
   }

+  bool has_symbolic_inputs = std::any_of(size.begin(), size.end(), [=](auto& s) { return s.is_symbolic(); });
+
   at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
     [size = size.vec()](const at::Tensor & base, int64_t mutated_view_idx) -> at::Tensor {
       return at::_unsafe_view_symint(base, size);
     },
     [size = size.vec()](const at::Tensor & base, const at::Tensor & mutated_view, int64_t mutated_view_idx) -> at::Tensor {
       return at::_unsafe_view_symint(mutated_view, base.sym_sizes());
-    }
+    },
+    /*has_symbolic_inputs=*/has_symbolic_inputs
   );

   auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, self, std::move(view_meta));
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 86d9b88cf75b..5c17b7f84d0d 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -67,6 +67,7 @@
     parametrize,
     run_tests,
     skipIfRocm,
+    skipIfTorchDynamo,
     TestCase,
 )
 from torch.testing._internal.hop_db import hop_db
@@ -3475,7 +3476,6 @@ def wrapper(g, *args, **kwargs):

         return lambda f: aot_function(f, fw_compiler=lambda g, _: partial(wrapper, g))

-    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_input_view_meta_replay(self):
         @self._compile_and_erase_bases(0)
         def f(a):
@@ -3489,7 +3489,6 @@ def f(a):
             str(out.grad_fn.__class__), """"""
         )

-    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_intermediate_view_meta_replay(self):
         @self._compile_and_erase_bases(0, 1)
         def f(a):
@@ -3509,7 +3508,6 @@ def f(a):
             str(out2.grad_fn.__class__), """"""
         )

-    @patch("functorch.compile.config.view_replay_for_aliased_outputs", True)
     def test_output_aliases_output_view_meta_replay(self):
         @self._compile_and_erase_bases(1)
         def f(a):
@@ -3525,6 +3523,33 @@ def f(a):
             str(out2.grad_fn.__class__), """"""
         )

+    @skipIfTorchDynamo()
+    @patch("torch._dynamo.config.assume_static_by_default", False)
+    def test_dynamic_output_aliases_input_view_meta_replay(self):
+        # - torch.compile: using it so we can have a SymInt in the FX graph.
+        # - Compiling with inductor, so that tensor._base isn't tracked.
+        #
+        # This should force the use of as_strided in the view reconstruction path.
+        # The first 2 view-replay paths won't be taken because:
+        # - target_functional_tensor will be symbolic (_functionalize_is_symbolic call)
+        # - tensor._base will be None
+        @torch.compile(backend="inductor")
+        def f(a, sz):
+            return a.view(sz), a.view(-1)
+
+        inp = torch.ones(2, 2, requires_grad=True)
+        out1, out2 = f(inp, (4,))
+
+        self.assertIsNotNone(out1.grad_fn)
+        self.assertExpectedInline(
+            str(out1.grad_fn.__class__), """"""
+        )
+
+        self.assertIsNotNone(out2.grad_fn)
+        self.assertExpectedInline(
+            str(out2.grad_fn.__class__), """"""
+        )
+

 def extract_graph(fx_g, _, graph_cell):
     graph_cell[0] = fx_g
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 9f31e2af2e86..d5f6837cba01 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -811,6 +811,9 @@ def gen_pyi(
         "_functionalize_apply_view_metas": [
             "def _functionalize_apply_view_metas(tensor: Tensor, base: Tensor) -> Tensor: ..."
         ],
+        "_functionalize_is_symbolic": [
+            "def _functionalize_is_symbolic(tensor: Tensor) -> _bool: ..."
+        ],
         "_enable_functionalization": [
             "def _enable_functionalization(*, reapply_views: _bool = False): ..."
         ],
diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py
index fd7123c8ada0..2f89180bf218 100644
--- a/torch/_functorch/_aot_autograd/functional_utils.py
+++ b/torch/_functorch/_aot_autograd/functional_utils.py
@@ -234,39 +234,27 @@ def patch_requires_grad(out):
     # In summary, we use the fact that FunctionalTensorWrapper saves the view
     # functions applied to itself (collected during functionalization) so as
     # to replay them (view functions) on the aliased_base_tensor.
-    if config.view_replay_for_aliased_outputs and target_functional_tensor is not None:
+    if (
+        config.view_replay_for_aliased_outputs
+        and target_functional_tensor is not None
+        and not torch._functionalize_is_symbolic(target_functional_tensor.tensor)
+    ):
         from .schemas import FunctionalTensorMetadataEq

         assert isinstance(target_functional_tensor, FunctionalTensorMetadataEq)

         functional_tensor = target_functional_tensor.tensor

-        try:
-            out = torch._functionalize_apply_view_metas(
-                functional_tensor, aliased_base_tensor
-            )
-        except RuntimeError as e:
-            # NYI for dynamic shapes.
-            #
-            # On functionalization, the ViewMeta lambdas will have symbolic shapes.
-            # When trying to apply those lambdas on concrete tensors, it will fail.
-            #
-            # In order for this to work, we should have a way to replace those
-            # symbolic shapes with concrete numbers.
-            aot_joint_log.info(
-                "could not reconstruct view by re-applying a ViewMeta sequence. "
-                "Fallbacking to reconstruction using as_strided. "
-                "Reason: %s",
-                str(e),
-            )
-        else:
-            # If re-applying the ViewMeta sequence succeeded, there should be no more
-            # problems going forward. We just check we got to the target shape and
-            # patch requires_grad flag.
-            assert out.shape == target_meta_tensor.shape, (
-                "incorrect out shape after application of ViewMeta sequence: "
-                f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
-            )
-            return patch_requires_grad(out)
+        out = torch._functionalize_apply_view_metas(
+            functional_tensor, aliased_base_tensor
+        )
+        # If re-applying the ViewMeta sequence succeeded, there should be no more
+        # problems going forward. We just check we got to the target shape and
+        # patch requires_grad flag.
+        assert out.shape == target_meta_tensor.shape, (
+            "incorrect out shape after application of ViewMeta sequence: "
+            f"{tuple(out.shape)} (actual) vs {tuple(target_meta_tensor.shape)} (expected)"
+        )
+        return patch_requires_grad(out)

     # Try to do view-replay if possible.
     # fall back to .as_strided() if we can't.
diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp
index 4a3b94c36248..1616ac79c403 100644
--- a/torch/csrc/autograd/python_torch_functions_manual.cpp
+++ b/torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -685,6 +685,29 @@ static PyObject* THPVariable__functionalize_sync(
   END_HANDLE_TH_ERRORS
 }

+static PyObject* THPVariable__functionalize_is_symbolic(
+    PyObject* self,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_functionalize_is_symbolic(Tensor tensor)"},
+      /*traceable=*/true);
+
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto tensor = r.tensor(0);
+  TORCH_INTERNAL_ASSERT(
+      at::functionalization::impl::isFunctionalTensor(tensor));
+  auto impl = at::functionalization::impl::unsafeGetFunctionalWrapper(tensor);
+  if (impl->is_symbolic()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* THPVariable__functionalize_apply_view_metas(
     PyObject* self,
     PyObject* args,
@@ -694,7 +717,7 @@ static PyObject* THPVariable__functionalize_apply_view_metas(
       {"_functionalize_apply_view_metas(Tensor tensor, Tensor base)"},
       /*traceable=*/true);

-  ParsedArgs<4> parsed_args;
+  ParsedArgs<2> parsed_args;
   auto r = parser.parse(args, kwargs, parsed_args);
   auto tensor = r.tensor(0);
   TORCH_INTERNAL_ASSERT(
@@ -840,6 +863,10 @@ static PyMethodDef torch_functions_manual[] = {
      castPyCFunctionWithKeywords(THPVariable__functionalize_sync),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
+    {"_functionalize_is_symbolic",
+     castPyCFunctionWithKeywords(THPVariable__functionalize_is_symbolic),
+     METH_VARARGS | METH_KEYWORDS | METH_STATIC,
+     nullptr},
     {"_functionalize_apply_view_metas",
      castPyCFunctionWithKeywords(THPVariable__functionalize_apply_view_metas),
      METH_VARARGS | METH_KEYWORDS | METH_STATIC,
      nullptr},
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 92e2ff8ad9e0..8d2c567c3478 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -11,6 +11,10 @@
     FunctionalizationLambda,
     iTensorListRefT,
     NativeSignature,
+    OptionalCType,
+    optionalSymIntArrayRefT,
+    symIntArrayRefT,
+    SymIntT,
     tensorListT,
     tensorT,
     VectorCType,
@@ -281,6 +285,58 @@ def is_alias(a: Argument) -> bool:
 View operators with multiple aliasing inputs aren't supported yet.
 Found an operator that doesn't satisfy this constraint"""


+# One-liner expression for checking if an expression expr of type type has any
+# symbolic values.
+def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
+    if type == BaseCType(SymIntT):
+        return f"{expr}.is_symbolic()"
+
+    if isinstance(type, OptionalCType):
+        innerexpr = f"(*{expr})"
+        return f"{expr}.has_value() ? {emit_expr_has_symbolic_values(innerexpr, type.elem)} : false"
+
+    if type == BaseCType(optionalSymIntArrayRefT):
+        return emit_expr_has_symbolic_values(
+            expr, OptionalCType(BaseCType(symIntArrayRefT))
+        )
+
+    if type in (BaseCType(symIntArrayRefT), VectorCType(BaseCType(SymIntT))):
+        argname = "arg"
+        lambda_check = emit_expr_has_symbolic_values(argname, BaseCType(SymIntT))
+        return (
+            "std::any_of("
+            f"{expr}.begin(), {expr}.end(), "
+            f"[=](auto& {argname}) {{ return {lambda_check}; }})"
+        )
+
+    raise ValueError(
+        "unsupported type for has_symbolic_values check. "
+        "It should be a SymInt or a collection of those. "
+        f"Got: {type.cpp_type()}"
+    )
+
+
+# Detects whether any of the SymInt arguments are, in fact, symbolic values.
+# This is used in the constructor of ViewMeta.
+def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]:
+    name = "has_symbolic_inputs"
+    statements = [
+        f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
+        for binding in sig.arguments()
+        if (
+            isinstance(binding.argument, Argument)
+            and binding.argument.type.is_symint_like()
+        )
+    ]
+    body = "\n ".join(statements)
+    return (
+        name,
+        f"""
+    bool {name} = false;
+    {body}""",
+    )
+

 # Generates the Functionalization kernel for:
 # - ops that create aliases (e.g. transpose())
 # - ops that are views AND mutations (e.g. transpose_())
@@ -334,6 +390,11 @@ def emit_view_functionalization_body(
         e.expr for e in translate(meta_call_ctx, call_sig.arguments(), method=False)
     ]

+    (
+        symbolic_inputs_varname,
+        symbolic_inputs_check,
+    ) = emit_has_symbolic_inputs(call_sig)
+
     if "inplace_view" in f.tags:
         # See Note [Functionalization Pass - Inplace View Ops] for more details
         return f"""
@@ -349,6 +410,7 @@ def emit_view_functionalization_body(
         reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse
           : at::functionalization::InverseReturnMode::NeverView
       );
+      {symbolic_inputs_check}
       at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
         {forward_lambda.decl()} {{
           if (reapply_views) {{
@@ -359,7 +421,8 @@ def emit_view_functionalization_body(
         }},
         {reverse_lambda.decl()} {{
           return {reverse_lambda.inner_call()}
-        }}
+        }},
+        /*has_symbolic_inputs=*/{symbolic_inputs_varname}
       );
       auto compute_reference_meta =
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
@@ -421,6 +484,7 @@ def emit_view_functionalization_body(
           tmp_output = at::_ops::{api_name}::call({', '.join(view_redispatch_args)});
         }}
       }}
+      {symbolic_inputs_check}
       at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta(
         {forward_lambda.decl()} {{
           if (reapply_views) {{
@@ -432,6 +496,7 @@ def emit_view_functionalization_body(
         {reverse_lambda.decl()} {{
           return {reverse_lambda.inner_call()}
         }},
+        /*has_symbolic_inputs=*/{symbolic_inputs_varname},
        /*is_multi_output=*/{str(is_multi_output_view).lower()},
        /*is_as_strided=*/{str(str(f.func.name) == 'as_strided').lower()}
      );
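For reference, here is a hedged sketch (an assumption, not verbatim output of this patch) of roughly what the `{symbolic_inputs_check}` block expands to in the generated functionalization kernel of an op like `as_strided`, whose SymInt-like arguments are `size`, `stride`, and the optional `storage_offset`; the exact spacing and ordering come from `emit_has_symbolic_inputs` above.

    bool has_symbolic_inputs = false;
    has_symbolic_inputs = has_symbolic_inputs | (std::any_of(size.begin(), size.end(), [=](auto& arg) { return arg.is_symbolic(); }));
    has_symbolic_inputs = has_symbolic_inputs | (std::any_of(stride.begin(), stride.end(), [=](auto& arg) { return arg.is_symbolic(); }));
    has_symbolic_inputs = has_symbolic_inputs | (storage_offset.has_value() ? (*storage_offset).is_symbolic() : false);
    // The resulting flag is then forwarded to the ViewMeta constructor as
    // /*has_symbolic_inputs=*/has_symbolic_inputs, so that the FunctionalTensorWrapper
    // produced by this view can be marked symbolic and excluded from ViewMeta replay.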