-
Notifications
You must be signed in to change notification settings - Fork 25.4k
[Distributed] add pack-check method for float8_e5m2 #136115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136115
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit c907c0c with merge base 0216936 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
// We want to check 8 x FP8 simultaneously, hence this template definition. | ||
template<typename T> | ||
struct HasNanFP8x8 { | ||
// I am a dumb implementation. You should never call in here, unless the check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Static Assert False to raise a compile error if we call in here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or at least can we issue warning somehow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I struggle between providing basic functionality so that user code can run without break (current code) vs speed. And eventually chose the former :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But your point seems better: compile error can force a developer to implement the template if they want to add a new data type to AT_DISPATCH_FLOATING_TYPES_AND4
below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, it seems static_assert
doesn't work well with template definition prior to c++23. I may try =delete
as suggested here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler realizes to inline isnan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also use the self != self
to check for NaN based on the dtype, but I am not sure if that's faster / more portable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler realizes to inline isnan
Yeah, we tried that. Since the final result is a reduction, i.e.
packHasNan = isnan(byte0) || isnan(byte1) ... || isnan(byte7),
the compiler does not seem quite willing the unroll the loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kwen2501 It doesn't like #pragma unroll
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could also manually unroll since it's only 8 elements (as painful as that would be).
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check). Made `HasNanFP8x8` a template so that it is extendable based on dtype. cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check). Made `HasNanFP8x8` a template so that it is extendable based on dtype. Pull Request resolved: pytorch#136115 Approved by: https://github.com/Skylion007 ghstack dependencies: pytorch#135891, pytorch#135961
Stack from ghstack (oldest at bottom):
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).
Made
HasNanFP8x8
a template so that it is extendable based on dtype.cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o