Skip to content

Conversation

kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Sep 15, 2024

Stack from ghstack (oldest at bottom):

Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).

Made HasNanFP8x8 a template so that it is extendable based on dtype.

cc @XilunWu @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

Copy link

pytorch-bot bot commented Sep 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136115

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c907c0c with merge base 0216936 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category labels Sep 15, 2024
// We want to check 8 x FP8 simultaneously, hence this template definition.
template<typename T>
struct HasNanFP8x8 {
// I am a dumb implementation. You should never call in here, unless the check
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static Assert False to raise a compile error if we call in here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or at least can we issue warning somehow?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I struggle between providing basic functionality so that user code can run without break (current code) vs speed. And eventually chose the former :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But your point seems better: compile error can force a developer to implement the template if they want to add a new data type to AT_DISPATCH_FLOATING_TYPES_AND4 below.

Copy link
Contributor Author

@kwen2501 kwen2501 Sep 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, it seems static_assert doesn't work well with template definition prior to c++23. I may try =delete as suggested here.

Copy link
Collaborator

@Skylion007 Skylion007 Sep 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler realizes to inline isnan

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also use the self != self to check for NaN based on the dtype, but I am not sure if that's faster / more portable.

Copy link
Contributor Author

@kwen2501 kwen2501 Sep 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kwen2501 You could also at least unroll the loop, but the speed gains would be minimal unless the compiler realizes to inline isnan

Yeah, we tried that. Since the final result is a reduction, i.e.
packHasNan = isnan(byte0) || isnan(byte1) ... || isnan(byte7),
the compiler does not seem quite willing the unroll the loop.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kwen2501 It doesn't like #pragma unroll?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also manually unroll since it's only 8 elements (as painful as that would be).

Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).

Made `HasNanFP8x8` a template so that it is extendable based on dtype.

cc XilunWu H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Sep 15, 2024
ghstack-source-id: 4b11623
Pull Request resolved: #136115
@kwen2501
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 15, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Sep 20, 2024
Add support for Float8_e5m2, following similar algorithm used for Float8_e4m3fn (i.e. overflow check).

Made `HasNanFP8x8` a template so that it is extendable based on dtype.

Pull Request resolved: pytorch#136115
Approved by: https://github.com/Skylion007
ghstack dependencies: pytorch#135891, pytorch#135961
@github-actions github-actions bot deleted the gh/kwen2501/61/head branch October 16, 2024 02:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (c10d) release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants