Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added checks to cudnn Convolution for stride, dilation, kernel size and num input planes #1723

Merged

Conversation

alykhantejani
Copy link
Contributor

Added checks for stride, dilation, kernel sizes and num input planes for cuda convolutions (issue #1406)

}
params.groups = groups;

for (int i = 2; i != input->nDimension; ++i) {
int kernel_size = params.weight_size[i];

This comment was marked as off-topic.

@colesbury
Copy link
Member

Adding additional checks here seems okay, but doesn't fix the issue referred to in #1406. The CHECK_ARG error message is very minimal: that's because its used for checking internal invariants, not mistakes that the user might make.

We need something like SpatialConvolutionMM_shapeCheck for cuDNN convolutions, especially the check that for a given input the output size is > 0.

@alykhantejani
Copy link
Contributor Author

I think what @soumith is saying here is that the kernel_size > 0 checks here are pointless, as the weight tensor is already created so the sizes must be > 0.

However, I agree CHECK_ARG is too minimal and still doesn't give a good error message. I also don't check for the case that you mention @colesbury.

Perhaps I should create a Convolution_shapeCheck function in csrc/cudnn/Conv.cpp and call it from the constructor, in there I can mimic SpatialConvolutionMM_shapeCheck and throw runtime errors with more meaningful messages?

If you guys agree, I'll just do this on this branch then?

@colesbury
Copy link
Member

Yeah, that seems like a good idea

@alykhantejani
Copy link
Contributor Author

@soumith @colesbury I think this should now contain all the checks we need. Let me know if I've missed anything

@soumith soumith merged commit c6a6391 into pytorch:master Jun 6, 2017
@soumith
Copy link
Member

soumith commented Jun 6, 2017

thanks @alykhantejani !

houseroad added a commit to houseroad/pytorch that referenced this pull request Jan 10, 2019
…b9fdd5

Summary:
Previous import was 8384c788939bc65463f9754b6a7a00b212b18ba1

Included changes:
- **[7abd834](onnx/onnx@7abd834)**: Clarify some aspects of the Loop spec. (pytorch#1587) <Scott McKay>
- **[5a5b15f](onnx/onnx@5a5b15f)**: Support rtol and atol at the model granularity (pytorch#1723) <Lu Fang>
- **[ba76e45](onnx/onnx@ba76e45)**: print some information (pytorch#1724) <Lu Fang>
- **[797390d](onnx/onnx@797390d)**: Update README.md (pytorch#1722) <Prasanth Pulavarthi>
- **[40cdb5f](onnx/onnx@40cdb5f)**: repaire convtranspose shape inference (pytorch#1660) <peter yang>
- **[68fdb3f](onnx/onnx@68fdb3f)**: [Minor] Fix Windows line ending in test coverage generating script (pytorch#1717) <Raymond Yang>
- **[00101bf](onnx/onnx@00101bf)**: Remove ConstantLike op. Updates to ConstantOfShape op. (pytorch#1716) <Spandan Tiwari>
- **[c59e90a](onnx/onnx@c59e90a)**: add a shape inference test for group conv (pytorch#1719) <Lu Fang>

Reviewed By: zrphercule

Differential Revision: D13629499

fbshipit-source-id: 61d4b30be2018f6b1e39a6acf9d80f8a5f26d7fc
facebook-github-bot pushed a commit that referenced this pull request Jan 11, 2019
…b9fdd5 (#15942)

Summary:
Pull Request resolved: #15942

Previous import was 8384c788939bc65463f9754b6a7a00b212b18ba1

Included changes:
- **[7abd834](onnx/onnx@7abd834)**: Clarify some aspects of the Loop spec. (#1587) <Scott McKay>
- **[5a5b15f](onnx/onnx@5a5b15f)**: Support rtol and atol at the model granularity (#1723) <Lu Fang>
- **[ba76e45](onnx/onnx@ba76e45)**: print some information (#1724) <Lu Fang>
- **[797390d](onnx/onnx@797390d)**: Update README.md (#1722) <Prasanth Pulavarthi>
- **[40cdb5f](onnx/onnx@40cdb5f)**: repaire convtranspose shape inference (#1660) <peter yang>
- **[68fdb3f](onnx/onnx@68fdb3f)**: [Minor] Fix Windows line ending in test coverage generating script (#1717) <Raymond Yang>
- **[00101bf](onnx/onnx@00101bf)**: Remove ConstantLike op. Updates to ConstantOfShape op. (#1716) <Spandan Tiwari>
- **[c59e90a](onnx/onnx@c59e90a)**: add a shape inference test for group conv (#1719) <Lu Fang>

Reviewed By: zrphercule

Differential Revision: D13629499

fbshipit-source-id: 4b3e4cb29bdb84c3777a8fb26263548efb20f317
jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request May 24, 2022
malfet pushed a commit that referenced this pull request Jun 8, 2022
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen#1440

Commits that's actually in this PR from the csarofeen branch
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```
Pull Request resolved: #78244
Approved by: https://github.com/csarofeen, https://github.com/malfet
facebook-github-bot pushed a commit that referenced this pull request Jun 8, 2022
Summary:
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

A few bigger updates:
1. Initial support of cp.async and cp.async.wait: csarofeen#1619
2. Emulate ampere's mma 16816 with Turing's mma 1688, for a unified interface: csarofeen#1643
3. Extending the infrastructure to support mma operators on turing and ampere arch: csarofeen#1440

Commits that's actually in this PR from the csarofeen branch
```
* dd23252 (csarofeen/devel) Fusion Segmenter: Unify single kernel and multi-kernel runtime path (#1710)
* b3d1c3f Fix missing cooperative launch (#1726)
* dc670a2 Async gmem copy support on sm80+ (#1619)
* 5e6a8da Add turing mma support and test (#1643)
* d6d6b7d Fix rFactor when there are indirect root domain(s), and refactor (#1723)
* 7093e39 Mma op integration on ampere (#1440)
* fade8da patch python test for bfloat16 (#1724)
* 8fbd0b1 Fine-grained kernel profiling (#1720)
* 77c1b4f Adding dry run mode to skip arch dependent checks (#1702)
* 151d95b More precise concretization analysis (#1719)
* f4d3630 Enable complex python tests (#1667)
* 4ceeee5 Minor bugfix in transform_rfactor.cpp (#1715)
* 3675c70 Separate root domain and rfactor domain in TransformPrinter (#1716)
* f68b830 Fix scheduling with polymorphic broadcast (#1714)
* 4ab5ef7 updating_ci_machine (#1718)
* 56585c5 Merge pull request #1711 from csarofeen/upstream_master_bump_0517
* 174d453 Allow using nvFuser on CUDA extension (#1701)
* 18bee67 Validate LOOP concrete IDs have complete IterDomains (#1676)
```

Pull Request resolved: #78244

Reviewed By: ejguan

Differential Revision: D36678948

Pulled By: davidberard98

fbshipit-source-id: 0ccde965acbd31da67d99c6adb2eaaa888948105
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants