New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] decomposition for complex addition #110740
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110740
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit 4ca773b with merge base 0a26e5f (): BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
54ab871
to
c10c88c
Compare
@jansel Can you please take a look? There are some test failures that seem unrelated and look like infra issue. |
torch/_inductor/decomposition.py
Outdated
r = x.real + r | ||
if x_is_complex_tensor: | ||
i = x.imag + i | ||
complex_type = x.dtype if x_is_complex_tensor else y.dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch.promote_types
torch/_inductor/decomposition.py
Outdated
return ( | ||
torch.where( | ||
torch.arange(2, device=x.device, dtype=torch.uint8) == 0, | ||
r.unsqueeze(-1), | ||
i.unsqueeze(-1), | ||
) | ||
.view(complex_type) | ||
.squeeze(-1) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For add in particular I think you can just do:
return (x.view(torch.float32)+x.view(torch.float32)).view(complex_type)
Since you do the same thing for both .real and .complex.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is cool. Thanks for the suggestion. However, the generated code still have aten.view
calls and CSE cannot be done across kenerls. I guess the goal here is to optimize away aten.view
?
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0) # no-op to ensure context
# Source Nodes: [Z], Original ATen: [aten.add]
buf0 = aten.view(arg0_1, torch.float32)
del arg0_1
buf1 = buf0
del buf0
# Source Nodes: [Z], Original ATen: [aten.add]
buf2 = aten.view(arg1_1, torch.float32)
del arg1_1
buf3 = buf2
del buf2
buf4 = buf1; del buf1 # reuse
# Source Nodes: [Z], Original ATen: [aten.add]
stream0 = get_cuda_stream(0)
triton_poi_fused_add_0.run(buf4, buf3, 2000000, grid=grid(2000000), stream=stream0)
del buf3
# Source Nodes: [Z], Original ATen: [aten.add]
buf5 = aten.view(buf4, torch.complex64)
del buf4
buf6 = buf5
del buf5
return (buf6, )
Also the new form does not handle that when either x
or y
is a scalar.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, fixing the view problems requires teaching inductor how to do complex views.
The scalar case would need to be handled differently. You could first convert the scalar to complex.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jansel I think an easier thing to do would just be to pattern-match away redundant view calls right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where would you do that? BTW, are they redundant? I thought without the view calls, the type of the tensors wouldn't be correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, they're not redundant. But if you have two adds in a row, the views in the way would prevent fusion.
So for a single complex add, it would be
x.view(torch.float64)
x + x
x.view(torch.complex64)
Then for two complex adds, it would be
x.view(torch.float64)
x + x
x.view(torch.complex64)
x.view(torch.float64)
x + x
x.view(torch.complex64)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is the additional problem in that these views are getting mapped to fallback kernels, which don't properly encode the aliasing relationship. So inductor may try to reuse the memory for something else, which is a correctness issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I’m working on a new inductor IR node that handles complex/float views to avoid calling into the runtime. According to recent discussion, it looks like the triton community doesn’t like supporting complex natively there. But do you see a way to handle the views without using a triton kernel? Once the views are lowered, inductor will fuse the lowered code into the same triton kernel, which exposes complex64 to triton.
caa2d13
to
6e43e32
Compare
Awesome! I'm also looking for an optimal methodology to support Complex. @htyu , May I know the performance behavior? I'm kind of concerned about the performance on the Triton side, as the non-contiguous access may not be vectorized easily. |
Optimal performance is our goal too in general. For this particular issue, since there is no strided memory accesses created, the performance is acutally on-par with Non-contiguous access will be an issue for other complex num operations such as matmul. There is a discussion about whether to support complex natively in Triton which may be smarter to handle non-contiguous access with packing/unpacking: https://triton-lang.slack.com/archives/C01LY4FJL56/p1697060717465109 . You are welcome to discuss there. My understanding is it's not easy to come up with an optimal scheme for both cuda cores and tensor cores. |
b4d00b5
to
4bba606
Compare
@@ -611,6 +611,15 @@ def fn(x, y): | |||
|
|||
self.common(fn, (x, y)) | |||
|
|||
def test_add_complex(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add tests for:
- complex + non-complex
- complex + scalar
alpha=...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also if you have a test that demonstrates the aliasing issue with the prior version you should add that too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added test for alpha.
For # 1 and # 2, I'm afraid that a strided load will be needed as the computation has nothing to do with the imaginary part. Alternatively we would load the complex number as a whole in a thread, but for complex128 we need int128 support which isn't there in Trition. Can we leave them in a separate change until downstream support is ready?
Also if you have a test that demonstrates the aliasing issue with the prior version you should add that too.
Somehow I couldn't repro this anymore :( Could it be due to recent has_aliasing
change #110651 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is the correct approach. How would you implement things like multiplicaiton or abs
? If you do it naïvely, you'll end up with 2 non-contiguous loads which will give you a pretty horrific performance, and as it stands, it's not possible to do it better without introducing quite a few structural changes to how inductor loads values.
This is an intrinsic limitation of this approach, as addition is pretty much the only operation you can implement without having to split the real and complex part.
I think that, if we want to support complex number, we should look into helping adding complex number support within triton, and supporting complex numbers natively in inductor.
After reading #98161 (comment), it looks like triton does not want to support complex numbers, but we would still need to have a way to load complex tensors efficiently, and have a way to implement performant lowerings for complex numbers in triton, which may be different to those from CPU, as CPU already has all these ops implemented in the scalar and vectorized dtypes |
@pytorchbot label "topic: not user facing" |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 2 mandatory check(s) failed. The first few are:
Dig deeper by viewing the failures on hud |
@pytorchbot merge -f "bypass unrelated failure" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
As a follow-up to #110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: #111773 Approved by: https://github.com/jansel
Tracks pytorch#98161 Complex number support in Pytorch isn't ideal today as complex operations will mostly end up taken care of by the aten runtime, except for `torch.angle` which is handled in [105609](pytorch#105609). In general a better way to handle that could be to decompose complex operations first so that more opportunities for fusion could be unveiled, and then to have Triton take care of non-continuous (strided) tensor operations more efficiently. This change adds support to decompose complex addtions. ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) ``` Pull Request resolved: pytorch#110740 Approved by: https://github.com/jansel
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
Tracks pytorch#98161 Complex number support in Pytorch isn't ideal today as complex operations will mostly end up taken care of by the aten runtime, except for `torch.angle` which is handled in [105609](pytorch#105609). In general a better way to handle that could be to decompose complex operations first so that more opportunities for fusion could be unveiled, and then to have Triton take care of non-continuous (strided) tensor operations more efficiently. This change adds support to decompose complex addtions. ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) ``` Pull Request resolved: pytorch#110740 Approved by: https://github.com/jansel
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
Tracks pytorch#98161 Complex number support in Pytorch isn't ideal today as complex operations will mostly end up taken care of by the aten runtime, except for `torch.angle` which is handled in [105609](pytorch#105609). In general a better way to handle that could be to decompose complex operations first so that more opportunities for fusion could be unveiled, and then to have Triton take care of non-continuous (strided) tensor operations more efficiently. This change adds support to decompose complex addtions. ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) ``` Pull Request resolved: pytorch#110740 Approved by: https://github.com/jansel
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
Tracks #98161
Complex number support in Pytorch isn't ideal today as complex operations will mostly end up taken care of by the aten runtime, except for
torch.angle
which is handled in 105609. In general a better way to handle that could be to decompose complex operations first so that more opportunities for fusion could be unveiled, and then to have Triton take care of non-continuous (strided) tensor operations more efficiently. This change adds support to decompose complex addtions.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler