
fix asserts in cuda code #39047

Closed · wants to merge 3 commits

Conversation

@ngimel (Collaborator) commented May 27, 2020

- Gets rid of some in-kernel asserts where they can be replaced with `static_assert`s.
- Replaces a bare in-kernel `assert` with `CUDA_KERNEL_ASSERT` in the one case where a runtime check is still necessary.
- Replaces host-code `assert`s with `TORCH_INTERNAL_ASSERT`.

Another group of asserts is in the fractional max pooling kernels, which should be fixed regardless (#39044); the problems there are not just the asserts. I've audited the remaining in-kernel asserts: they guard internal invariants, much like `TORCH_INTERNAL_ASSERT`, so they should not be triggered by invalid user data, and I think it's OK to leave them as is. A sketch of the three assert flavors involved follows below.
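For readers unfamiliar with the distinction, here is a minimal sketch of the three assert mechanisms this PR moves between (illustrative only; `check_host` and `check_device` are hypothetical names, not code from this PR):

```cuda
#include <cstdint>
#include <c10/macros/Macros.h>   // CUDA_KERNEL_ASSERT
#include <c10/util/Exception.h>  // TORCH_INTERNAL_ASSERT

// Compile-time: checked during the build at zero runtime cost; only usable
// when the condition is a constant expression.
static_assert(sizeof(int64_t) == 8, "int64_t must be 8 bytes");

// Host code: TORCH_INTERNAL_ASSERT throws a catchable c10::Error carrying
// file/line info, instead of abort()-ing the process like a bare assert().
void check_host(int dim, int dims) {
  TORCH_INTERNAL_ASSERT(dim < dims && dim >= 0);
}

// Device code: exceptions are unavailable inside kernels, so the remaining
// runtime option is CUDA_KERNEL_ASSERT; when it fires, it traps and leaves
// the CUDA context unusable.
__global__ void check_device(const int64_t* sizes, int dims) {
  CUDA_KERNEL_ASSERT(dims > 0 && sizes[0] >= 0);
}
```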

@ngimel requested a review from @ezyang May 27, 2020 03:32
@dr-ci bot commented May 27, 2020

💊 CI failures summary and remediations

As of commit 2692d8f (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



```diff
@@ -73,15 +73,15 @@ TensorInfo<T, IndexType>::TensorInfo(T* p,
 template <typename T, typename IndexType>
 void
 TensorInfo<T, IndexType>::reduceDim(int dim) {
-  assert(dim < dims && dim >= 0);
+  TORCH_INTERNAL_ASSERT(dim < dims && dim >= 0);
```
Contributor

oh, these aren't run in CUDA?! Intriguing.
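The point of surprise: `TensorInfo::reduceDim` executes on the host, where it collapses a dimension before a kernel launch, so a throwing assert is available. A rough stand-in for the THC `TensorInfo` (not the actual definition), enough to show where the check runs:

```cuda
#include <c10/util/Exception.h>

template <typename T, typename IndexType>
struct TensorInfo {
  IndexType sizes[25];  // THC caps dimensions at MAX_TENSORINFO_DIMS (25)
  int dims;

  // Runs on the CPU while a kernel launch is being set up, so throwing a
  // c10::Error via TORCH_INTERNAL_ASSERT is fine here.
  void reduceDim(int dim) {
    TORCH_INTERNAL_ASSERT(dim < dims && dim >= 0);
    sizes[dim] = 1;  // collapse the reduced dimension
  }
};

// Inside __global__/__device__ code the same check would have to be
// CUDA_KERNEL_ASSERT(...); exceptions cannot be thrown from device code.
```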

@ezyang (Contributor) left a comment

I make no claims about the completeness of this PR, but this is certainly an improvement.

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator, Author) commented May 28, 2020

`static_assert(sizeof(long) == 8)` fails on Windows, so I turned it back into `CUDA_KERNEL_ASSERT`; may it never be triggered (it's in caffe2 code).
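The underlying portability issue: 64-bit Windows uses the LLP64 data model, where `long` is 4 bytes, while 64-bit Linux and macOS use LP64, where it is 8. A minimal illustration (not code from this PR):

```cpp
#include <cstdio>

int main() {
  // LP64  (64-bit Linux/macOS):  sizeof(long) == 8
  // LLP64 (64-bit Windows/MSVC): sizeof(long) == 4
  std::printf("sizeof(long) = %zu\n", sizeof(long));

  // This would be a hard compile error on Windows even if the surrounding
  // code were never executed, which is why the PR reverts it:
  // static_assert(sizeof(long) == 8, "long must be 8 bytes");
  return 0;
}
```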

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

```diff
@@ -104,6 +104,7 @@ struct TopKTypeConfig<long> {
   typedef unsigned long long int RadixType;

   static inline __device__ RadixType convert(long v) {
+    //static_assert fails on windows, so leave it as CUDA_KERNEL_ASSERT
```
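For context, the full specialization plausibly looks like the sketch below (reconstructed around the quoted diff line, so treat it as illustrative). The radix-select trick is to add 2^63, which flips the sign bit so that unsigned ordering of the keys matches signed ordering of the values:

```cuda
template <>
struct TopKTypeConfig<long> {
  typedef unsigned long long int RadixType;

  static inline __device__ RadixType convert(long v) {
    // Runtime check instead of static_assert, which would break the
    // Windows build outright (sizeof(long) == 4 there under LLP64).
    CUDA_KERNEL_ASSERT(sizeof(long) == 8);
    // Adding 2^63 flips the sign bit, mapping the signed range
    // [-2^63, 2^63) monotonically onto [0, 2^64).
    return 9223372036854775808ull + v;
  }

  static inline __device__ long deconvert(RadixType v) {
    return v - 9223372036854775808ull;
  }
};
```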
Contributor

...how does the CUDA_KERNEL_ASSERT not fail on Windows?

Collaborator (Author)

It will fail if someone runs the caffe2 radix sort on long inputs in a Windows PyTorch build. Hopefully no one will, but I can't make it a static_assert, because with static_assert the Windows build itself fails.

Collaborator

Should we just make this into int64_t?

Collaborator (Author)

I'm not familiar with the caffe2 code and don't know whether that's possible. It is also probably untested?

Collaborator

I don't think it will break anything, because of the `assert(sizeof(long) == 8)` here.
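Sketching out the suggestion: a specialization keyed on `int64_t` would fix the width on every platform, making the compile-time check viable again (a hypothetical rewrite, not what this PR merged; the open question in the thread is whether anything in caffe2 instantiates the `long` version directly):

```cuda
template <>
struct TopKTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType convert(int64_t v) {
    // int64_t is exactly 8 bytes on every platform, Windows included, so
    // the check can be compile-time again at zero runtime cost.
    static_assert(sizeof(int64_t) == 8, "int64_t must be 8 bytes");
    return 9223372036854775808ull + v;
  }

  static inline __device__ int64_t deconvert(RadixType v) {
    return static_cast<int64_t>(v - 9223372036854775808ull);
  }
};
```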

@gchanan added this to the 1.5.1 milestone May 28, 2020
@gchanan (Contributor) commented May 28, 2020

Adding milestone 1.5.1 since this seems worth getting into the release; some of these are regressions.

@facebook-github-bot (Contributor)

@ngimel merged this pull request in 9c19a12.

gchanan pushed a commit to gchanan/pytorch that referenced this pull request Jun 2, 2020
Summary:
Gets rid of some in-kernel asserts where they can be replaced with `static_assert`s.
Replaces a bare in-kernel `assert` with `CUDA_KERNEL_ASSERT` in the one case where a runtime check is still necessary.
Replaces host-code `assert`s with `TORCH_INTERNAL_ASSERT`.
Another group of asserts is in the fractional max pooling kernels, which should be fixed regardless (pytorch#39044); the problems there are not just the asserts.
I've audited the remaining in-kernel asserts: they guard internal invariants, much like `TORCH_INTERNAL_ASSERT`, so they should not be triggered by invalid user data, and I think it's OK to leave them as is.
Pull Request resolved: pytorch#39047

Differential Revision: D21750392

Pulled By: ngimel

fbshipit-source-id: e9417523a2c672284de3515933cb7ed166e56719
gchanan pushed a commit that referenced this pull request Jun 3, 2020
Summary:
Gets rid of some in-kernel asserts where they can be replaced with `static_assert`s.
Replaces a bare in-kernel `assert` with `CUDA_KERNEL_ASSERT` in the one case where a runtime check is still necessary.
Replaces host-code `assert`s with `TORCH_INTERNAL_ASSERT`.
Another group of asserts is in the fractional max pooling kernels, which should be fixed regardless (#39044); the problems there are not just the asserts.
I've audited the remaining in-kernel asserts: they guard internal invariants, much like `TORCH_INTERNAL_ASSERT`, so they should not be triggered by invalid user data, and I think it's OK to leave them as is.
Pull Request resolved: #39047

Differential Revision: D21750392

Pulled By: ngimel

fbshipit-source-id: e9417523a2c672284de3515933cb7ed166e56719