-
Notifications
You must be signed in to change notification settings - Fork 22.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[inductor] Remove redundant views #111773
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111773
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit d70c7a3 with merge base cbc6213 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
b913961
to
165dbbf
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This needs tests
break | ||
for unused in unused_aliases: | ||
aliases.pop(unused) | ||
graph.erase_node(unused) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't these already get erased on L150? Why do we need to do it again?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For patterns like
c1 = view(f1, complex64)
f2 = view(c1, float32)
f3 = f2 + f4
L150 will replace all uses of f2
with f1
and remove f2
. This leaves c1
unused. The code here removes c1
and recursively all its predecessors.
|
||
if existing_aliases: | ||
# Replace the view with the an existing view if available. | ||
for alias in existing_aliases: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This loop wouldn't be needed if you made aliases map from (input_node, dtype) -> view_node
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
existing_aliases
is a closure of all aliases with different types. For now since we are only dealing with complex views, the way you pointed out works. But thinking beyond complex views, without the loop we may need recursively search aliases
for patterns like
i1 = view(c1, int64)
f2 = view(i1, float32)
...
c2 = view(f2, complex64)
where c1
can be replaced by c2
.
For now with complex adds only, existing_aliases
should only contain one float32 and one complex64.
165dbbf
to
0568867
Compare
0568867
to
86eb857
Compare
Launched a perf run : https://github.com/pytorch/pytorch/actions/runs/6623428804 |
|
||
if existing_views: | ||
# Replace the view with the an existing view if available. | ||
for alias in existing_views: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than using a loop here, you can modify the views dict key to be:
(source_node, from_dtype, to_dtype)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion. I thought about that previously, but given that aten.view
is bidirectional, we are actually constructing a acyclic graph for a given alias closure. I'm using a shared list to store all nodes in the closure. Alternatively as you suggested, we can store each graph edge separately.
I feel that the current implementation is simpler, and of course each search takes in theory linear time to finish. Do you think that is a concern? How big an alias closure can be in reality?
With the alternative, the search for a view with desired type can be as complex as a DFS or BFS, and we may need to handle cycles. It has a linear complexity theoretically as well. Also edges are usually more than nodes, and I'd expect it causes more memory to store them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to using a dictionary instead of a list for efficiency.
1e7eeec
to
f6230d0
Compare
Is there any reason we don't just do this with the pattern matcher? |
Good point. I'm honestly not sure what can be done by the pattern matcher. What does a pattern look like usually and what is the benefit to do it there? cc @jansel to chime in. |
Yeah, I think we could likely use the pattern matcher for this... |
20876f8
to
d70c7a3
Compare
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
As a follow-up to #110740, this patches enables removing redundant complex views to allow more operation fusing.
E.g, given
the generated code is:
whereas previously the generated code was:
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @gujinghui @PenghuiCheng @jianyuh @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @mcarilli @ptrblck @leslie-fang-intel @voznesenskym @penguinwu @EikanWang @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan