
Conversation

@janeyx99 (Contributor) commented May 16, 2023

Completes action item 1 in #99640

Stack from ghstack (oldest at bottom):

@pytorch-bot (bot) commented May 16, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/101569

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a9ed77c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

janeyx99 added the release notes: cuda and topic: docs labels May 16, 2023
janeyx99 requested a review from ngimel May 16, 2023 15:10
@janeyx99 (PR author):

@stas00 @crcrpar Would appreciate your reviews on phrasing!

Comment on lines 79 to 84
AMP/fp16 may not be for every model! For example, most bf16-pretrained models cannot operate in
the fp16 numerical range of max 65k and will cause gradients to overflow instead of underflow. In
this case, the scale factor may decrease under 1 as an attempt to bring gradients to a number
representable in fp16. While one may expect the scale to always be above 1, our GradScaler does
NOT make this guarantee to maintain performance. If you encounter NaNs in your loss or gradients
when running with AMP or fp16, verify your model is compatible.

Suggested change
- AMP/fp16 may not be for every model! For example, most bf16-pretrained models cannot operate in
- the fp16 numerical range of max 65k and will cause gradients to overflow instead of underflow. In
- this case, the scale factor may decrease under 1 as an attempt to bring gradients to a number
- representable in fp16. While one may expect the scale to always be above 1, our GradScaler does
- NOT make this guarantee to maintain performance. If you encounter NaNs in your loss or gradients
- when running with AMP or fp16, verify your model is compatible.
+ AMP/fp16 may not work for every model! For example, most bf16-pretrained models cannot operate in
+ the fp16 numerical range of max 64k and will cause gradients to overflow instead of underflow. In
+ this case, the scale factor may decrease under 1 as an attempt to bring gradients to a number
+ representable in fp16 dynamic range. While one may expect the scale to always be above 1, our GradScaler does
+ NOT make this guarantee to maintain performance. If you encounter NaNs in your loss or gradients
+ when running with AMP/fp16, verify your model is compatible.
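(Illustrative aside, not part of the PR: the overflow described in this warning can be reproduced directly, since magnitudes that bf16 handles comfortably exceed fp16's maximum finite value of 65504.)

```python
import torch

# A value that is representable in bf16 overflows to inf when cast to fp16,
# because fp16's largest finite value is 65504.
x = torch.tensor(1.0e5, dtype=torch.bfloat16)
print(x.to(torch.float16))  # tensor(inf, dtype=torch.float16)
```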

@janeyx99 (PR author):

Thanks for the comments! I thought the max for fp16 is 65.5k or something, no?

@stas00 (Contributor), May 16, 2023:

The maximum representable value is (2 − 2^-10) × 2^15 = 65504

which is 64K (65504/1024)

@janeyx99 (PR author):

Ah, I see, the distinction was K (×1024) vs k (×1000). I'm just going to use the actual number for maximal clarity.
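(For reference, the exact limits under discussion can be queried with torch.finfo; a quick check:)

```python
import torch

# fp16's largest finite value: (2 - 2**-10) * 2**15 = 65504.
print(torch.finfo(torch.float16).max)   # 65504.0
# bf16 shares float32's exponent range, hence its much larger dynamic range.
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
```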

Comment on lines 401 to 404
For performance reasons, the scale factor is not guaranteed to be above 1. If the
scale falls below 1 and/or you are seeing NaNs in your gradients or loss, something
is likely wrong. For example, bf16-pretrained models are often incompatible with
AMP/fp16 due to differing numerical ranges.

Suggested change
- For performance reasons, the scale factor is not guaranteed to be above 1. If the
- scale falls below 1 and/or you are seeing NaNs in your gradients or loss, something
- is likely wrong. For example, bf16-pretrained models are often incompatible with
- AMP/fp16 due to differing numerical ranges.
+ For performance reasons, the scale factor is not guaranteed to be above 1. If the
+ scale falls below 1 and/or you are seeing NaNs in your gradients or loss, something
+ is likely wrong. For example, bf16-pretrained models are often incompatible with
+ AMP/fp16 due to differing dynamic ranges.

Reviewer (Contributor):

why is this a performance reason? I'd call it "numerical stability reasons", no?

@janeyx99 (PR author):

Because adding a check would incur a device sync per step call, and device syncs are expensive

@janeyx99 (PR author):

This is the current reason why we don't, but I intentionally did not mark this PR as one that would "fix" the issue, as I'd like to leave that issue open for more thoughts from the community if it comes up more often.

Reviewer (Contributor):

Oh, I understand what you meant now.

But if this is not synced, how will a user know that it fell below 1? I'm not suggesting adding the overhead, just trying to understand the explanation.

@janeyx99 (PR author):

Oh, like if they check the scale directly as part of debugging, though I think people would notice the NaNs first.
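(A minimal sketch, not from this PR, of what that debugging check could look like in a standard fp16 AMP loop. Note that GradScaler.get_scale() returns a Python float, so calling it every iteration incurs the device sync the library itself avoids.)

```python
import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()

data = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 4, device="cuda")

for _ in range(10):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    # Debug-only check: a scale below 1 suggests gradients are overflowing fp16.
    if scaler.get_scale() < 1.0:
        print("scale fell below 1; the model may not be fp16-compatible")
```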

been invoked for all optimizers used this iteration.
.. warning::
For performance reasons, the scale factor is not guaranteed to be above 1. If the
Reviewer (Collaborator):

Add here that, for performance reasons, we are not checking the scale factor value (to avoid synchronization), so it is not guaranteed to be above 1.

janeyx99 added a commit that referenced this pull request May 16, 2023
ghstack-source-id: bf0bfaf
Pull Request resolved: #101569
@janeyx99 (PR author):

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label May 16, 2023
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


jcaip pushed a commit that referenced this pull request May 23, 2023
Completes action item 1 in #99640

Pull Request resolved: #101569
Approved by: https://github.com/ngimel
facebook-github-bot deleted the gh/janeyx99/49/head branch June 8, 2023 17:23