CUDA BFloat16 TopK #44755

Closed

zasdfgbnm wants to merge 5 commits into master from bfloat-topk

Conversation

zasdfgbnm
Collaborator

No description provided.

@dr-ci

dr-ci bot commented Sep 15, 2020

💊 CI failures summary and remediations

As of commit ce32659 (more details on the Dr. CI page):


Commit ce32659 was recently pushed. Waiting for builds...



@@ -39,6 +41,15 @@ __device__ __forceinline__ T doLdg(const T* p) {
#endif
}

template <>
__device__ __forceinline__ c10::BFloat16 doLdg<c10::BFloat16>(const c10::BFloat16* p) {
#if __CUDA_ARCH__ >= 350
Collaborator

do you need #if here? torch is only supported on CUDA_ARCH >= 350

Collaborator Author

I think this is actually equivalent to #ifndef __HIP_PLATFORM_HCC__?

Collaborator

Then it should say so? __ldg doesn't provide performance benefit these days, but I guess you still need to load short and construct bfloat16 from bits on cuda, and hip is able to handle it natively?

Collaborator Author

Yes, I will change this to #ifndef __HIP_PLATFORM_HCC__, and maybe remove it later (needs benchmark). On HIP, it is just *p, so it's OK.
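For reference, a minimal sketch of what the specialization under discussion could look like once the guard is switched to __HIP_PLATFORM_HCC__. This is illustrative only, not necessarily the exact code that landed; it assumes the c10::BFloat16(bits, from_bits_t) constructor and the unsigned short overload of __ldg. The from_bits call here is also what the internal build error later in this thread points at.

// Sketch only (not the exact PR code): load the raw 16-bit pattern through
// the read-only cache on CUDA and rebuild the value from its bits; on HIP a
// plain dereference is enough.
template <>
__device__ __forceinline__ c10::BFloat16 doLdg<c10::BFloat16>(const c10::BFloat16* p) {
#ifndef __HIP_PLATFORM_HCC__
  // __ldg has no c10::BFloat16 overload, so read the underlying bits.
  return c10::BFloat16(
      __ldg(reinterpret_cast<const unsigned short*>(p)), c10::BFloat16::from_bits());
#else
  return *p;
#endif
}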

Collaborator Author

fixed


Just to let you know, changing the #if does break the build for NVIDIA GRID K520 GPU. I understand that is not a supported CUDA architecture though.

Collaborator

Sorry about that, but as you note it is not a supported architecture.

@ngimel added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 16, 2020
@codecov

codecov bot commented Sep 16, 2020

Codecov Report

Merging #44755 into master will increase coverage by 0.00%.
The diff coverage is 96.22%.


@@           Coverage Diff           @@
##           master   #44755   +/-   ##
=======================================
  Coverage   68.07%   68.08%           
=======================================
  Files         384      384           
  Lines       49765    49774    +9     
=======================================
+ Hits        33879    33890   +11     
+ Misses      15886    15884    -2     
Impacted Files                          Coverage Δ
torch/optim/lr_scheduler.py             88.73% <ø> (-0.05%) ⬇️
torch/fx/proxy.py                       92.66% <91.30%> (-0.45%) ⬇️
torch/fx/__init__.py                    100.00% <100.00%> (ø)
torch/fx/graph.py                       96.66% <100.00%> (+0.28%) ⬆️
torch/fx/graph_module.py                97.43% <100.00%> (ø)
torch/fx/symbolic_trace.py              95.34% <100.00%> (+1.65%) ⬆️
torch/testing/_internal/expecttest.py   78.57% <0.00%> (+1.02%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mcarilli
Collaborator

if you observe weird numerical behavior with bfloat16 topk, https://github.com/pytorch/pytorch/blame/b85568a54a9c60986235ad1e0cc5dffc71b9d5b1/aten/src/ATen/native/cuda/SortingRadixSelect.cuh#L147-L163 is the main suspect. @ngimel you remember our adventures with that for fp16. The same fix was also necessary for bfloat16, and @gchanan included the fix for bfloat16 in his PR, but we had no way to test bfloat16 on cuda at the time.
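For readers unfamiliar with those lines: they implement an order-preserving mapping from floating-point values to unsigned integer keys so the radix select can compare raw bits. Below is a rough sketch of that trick adapted to bfloat16; it is illustrative, not the actual SortingRadixSelect.cuh code, and it assumes c10::BFloat16 exposes its bit pattern as the .x member.

// Illustrative sketch (not the PR's code): map a bfloat16 value to a 16-bit
// key whose unsigned ordering matches the floating-point ordering. Flip all
// bits of negative values, flip only the sign bit of non-negative values,
// and send NaN to the largest key.
__device__ __forceinline__ unsigned short bfloat16ToSortableBits(c10::BFloat16 v) {
  unsigned short x = v.x;                               // raw bit pattern
  unsigned short mask = (x & 0x8000) ? 0xffff : 0x8000; // negative vs. non-negative
  return (v == v) ? (x ^ mask) : 0xffff;                // v != v detects NaN
}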

@zasdfgbnm
Collaborator Author

@mcarilli Tests on CI are passing, so it should be OK? Do you think we need more tests beyond the existing unit tests?

Contributor

@facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel
Collaborator

ngimel commented Sep 16, 2020

CI is enough, get_all_dtypes is testing bfloat16, right?

@zasdfgbnm
Collaborator Author

@ngimel Yes, by default it includes everything, unless you say include_bfloat=False:

>>> torch.testing.get_all_dtypes()
[torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.float32, torch.float64, torch.float16, torch.bfloat16, torch.bool, torch.complex64, torch.complex128]
>>> torch.testing.get_all_fp_dtypes()
[torch.float32, torch.float64, torch.float16, torch.bfloat16]

@ngimel
Collaborator

ngimel commented Sep 17, 2020

There are internal build failures:

stderr: caffe2/aten/src/THC/THCDeviceUtils.cuh(47): error: calling a constexpr __host__ function("from_bits") from a __device__ function("doLdg") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
caffe2/c10/util/TypeCast.h(27): warning: calling a constexpr __host__ function("real") from a __host__ __device__ function("apply") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
          detected during:
            instantiation of "decltype(auto) c10::maybe_real<true, src_t>::apply(src_t) [with src_t=c10::complex<double>]" 
(57): here
            instantiation of "uint8_t c10::static_cast_with_inter_type<uint8_t, src_t>::apply(src_t) [with src_t=c10::complex<double>]" 
(157): here
            instantiation of "To c10::convert<To,From>(From) [with To=uint8_t, From=c10::complex<double>]" 
(169): here
            instantiation of "To c10::checked_convert<To,From>(From, const char *) [with To=uint8_t, From=c10::complex<double>]" 

@zasdfgbnm
Collaborator Author

Let me benchmark and remove ldg

@ngimel
Collaborator

ngimel commented Sep 17, 2020

The problem is not __ldg, I believe, it's fromBits. I have no idea why --expt-relaxed-constexpr is not passed in internal builds and how it used to work. Maybe just bringing back the CUDA_ARCH guard is the way to go ;-)
Edit: oh, fromBits never worked, you've just added it.
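To spell out the failure mode nvcc is reporting above: a constexpr function that is (implicitly) __host__-only cannot be called from __device__ code unless --expt-relaxed-constexpr is passed. A minimal reproduction, with the usual remedy shown in a comment (illustrative only; not the change made in this PR or in #44925):

// Minimal reproduction, not PyTorch code.
struct from_bits_t {};

// Host-only constexpr helper: calling it from device code triggers
//   error: calling a constexpr __host__ function from a __device__ function
// unless nvcc is given --expt-relaxed-constexpr.
static constexpr from_bits_t from_bits() { return from_bits_t(); }

__device__ void device_caller() {
  from_bits();  // fails without --expt-relaxed-constexpr
}

// The usual remedy is to make the helper callable from both sides, e.g.
//   static constexpr __host__ __device__ from_bits_t from_bits() { ... }
// (c10 spells the attribute pair C10_HOST_DEVICE), so no experimental flag
// is needed.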

@zasdfgbnm
Collaborator Author

The solution for the __ldg issue is in #44925; I will rebase and fix this after that PR is merged.

Contributor

@facebook-github-bot left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@ngimel merged this pull request in e1ff46b.

zasdfgbnm deleted the bfloat-topk branch October 5, 2020 08:17
Labels
Merged · open source · triaged