
Bug when dealing with fallbacks on CPU #105853

Closed
lezcano opened this issue Jul 24, 2023 · 5 comments
Labels: module: complex, module: dynamic shapes, module: inductor, oncall: pt2, triaged

Comments

lezcano (Collaborator) commented Jul 24, 2023

🐛 Describe the bug

To repro, patch in #105850, change the line

ans = torch.matmul(x, y)

to

ans = torch.compile(torch.matmul)(x, y)

and run

python test/test_linalg.py -vk test_matmul_small_brute_force_1d_Nd_cpu_complex64
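
For reference, a minimal self-contained sketch of roughly what that test exercises; the shapes here are illustrative guesses reconstructed from the FixedLayout in the fallback arguments below, not copied from the test file:

import torch

# y is a complex64 CPU tensor matching size=[0, s0, 0] from the error
# below; dynamic=True forces symbolic shapes (SymInts) into the graph.
x = torch.randn(5, dtype=torch.complex64)
y = torch.randn(0, 5, 0, dtype=torch.complex64)
ans = torch.compile(torch.matmul, dynamic=True)(x, y)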

It fails with the following traceback:

  File "/home/lezcano/git/pytorch/pytorch/torch/fx/interpreter.py", line 138, in run
    self.env[node] = self.run_node(node)
  File "/home/lezcano/git/pytorch/pytorch/torch/_inductor/graph.py", line 658, in run_node
    result = fallback_handler(n.target, add_to_fallback_set=False)(
  File "/home/lezcano/git/pytorch/pytorch/torch/_inductor/lowering.py", line 1291, in handler
    TensorBox.create, ir.FallbackKernel.create(kernel, *args, **kwargs)
  File "/home/lezcano/git/pytorch/pytorch/torch/_inductor/ir.py", line 3430, in create
    return generate_output(example_output, [])
  File "/home/lezcano/git/pytorch/pytorch/torch/_inductor/ir.py", line 3427, in generate_output
    assert output is None, f"FallbackKernel output type {type(output)} is not supported"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: FallbackKernel output type <class 'torch.SymInt'> is not supported

It fails when creating a fallback for:

aten.sym_size
(TensorBox(StorageBox(
  InputBuffer(name='arg3_1', layout=FixedLayout('cpu', torch.complex64, size=[0, s0, 0], stride=[s0, 1, 1]))
)), 1)
{}

This is odd, as sym_size does have a lowering.
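
One way to sanity-check that claim (a sketch only; it assumes the internal registry dict torch._inductor.lowering.lowerings, a private detail that may move between releases):

import torch
from torch._inductor import lowering

# `lowerings` maps ATen ops to their Inductor lowering functions.
# The string match avoids guessing whether the key is the overload
# packet or a specific overload such as sym_size.int.
print(any("sym_size" in str(op) for op in lowering.lowerings))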

Versions

master

cc @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @nikitaved @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @Xia-Weiwen @ngimel

lezcano added the triaged and module: cpu inductor labels on Jul 24, 2023
lezcano added the module: complex label on Jul 24, 2023
jon-chuang (Collaborator) commented Sep 11, 2023

Hello @lezcano, I also hit a similar issue when TorchInductor compiles fft_ihfftn_cuda_float32, which produces complex float outputs.

Is this currently lowered?

This seems like a general issue with lowering the SymInt shapes of complex tensors?

See: #109001
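
For context, a hedged sketch of the kind of call that test makes (the exact shapes from the test are not reproduced here):

import torch

# ihfftn takes a real input and returns a complex one-sided output,
# so compiling it forces Inductor to handle complex result tensors.
x = torch.randn(4, 4, device="cuda", dtype=torch.float32)
out = torch.compile(torch.fft.ihfftn)(x)
print(out.dtype)  # torch.complex64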

jon-chuang (Collaborator) commented Sep 11, 2023

I was not able to repro the FallbackKernel output-type bug with your code, but I did find the following (potentially related) bug:

import torch

x = torch.tensor([-8.4784-1.7658j])
y = torch.tensor([-8.4784-1.7658j])
ans = torch.compile(torch.matmul)(x, y)
out = torch.empty_like(ans)
# Writing through a preallocated `out` diverges from the functional result.
torch.compile(torch.matmul)(x, y, out=out)
torch.testing.assert_close(ans, out)  # fails

This succeeds:

out = torch.compile(torch.matmul)(x, y)
torch.testing.assert_close(ans, out)  # succeeds

Note to self: try zeros_like instead of empty_like as well.
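
Following that note, a sketch of the zeros_like variant (not actually run in this thread, so its outcome is an open question):

# Same check with a zero-initialized `out`, ruling out uninitialized
# memory from empty_like as the source of the mismatch.
out = torch.zeros_like(ans)
torch.compile(torch.matmul)(x, y, out=out)
torch.testing.assert_close(ans, out)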

ezyang (Contributor) commented Sep 13, 2023

@yf225 is this related to things you were looking at?

penguinwu added the oncall: pt2 and oncall: cpu inductor labels and removed the module: cpu inductor label on Dec 2, 2023
leslie-fang-intel (Collaborator) commented

This doesn't look like a CPU-specific issue.

leslie-fang-intel removed the oncall: cpu inductor label on Dec 7, 2023
ZailiWang (Contributor) commented

Tested with the latest nightly build (2.2.0.dev20231209); neither of the issues reported in this thread reproduces.
Running test_linalg.py with the change:

Fail to import hypothesis in common_utils, tests are not derandomized
test_matmul_small_brute_force_1d_Nd_cpu_complex64 (__main__.TestLinalgCPU.test_matmul_small_brute_force_1d_Nd_cpu_complex64) ... /home/.../miniconda3/envs/ptni1209/lib/python3.11/site-packages/torch/_inductor/lowering.py:1611: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
/home/zailiwan/pytorch/pytorch/test/test_linalg.py:4350: UserWarning: An output with one or more elements was resized since it had shape [1], which does not match the required output shape [1, 1]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at /opt/conda/conda-bld/pytorch_1702107914768/work/aten/src/ATen/native/Resize.cpp:28.)
  ans = torch.matmul(x, y, out=out)
[2023-12-10 17:20:44,566] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2023-12-10 17:20:44,566] torch._dynamo.convert_frame: [WARNING]    function: 'inner' (/home/zailiwan/miniconda3/envs/ptni1209/lib/python3.11/site-packages/torch/_dynamo/external_utils.py:15)
[2023-12-10 17:20:44,566] torch._dynamo.convert_frame: [WARNING]    last reason: tensor 'L['args'][1]' rank mismatch. expected 1, actual 2
[2023-12-10 17:20:44,566] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2023-12-10 17:20:44,566] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
/home/zailiwan/pytorch/pytorch/test/test_linalg.py:4350: UserWarning: An output with one or more elements was resized since it had shape [2], which does not match the required output shape [1, 2]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at /opt/conda/conda-bld/pytorch_1702107914768/work/aten/src/ATen/native/Resize.cpp:28.)
  ans = torch.matmul(x, y, out=out)
/home/zailiwan/pytorch/pytorch/test/test_linalg.py:4350: UserWarning: An output with one or more elements was resized since it had shape [3], which does not match the required output shape [1, 3]. This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (Triggered internally at /opt/conda/conda-bld/pytorch_1702107914768/work/aten/src/ATen/native/Resize.cpp:28.)
  ans = torch.matmul(x, y, out=out)
ok

----------------------------------------------------------------------
Ran 1 test in 1.895s

OK

The snippet from the comment above (empty_like followed by the out= call) also finishes without failure.

lezcano closed this as completed on Dec 10, 2023