
support for fp8 allgather FSDP #109654

Closed · wants to merge 3 commits

Conversation

jspark1105 (Contributor)

@pytorch-bot pytorch-bot bot added the release notes: distributed (c10d) release notes category label Sep 19, 2023
pytorch-bot bot commented Sep 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109654

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit dd4c868 with merge base cff8bf4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

if (tensor.scalar_type() == ScalarType::Float8_e5m2 ||
    tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
  AT_DISPATCH_FP8_TYPES(
      tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
Contributor
does fill_empty_deterministic not work for float8?
cc @malfet

jspark1105 (Author)

An issue is that "tensor.is_floating_point() || tensor.is_complex()" is true for fp8, but AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 doesn't handle fp8, so it raises an error. And I'm not sure whether fp8 has a quiet_NaN.
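For reference on the quiet_NaN question: e5m2 keeps IEEE-style NaNs (exponent all ones, nonzero mantissa), while e4m3fn gives up infinities and has exactly one NaN encoding (exponent and mantissa all ones). A minimal sketch constructing both from bits, using the same from_bits_t pattern this PR already uses:

```cpp
#include <ATen/ATen.h>

// e5m2 (S EEEEE MM): exponent all ones + nonzero mantissa is NaN,
// e.g. 0b0'11111'10 = 0x7e.
at::Float8_e5m2 nan_e5m2(0x7e, at::Float8_e5m2::from_bits_t{});

// e4m3fn (S EEEE MMM): no infinities; the single NaN is
// 0b0'1111'111 = 0x7f (or 0xff with the sign bit set).
at::Float8_e4m3fn nan_e4m3fn(0x7f, at::Float8_e4m3fn::from_bits_t{});
```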

@@ -60,6 +60,10 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
    {at::kLong, ncclInt64},
    {at::kHalf, ncclHalf},
    {at::kBool, ncclUint8},
    // TODO: need per collective handling
    // (e.g., fp8 allgather OK, reduce-scatter NO)
    {at::kFloat8_e5m2, ncclInt8},
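Context for the TODO above, as a hedged sketch: allgather only concatenates buffers, so transporting fp8 as raw int8 bytes is lossless, whereas a reduction over ncclInt8 would do integer arithmetic on fp8 bit patterns. For example, in e5m2 the value 1.0 encodes as 0x3c, and 0x3c + 0x3c = 0x78, which decodes to 2^15, not 2.0:

```cpp
#include <ATen/ATen.h>

// An fp8 shard such as FSDP might allgather (illustrative, not PR code).
at::Tensor shard = at::randn({8}).to(at::kFloat8_e5m2);

// Reinterpreting the same memory as int8 is a pure byte view: an allgather
// over this view moves the exact fp8 bytes, and the gathered output views
// back to fp8 losslessly.
at::Tensor bytes = shard.view(at::kChar);

// A reducing collective over the same view would be wrong: ncclSum on int8
// adds bit patterns (0x3c + 0x3c = 0x78, i.e. 32768 in e5m2), not values.
```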
Contributor
For provenance, @awgu said that these look fine

Contributor
Oh, my bad: I meant more that reviewing / figuring out the right way to do this should be fine. I am not too knowledgeable about ProcessGroupNCCL.

Let me cc: @H-Huang @kwen2501 @wconstab

Member

Looks fine to me. For robustness we will probably need to add validity checks in the reduce ops to ensure these dtypes aren't being used, something like:

void check_gpu_single_tensor(const at::Tensor& tensor) {

Either that, or add another argument to getNcclDataType to pass in the collective and check there:

ncclDataType_t getNcclDataType(at::ScalarType type) {
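A sketch of the first option, an explicit dtype check on the reducing paths; the function name and message below are illustrative, not existing ProcessGroupNCCL API:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Sketch only: call this from allreduce/reduce/reduce_scatter entry points
// to reject fp8, whose byte-level int8 transport cannot be reduced correctly.
void checkFp8NotReduced(const at::Tensor& tensor) {
  const auto st = tensor.scalar_type();
  TORCH_CHECK(
      st != at::kFloat8_e5m2 && st != at::kFloat8_e4m3fn,
      "fp8 tensors are transported as raw int8 bytes; reducing collectives "
      "are not supported for ",
      st);
}
```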

@wconstab (Contributor)

tag @kwen2501 @H-Huang for additional insight on the PGNccl changes

@jspark1105 force-pushed the jspark_fp8 branch 2 times, most recently from 1ce229e to 2f89908, on October 14, 2023
@jspark1105 changed the title from "WIP support for fp8 allgather FSDP" to "support for fp8 allgather FSDP" on October 14, 2023
@jspark1105 marked this pull request as ready for review on October 14, 2023
@facebook-github-bot (Contributor)

@jspark1105 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

} else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
  at::Float8_e4m3fn nan(FP8_NAN, at::Float8_e4m3fn::from_bits_t{});
  tensor.fill_(nan);
} else if (tensor.is_floating_point() || tensor.is_complex()) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
Contributor

IIRC, for other types that aren't standard C++ types (i.e., float16, bfloat16), we just fill out the std::numeric_limits specialization, so this code works throughout the codebase and we don't need to special-case each place.
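A sketch of that suggestion for the fp8 case; note that c10 may already ship numeric_limits specializations for these types, so treat this as illustrative only:

```cpp
#include <c10/util/Float8_e4m3fn.h>

// Sketch: specialize std::numeric_limits so the generic
// fill_(std::numeric_limits<scalar_t>::quiet_NaN()) path works for fp8
// without a special case. c10 may already provide this specialization.
namespace std {
template <>
class numeric_limits<c10::Float8_e4m3fn> {
 public:
  static constexpr bool is_specialized = true;
  static constexpr bool has_quiet_NaN = true;
  // e4m3fn's single NaN encoding: exponent and mantissa all ones (0x7f).
  static constexpr c10::Float8_e4m3fn quiet_NaN() {
    return c10::Float8_e4m3fn(0x7f, c10::Float8_e4m3fn::from_bits_t{});
  }
};
} // namespace std
```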

github-actions bot commented Jan 6, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 6, 2024
@rahul003
Any update on this?

    input.scalar_type() == at::kFloat8_e4m3fn) {
  nccl_dtype = ncclInt8;
} else {
  nccl_dtype = getNcclDataType(input.scalar_type());
Contributor

Just curious: why not add the type mapping into ncclDataType above, so that getNcclDataType returns it?

Or do we potentially want different behavior for this type depending on which collective we're in?

jspark1105 (Author)

Yes, this is based on feedback from someone on the PyTorch team (I think it was in a Workplace chat).

Contributor

Ping @kwen2501 @H-Huang to comment further: do we want it this way, or inside ncclDataType?
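For comparison, the "inside ncclDataType" alternative under discussion would look roughly like this; a sketch only, with the extra parameter invented for illustration:

```cpp
// Sketch: keep the fp8 -> ncclInt8 mapping in the shared ncclDataType map,
// and have getNcclDataType reject fp8 when the caller is a reducing
// collective.
ncclDataType_t getNcclDataType(at::ScalarType type, bool is_reduce_op) {
  TORCH_CHECK(
      !(is_reduce_op &&
        (type == at::kFloat8_e5m2 || type == at::kFloat8_e4m3fn)),
      "fp8 is transported as raw bytes; reducing collectives are unsupported");
  auto it = ncclDataType.find(type);
  TORCH_CHECK(it != ncclDataType.end(), "Unsupported data type for NCCL");
  return it->second;
}
```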

@github-actions github-actions bot closed this Feb 16, 2024