Description
🐛 Describe the bug
Repro:
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


def f():
    s0 = 5
    s1 = 6
    s2 = 7
    s3 = 3
    s4 = 10
    s5 = 2
    x = torch.randn(s0, s1, s2)
    out = torch.randn(s0, s3, s4)
    kwargs = {
        's': (s3, s4),
        'dim': (1, s5),
        'norm': 'ortho',
    }
    r = torch._C._fft.fft_hfft2(x, **kwargs, out=out)
    # An out= op should return `out`, so the shapes must agree.
    assert r.shape == out.shape, r.shape


# print("real")
# f()
print("fake")
with FakeTensorMode():
    f()
```
First, note that fft_hfft2.out sets use_const_ref_for_mutable_tensors in aten/src/ATen/native/native_functions.yaml:
```yaml
- func: fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)
  use_const_ref_for_mutable_tensors: True
  python_module: fft
  variants: function
  dispatch:
    CompositeImplicitAutograd: fft_hfft2_symint_out
```
This means the signature of fft_hfft2_symint_out is:
```cpp
const Tensor& fft_hfft2_symint_out(
    const Tensor& self, at::OptionalSymIntArrayRef s, IntArrayRef dim,
    std::optional<std::string_view> norm, const Tensor& out);
```
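The effect of the flag can be illustrated with a toy Python model. The helpers cpp_arg_type and tensor_signature are hypothetical and much simpler than the real torchgen codegen; they render only the tensor parts of fft_hfft2.out. With the flag off, out and the return type are Tensor&; with it on, both become const Tensor&, so the return type and the first argument are spelled the same way as in a const-ref in-place op.

```python
def cpp_arg_type(schema_type: str, use_const_ref_for_mutable_tensors: bool) -> str:
    """Toy mapping from a schema tensor type to a C++ parameter type (hypothetical)."""
    if schema_type == "Tensor":
        # Immutable tensor arguments are passed by const reference either way.
        return "const Tensor&"
    if schema_type.startswith("Tensor(") and schema_type.endswith("!)"):
        # Mutable (written-to) tensors: Tensor& by default,
        # const Tensor& when the flag is set.
        return "const Tensor&" if use_const_ref_for_mutable_tensors else "Tensor&"
    raise ValueError(f"unsupported schema type: {schema_type}")


def tensor_signature(use_const_ref: bool) -> str:
    """Render only the tensor parts of fft_hfft2.out as ret(self, ..., out)."""
    ret = cpp_arg_type("Tensor(a!)", use_const_ref)  # the return aliases `out`
    self_t = cpp_arg_type("Tensor", use_const_ref)
    out_t = cpp_arg_type("Tensor(a!)", use_const_ref)
    return f"{ret}({self_t}, ..., {out_t})"


print(tensor_signature(use_const_ref=False))  # Tensor&(const Tensor&, ..., Tensor&)
print(tensor_signature(use_const_ref=True))   # const Tensor&(const Tensor&, ..., const Tensor&)
```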
Now consider how aten/src/ATen/core/boxing/impl/boxing.h wraps a boxed kernel for a call with this signature:
```cpp
// 3.5. In-process migration to make in-place ops take and return
// const references instead.
template <class... OtherArgs>
struct BoxedKernelWrapper<
    const at::Tensor&(const at::Tensor&, OtherArgs...),
    std::enable_if_t<can_box_all<OtherArgs...>::value, void>> {
  static const at::Tensor& call(
      const BoxedKernel& boxed_kernel_func,
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      const at::Tensor& outArg,
      OtherArgs... otherArgs) {
    torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
    boxed_kernel_func.callBoxed(opHandle, dispatchKeySet, &stack);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        stack.size() == 1,
        "Boxed kernel was expected to return a single value on the stack, ",
        "but instead returned ",
        stack.size(),
        " values.");
    return outArg;
  }
};
```
Our signature matches this in-place boxing rule because it returns a const Tensor& and its first argument is a const Tensor&. The parameter the wrapper calls outArg is therefore bound to self, so the wrapper returns self rather than out.
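The ambiguity can be modeled in a few lines of Python. This is a toy under stated assumptions: signatures are plain strings, matches_const_inplace_rule and boxed_wrapper_return are hypothetical names, and the can_box_all constraint and every other boxing.h specialization are ignored. Because the rule inspects only the return type and the first argument, fft_hfft2_symint_out and a genuine in-place op are classified the same way, and the wrapper hands back self.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Signature:
    """A C++ function type reduced to strings, e.g. "const Tensor&"."""
    ret: str
    args: tuple


def matches_const_inplace_rule(sig: Signature) -> bool:
    # Mirrors the pattern of specialization 3.5:
    #   const at::Tensor&(const at::Tensor&, OtherArgs...)
    # Only the return type and the FIRST argument are inspected.
    return sig.ret == "const Tensor&" and len(sig.args) >= 1 and sig.args[0] == "const Tensor&"


def boxed_wrapper_return(sig: Signature, call_args: list):
    """Return what a wrapper selected by the rule above would hand back."""
    if matches_const_inplace_rule(sig):
        return call_args[0]  # "outArg" is really the first argument
    raise NotImplementedError("other boxing rules are not modeled here")


hfft2_out = Signature(
    ret="const Tensor&",
    args=("const Tensor&", "at::OptionalSymIntArrayRef", "IntArrayRef",
          "std::optional<std::string_view>", "const Tensor&"),
)
inplace_op = Signature(ret="const Tensor&", args=("const Tensor&", "double"))

# Both signatures match the same rule, so the wrapper cannot tell them apart.
print(matches_const_inplace_rule(hfft2_out), matches_const_inplace_rule(inplace_op))  # True True
print(boxed_wrapper_return(hfft2_out, ["self", None, [1, 2], "ortho", "out"]))  # self
```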
I think this cannot be fixed within the current pattern matching, because the C++ signature alone is ambiguous: nothing in it says whether the function is an out variant or an in-place one. A dedicated "Out" struct wrapping the tensor reference could disambiguate the two. An easy workaround (WAR) is to stop using the modern const Tensor& style (use_const_ref_for_mutable_tensors) for these out functions.
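The Out-struct idea can be sketched as another Python toy. Out<...> is a hypothetical type spelling, not an existing ATen type, and returned_arg_index is a hypothetical stand-in for template specialization. Once out arguments carry their own type, choosing which argument to return no longer depends on the ambiguous const Tensor& spelling.

```python
def returned_arg_index(arg_types: tuple) -> int:
    """Toy rule: which argument a boxing wrapper should return.

    Out arguments are spelled with a hypothetical Out<...> wrapper type,
    so the rule no longer has to guess from "const Tensor&" alone.
    """
    out_positions = [i for i, t in enumerate(arg_types) if t.startswith("Out<")]
    if len(out_positions) == 1:
        return out_positions[0]  # out= variant: return the out argument
    if not out_positions:
        return 0  # no Out marker: in-place variant, return the first argument
    raise NotImplementedError("multiple out arguments are not modeled here")


hfft2_out_args = ("const Tensor&", "at::OptionalSymIntArrayRef", "IntArrayRef",
                  "std::optional<std::string_view>", "Out<const Tensor&>")
inplace_args = ("const Tensor&", "double")

print(returned_arg_index(hfft2_out_args))  # 4
print(returned_arg_index(inplace_args))    # 0
```

In C++, the same idea would presumably be a distinct wrapper type that a BoxedKernelWrapper specialization could match on, so out and in-place signatures would no longer collide.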
Versions
main