Enabling concat fast path for channels last inputs #39448
Conversation
Note that I think this can be combined with the existing kernel in CatKernel.cpp. However, it would be good to parallelize that one. The original motivation for it was to fix a perf regression, but I believe it did not have to be serial.
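For context on what the fast path under discussion is meant to preserve, here is a minimal sketch (shapes chosen arbitrarily): concatenating two channels_last tensors should yield a channels_last result rather than silently falling back to a plain contiguous one.

```python
import torch

# Two NCHW tensors stored in channels_last (NHWC) layout.
x = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
y = torch.randn(2, 5, 4, 4).contiguous(memory_format=torch.channels_last)

# With the channels-last fast path, cat along the channel dim
# keeps the input memory format.
z = torch.cat((x, y), dim=1)
print(z.is_contiguous(memory_format=torch.channels_last))
```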
cc: @mingfeima
I don't know much about memory format, so I am removing myself from the reviewers list. @VitalyFedyunin could you take a look at this?
Ya, no worries. I added you because gh showed me you reviewed those files recently. I will let other reviewers chime in.
@VitalyFedyunin, can you please take a look at this? Thanks.
I tried parallelizing the serial kernel. The perf results were inconclusive, and the variance was very high in some cases. For now I am leaving that as is.
@VitalyFedyunin, can you please take a look at this?
@colesbury, can you help review this PR? Thanks.
I'm concerned about changing the math for `inner` and `outer`; please add tests covering various `dim=` values.
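A test along the requested lines might sweep every `dim` and compare the channels_last result against a plain contiguous baseline. This is a sketch of the idea, not the PR's actual test; the helper name and shape are made up for illustration.

```python
import torch

def check_cat_all_dims(shape=(2, 3, 4, 4)):
    """Concat channels_last inputs along every dim and compare
    against the contiguous baseline result."""
    for dim in range(len(shape)):
        a = torch.randn(*shape)
        b = torch.randn(*shape)
        # channels_last copies of the same data
        a_cl = a.contiguous(memory_format=torch.channels_last)
        b_cl = b.contiguous(memory_format=torch.channels_last)
        ref = torch.cat((a, b), dim=dim)        # contiguous baseline
        out = torch.cat((a_cl, b_cl), dim=dim)  # exercises the fast path
        assert torch.equal(ref, out), f"mismatch at dim={dim}"

check_cat_all_dims()
```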
The approach looks reasonable, but I'm not super familiar with the `cat` code. I noticed a few issues, mentioned inline.
Additionally, in the PR description you wrote:
> Enables that kernel by default for contig tensors even for large inputs and when more than one thread is available.
I don't see that change in the PR. The call to `cat_serial_stub` looks like it's still guarded by `use_serial_kernel`.
Yes, my bad. Initially I did try what the old summary said, but it did not seem to be provably better, so I kept the changes minimal. I have updated the summary.
@VitalyFedyunin @colesbury, I have addressed your comments.
@@ -7045,6 +7045,56 @@ def test_cat_out_channels_last(self, device):
res2 = torch.cat((x, y), out=z)
self.assertEqual(res1, res2)

def test_cat_in_channels_last(self, device):
Please move the test to another class where `device` is unnecessary, or add a CPU-only decorator, or use `device` appropriately.
Tests still require some fixes.
test/test_torch.py
self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Concat across dim 2
This could be a simple for loop over dimensions.
@VitalyFedyunin, thanks a lot for the comments. I have incorporated them. Please let me know if there is anything else.
Windows test failures look real
@VitalyFedyunin, looks like this PR fixes the Windows failures: #40369.
@kimishpatel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
}
}
}

void cat_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_serial_kernel", [&]() {
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_contig_kernel", [&]() {
s/cat_contig_kernel/cat_serial_kernel
Summary: The existing cat implementation produces the output tensor in contiguous format, disregarding the input memory format. This PR fixes the kernel as well as the op implementation to account for that. Test Plan: CI
@kimishpatel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@kimishpatel merged this pull request in 6a421d5.
Summary:
Updates concat kernel for contiguous input to support channels_last contig tensors.
This was tried on a SqueezeNet model on a Pixel 2 device. It improves model perf by about 25%.
Test Plan:
test_cat_in_channels_last
Reviewers:
Subscribers:
Tasks:
Tags:
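The ~25% figure above is from the author's SqueezeNet measurement; absolute numbers will vary by machine. A rough way to time the channels_last cat path yourself, using `torch.utils.benchmark` (available in recent PyTorch; the tensor sizes here are arbitrary):

```python
import torch
import torch.utils.benchmark as benchmark

# Channels_last inputs exercise the fast path added by this PR.
x = torch.randn(8, 64, 28, 28).contiguous(memory_format=torch.channels_last)
y = torch.randn(8, 64, 28, 28).contiguous(memory_format=torch.channels_last)

timer = benchmark.Timer(
    stmt="torch.cat((x, y), dim=1)",
    globals={"torch": torch, "x": x, "y": y},
)
# Prints a timing summary; compare against plain contiguous inputs
# to estimate the speedup on your hardware.
print(timer.timeit(100))
```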