Add Bfloat16 scalar support to gloo backend #113557

Closed · wants to merge 2 commits

Conversation

@anko-intel (Contributor) commented Nov 13, 2023

Support for bfloat16 scalars was missing. When I use the gloo backend,
torch.distributed.init_process_group(backend='gloo')
and run
torch.nn.parallel.DistributedDataParallel(model)
with a model that has bfloat16 features, I receive the following error:
RuntimeError: Invalid scalar type

This change fixes the issue.
c10::BFloat16 defines conversions from/to float, so the calculations for bfloat16 are performed in float.
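
For context, DDP broadcasts module parameters and buffers at construction and allreduces gradients during the backward pass, so an unsupported scalar type in the gloo process group surfaces there. A minimal reproduction along the lines of the description might look like the sketch below; the two-process setup, model shape, and rendezvous port are assumptions, not details taken from this PR.

```python
# Minimal repro sketch (assumptions: 2 CPU processes, a tiny Linear model,
# localhost rendezvous on port 29500). Before this change, the bfloat16
# communication inside DDP over gloo raised "RuntimeError: Invalid scalar type".
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Model cast to bfloat16 so DDP has to broadcast and allreduce bfloat16 tensors.
    model = torch.nn.Linear(8, 8).to(torch.bfloat16)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model)

    out = ddp_model(torch.randn(4, 8, dtype=torch.bfloat16))
    out.sum().backward()  # gradient allreduce over the gloo process group

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```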

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu

Bfloat16 uses c10::BFloat16, which defines conversions
from/to float, so calculations are made on floats.
@pytorch-bot bot added the "release notes: distributed (c10d)" label (release notes category) on Nov 13, 2023

pytorch-bot bot commented Nov 13, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113557

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit abb0a4c with merge base 115da02:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Nov 13, 2023

CLA Signed

The committers listed above are authorized under a signed CLA.

@facebook-github-bot (Contributor):

@XilunWu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@XilunWu (Contributor) left a comment

Thanks for adding BFloat16 support to gloo PG. I'm not sure if this is the best way to support BFloat16 because we already have float16 support in gloo (Half, i.e., gloo::float16).

Review comment on torch/csrc/distributed/c10d/ProcessGroupGloo.cpp (outdated, resolved)
@jgong5 (Collaborator) commented Nov 14, 2023

Thanks for adding BFloat16 support to gloo PG. I'm not sure if this is the best way to support BFloat16 because we already have float16 support in gloo (Half, i.e., gloo::float16).

@XilunWu are you suggesting adding a gloo::bfloat16 type to gloo instead of using c10::BFloat16?

@jgong5 (Collaborator) left a comment

Please add UT.

@XilunWu (Contributor) commented Nov 14, 2023

@jgong5 That seems to be a reasonable call, right? But to quickly unblock, we can merge this PR once the Windows part is fixed, and add bfloat16 to gloo later.

@jgong5 (Collaborator) commented Nov 14, 2023

@jgong5 That seems to be a reasonable call, right? But to quickly unblock, we can merge this PR once the Windows part is fixed, and add bfloat16 to gloo later.

Yes, that sounds reasonable.

@XilunWu (Contributor) commented Nov 14, 2023

Just a reminder that the Windows part has an issue and needs to be fixed.

Fix Windows and add unit tests for bfloat.
@anko-intel (Contributor, Author) commented:

Please add UT.

I added two test cases for bfloat16. I think this will be enough and will not increase the testing time too much.
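
For illustration, a bfloat16 collective test over gloo could look roughly like the sketch below; the test name, tensor shape, and the choice of all_reduce are assumptions and not the actual test cases added by this PR.

```python
# Hedged sketch of a bfloat16 allreduce test over gloo; names and structure
# are illustrative only, not the tests added by this PR.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _allreduce_bfloat16(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Each rank contributes (rank + 1); with 2 ranks the elementwise sum is 3,
    # which is exactly representable in bfloat16.
    t = torch.full((4,), rank + 1, dtype=torch.bfloat16)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    assert torch.equal(t, torch.full((4,), 3, dtype=torch.bfloat16))

    dist.destroy_process_group()


def test_allreduce_bfloat16():
    mp.spawn(_allreduce_bfloat16, args=(2,), nprocs=2)
```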

@janeyx99 added the "oncall: distributed" label (Add this issue/PR to distributed oncall triage queue) on Nov 15, 2023
@jgong5 (Collaborator) left a comment

LGTM as long as CI passes.

@anko-intel (Contributor, Author) commented:

@XilunWu can we go forward with this change?

@janeyx99 added the "triaged" label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) on Nov 17, 2023
@facebook-github-bot (Contributor):

@XilunWu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@XilunWu (Contributor) left a comment

LGTM! Really appreciate the effort of adding BFloat16 scalar support!

@XilunWu (Contributor) commented Nov 17, 2023

@pytorchbot merge

@pytorch-bot bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) on Nov 17, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.
