
[inductor] decomposition for complex addition #110740

Closed
wants to merge 5 commits

Conversation

@htyu (Contributor) commented Oct 6, 2023

Tracks #98161

Complex number support in PyTorch isn't ideal today: most complex operations end up being handled by the ATen runtime, except for torch.angle, which is handled in #105609. In general, a better way to handle this is to decompose complex operations first, so that more fusion opportunities are exposed, and then have Triton take care of non-contiguous (strided) tensor operations more efficiently. This change adds support for decomposing complex additions. The generated Triton kernel for a simple complex add is shown below.

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)
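
As an illustration of the idea only (not the exact decomposition registered in torch/_inductor/decomposition.py), here is a minimal eager-mode sketch, assuming contiguous complex64 inputs: the complex storage is reinterpreted as interleaved float32 values, added elementwise, and reinterpreted back, so the kernel only ever sees a real dtype.

```python
import torch

def add_complex_decomposed(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Reinterpret complex64 storage as interleaved float32 (re0, im0, re1, ...),
    # add elementwise, then reinterpret the result back as complex64.
    assert x.dtype == y.dtype == torch.complex64
    return (x.view(torch.float32) + y.view(torch.float32)).view(torch.complex64)

x = torch.randn(3, dtype=torch.complex64)
y = torch.randn(3, dtype=torch.complex64)
torch.testing.assert_close(add_complex_decomposed(x, y), x + y)
```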

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler

pytorch-bot bot commented Oct 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110740

Note: links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 4ca773b with merge base 0a26e5f:

BROKEN TRUNK - the following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Oct 6, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@htyu htyu marked this pull request as draft October 6, 2023 18:48
@htyu htyu force-pushed the hoy branch 10 times, most recently from 54ab871 to c10c88c on October 9, 2023 07:14
@htyu htyu marked this pull request as ready for review October 9, 2023 16:35
@htyu (Contributor Author) commented Oct 9, 2023

@jansel Can you please take a look? There are some test failures that seem unrelated and look like an infra issue.

@htyu htyu requested a review from jansel October 9, 2023 17:12
r = x.real + r
if x_is_complex_tensor:
    i = x.imag + i
complex_type = x.dtype if x_is_complex_tensor else y.dtype
Contributor:

torch.promote_types
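
For context, a minimal illustration of the suggestion (not code from this PR): torch.promote_types computes the result dtype directly, so it could replace the manual x/y dtype selection in the excerpt above.

```python
import torch

# promote_types returns the smallest dtype that can represent both inputs,
# so it subsumes the "x.dtype if x_is_complex_tensor else y.dtype" branch.
assert torch.promote_types(torch.complex64, torch.float32) == torch.complex64
assert torch.promote_types(torch.complex64, torch.float64) == torch.complex128
```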

torch/_inductor/decomposition.py (outdated review comments, resolved)
Comment on lines 288 to 296
return (
    torch.where(
        torch.arange(2, device=x.device, dtype=torch.uint8) == 0,
        r.unsqueeze(-1),
        i.unsqueeze(-1),
    )
    .view(complex_type)
    .squeeze(-1)
)
Contributor:

For add in particular I think you can just do:

return (x.view(torch.float32) + y.view(torch.float32)).view(complex_type)

since you do the same thing for both .real and .imag.

@htyu (Contributor Author) commented Oct 10, 2023:

This is cool. Thanks for the suggestion. However, the generated code still has aten.view calls, and CSE cannot be done across kernels. I guess the goal here is to optimize away aten.view?

with torch.cuda._DeviceGuard(0):
     torch.cuda.set_device(0) # no-op to ensure context
     # Source Nodes: [Z], Original ATen: [aten.add]
     buf0 = aten.view(arg0_1, torch.float32)
     del arg0_1
     buf1 = buf0
     del buf0
     # Source Nodes: [Z], Original ATen: [aten.add]
     buf2 = aten.view(arg1_1, torch.float32)
     del arg1_1
     buf3 = buf2
     del buf2
     buf4 = buf1; del buf1  # reuse
     # Source Nodes: [Z], Original ATen: [aten.add]
     stream0 = get_cuda_stream(0)
     triton_poi_fused_add_0.run(buf4, buf3, 2000000, grid=grid(2000000), stream=stream0)
     del buf3
     # Source Nodes: [Z], Original ATen: [aten.add]
     buf5 = aten.view(buf4, torch.complex64)
     del buf4
     buf6 = buf5
     del buf5
     return (buf6, )

Also, the new form does not handle the case where either x or y is a scalar.

Contributor:

Yeah, fixing the view problems requires teaching inductor how to do complex views.

The scalar case would need to be handled differently. You could first convert the scalar to complex.
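
A minimal sketch of that conversion (an illustrative, hypothetical helper rather than code from this PR):

```python
import torch

def scalar_as_complex(value, like: torch.Tensor) -> torch.Tensor:
    # Wrap a Python scalar as a 0-d tensor on the same device, then cast it to
    # the complex dtype of `like` so the tensor-tensor decomposition applies.
    return torch.as_tensor(value, device=like.device).to(like.dtype)

x = torch.randn(4, dtype=torch.complex64)
y = scalar_as_complex(2.5, x)  # complex64 scalar tensor, imaginary part 0
torch.testing.assert_close(x + y, x + 2.5)
```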

Contributor:

@jansel I think an easier thing to do would just be to pattern-match away the redundant view calls, right?

Contributor Author:

Where would you do that? BTW, are they redundant? I thought without the view calls, the type of the tensors wouldn't be correct.

Contributor:

In this case, they're not redundant. But if you have two adds in a row, the views in the way would prevent fusion.

So for a single complex add, it would be

x.view(torch.float32)
x + x
x.view(torch.complex64)

Then for two complex adds, it would be

x.view(torch.float32)
x + x
x.view(torch.complex64)
x.view(torch.float32)
x + x
x.view(torch.complex64)

Contributor:

There is the additional problem that these views are getting mapped to fallback kernels, which don't properly encode the aliasing relationship. So inductor may try to reuse the memory for something else, which is a correctness issue.

Contributor Author:

I'm working on a new inductor IR node that handles complex/float views to avoid calling into the runtime. According to recent discussion, it looks like the Triton community doesn't want to support complex natively there. But do you see a way to handle the views without using a Triton kernel? Once the views are lowered, inductor will fuse the lowered code into the same Triton kernel, which exposes complex64 to Triton.

torch/_inductor/decomposition.py (outdated review comment, resolved)
@htyu htyu force-pushed the hoy branch 2 times, most recently from caa2d13 to 6e43e32 on October 10, 2023 04:18
@EikanWang (Collaborator) commented:

Awesome! I'm also looking for an optimal methodology to support complex. @htyu, may I know the performance behavior? I'm kind of concerned about the performance on the Triton side, as the non-contiguous access may not be vectorized easily.

@htyu (Contributor Author) commented Oct 13, 2023

> Awesome! I'm also looking for an optimal methodology to support complex. @htyu, may I know the performance behavior? I'm kind of concerned about the performance on the Triton side, as the non-contiguous access may not be vectorized easily.

Optimal performance is our goal too, in general. For this particular case, since no strided memory accesses are created, the performance is actually on par with an aten.add call for a single add operation. But with some follow-up work, we should be able to fold/CSE multiple add operations, which would perform better than the fallback call version.

Non-contiguous access will be an issue for other complex-number operations such as matmul. There is a discussion about whether to support complex natively in Triton, which might be better placed to handle non-contiguous access with packing/unpacking: https://triton-lang.slack.com/archives/C01LY4FJL56/p1697060717465109 . You are welcome to join the discussion there. My understanding is that it's not easy to come up with a scheme that is optimal for both CUDA cores and tensor cores.

torch/_inductor/ir.py (outdated review comment, resolved)
torch/_inductor/graph.py (outdated review comment, resolved)
torch/_inductor/lowering.py (outdated review comment, resolved)
@htyu htyu force-pushed the hoy branch 2 times, most recently from b4d00b5 to 4bba606 on October 15, 2023 23:01
@@ -611,6 +611,15 @@ def fn(x, y):

self.common(fn, (x, y))

def test_add_complex(self):
Contributor:

Add tests for (see the sketch after this list):

  1. complex + non-complex
  2. complex + scalar
  3. alpha=...
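
A minimal sketch of what such tests might look like (hypothetical names, assuming the test file's existing imports and test class with its self.common helper, as used elsewhere in this file; not necessarily the tests added in this PR):

```python
def test_add_complex_scalar(self):
    # complex tensor + Python scalar, plus an explicit alpha
    def fn(x):
        return x + 2.0, torch.add(x, 1.0, alpha=0.5)

    self.common(fn, (torch.randn(8, dtype=torch.complex64),))

def test_add_complex_non_complex(self):
    # complex tensor + real tensor
    def fn(x, y):
        return torch.add(x, y, alpha=2.0)

    self.common(fn, (torch.randn(8, dtype=torch.complex64), torch.randn(8)))
```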

Contributor:

Also, if you have a test that demonstrates the aliasing issue with the prior version, you should add that too.

@htyu (Contributor Author) commented Oct 17, 2023:

Added test for alpha.

For # 1 and # 2, I'm afraid a strided load will be needed, since the computation has nothing to do with the imaginary part. Alternatively, we could load the complex number as a whole in a thread, but for complex128 that requires int128 support, which isn't available in Triton. Can we leave them to a separate change until downstream support is ready?

> Also, if you have a test that demonstrates the aliasing issue with the prior version, you should add that too.

Somehow I couldn't repro this anymore :( Could it be due to the recent has_aliasing change in #110651?

@lezcano (Collaborator) left a comment:

I don't think this is the correct approach. How would you implement things like multiplication or abs? If you do it naïvely, you'll end up with 2 non-contiguous loads, which will give you pretty horrific performance, and as it stands, it's not possible to do better without introducing quite a few structural changes to how inductor loads values.

This is an intrinsic limitation of this approach, as addition is pretty much the only operation you can implement without having to split the real and imaginary parts.

I think that, if we want to support complex numbers, we should look into helping add complex-number support within Triton, and supporting complex numbers natively in inductor.
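
As an illustration of that point: complex multiplication mixes the real and imaginary parts of both operands, (a + bi)(c + di) = (ac - bd) + (ad + bc)i, so with the interleaved layout every output element needs the real and imaginary halves of both inputs (i.e. strided loads), whereas addition can operate on the interleaved values directly.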

@lezcano (Collaborator) commented Oct 16, 2023

After reading #98161 (comment), it looks like Triton does not want to support complex numbers. But we would still need a way to load complex tensors efficiently, and a way to implement performant lowerings for complex numbers in Triton, which may be different from those for CPU, as the CPU backend already has all these ops implemented for the scalar and vectorized dtypes.

@htyu (Contributor Author) commented Oct 23, 2023

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the "topic: not user facing" label Oct 23, 2023
@htyu (Contributor Author) commented Oct 23, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:
Merge failed

Reason: 2 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@htyu (Contributor Author) commented Oct 24, 2023

@pytorchbot merge -f "bypass unrelated failure"

@pytorchmergebot (Collaborator) commented:
Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge, ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 26, 2023
As a follow-up to #110740, this patch enables removing redundant complex views to allow more operation fusion.

E.g., given

```
@torch.compile
def foo(X, Y):
    Z = X + Y
    A = X + Y
    return A + Z
```

the generated code is:

```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''')

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [add_2], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        del buf4
        buf6 = buf5
        del buf5
        return (buf6, )
```

whereas previously the generated code was:

```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [A], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [A], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        buf6 = buf5
        del buf5
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf7 = aten.view.dtype(buf6, torch.float32)
        del buf6
        buf8 = buf7
        del buf7
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf9 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf10 = buf9
        del buf9
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf11 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf12 = buf11
        del buf11
        buf13 = buf4; del buf4  # reuse
        # Source Nodes: [Z], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0)
        del buf10
        del buf12
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf14 = aten.view.dtype(buf13, torch.complex64)
        buf15 = buf14
        del buf14
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf16 = aten.view.dtype(buf15, torch.float32)
        del buf15
        buf17 = buf16
        del buf16
        buf18 = buf13; del buf13  # reuse
        # Source Nodes: [add_2], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0)
        del buf17
        del buf8
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf19 = aten.view.dtype(buf18, torch.complex64)
        del buf18
        buf20 = buf19
        del buf19
        return (buf20, )
```

Pull Request resolved: #111773
Approved by: https://github.com/jansel
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023