
Use atomicAdd for bfloat16 in Ampere and above #84981

Closed
wants to merge 3 commits

Conversation

@eqy (Collaborator) commented Sep 14, 2022

WIP to fix extremely slow scatter_add issue vs. fp16. The current changes seem to improve performance, but it still appears to lag behind the fp16 equivalent.

CC @ngimel @ptrblck

@pytorch-bot bot commented Sep 14, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84981

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit eeff148:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Sep 14, 2022
@@ -6,6 +6,10 @@

#include <ATen/NumericUtils.h>

#if !(defined(USE_ROCM) || ((defined(CUDA_VERSION) && CUDA_VERSION < 11000) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))))
#include <cuda_bf16.h>
@eqy (Collaborator, Author) commented on this line, Sep 14, 2022:

I'm not sure why this is needed; if it is removed, the build fails with the complaint that __nv_bfloat16 isn't defined, even though the included c10/util/BFloat16.h header should also include it...

return bsum + val;
});
#else
__nv_bfloat16 r = atomicAdd(reinterpret_cast<__nv_bfloat16*>(address), *reinterpret_cast<__nv_bfloat16*>(&val));
@eqy (Collaborator, Author) commented on this line:

This is very clunky; I should be able to use the functions in c10/util/BFloat16-inl.h, but I was tripping over the syntax.

@ngimel (Collaborator) commented Sep 14, 2022

What are the perf numbers you are getting with this?

@eqy (Collaborator, Author) commented Sep 14, 2022

Better than before, but still very much the wrong order of magnitude according to the microbenchmark (on an A6000):
before

INFO:root:matmul tf32: True
INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:Input_size 2048, dtype torch.float16
INFO:root:batch size 128, forward 0.072292057, backward 0.034451169
INFO:root:batch size 256, forward 0.00296005, backward 0.016321898
INFO:root:batch size 512, forward 0.005347602, backward 0.029631895
INFO:root:****
INFO:root:Input_size 2048, dtype torch.bfloat16
INFO:root:matmul tf32: True
INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:batch size 128, forward 0.049289932, backward 3.460327631
INFO:root:batch size 256, forward 0.002965743, backward 5.031641967
INFO:root:batch size 512, forward 0.005354016, backward 4.805066426

after

INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:Input_size 2048, dtype torch.float16
INFO:root:batch size 128, forward 0.071989492, backward 0.040035025
INFO:root:batch size 256, forward 0.002961816, backward 0.016255412
INFO:root:batch size 512, forward 0.005348356, backward 0.029601357
INFO:root:****
INFO:root:Input_size 2048, dtype torch.bfloat16
INFO:root:matmul tf32: True
INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:batch size 128, forward 0.047254223, backward 1.572099484
INFO:root:batch size 256, forward 0.002972221, backward 2.16478994
INFO:root:batch size 512, forward 0.00535267, backward 2.945632733

@ngimel (Collaborator) commented Sep 14, 2022

Would bfloat162 atomicAdd help? (This is what's used for half.) https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT162__ARITHMETIC.html#group__CUDA__MATH____BFLOAT162__ARITHMETIC_1g550f52c89d672213390e9bfd8a3c42bf
You can check your generated SASS with the bfloat16 atomicAdd; I suspect it still generates CAS.

@eqy (Collaborator, Author) commented Sep 14, 2022

Would bfloat162 atomicAdd help? (This is what's used for half.) https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT162__ARITHMETIC.html#group__CUDA__MATH____BFLOAT162__ARITHMETIC_1g550f52c89d672213390e9bfd8a3c42bf You can check your generated SASS with the bfloat16 atomicAdd; I suspect it still generates CAS.

Not sure what this entails; naively changing the cast types yields a misaligned address error...

Will check the generated code.

@ngimel (Collaborator) commented Sep 15, 2022

Yeah, you cannot naively change the type; you should use an approach similar to fastAtomicAdd for fp16, where you apply atomicAdd to a properly aligned reinterpret-casted half2 (or bfloat162) pointer, adding 0 to the other half.
However, I've checked with godbolt, and CAS is used for both the bfloat16 and bfloat162 atomic versions, so I don't have high hopes.
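
For reference, a minimal sketch of the aligned-__nv_bfloat162 idea described above (illustrative only, not the code in this PR; the helper name is made up, and it assumes an sm_80+ GPU and a CUDA 11+ toolchain):

// Sketch: reinterpret the target bf16 element as one lane of a 32-bit-aligned
// __nv_bfloat162 and add zero to the other lane, so the vector bfloat162 atomicAdd
// can be used. Hypothetical helper, modeled on the fp16 fastSpecializedAtomicAdd.
#include <cuda_bf16.h>
#include <cstddef>
#include <cstdint>

__device__ __forceinline__ void bf16AlignedAtomicAdd(
    __nv_bfloat16* tensor,
    std::size_t index,
    std::size_t numel,
    __nv_bfloat16 value) {
  __nv_bfloat16* target_addr = tensor + index;
  // True when target_addr is the low (32-bit-aligned) element of a __nv_bfloat162.
  bool low_half =
      (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) == 0);
  if (low_half && index < (numel - 1)) {
    __nv_bfloat162 value2;
    value2.x = value;
    value2.y = __int2bfloat16_rz(0);
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
  } else if (!low_half && index > 0) {
    __nv_bfloat162 value2;
    value2.x = __int2bfloat16_rz(0);
    value2.y = value;
    atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
  } else {
    // Fall back to the scalar bf16 atomicAdd at the tensor boundaries.
    atomicAdd(target_addr, value);
  }
}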

@alihassanijr commented Sep 17, 2022

I can confirm that I already tried reimplementing fastSpecializedAtomicAdd (see below) with __nv_bfloat162, and it's still significantly slower in my use case. So I'm assuming it's because both use CAS, as opposed to half and half2?

template <
    typename scalar_t,
    typename index_t,
    typename std::enable_if<std::is_same<c10::Half, scalar_t>::value>::type* =
        nullptr>
__device__ __forceinline__ void fastSpecializedAtomicAdd(
    scalar_t* tensor,
    index_t index,
    const index_t numel,
    scalar_t value) {
#if (                                                     \
    (defined(USE_ROCM)) ||                                \
    (defined(CUDA_VERSION) && (CUDA_VERSION < 10000)) ||  \
    (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)))
  gpuAtomicAddNoReturn(
      reinterpret_cast<at::Half*>(tensor) + index,
      static_cast<at::Half>(value));
#else
  // Accounts for the chance tensor falls on an odd 16 bit alignment (ie, not 32 bit aligned)
  __half* target_addr = reinterpret_cast<__half*>(tensor + index);
  bool low_byte = (reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__half2) == 0);

  if (low_byte && index < (numel - 1)) {
    __half2 value2;
    value2.x = value;
    value2.y = __int2half_rz(0);
    atomicAdd(reinterpret_cast<__half2*>(target_addr), value2);
  } else if (!low_byte && index > 0) {
    __half2 value2;
    value2.x = __int2half_rz(0);
    value2.y = value;
    atomicAdd(reinterpret_cast<__half2*>(target_addr - 1), value2);
  } else {
    atomicAdd(
        reinterpret_cast<__half*>(tensor) + index, static_cast<__half>(value));
  }
#endif
}

One more interesting thing that I keep running into (unrelated to this) is the lack of a defined typecast between c10::BFloat16 and __nv_bfloat16. c10::Half casts to __half without a hitch, but in the case of bfloat16 I've had to just manually reinterpret_cast to avoid the issue.
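
For concreteness, the manual workaround described above amounts to something like the following hypothetical helper (not an API from c10 or from this PR; it simply relies on both types being 16-bit with the same bit layout):

// Assumes <cuda_bf16.h> and c10/util/BFloat16.h are already included.
__device__ __forceinline__ __nv_bfloat16 c10_bf16_to_nv_bf16(c10::BFloat16 v) {
  // Reinterpret the raw 16 bits; c10::BFloat16 and __nv_bfloat16 share the same layout.
  return *reinterpret_cast<__nv_bfloat16*>(&v);
}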

As far as atomicAdds for bfloat16 go, I'm getting roughly 100x the latency I get with Half and float.

@alihassanijr commented Sep 17, 2022

I may be completely reading this wrong, but it seems like all atomic operations for bfloat16 are implemented with CAS, as opposed to float16 and float162 which seem to have a dedicated add instruction:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom

I ended up switching the kernel I'm struggling with to float32 at all times (since it's just a backwards kernel for a really small weight tensor, so returning a float32 tensor isn't too big of a deal). It is just a temporary solution for my use case to drop that 10ms latency back down to 100us. It's almost the same latency as float16 now, but this particular kernel doesn't even use half2 right now, so no harm done.
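
As a sketch of that kind of fallback (names are hypothetical, not this PR's or the above kernel's actual code), the accumulation can target an fp32 buffer with the plain float atomicAdd and be cast back to bf16 afterwards if needed:

#include <cuda_bf16.h>

// Accumulate a bf16 gradient into an fp32 buffer; several threads may hit the same
// output index, hence the atomic add on float, which has a native add instruction.
__global__ void accumulate_grad_fp32(float* grad_accum,
                                     const __nv_bfloat16* grad_in,
                                     const int* index,
                                     int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(&grad_accum[index[i]], __bfloat162float(grad_in[i]));
  }
}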

That said, I will try and look into this some more and see if I can figure something else out.

Also, this PR could probably be merged with #80340.

@facebook-github-bot (Contributor) commented:

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla bot commented Oct 4, 2022

CLA Signed

The committers listed above are authorized under a signed CLA.

@alihassanijr commented Oct 6, 2022

I think CUDA v11.8 might help resolve the speed issue by the way.

Notice how CUDA v11.7 only had a CAS instruction for b16:

atom{.sem}{.scope}{.space}.cas.b16 d, [a], b, c;

Source: https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom

However, in v11.8 we get this nice little addition, which I think is all we needed:

atom{.sem}{.scope}{.space}.add.noftz{.level::cache_hint}.bf16    d, [a], b{, cache-policy};
atom{.sem}{.scope}{.space}.add.noftz{.level::cache_hint}.bf16x2  d, [a], b{, cache-policy};

Source: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom

@eqy (Collaborator, Author) commented Nov 3, 2022

I'm seeing the following PTX being generated with just a standalone call to atomicAdd on sm90:
atom.add.noftz.f16 %rs2,[%rd1],%rs1; (fp16)
atom.add.noftz.bf16 %rs2,[%rd1],%rs1; (bf16)

However, the slowdown remains, so I suspect that it might be another kernel/operation that is slow.
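
A standalone check along those lines could be as simple as the following (file name and flags are just an example; it assumes a CUDA 11.8+ toolkit for the bf16 atomicAdd and sm_90 support):

// bf16_atomic.cu -- minimal kernel for inspecting the generated code.
// PTX:  nvcc -ptx -arch=sm_90 bf16_atomic.cu      (look for atom.add.noftz.bf16 vs. a CAS loop)
// SASS: nvcc -cubin -arch=sm_90 bf16_atomic.cu && cuobjdump -sass bf16_atomic.cubin
#include <cuda_bf16.h>

__global__ void bf16_atomic_add(__nv_bfloat16* out, __nv_bfloat16 val) {
  atomicAdd(out, val);
}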

@eqy eqy marked this pull request as ready for review November 3, 2022 22:09
@alihassanijr commented Nov 3, 2022

@eqy Interesting. I'm assuming you're on 11.8? I thought 12 was going to be the first version to support sm90.

@eqy (Collaborator, Author) commented Nov 4, 2022

Yes, this is on 11.8 but I'm not 100% sure my build is correct and need to recheck.

@alihassanijr commented:

Yeah I just noticed there's a new NGC image with the latest torch. I'm going to give it a shot with my example and see if I observe anything different.

@alihassanijr commented Nov 4, 2022

So I'm still confirming, because my environment's a bit of a mess, but in my case the issue is partially resolved, I guess. I will note that I'm not building torch from your commits, though (and not using the vectorized bfloat162 ops or anything, just plain fastAtomicAdd from aten::native).
I was getting upwards of 10ms latency on a really simple kernel with an atomicAdd, and under 0.1ms with half and float.
Now I'm getting around 1ms with bfloat16, which is a noticeable improvement, but still not good enough compared to just manually casting to float and having that single kernel run on floats.
I'm on SM86.
I tried this on the 22.10 NGC release; I'll try 22.08, which is on CUDA 11.7, afterwards and share the profile logs.

Update

I can confirm I'm still facing the issue on my end even on CUDA 11.8.
(My bad earlier, had the wrong size in the profiler)
In case it helps, here are profiler logs for the kernel in question:

NGC pytorch:22.08 (cu117)

------------------------------------------------------------  ------------  ------------
                                                        Name     Self CUDA    # of Calls
------------------------------------------------------------  ------------  ------------
void nattenrpb_cuda_backward_kernel<7, 3, 1, float>(...           65.000us             1  
void nattenrpb_cuda_backward_kernel<7, 3, 1, c10::BFloat(...       3.309ms             1  

NGC pytorch:22.10 (cu118)

------------------------------------------------------------  ------------  ------------
                                                        Name     Self CUDA    # of Calls
------------------------------------------------------------  ------------  ------------
void nattenrpb_cuda_backward_kernel<7, 3, 1, float>(...           68.000us             1
void nattenrpb_cuda_backward_kernel<7, 3, 1, c10::BFloat(...       3.492ms             1

@eqy (Collaborator, Author) commented Nov 4, 2022

Thanks for checking on your end @alihassanijr. On sm_90, after more testing, I noticed that the __half2 codepath was important for performance, and removing it regressed performance to be about the same as __nv_bfloat16. After adding the __nv_bfloat162 equivalent, I'm observing comparable performance in the benchmark from @ngimel:

INFO:root:matmul tf32: True
INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:Input_size 2048, dtype torch.float16
INFO:root:batch size 128, forward 0.186238571, backward 0.032990263
INFO:root:batch size 256, forward 0.002154732, backward 0.013528392
INFO:root:batch size 512, forward 0.003322089, backward 0.02547441
INFO:root:****
INFO:root:Input_size 2048, dtype torch.bfloat16
INFO:root:matmul tf32: True
INFO:root:cudnn tf32: True
INFO:root:matmul fp16 reduction: True
INFO:root:batch size 128, forward 0.006974035, backward 0.007897218
INFO:root:batch size 256, forward 0.002154933, backward 0.013497994
INFO:root:batch size 512, forward 0.003340705, backward 0.025482254

Dropping the [WIP] label now.

@eqy eqy changed the title [WIP] Use atomicAdd for bfloat16 in Ampere and above Use atomicAdd for bfloat16 in Ampere and above Nov 4, 2022
@eqy eqy added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 4, 2022
@eqy (Collaborator, Author) commented Nov 10, 2022

@pytorchmergebot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again. You can rebase by leaving the following comment on this PR:
@pytorchbot rebase

Details for Dev Infra team: raised by workflow job.

@eqy (Collaborator, Author) commented Nov 10, 2022

@pytorchmergebot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot (Collaborator) commented:

Successfully rebased bfloat16_atomic_add onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout bfloat16_atomic_add && git pull --rebase)

@eqy (Collaborator, Author) commented Nov 11, 2022

@pytorchmergebot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
WIP to fix extremely slow `scatter_add` issue vs. fp16. The current changes seem to improve performance, but it still appears to lag behind the fp16 equivalent.

CC @ngimel @ptrblck
Pull Request resolved: pytorch#84981
Approved by: https://github.com/ngimel
@eqy eqy mentioned this pull request Feb 23, 2023
Labels
ciflow/trunk (Trigger trunk jobs on your pull request), cla signed, Merged, open source, release notes: cuda (release notes category)