use_const_ref_for_mutable_tensors doesn't work with out= overloads #145522

@ezyang

🐛 Describe the bug

Repro:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._dispatch.python import enable_python_dispatcher
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, SymNode
from torch import SymInt
import torch._dynamo

def f():
    shape_env = ShapeEnv()
    s0 = 5
    s1 = 6
    s2 = 7
    s3 = 3
    s4 = 10
    s5 = 2
    x = torch.randn(s0, s1, s2)
    out = torch.randn(s0, s3, s4)
    kwargs = {
        's': (s3, s4),
        'dim': (1, s5),
        'norm': 'ortho',
    }

    from torch.fx.experimental.proxy_tensor import make_fx

    r = torch._C._fft.fft_hfft2(x, **kwargs, out=out)
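    # Under FakeTensorMode, r ends up aliasing x (shape (5, 6, 7)) rather
    # than out (shape (5, 3, 10)), so this assert fires.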
    assert r.shape == out.shape, r.shape

#print("real")
#f()

print("fake")
with FakeTensorMode():
    f()

First, we observe that fft_hfft2 uses use_const_ref_for_mutable_tensors:

- func: fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
  use_const_ref_for_mutable_tensors: True
  python_module: fft
  variants: function
  dispatch: 
    CompositeImplicitAutograd: fft_hfft2_symint_out

This means the signature of fft_hfft2_symint_out is:

const Tensor& fft_hfft2_symint_out(
    const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
    std::optional<std::string_view> norm, const Tensor& out) {
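For contrast: without use_const_ref_for_mutable_tensors, the codegen would emit the classic out= style, where the out argument is a mutable reference (a sketch; the exact generated declaration may differ slightly):

Tensor& fft_hfft2_symint_out(
    const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
    std::optional<std::string_view> norm, Tensor& out);

With a non-const Tensor& in the final position, the boxing machinery can tell this is an out= overload.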

Now, consider how we implement this signature in aten/src/ATen/core/boxing/impl/boxing.h:

// 3.5. In-process migration to make in-place ops take and return
// const references instead.
template <class... OtherArgs>
struct BoxedKernelWrapper<
    const at::Tensor&(const at::Tensor&, OtherArgs...),
    std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
  static const at::Tensor& call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      const at::Tensor& outArg,
      OtherArgs... otherArgs) {
    torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return a single value on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");

    return outArg;
  }
};

Our signature matches the boxing rule for in-place arguments, because it returns a const Tensor& and its first argument is a const Tensor&. This means the wrapper returns self, rather than out, as the return value.
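To see the ambiguity in isolation, here is a minimal standalone sketch (toy types, not the real boxing.h code): once both the in-place and the out= conventions are spelled with only const Tensor& parameters, a specialization keyed purely on the signature type has to commit to one convention, and it picks the wrong argument for out= ops.

#include <iostream>

struct Tensor { int id; };

template <class Sig> struct Wrapper;

// One specialization keyed only on the signature type. Mirroring
// boxing.h rule 3.5, it returns the first const Tensor& argument,
// which is correct for in-place ops but wrong for out= ops.
template <class... OtherArgs>
struct Wrapper<const Tensor&(const Tensor&, OtherArgs...)> {
  static const Tensor& call(const Tensor& first, OtherArgs... /*rest*/) {
    return first;
  }
};

// An out= op: the caller expects `out` back, but gets `self`.
const Tensor& my_op_out(const Tensor& self, const Tensor& out) {
  return Wrapper<const Tensor&(const Tensor&, const Tensor&)>::call(self, out);
}

int main() {
  Tensor self{1}, out{2};
  std::cout << my_op_out(self, out).id << '\n';  // prints 1 (self), not 2 (out)
}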

I think it is fundamentally impossible to solve this problem as things stand: there is an ambiguity as to whether a given function signature is an out= signature or an in-place one, and we do not encode that distinction in the C++ signature itself. A dedicated "Out" struct that wraps the tensor reference could help. But an easy workaround is to stop using the modern const Tensor& style for these out= functions.
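As an illustration of the dedicated "Out" struct idea, here is a hypothetical sketch (the Out type and spellings below are invented for illustration; nothing like this exists in ATen today):

// Hypothetical: a type-level tag for the out argument, so the out=
// convention becomes visible in the signature itself.
template <class T>
struct Out {
  T& ref;
};

// An out= kernel would then be spelled with Out<Tensor> in the final
// position:
//   const Tensor& op_out(const Tensor& self, ..., Out<Tensor> out);
// and BoxedKernelWrapper could specialize on
//   const Tensor&(OtherArgs..., Out<Tensor>)
// and return out.ref, while the in-place rule keeps matching
//   const Tensor&(const Tensor&, OtherArgs...).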

Versions

main

cc @chauhang @penguinwu @zou3519 @bdhirsh @yf225
