
Enabling concat fast path for channels last inputs #39448

Closed

Conversation

@kimishpatel (Contributor) commented Jun 3, 2020:

Summary:
Updates the concat kernel for contiguous inputs to also support channels_last-contiguous tensors.

This was tried on a SqueezeNet model on a Pixel 2 device, where it improves model performance by about 25%.

Test Plan:
test_cat_in_channels_last

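For context, a minimal sketch of the case this PR targets, assuming only the public torch.cat and memory_format APIs (the shapes are illustrative, not taken from the PR):

import torch

# Two channels_last (NHWC) tensors, concatenated along the channel dim on CPU.
x = torch.randn(1, 64, 56, 56).contiguous(memory_format=torch.channels_last)
y = torch.randn(1, 64, 56, 56).contiguous(memory_format=torch.channels_last)

# With this change, the output keeps the channels_last layout and the copy
# goes through the contiguous fast-path kernel.
z = torch.cat((x, y), dim=1)
print(z.is_contiguous(memory_format=torch.channels_last))  # True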

@kimishpatel (Contributor, Author) commented Jun 3, 2020:

Note that I think this can be combined with the existing kernel in CatKernel.cpp; however, it would be good to parallelize that one. The original motivation for it was to fix a perf regression, but I believe it did not have to be serial.
I will try that out as well.

@dr-ci bot commented Jun 3, 2020:

💊 CI failures summary and remediations

As of commit 914fe8a (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@kimishpatel (Contributor, Author) commented:

cc: @mingfeima

@kimishpatel changed the title from "Add fast path for channel dim concat when input memory format is" to "Enabling concat fast path for channels last inputs" Jun 3, 2020
@zou3519 removed their request for review June 4, 2020 15:13
@zou3519 (Contributor) commented Jun 4, 2020:

I don't know much about memory format so I am removing myself from the reviewers list. @VitalyFedyunin could you take a look at this?

@kimishpatel (Contributor, Author) commented:

> I don't know much about memory format so I am removing myself from the reviewers list. @VitalyFedyunin could you take a look at this?

Yeah, no worries. I added you because GitHub showed me you had reviewed those files recently. I will let other reviewers chime in.

@kimishpatel force-pushed the channels_last_concat branch 2 times, most recently from cf15124 to bd4663e, on June 15, 2020 16:14
@kimishpatel (Contributor, Author) commented:

@VitalyFedyunin, can you please take a look at this? Thanks.

> Note that I think this can be combined with the existing kernel in CatKernel.cpp; however, it would be good to parallelize that one. The original motivation for it was to fix a perf regression, but I believe it did not have to be serial.
> I will try that out as well.

I tried parallelizing the serial kernel. The perf results were inconclusive, and the variance was very high in some cases. For now I am just leaving that as is.

@kimishpatel (Contributor, Author) commented:

@VitalyFedyunin, can you please take a look at this?

@kimishpatel (Contributor, Author) commented:

@colesbury, can you help review this PR? Thanks.

(Outdated, resolved review thread on test/test_torch.py)
@VitalyFedyunin (Contributor) left a comment:

I'm concerned about changing the math for inner and outer; please add tests covering various dim= values.
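For readers, a rough sketch of the inner/outer decomposition being referred to, assuming default-contiguous inputs; this is an illustrative Python reference, not the actual ATen kernel:

import torch
from functools import reduce
import operator

def cat_inner_outer(tensors, dim):
    # View each input as (outer, size[dim] * inner), where
    # outer = prod(sizes[:dim]) and inner = prod(sizes[dim+1:]),
    # then interleave contiguous row chunks into the output.
    sizes = list(tensors[0].shape)
    outer = reduce(operator.mul, sizes[:dim], 1)
    inner = reduce(operator.mul, sizes[dim + 1:], 1)
    out_dim = sum(int(t.shape[dim]) for t in tensors)
    out = torch.empty(sizes[:dim] + [out_dim] + sizes[dim + 1:],
                      dtype=tensors[0].dtype)
    out2d = out.view(outer, out_dim * inner)
    col = 0
    for t in tensors:
        chunk = int(t.shape[dim]) * inner
        out2d[:, col:col + chunk] = t.reshape(outer, chunk)
        col += chunk
    return out

# Sanity check against torch.cat; note that outer and inner
# both change with the choice of dim.
a, b = torch.randn(2, 3, 4), torch.randn(2, 5, 4)
assert torch.equal(cat_inner_outer([a, b], 1), torch.cat([a, b], 1))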

(Outdated, resolved review threads on aten/src/ATen/native/cpu/CatKernel.cpp and test/test_torch.py)
@colesbury (Member) left a comment:

The approach looks reasonable, but I'm not super familiar with the cat code. I noticed a few issues mentioned inline.

Additionally, in the PR description you wrote:

> Enables that kernel by default for contig tensors even for large inputs and when more than one thread is available.

I don't see that change in the PR. The call to cat_serial_stub looks like it's still guarded by use_serial_kernel.

(Outdated, resolved review threads on aten/src/ATen/native/cpu/CatKernel.cpp)
@kimishpatel (Contributor, Author) commented:

> The approach looks reasonable, but I'm not super familiar with the cat code. I noticed a few issues mentioned inline.
>
> Additionally, in the PR description you wrote:
>
> > Enables that kernel by default for contig tensors even for large inputs and when more than one thread is available.
>
> I don't see that change in the PR. The call to cat_serial_stub looks like it's still guarded by use_serial_kernel.

Yes, my bad. Initially I did try what the old summary said, but it did not seem to be provably better, so I kept the changes minimal. I have updated the summary.

@kimishpatel (Contributor, Author) commented:

@VitalyFedyunin @colesbury, I have addressed your comments.

@@ -7045,6 +7045,56 @@ def test_cat_out_channels_last(self, device):
res2 = torch.cat((x, y), out=z)
self.assertEqual(res1, res2)

def test_cat_in_channels_last(self, device):
Contributor review comment on test/test_torch.py:

Please move the test to another class where device is unnecessary, or add a CPU-only decorator, or use device appropriately.
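For illustration, one way to make the CPU-only variant, assuming PyTorch's device-generic test infrastructure (onlyCPU and instantiate_device_type_tests); the class name and shapes here are hypothetical:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCPU)

class TestCatChannelsLast(TestCase):  # hypothetical class name
    @onlyCPU  # run only on CPU, where the fast path lives
    def test_cat_in_channels_last(self, device):
        x = torch.randn(2, 3, 4, 4, device=device).contiguous(
            memory_format=torch.channels_last)
        y = torch.randn(2, 3, 4, 4, device=device).contiguous(
            memory_format=torch.channels_last)
        res = torch.cat((x, y), dim=1)
        self.assertTrue(res.is_contiguous(memory_format=torch.channels_last))

instantiate_device_type_tests(TestCatChannelsLast, globals())

if __name__ == "__main__":
    run_tests()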

@VitalyFedyunin (Contributor) left a comment:

The tests still require some fixes.

self.assertTrue(res2.is_contiguous(memory_format=torch.channels_last))
self.assertEqual(res1, res2)

# Concat across dim 2
Contributor review comment on test/test_torch.py:

This could be a simple for loop over the dimensions.
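A minimal sketch of that suggestion, with illustrative shapes rather than the PR's exact test values:

import torch

x = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)
y = torch.randn(2, 3, 4, 4).contiguous(memory_format=torch.channels_last)

# One loop over every concat dimension instead of a duplicated
# test body per dim; compare against the default-contiguous result.
for dim in range(x.dim()):
    res = torch.cat((x, y), dim=dim)
    ref = torch.cat((x.contiguous(), y.contiguous()), dim=dim)
    assert torch.equal(res, ref)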

@kimishpatel (Contributor, Author) commented:

@VitalyFedyunin, thanks a lot for the comments. I have incorporated them. Please let me know if there is anything else.

@VitalyFedyunin (Contributor) left a comment:

The Windows test failures look real.

@kimishpatel (Contributor, Author) commented:

@VitalyFedyunin, it looks like this PR fixes the Windows failures: #40369.
I have rebased; let's see.

@facebook-github-bot (Contributor) left a comment:

@kimishpatel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

}
}
}

void cat_serial_kernel(Tensor& result, TensorList tensors, int64_t dim) {
-  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_serial_kernel", [&]() {
+  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cat_contig_kernel", [&]() {
Member review comment on aten/src/ATen/native/cpu/CatKernel.cpp:

s/cat_contig_kernel/cat_serial_kernel

Commit message:

Summary:
The existing cat implementation produces the output tensor in contiguous format, disregarding the input memory format. This PR fixes the kernel as well as the op implementation to account for that.

Test Plan:
CI

@facebook-github-bot (Contributor) left a comment:

@kimishpatel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@kimishpatel merged this pull request in 6a421d5.


@facebook-github-bot deleted the channels_last_concat branch July 13, 2020 17:53