Map float8 types to uint8 for allgather #126556

Closed · wants to merge 2 commits

Conversation

@drisspg (Contributor) commented May 17, 2024

Summary

Different take on this one:
#126338

We should probably not allow this mapping for 'compute' ops, e.g. reductions.

Corresponding fp8 PR

pytorch-labs/float8_experimental#263
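
For context, a minimal usage sketch of what this mapping enables (illustrative only, not code from this PR; the process-group setup and the choice of fp8 dtype are assumptions): an fp8 tensor can be passed to all_gather directly, without the previous manual uint8 view.

# Illustrative sketch, not from this PR: all_gather on an fp8 tensor once the
# float8 -> uint8 NCCL mapping is in place. Assumes an NCCL process group has
# already been initialized elsewhere.
import torch
import torch.distributed as dist

def allgather_fp8(t: torch.Tensor) -> list[torch.Tensor]:
    # Any float8 dtype behaves the same here; e4m3fn / e5m2 are just examples.
    assert t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
    out = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    # all_gather only moves bytes, so shipping fp8 as ncclUint8 is safe;
    # no arithmetic is performed on the values.
    dist.all_gather(out, t)
    return out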

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126556

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4000694 with merge base 6bcf156:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels May 17, 2024
@@ -64,6 +63,10 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
{at::kLong, ncclInt64},
{at::kHalf, ncclHalf},
{at::kBool, ncclUint8},
{at::kFloat8_e5m2, ncclUint8},
Contributor

I think this is reasonable! This also means we can get rid of the uint8 view when communicating fp8 format?

Contributor Author

Yup, did that here and all dtensor tests passed: pytorch-labs/float8_experimental#263

{at::kFloat8_e5m2, ncclUint8},
{at::kFloat8_e4m3fn, ncclUint8},
{at::kFloat8_e4m3fnuz, ncclUint8},
{at::kFloat8_e5m2fnuz, ncclUint8},
Contributor

Wondering what the concerns are with running a reduction op on uint8 directly?

Contributor Author

I guess it depends on how nccl_reduce works: if only the comms are in uint8 but the actual computation is done in fp8, then it should be fine. That being said, fp8 addition without a scale is not accurate, so maybe this just isn't a problem.

Contributor

I think the point is that nccl_reduction would perform both communication and reduction together when running allreduce or reduce_scatter. So if this is not accurate for fp8, then we should throw an error when hitting a reduction with an fp8 dtype; otherwise users would hit silent correctness issues.

Contributor Author

Yeah, okay, this is what I thought would happen and is incorrect for fp8; I'll add a TORCH_CHECK to the reduce ops.
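
To make the accuracy concern concrete, here is a rough illustration (not from this PR; the element count and values are arbitrary assumptions) of how an unscaled fp8 accumulation drifts:

# Illustrative only: simulate a reduction whose running accumulator is kept in
# fp8 (re-quantized after every add). e4m3 has so few mantissa bits that the
# small increments are eventually swallowed and the running sum stalls.
import torch

x = torch.full((1024,), 0.01)
ref = x.sum()  # fp32 reduction: ~10.24
acc = torch.tensor(0.0).to(torch.float8_e4m3fn)
for v in x.to(torch.float8_e4m3fn):
    acc = (acc.float() + v.float()).to(torch.float8_e4m3fn)
print(ref.item(), acc.float().item())  # the fp8 running sum ends up far below the fp32 result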

@drisspg (Contributor Author) commented May 17, 2024

@wanchaol what is a good place to add tests for this?

@wanchaol (Contributor)

@wanchaol what is a good place to add tests for this?

I think you can add tests in this file https://github.com/pytorch/pytorch/blob/main/test/distributed/test_c10d_nccl.py
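
For reference, a rough sketch of what such a test could look like (names, sizes, and structure are illustrative assumptions, not the test actually added in this PR):

# Sketch in the spirit of test/distributed/test_c10d_nccl.py; the test-class
# scaffolding (self.rank, self.world_size, NCCL process-group setup) is assumed
# to exist as it does in that file.
import torch
import torch.distributed as dist

def _test_float8_allgather_and_reduce(self):
    device = torch.device("cuda", self.rank)
    t = torch.rand(16, device=device).to(torch.float8_e5m2)
    out = [torch.empty_like(t) for _ in range(self.world_size)]
    dist.all_gather(out, t)  # pure data movement: expected to work via the uint8 mapping
    with self.assertRaises(RuntimeError):
        dist.all_reduce(t.clone())  # reductions on fp8 are expected to be rejected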

@drisspg drisspg requested a review from wanchaol May 17, 2024 22:54
@wanchaol (Contributor) left a comment

Thanks, this sounds good to me! Some more comments about reduce_scatter, otherwise LGTM :)

@@ -3039,6 +3042,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
const AllreduceOptions& opts) {
TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
auto tensor = tensors.back();
TORCH_CHECK(
!isFloat8Type(tensor.scalar_type()),
Contributor

shall we also add some checks to reduce_scatter?

Contributor Author

ahh yeah, will add
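
A small sketch of the user-visible effect once the same guard covers reduce_scatter (illustrative assumption; the exact error type and message are not taken from this PR):

# Illustrative only: with the guard in place, reduce_scatter over fp8 inputs
# should raise rather than silently producing an inaccurate fp8 reduction.
import torch
import torch.distributed as dist

def demo_reduce_scatter_guard(rank: int, world_size: int) -> None:
    device = torch.device("cuda", rank)
    inputs = [torch.rand(8, device=device).to(torch.float8_e5m2) for _ in range(world_size)]
    output = torch.empty(8, device=device, dtype=torch.float8_e5m2)
    try:
        dist.reduce_scatter(output, inputs)
    except RuntimeError as e:
        print("rejected as expected:", e)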

@drisspg (Contributor Author) commented May 17, 2024

@pytorchbot merge

@pytorch-bot bot added the ciflow/trunk label May 17, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

facebook-github-bot pushed a commit to pytorch-labs/float8_experimental that referenced this pull request May 18, 2024
Summary:
Coupled with this: pytorch/pytorch#126556
test everything is passing

Pull Request resolved: #263

Reviewed By: wanchaol

Differential Revision: D57505783

Pulled By: drisspg

fbshipit-source-id: cd928420f559839c63d79bfe7558416fbcfe1d69
ZelboK pushed a commit to ZelboK/pytorch that referenced this pull request May 19, 2024
# Summary
Different take on this one:
pytorch#126338

We should probably not allow this mapping for 'compute' ops, e.g. reductions.

### Corresponding fp8 PR
pytorch-labs/float8_experimental#263

Pull Request resolved: pytorch#126556
Approved by: https://github.com/wanchaol
Labels
ciflow/trunk · Merged · oncall: distributed · release notes: distributed (c10d)