
[inductor] Remove redundant views #111773

Closed
wants to merge 4 commits into from

Conversation

htyu
Contributor

@htyu htyu commented Oct 23, 2023

As a follow-up to #110740, this patch enables removing redundant complex views to allow more operation fusion.

E.g., given

@torch.compile
def foo(X, Y):
    Z = X + Y
    A = X + Y
    return A + Z
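(For context, X and Y are complex tensors. Judging from the size/stride asserts and the float32 views in the generated code below, an input roughly like the following would exercise this path; the exact shapes are only an illustration:)

X = torch.randn(3, dtype=torch.complex64, device="cuda")
Y = torch.randn(3, dtype=torch.complex64, device="cuda")
foo(X, Y)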

the generated code is:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''')

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [add_2], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        del buf4
        buf6 = buf5
        del buf5
        return (buf6, )

whereas previously the generated code was:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [A], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [A], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        buf6 = buf5
        del buf5
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf7 = aten.view.dtype(buf6, torch.float32)
        del buf6
        buf8 = buf7
        del buf7
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf9 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf10 = buf9
        del buf9
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf11 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf12 = buf11
        del buf11
        buf13 = buf4; del buf4  # reuse
        # Source Nodes: [Z], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0)
        del buf10
        del buf12
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf14 = aten.view.dtype(buf13, torch.complex64)
        buf15 = buf14
        del buf14
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf16 = aten.view.dtype(buf15, torch.float32)
        del buf15
        buf17 = buf16
        del buf16
        buf18 = buf13; del buf13  # reuse
        # Source Nodes: [add_2], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0)
        del buf17
        del buf8
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf19 = aten.view.dtype(buf18, torch.complex64)
        del buf18
        buf20 = buf19
        del buf19
        return (buf20, )

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan

@pytorch-bot

pytorch-bot bot commented Oct 23, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111773

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d70c7a3 with merge base cbc6213:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@htyu htyu requested a review from jansel October 23, 2023 16:54
@htyu htyu marked this pull request as ready for review October 23, 2023 16:55
@htyu htyu requested review from jansel and removed request for jansel October 23, 2023 16:58
Contributor

@jansel jansel left a comment


This needs tests

torch/_inductor/fx_passes/joint_graph.py
    break
for unused in unused_aliases:
    aliases.pop(unused)
    graph.erase_node(unused)
Contributor


Don't these already get erased on L150? Why do we need to do it again?

Contributor Author


For patterns like

c1 = view(f1, complex64)
f2 = view(c1, float32)
f3 = f2 + f4

L150 will replace all uses of f2 with f1 and remove f2, which leaves c1 unused. The code here removes c1 and, recursively, any of its predecessors that become unused.
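Roughly, the cleanup looks like this (an illustrative sketch with made-up names, not the PR's exact code): once f2's uses are replaced and f2 is erased, c1 may be left with no users, so we walk backwards erasing view nodes that became dead.

def erase_dead_views(graph, node):
    # A real pass would also check that `node` is an aten.view.dtype call
    # before erasing it; this sketch only checks for dead call_function nodes.
    while node is not None and node.op == "call_function" and not node.users:
        inputs = node.all_input_nodes
        graph.erase_node(node)
        node = inputs[0] if inputs else None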


if existing_aliases:
    # Replace the view with an existing view if available.
    for alias in existing_aliases:
Contributor


This loop wouldn't be needed if you made aliases a map from (input_node, dtype) -> view_node

Contributor Author


existing_aliases is a closure of all aliases with different types. For now, since we are only dealing with complex views, the approach you pointed out works. But thinking beyond complex views, without the loop we may need to recursively search the aliases for patterns like

i1 = view(c1, int64)
f2 = view(i1, float32)
...
c2 = view(f2, complex64)

where c1 can be replaced by c2.

For now with complex adds only, existing_aliases should only contain one float32 and one complex64.

@htyu htyu changed the title [inductor] Remove redundant complex views [inductor] Remove redundant views Oct 24, 2023
@htyu
Contributor Author

htyu commented Oct 24, 2023

test/inductor/test_torchinductor.py

if existing_views:
    # Replace the view with an existing view if available.
    for alias in existing_views:
Contributor


Rather than using a loop here, you can modify the views dict key to be:

(source_node, from_dtype, to_dtype)

Contributor Author

@htyu htyu Oct 24, 2023


Thanks for the suggestion. I thought about that previously, but given that aten.view is bidirectional, we are actually constructing an acyclic graph for a given alias closure. I'm using a shared list to store all nodes in the closure. Alternatively, as you suggested, we could store each graph edge separately.

I feel that the current implementation is simpler, and each search takes linear time in theory. Do you think that is a concern? How big can an alias closure be in reality?

With the alternative, the search for a view with the desired type can be as complex as a DFS or BFS, and we may need to handle cycles. It is theoretically linear as well, but there are usually more edges than nodes, so I'd expect storing them to take more memory.

Contributor Author


Updated to use a dictionary instead of a list for efficiency.
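For illustration, the dictionary-based lookup amounts to something like the following (a rough sketch with assumed names; the actual key used in the PR may differ, e.g. the reviewer suggested (source_node, from_dtype, to_dtype)):

# Map (source_node, target_dtype) -> an existing dtype-view node, so a
# redundant aten.view.dtype can be answered with one dictionary lookup
# instead of scanning a list of aliases.
views = {}

def get_or_record_view(source, dtype, view_node):
    key = (source, dtype)
    if key in views:
        return views[key]  # reuse the earlier, equivalent view
    views[key] = view_node
    return view_node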

@Chillee
Contributor

Chillee commented Oct 24, 2023

Is there any reason we don't just do this with the pattern matcher?

@htyu
Contributor Author

htyu commented Oct 24, 2023

Is there any reason we don't just do this with the pattern matcher?

Good point. I'm honestly not sure what can be done by the pattern matcher. What does a pattern usually look like, and what is the benefit of doing it there? cc @jansel to chime in.

@jansel
Contributor

jansel commented Oct 25, 2023

Yeah, I think we could likely use the pattern matcher for this...
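For reference, a search/replacement pair for the redundant round-trip view could look roughly like the sketch below. This is only illustrative; the registration mechanics (e.g. torch._inductor.pattern_matcher.register_replacement and its exact arguments) are assumptions, not something taken from this PR.

import torch

def _roundtrip_view_pattern(x):
    # complex64 -> float32 -> complex64 dtype-view round trip
    f = torch.ops.aten.view.dtype(x, torch.float32)
    return torch.ops.aten.view.dtype(f, torch.complex64)

def _roundtrip_view_replacement(x):
    # the round trip is a no-op on the underlying storage
    return x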

@htyu
Contributor Author

htyu commented Oct 26, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 26, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@htyu htyu deleted the hoy-complex-folding branch November 1, 2023 17:25
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023