diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index af0e5af3be8d..116c23fc2ddb 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -12,27 +12,27 @@ namespace at::functionalization { // We can't easily share it though, because (eventually) these functions // will all call `permute/unsqueeze_copy()` instead of `permute/unsqueeze`. -static Tensor permute_copy_inverse(const Tensor& self, IntArrayRef dims, bool reapply_views) { +static Tensor permute_copy_inverse(const Tensor& self, IntArrayRef dims, InverseReturnMode inverse_return_mode) { // invert the permutation auto ndims = dims.size(); std::vector dims_(ndims); for(const auto i : c10::irange(ndims)) { dims_[at::maybe_wrap_dim(dims[i], ndims)] = i; } - if (reapply_views) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::permute(self, dims_); } else { return at::permute_copy(self, dims_); } } -static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, bool reapply_views) { +static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) { auto result = self; int64_t nDims = sizes.size(); for(const auto dim : c10::irange(nDims)) { if (sizes[dim] == 1) { - if (reapply_views) { + if (inverse_return_mode != InverseReturnMode::NeverView) { result = at::unsqueeze(result, dim); } else { result = at::unsqueeze_copy(result, dim); @@ -42,7 +42,7 @@ static Tensor unsqueeze_copy_to(const Tensor & self, c10::SymIntArrayRef sizes, return result; } -static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, bool reapply_views) { +static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIntArrayRef sizes, InverseReturnMode inverse_return_mode) { const auto ndim = sizes.size(); const auto mask = at::dim_list_to_bitset(dim, ndim); // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided @@ -54,7 +54,7 @@ static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIn Tensor result = self; for (const auto d : c10::irange(ndim)) { if (mask.test(d) && sizes[d] == 1) { - if (reapply_views) { + if (inverse_return_mode != InverseReturnMode::NeverView) { result = at::unsqueeze(result, d); } else { result = at::unsqueeze_copy(result, d); @@ -95,220 +95,273 @@ static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIn // The codegen automatically generates the corresponding function declaration. // ---------------------------------------------------------- -Tensor FunctionalInverses::_fw_primal_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, int64_t level) { +Tensor FunctionalInverses::_fw_primal_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported."); return Tensor(); } -Tensor FunctionalInverses::_make_dual_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views, const at::Tensor& tangent, int64_t level) { +Tensor FunctionalInverses::_make_dual_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor& tangent, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. 
For now, this is not supported."); return Tensor(); } -Tensor FunctionalInverses::view_as_real_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::view_as_real_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::view_as_complex(mutated_view); } else { return at::view_as_complex_copy(mutated_view); } } -Tensor FunctionalInverses::view_as_complex_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::view_as_complex_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::view_as_real(mutated_view.resolve_conj()); } else { return at::view_as_real_copy(mutated_view.resolve_conj()); } } -Tensor FunctionalInverses::_conj_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::_conj_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::_conj(mutated_view); } else { return at::_conj_copy(mutated_view); } } -Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::_neg_view_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::_neg_view(mutated_view); } else { return at::_neg_view_copy(mutated_view); } } -Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional storage_offset) { - // Pessimism: we can't reapply views for as_strided_scatter. - return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset)); +Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, at::SymIntArrayRef stride, c10::optional storage_offset) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + return base.as_strided_scatter_symint(mutated_view, size, stride, std::move(storage_offset)); + } } -Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t offset, int64_t dim1, int64_t dim2) { - // Pessimism: we can't reapply views for slice_scatter. - return base.diagonal_scatter(mutated_view, offset, dim1, dim2); +Tensor FunctionalInverses::diagonal_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t offset, int64_t dim1, int64_t dim2) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. 
+ // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + return base.diagonal_scatter(mutated_view, offset, dim1, dim2); + } } -Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, bool implicit) { - return at::sum_to(mutated_view, base.sym_sizes(),/*always_return_non_view=*/!reapply_views); +Tensor FunctionalInverses::expand_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, bool implicit) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is an expanded view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + return at::sum_to( + mutated_view, + base.sym_sizes(), + /*always_return_non_view=*/inverse_return_mode == InverseReturnMode::NeverView + ); + } } -Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef dims) { - return at::functionalization::permute_copy_inverse(mutated_view, dims, reapply_views); +Tensor FunctionalInverses::permute_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef dims) { + return at::functionalization::permute_copy_inverse(mutated_view, dims, inverse_return_mode); } -Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size, at::SymIntArrayRef stride) { +Tensor FunctionalInverses::_reshape_alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size, at::SymIntArrayRef stride) { // Note that I'm directly calling reshape(), and ignoring the strides. // _reshape_alias() isn't available from user code, and is an implementation detail of reshape(). // Specifically, passing in the strides directly can get us into trouble in cases like: // b = a[0]; c = b.reshape(...); c.add_(1); print(a) // When we eventually run the _reshape_alias_inverse() call here, if we were to pass in both sizes and strides, // The call would fail because `mutated_view` doesn't have enough bytes of storage. - if (reapply_views) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::_reshape_alias_symint(mutated_view, base.sym_sizes(), base.sym_strides()); } else { return at::_reshape_alias_copy_symint(mutated_view, base.sym_sizes(), base.sym_strides()); } } -Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::SymInt index) { - // Pessimism: we can't reapply views for slice_scatter. - return base.select_scatter_symint(mutated_view, dim, std::move(index)); +Tensor FunctionalInverses::select_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim, c10::SymInt index) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. 
+ // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + return base.select_scatter_symint(mutated_view, dim, std::move(index)); + } } -Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::detach_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { // the functionalization pass doesn't care about autograd metadata - as a view, I think detach() is just an identity function return mutated_view; } -Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { return mutated_view; } -Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional start, c10::optional end, c10::SymInt step) { - // Pessimism: we can't reapply views for slice_scatter. - return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step)); +Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim, c10::optional start, c10::optional end, c10::SymInt step) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + return base.slice_scatter_symint(mutated_view, dim, std::move(start), std::move(end), std::move(step)); + } } -Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { - // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. - // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i - // on top of the base tensor. - // For autograd, we have all of the tensors outputted by split() and we just want to stack them. - dim = at::maybe_wrap_dim(dim, base.dim()); - auto dim_size = base.sym_size(dim); - auto start = split_size * mutated_view_idx; - auto end = split_size + start; - if (end > dim_size) end = dim_size; - // Pessimism: we can't reapply views for slice_scatter. - return base.slice_scatter_symint(mutated_view, dim, start, end, 1); +Tensor FunctionalInverses::split_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. 
+ // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i + // on top of the base tensor. + // For autograd, we have all of the tensors outputted by split() and we just want to stack them. + dim = at::maybe_wrap_dim(dim, base.dim()); + auto dim_size = base.sym_size(dim); + auto start = split_size * mutated_view_idx; + auto end = split_size + start; + if (end > dim_size) end = dim_size; + return base.slice_scatter_symint(mutated_view, dim, start, end, 1); + } } -Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) { - dim = at::maybe_wrap_dim(dim, base.dim()); - auto dim_size = base.sym_size(dim); - c10::SymInt start = 0; - for (auto i = 0; i < mutated_view_idx; ++i) { - start += split_sizes[i]; +Tensor FunctionalInverses::split_with_sizes_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymIntArrayRef split_sizes, int64_t dim) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + dim = at::maybe_wrap_dim(dim, base.dim()); + auto dim_size = base.sym_size(dim); + c10::SymInt start = 0; + for (auto i = 0; i < mutated_view_idx; ++i) { + start += split_sizes[i]; + } + auto end = start + split_sizes[mutated_view_idx]; + if (end > dim_size) end = dim_size; + return base.slice_scatter_symint(mutated_view, dim, start, end, 1); } - auto end = start + split_sizes[mutated_view_idx]; - if (end > dim_size) end = dim_size; - // Pessimism: we can't reapply views for slice_scatter. 
- return base.slice_scatter_symint(mutated_view, dim, start, end, 1); } -Tensor FunctionalInverses::squeeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - return unsqueeze_copy_to(mutated_view, base.sym_sizes(), reapply_views); +Tensor FunctionalInverses::squeeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + return unsqueeze_copy_to(mutated_view, base.sym_sizes(), inverse_return_mode); } -Tensor FunctionalInverses::squeeze_copy_dim_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) { - return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views); +Tensor FunctionalInverses::squeeze_copy_dim_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) { + return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), inverse_return_mode); } -Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, IntArrayRef dim) { - return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), reapply_views); +Tensor FunctionalInverses::squeeze_copy_dims_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, IntArrayRef dim) { + return unsqueeze_copy_to(mutated_view, dim, base.sym_sizes(), inverse_return_mode); } -Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::t_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::t(mutated_view); } else { return at::t_copy(mutated_view); } } -Tensor FunctionalInverses::transpose_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim0, int64_t dim1) { - if (reapply_views) { +Tensor FunctionalInverses::transpose_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim0, int64_t dim1) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return transpose(mutated_view, dim0, dim1); } else { return transpose_copy(mutated_view, dim0, dim1); } } -Tensor FunctionalInverses::_nested_view_from_buffer_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, const Tensor& nested_sizes, const Tensor& nested_strides, const Tensor& storage_offsets) { +Tensor FunctionalInverses::_nested_view_from_buffer_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& nested_sizes, const Tensor& nested_strides, const Tensor& storage_offsets) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. 
For now, nested tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::unsqueeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim) { - if (reapply_views) { +Tensor FunctionalInverses::unsqueeze_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dim) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::squeeze(mutated_view, dim); } else { return at::squeeze_copy(mutated_view, dim); } } -Tensor FunctionalInverses::_indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::_indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::_values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::_values_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::indices_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::values_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::values_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::_sparse_broadcast_to_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size) { +Tensor FunctionalInverses::_sparse_broadcast_to_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef size) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::crow_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::crow_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. 
For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::col_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::col_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::ccol_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::ccol_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::row_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::row_indices_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); return Tensor(); } -Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t mutated_view_idx, int64_t dim) { - dim = at::maybe_wrap_dim(dim, base.sizes().size()); - // Pessimism: we can't reapply views for select_scatter. - return base.select_scatter(mutated_view, dim, mutated_view_idx); +Tensor FunctionalInverses::unbind_copy_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int64_t dim) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. 
+ // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + dim = at::maybe_wrap_dim(dim, base.sizes().size()); + return base.select_scatter(mutated_view, dim, mutated_view_idx); + } } -Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::SymIntArrayRef size) { - if (reapply_views) { +Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::SymIntArrayRef size) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return mutated_view.view_symint(base.sym_sizes()); } else { return at::view_copy_symint(mutated_view, base.sym_sizes()); @@ -316,31 +369,38 @@ Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& m } -Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) { - if (reapply_views) { +Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::ScalarType dtype) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return mutated_view.view(base.scalar_type()); } else { return at::view_copy(mutated_view, base.scalar_type()); } } -Tensor FunctionalInverses::unfold_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dimension, int64_t size, int64_t step) { - // I think autograd and the functionalization pass want the exact same thing here, but need to test to confirm. - // unfold_backward() is safe to use here because it is NOT a view op. - // (note: technically, "reapply_views" won't do anything here and we'll have an extra memory copy. - // We'd need to add an aliasing version of unfold_backward to fix that though). - TORCH_CHECK( - !(reapply_views && size > step), - "While executing unfold, functionalization encountered a tensor being mutated that has internal overlap. \ +Tensor FunctionalInverses::unfold_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t dimension, int64_t size, int64_t step) { + if (inverse_return_mode == InverseReturnMode::AlwaysView) { + // NB: assumes mutated_view is a narrowed view of base. + // We should NOT do this for functionalization + return mutated_view.as_strided_symint( + base.sym_sizes(), base.sym_strides(), base.sym_storage_offset()); + } else { + // I think autograd and the functionalization pass want the exact same thing here, but need to test to confirm. + // unfold_backward() is safe to use here because it is NOT a view op. + // (note: technically, we'll have an extra memory copy. + // We'd need to add an aliasing version of unfold_backward to fix that though). + TORCH_CHECK( + !(inverse_return_mode == InverseReturnMode::ViewOrScatterInverse && size > step), + "While executing unfold, functionalization encountered a tensor being mutated that has internal overlap. \ When using torch.compile (or running functionalization directly), this is banned \ as the behavior is not well defined. Consider cloning the tensor before mutating it, \ or removing the mutation from your model." 
- ); - return unfold_backward(mutated_view, base.sizes(), dimension, size, step); + ); + return unfold_backward(mutated_view, base.sizes(), dimension, size, step); + } } -Tensor FunctionalInverses::alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) { - if (reapply_views) { +Tensor FunctionalInverses::alias_copy_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { + if (inverse_return_mode != InverseReturnMode::NeverView) { return at::alias(mutated_view); } else { return at::alias_copy(mutated_view); diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index fbb787079b97..e78badaacb6d 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -116,7 +116,7 @@ Tensor _test_check_tensor(const Tensor& self) { namespace at::functionalization { // view_copy ops must have a functional inverse registered -Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, bool reapply_views) { +Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_copy_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _test_autograd_multiple_dispatch_view_copy_inverse() during the functionalization pass. ", "This function is for testing only and should never be called."); diff --git a/aten/src/ATen/templates/FunctionalInverses.h b/aten/src/ATen/templates/FunctionalInverses.h index eea76eeecb14..4426fa5baece 100644 --- a/aten/src/ATen/templates/FunctionalInverses.h +++ b/aten/src/ATen/templates/FunctionalInverses.h @@ -7,6 +7,17 @@ namespace at { namespace functionalization { +enum class InverseReturnMode { + /// Specifies that functional inverses should always return a view. + AlwaysView, + /// Specifies that functional inverses should always return a non-view / copy. + NeverView, + /// Specifies that functional inverses should return a view unless a (copying) scatter + /// inverse exists, in which case that will be used instead. + /// This avoids as_strided() calls that can be difficult for subclasses to handle. 
+ ViewOrScatterInverse, +}; + struct FunctionalInverses { ${view_inverse_declarations} diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 48e1423d96ac..8a0939f6c277 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -475,6 +475,21 @@ def forward(self, arg0_1): as_strided_copy_1 = torch.ops.aten.as_strided_copy.default(as_strided_scatter, [2], [2], 1) copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None return as_strided_scatter + """) + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + as_strided = torch.ops.aten.as_strided.default(arg0_1, [2], [2], 1) + add = torch.ops.aten.add.Tensor(as_strided, 1); as_strided = None + as_strided_scatter = torch.ops.aten.as_strided_scatter.default(arg0_1, add, [2], [2], 1); add = None + as_strided_1 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [2], 1) + copy_ = torch.ops.aten.copy_.default(arg0_1, as_strided_scatter); arg0_1 = None + return as_strided_scatter """) def test_tensor_list_composite(self): @@ -584,6 +599,22 @@ def forward(self, arg0_1): diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(diagonal_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None return diagonal_scatter + """) + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(2, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) + diagonal = torch.ops.aten.diagonal.default(arg0_1) + add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None + diagonal_scatter = torch.ops.aten.diagonal_scatter.default(arg0_1, add); add = None + diagonal_1 = torch.ops.aten.diagonal.default(diagonal_scatter) + copy_ = torch.ops.aten.copy_.default(arg0_1, diagonal_scatter); arg0_1 = None + return diagonal_scatter """) def test_channels_last_contiguous(self): @@ -635,6 +666,143 @@ def forward(self, arg0_1): mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None return diagonal_copy_1 + """) # noqa: B950 + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) + split = torch.ops.aten.split.Tensor(arg0_1, 2) + getitem = split[0] + getitem_1 = split[1]; split = None + diagonal = torch.ops.aten.diagonal.default(getitem_1); getitem_1 = None + add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None + split_1 = torch.ops.aten.split.Tensor(arg0_1, 2) + getitem_2 = split_1[0] + getitem_3 = split_1[1]; split_1 = None + diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_3, add); getitem_3 = add = None + slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 2, 4); diagonal_scatter = None + split_2 = torch.ops.aten.split.Tensor(slice_scatter, 2) + getitem_4 = split_2[0] + getitem_5 = split_2[1]; split_2 = None + 
diagonal_1 = torch.ops.aten.diagonal.default(getitem_5); getitem_5 = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + return diagonal_1 + """) # noqa: B950 + + def test_split_with_sizes(self): + def f(x): + # test: view ops that return multiple tensors (split_with_sizes) + tmp = torch.ones(2) + y1, y2 = x.split_with_sizes([2, 2]) + y3 = y1.diagonal() + y3.add_(tmp) + z = x * x + return y3 + self.assert_functionalization(f, torch.ones(4, 2)) + logs = self.get_logs(f, torch.ones(4, 2)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) + split_with_sizes_copy = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) + getitem = split_with_sizes_copy[0] + getitem_1 = split_with_sizes_copy[1]; split_with_sizes_copy = None + diagonal_copy = torch.ops.aten.diagonal_copy.default(getitem); getitem = None + add = torch.ops.aten.add.Tensor(diagonal_copy, ones); diagonal_copy = ones = None + split_with_sizes_copy_1 = torch.ops.aten.split_with_sizes_copy.default(arg0_1, [2, 2]) + getitem_2 = split_with_sizes_copy_1[0] + getitem_3 = split_with_sizes_copy_1[1]; split_with_sizes_copy_1 = None + diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None + slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None + split_with_sizes_copy_2 = torch.ops.aten.split_with_sizes_copy.default(slice_scatter, [2, 2]) + getitem_4 = split_with_sizes_copy_2[0] + getitem_5 = split_with_sizes_copy_2[1]; split_with_sizes_copy_2 = None + diagonal_copy_1 = torch.ops.aten.diagonal_copy.default(getitem_4); getitem_4 = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + return diagonal_copy_1 + """) # noqa: B950 + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([2], device = device(type='cpu'), pin_memory = False) + split_with_sizes = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) + getitem = split_with_sizes[0] + getitem_1 = split_with_sizes[1]; split_with_sizes = None + diagonal = torch.ops.aten.diagonal.default(getitem); getitem = None + add = torch.ops.aten.add.Tensor(diagonal, ones); diagonal = ones = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [2, 2]) + getitem_2 = split_with_sizes_1[0] + getitem_3 = split_with_sizes_1[1]; split_with_sizes_1 = None + diagonal_scatter = torch.ops.aten.diagonal_scatter.default(getitem_2, add); getitem_2 = add = None + slice_scatter = torch.ops.aten.slice_scatter.default(arg0_1, diagonal_scatter, 0, 0, 2); diagonal_scatter = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(slice_scatter, [2, 2]) + getitem_4 = split_with_sizes_2[0] + getitem_5 = split_with_sizes_2[1]; split_with_sizes_2 = None + diagonal_1 = torch.ops.aten.diagonal.default(getitem_4); getitem_4 = None + mul = torch.ops.aten.mul.Tensor(slice_scatter, slice_scatter) + copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = None + return diagonal_1 + """) # noqa: B950 + + def 
test_slice(self): + def f(x): + tmp = torch.ones(4) + x.transpose_(1, 0) + y = x[0:2] + y.add_(tmp) + return x + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) + logs = self.get_logs(f, torch.ones(4, 2)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) + transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) + slice_copy = torch.ops.aten.slice_copy.Tensor(transpose_copy, 0, 0, 2); transpose_copy = None + add = torch.ops.aten.add.Tensor(slice_copy, ones); slice_copy = ones = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None + slice_scatter = torch.ops.aten.slice_scatter.default(transpose_copy_1, add, 0, 0, 2); transpose_copy_1 = add = None + transpose_copy_2 = torch.ops.aten.transpose_copy.int(slice_scatter, 1, 0); slice_scatter = None + transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) + slice_copy_1 = torch.ops.aten.slice_copy.Tensor(transpose_copy_3, 0, 0, 2); transpose_copy_3 = None + transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None + return transpose_copy_4 + """) # noqa: B950 + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) + transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) + slice_1 = torch.ops.aten.slice.Tensor(transpose, 0, 0, 2); transpose = None + add = torch.ops.aten.add.Tensor(slice_1, ones); slice_1 = ones = None + transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None + slice_scatter = torch.ops.aten.slice_scatter.default(transpose_1, add, 0, 0, 2); transpose_1 = add = None + transpose_2 = torch.ops.aten.transpose.int(slice_scatter, 1, 0); slice_scatter = None + transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) + slice_2 = torch.ops.aten.slice.Tensor(transpose_3, 0, 0, 2); transpose_3 = None + transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None + return transpose_4 """) # noqa: B950 def test_view_inplace(self): @@ -663,6 +831,82 @@ def forward(self, arg0_1): select_copy_1 = torch.ops.aten.select_copy.int(transpose_copy_3, 0, 0); transpose_copy_3 = None transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 + """) # noqa: B950 + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) + transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) + select = torch.ops.aten.select.int(transpose, 0, 0); transpose = None + add = torch.ops.aten.add.Tensor(select, ones); select = ones = None + transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None + select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None + transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None + transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) + 
select_1 = torch.ops.aten.select.int(transpose_3, 0, 0); transpose_3 = None + transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None + return transpose_4 + """) # noqa: B950 + + def test_unbind(self): + def f(x): + # test: view + inplace op (transpose_) + tmp = torch.ones(4) + x.transpose_(1, 0) + y, _ = x.unbind(0) + y.add_(tmp) + return x + self.assert_functionalization(f, torch.ones(4, 2), mutated_input_metadata=True) + logs = self.get_logs(f, torch.ones(4, 2)) + self.assertExpectedInline(logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) + transpose_copy = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0) + unbind_copy = torch.ops.aten.unbind_copy.int(transpose_copy); transpose_copy = None + getitem = unbind_copy[0] + getitem_1 = unbind_copy[1]; unbind_copy = None + add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(arg0_1, 1, 0); arg0_1 = None + select_scatter = torch.ops.aten.select_scatter.default(transpose_copy_1, add, 0, 0); transpose_copy_1 = add = None + transpose_copy_2 = torch.ops.aten.transpose_copy.int(select_scatter, 1, 0); select_scatter = None + transpose_copy_3 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0) + unbind_copy_1 = torch.ops.aten.unbind_copy.int(transpose_copy_3); transpose_copy_3 = None + getitem_2 = unbind_copy_1[0] + getitem_3 = unbind_copy_1[1]; unbind_copy_1 = None + transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None + return transpose_copy_4 + """) # noqa: B950 + + # NB: even with reapply_views=True, we expect to see scatter op + reinplaced_logs = self.get_logs(f, torch.ones(4, 2), reapply_views=True, run_reinplace=False) + self.assertExpectedInline(reinplaced_logs, """\ + + + +def forward(self, arg0_1): + ones = torch.ops.aten.ones.default([4], device = device(type='cpu'), pin_memory = False) + transpose = torch.ops.aten.transpose.int(arg0_1, 1, 0) + unbind = torch.ops.aten.unbind.int(transpose); transpose = None + getitem = unbind[0] + getitem_1 = unbind[1]; unbind = None + add = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None + transpose_1 = torch.ops.aten.transpose.int(arg0_1, 1, 0); arg0_1 = None + select_scatter = torch.ops.aten.select_scatter.default(transpose_1, add, 0, 0); transpose_1 = add = None + transpose_2 = torch.ops.aten.transpose.int(select_scatter, 1, 0); select_scatter = None + transpose_3 = torch.ops.aten.transpose.int(transpose_2, 1, 0) + unbind_1 = torch.ops.aten.unbind.int(transpose_3); transpose_3 = None + getitem_2 = unbind_1[0] + getitem_3 = unbind_1[1]; unbind_1 = None + transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None + return transpose_4 """) # noqa: B950 def test_optional_tensor_list(self): @@ -1677,7 +1921,10 @@ def forward(self, arg0_1): "test_diagonal_mutated_input", "test_everything", "test_fill_", + "test_slice", "test_split", + "test_split_with_sizes", + "test_unbind", "test_view_clone_view_inplace", "test_view_inplace", ]) diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 0b86dd547f7d..5526b14575ee 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -2,6 +2,7 @@ from torchgen.api import dispatcher from torchgen.api.types import ( + BaseCppType, BaseCType, Binding, boolT, @@ -69,6 +70,20 @@ default=None, ) +InverseReturnModeT = 
BaseCppType("at::functionalization", "InverseReturnMode") +inverse_return_mode_binding = Binding( + name="inverse_return_mode", + nctype=NamedCType(name="inverse_return_mode", type=BaseCType(InverseReturnModeT)), + argument=Argument( + name="inverse_return_mode", + # NB: not actually a bool but it doesn't matter because this isn't used + type=BaseType(BaseTy.bool), + default=None, + annotation=None, + ), + default=None, +) + # The lambda capture itself doesn't have a name. # The name returned here corresponds to the name of the inner function called by the lambda. @@ -115,7 +130,11 @@ def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding non_self_value_bindings = [ dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args ] - all_bindings = [reapply_views_binding] + non_self_value_bindings + + all_bindings = [ + inverse_return_mode_binding if is_reverse else reapply_views_binding + ] + all_bindings.extend(non_self_value_bindings) return all_bindings @@ -165,12 +184,12 @@ def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]: return [ base_binding, mutated_view_binding, - reapply_views_binding, + inverse_return_mode_binding, index_binding, ] + non_self_bindings else: return [ base_binding, mutated_view_binding, - reapply_views_binding, + inverse_return_mode_binding, ] + non_self_bindings diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index 3af5d9c4cb45..de0d41f26376 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -316,7 +316,8 @@ def captures(self) -> List[Expr]: # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, # and plumb it into the lambda. outer_ctx = dispatcher.arguments(self.g.view.func) + [ - functionalization.reapply_views_binding + functionalization.reapply_views_binding, + functionalization.inverse_return_mode_binding, ] capture_bindings = functionalization.capture_arguments( self.g.view.func, is_reverse=self.is_reverse diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index c39fc3e3e3bf..6cd50ea70bad 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -338,6 +338,10 @@ def emit_view_functionalization_body( return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)}); }} auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); + auto inverse_return_mode = ( + reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse + : at::functionalization::InverseReturnMode::NeverView + ); at::functionalization::ViewMeta view_meta = at::functionalization::ViewMeta( {forward_lambda.decl()} {{ if (reapply_views) {{ @@ -387,6 +391,10 @@ def emit_view_functionalization_body( return at::_ops::{noop_api_name}::call({', '.join(view_redispatch_args)}); }} auto reapply_views = at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); + auto inverse_return_mode = ( + reapply_views ? at::functionalization::InverseReturnMode::ViewOrScatterInverse + : at::functionalization::InverseReturnMode::NeverView + ); auto compute_reference_meta = {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) || {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
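
The sketch below is an editorial illustration appended after the patch; it is not part of the diff. It assumes a standard libtorch build and defines a local InverseReturnMode enum as a stand-in for at::functionalization::InverseReturnMode (that header is generated inside the ATen build). It shows how the three modes change what a view inverse returns, using slice as the example: NeverView and ViewOrScatterInverse both fall back to the copying slice_scatter op, matching slice_copy_Tensor_inverse above, while AlwaysView reconstructs the base as a genuine view via as_strided, which is only legal when the mutated view actually aliases the base. toy_slice_inverse is a hypothetical helper written for this example only.

// Illustrative sketch only (not part of the patch). Assumes a libtorch build;
// the enum and helper below are local stand-ins for the generated ATen code.
#include <torch/torch.h>
#include <iostream>

enum class InverseReturnMode { AlwaysView, NeverView, ViewOrScatterInverse };

// Inverse of `mutated_view = base.slice(dim, start, end)`: produce a tensor
// with base's shape whose sliced region holds mutated_view's values.
at::Tensor toy_slice_inverse(
    const at::Tensor& base,
    const at::Tensor& mutated_view,
    InverseReturnMode mode,
    int64_t dim,
    int64_t start,
    int64_t end) {
  if (mode == InverseReturnMode::AlwaysView) {
    // Only valid when mutated_view really is a narrowing view of base:
    // reinterpret its storage with base's sizes/strides/offset.
    return mutated_view.as_strided(
        base.sizes(), base.strides(), base.storage_offset());
  }
  // NeverView and ViewOrScatterInverse both use the copying scatter op here,
  // because slice has a dedicated scatter inverse (no as_strided needed).
  return at::slice_scatter(base, mutated_view, dim, start, end, /*step=*/1);
}

int main() {
  auto base = torch::zeros({4, 2});

  // Functional-style update: the "mutated view" is computed out of place.
  auto mutated = base.slice(/*dim=*/0, /*start=*/0, /*end=*/2) + 1;
  auto scattered =
      toy_slice_inverse(base, mutated, InverseReturnMode::NeverView, 0, 0, 2);
  std::cout << scattered << "\n";  // rows 0-1 are ones, rows 2-3 stay zeros

  // AlwaysView requires mutated_view to alias base (a real in-place update).
  auto view = base.slice(0, 0, 2);
  view.add_(1);
  auto recovered =
      toy_slice_inverse(base, view, InverseReturnMode::AlwaysView, 0, 0, 2);
  std::cout << recovered.equal(base) << "\n";  // 1: aliases and matches base
  return 0;
}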
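
A second illustrative sketch (again, not part of the patch, and again assuming libtorch with a local stand-in enum): it combines the reapply_views-to-mode mapping that the generated kernels use (reapply_views ? ViewOrScatterInverse : NeverView, as emitted by the gen_functionalization_type.py hunks above) with a permute inverse in the style of permute_copy_inverse. The permutation is inverted index by index, and the mode only decides whether the result is an aliasing at::permute or a fresh at::permute_copy. mode_from_reapply_views and toy_permute_inverse are hypothetical names for this illustration, and the negative-dim wrapping done by at::maybe_wrap_dim in the real code is omitted for brevity.

// Illustrative sketch only (not part of the patch). Assumes a libtorch build.
#include <torch/torch.h>
#include <vector>

enum class InverseReturnMode { AlwaysView, NeverView, ViewOrScatterInverse };

InverseReturnMode mode_from_reapply_views(bool reapply_views) {
  // Mirrors the mapping the generated functionalization kernels set up.
  return reapply_views ? InverseReturnMode::ViewOrScatterInverse
                       : InverseReturnMode::NeverView;
}

at::Tensor toy_permute_inverse(
    const at::Tensor& mutated_view,
    at::IntArrayRef dims,  // assumed already non-negative for brevity
    InverseReturnMode mode) {
  // Invert the permutation: if dims[i] == j, the inverse sends j back to i.
  std::vector<int64_t> inv(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    inv[dims[i]] = static_cast<int64_t>(i);
  }
  if (mode != InverseReturnMode::NeverView) {
    return at::permute(mutated_view, inv);     // returns an aliasing view
  }
  return at::permute_copy(mutated_view, inv);  // returns a fresh tensor
}

int main() {
  auto base = torch::arange(24).reshape({2, 3, 4});
  auto permuted = base.permute({2, 0, 1});  // shape (4, 2, 3)
  auto back = toy_permute_inverse(
      permuted, {2, 0, 1}, mode_from_reapply_views(/*reapply_views=*/false));
  TORCH_CHECK(back.equal(base), "inverse permutation should recover base");
  return 0;
}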