
[ROCm] topk and sort fixes #12337

Closed
iotamudelta wants to merge 1 commit from the topk_sort_20181004 branch

Conversation

iotamudelta (Contributor)

@bddppq @ezyang

Note the one additional skipped test, which results from using the Thrust sort fallback for all sizes. We are working on getting bitonic sort to work properly (and for all sizes); until then, this test needs to be skipped on ROCm.

* Topk part 1: fix intrinsics for the 64-lane wavefront (#224)

64 lanes in a wavefront - the intrinsics change accordingly.

* Disable in-place sorting on ROCm. (#237)

It is known to hang - use the Thrust fallback instead.

Skip one test - it fails with the fallback.

* Topk fixes (#239)

* The PTX spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf), Sec. 9.7.1.19 (bfe) and Sec. 9.7.1.20 (bfi), requires pos and len to be limited to 0...255 (see the sketch after this list)

* Sec. 9.7.1.19 of the same spec requires the extracted bits to be placed in the LSBs

* Correct the logic of getLaneMaskLe. The previous logic returned 0x0 instead of 0xffffffffffffffff for lane 63

* Round up blockDim.x to prevent a negative index into smem
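
For illustration, a minimal sketch of what the bit-field and lane-mask fixes amount to. This is not the PR's code: the function names, signatures, and the 64-lane wavefront are assumptions.

// Hedged sketch, not the PR's exact code (illustrative names, HIP/CUDA device code).
// bfe-style bit-field extract: pos and len limited to 0...255 per PTX ISA Sec. 9.7.1.19,
// with the extracted field returned right-aligned in the LSBs.
__device__ __forceinline__ unsigned int getBitfield(unsigned int val, int pos, int len) {
  pos &= 0xff;
  len &= 0xff;
  if (pos >= 32 || len == 0) return 0u;                               // simplification for the sketch
  unsigned int mask = (len >= 32) ? 0xffffffffu : ((1u << len) - 1u); // avoid undefined shift by >= 32
  return (val >> pos) & mask;                                         // field lands in the LSBs
}

// Mask of all lanes with id <= the given lane on a 64-lane wavefront.
// A naive (1ull << (lane + 1)) - 1 shifts by 64 when lane == 63, which is undefined
// and in practice yielded 0x0 instead of 0xffffffffffffffff.
__device__ __forceinline__ unsigned long long getLaneMaskLe(int lane) {
  return ~0ull >> (63 - lane);                                        // lane 63 -> all 64 bits set
}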
iotamudelta (Contributor, Author)

@pytorchbot retest this please

ezyang (Contributor) commented on Oct 4, 2018

Just FYI, when you submit these PRs, please use full URLs for issues; they are cross-linking to the wrong issues now.

@@ -207,7 +213,7 @@ __device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, Bi
   *out -= (T) in;

   // The outgoing carry for all threads is the last warp's sum
-  *carry = smem[(blockDim.x / SCAN_UTILS_WARP_SIZE) - 1];
+  *carry = smem[THCCeilDiv<int>(blockDim.x, SCAN_UTILS_WARP_SIZE) - 1];
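
To see why rounding up matters, a minimal illustrative sketch: ceilDiv below is an assumption standing in for THCCeilDiv<int>, and the wave size on ROCm is assumed to be 64.

// Illustrative only: ordinary ceiling division, assumed to match THCCeilDiv<int>.
template <typename T>
__host__ __device__ T ceilDiv(T a, T b) {
  return (a + b - 1) / b;
}

// Example with blockDim.x == 32 and SCAN_UTILS_WARP_SIZE == 64:
//   32 / 64 - 1         == -1  -> negative smem index (out of bounds)
//   ceilDiv(32, 64) - 1 ==  0  -> index of the single (partial) wave's sum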

ezyang (Contributor) commented on Oct 4, 2018

@pytorchbot retest this please

bddppq added the module: rocm (AMD GPU support for Pytorch) label on Oct 5, 2018
facebook-github-bot (Contributor) left a comment


ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 9, 2018
Summary:
* Topk part 1: fix intrinsics for the 64-lane wavefront (#224)
64 lanes in a wavefront - the intrinsics change accordingly.
* Disable in-place sorting on ROCm. (#237)
It is known to hang - use the Thrust fallback instead.
Skip one test - it fails with the fallback.
* Topk fixes (#239)
* The PTX spec (https://docs.nvidia.com/cuda/pdf/ptx_isa_6.3.pdf), Sec. 9.7.1.19 (bfe) and Sec. 9.7.1.20 (bfi), requires pos and len to be limited to 0...255
* Sec. 9.7.1.19 of the same spec requires the extracted bits to be placed in the LSBs
* Correct the logic of getLaneMaskLe. The previous logic returned 0x0 instead of 0xffffffffffffffff for lane 63
* Round up blockDim.x to prevent a negative index into smem

bddppq ezyang

Note the one additional skipped test, which results from using the Thrust sort fallback for all sizes. We are working on getting bitonic sort to work properly (and for all sizes); until then, this test needs to be skipped on ROCm.
Pull Request resolved: pytorch/pytorch#12337

Differential Revision: D10259481

Pulled By: ezyang

fbshipit-source-id: 5c8dc6596d7a3103ba7b4b550cba895f38c8148e
gchanan pushed a commit to gchanan/pytorch that referenced this pull request Oct 10, 2018
gchanan pushed a commit to gchanan/pytorch that referenced this pull request Oct 10, 2018
iotamudelta deleted the topk_sort_20181004 branch on October 23, 2018 at 17:00
Labels: module: rocm (AMD GPU support for Pytorch), open source
Projects: None yet
5 participants