support for fp8 allgather FSDP #109654
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/109654
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit dd4c868 with merge base cff8bf4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
```cpp
if (tensor.scalar_type() == ScalarType::Float8_e5m2 ||
    tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
  AT_DISPATCH_FP8_TYPES(
      tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
```
does fill_empty_deterministic not work for float8?
cc @malfet
One issue is that `tensor.is_floating_point() || tensor.is_complex()` is true for fp8, but `AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2` doesn't handle fp8, so it raises an error. I'm also not sure whether fp8 has a quiet_NaN.
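For what it's worth, both fp8 formats do reserve NaN encodings: e5m2 follows the usual IEEE rule (exponent field all ones, nonzero mantissa), while e4m3fn has no infinities and a single NaN pattern, S.1111.111. A standalone sketch with hypothetical helper names (not PyTorch code):

```cpp
#include <cstdint>

// e5m2 layout: 1 sign | 5 exponent | 2 mantissa.
// NaN when the exponent field is all ones and the mantissa is nonzero.
bool is_nan_e5m2(uint8_t bits) {
  return ((bits & 0x7C) == 0x7C) && ((bits & 0x03) != 0);
}

// e4m3fn layout: 1 sign | 4 exponent | 3 mantissa.
// "fn" = finite + NaN only: no infinities, and S.1111.111 is the sole NaN.
bool is_nan_e4m3fn(uint8_t bits) {
  return (bits & 0x7F) == 0x7F;
}
```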
```cpp
@@ -60,6 +60,10 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
    {at::kLong, ncclInt64},
    {at::kHalf, ncclHalf},
    {at::kBool, ncclUint8},
    // TODO: need per collective handling
    // (e.g., fp8 allgather OK, reduce-scatter NO)
    {at::kFloat8_e5m2, ncclInt8},
```
For provenance, @awgu said that these look fine
Looks fine to me. For robustness we will probably need to add validity checks in the reduce ops to ensure these dtypes aren't being used, something like:

```cpp
void check_gpu_single_tensor(const at::Tensor& tensor) {
```

either that, or add another argument to `getNcclDataType` to pass in the collective and do the check there:

```cpp
ncclDataType_t getNcclDataType(at::ScalarType type) {
```
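A rough sketch of that second option, with the collective kind threaded through the dtype lookup so fp8 can be allowed for allgather but rejected for reductions. All names here (`CollectiveKind`, `get_nccl_data_type`, the stand-in enums) are illustrative, not the actual PyTorch/NCCL declarations:

```cpp
#include <map>
#include <stdexcept>

enum class ScalarType { Float, Half, Float8_e5m2, Float8_e4m3fn };
enum class CollectiveKind { AllGather, Broadcast, AllReduce, ReduceScatter };
using ncclDataType_t = int;  // stand-in for the real NCCL enum
constexpr ncclDataType_t ncclInt8 = 0, ncclHalf = 6, ncclFloat = 7;

// Reductions do arithmetic on the wire dtype, so byte-reinterpreted fp8
// would produce garbage there.
bool is_reduction(CollectiveKind c) {
  return c == CollectiveKind::AllReduce || c == CollectiveKind::ReduceScatter;
}

ncclDataType_t get_nccl_data_type(ScalarType t, CollectiveKind c) {
  static const std::map<ScalarType, ncclDataType_t> table = {
      {ScalarType::Float, ncclFloat},
      {ScalarType::Half, ncclHalf},
      // fp8 moves as raw bytes; only valid where no arithmetic happens
      {ScalarType::Float8_e5m2, ncclInt8},
      {ScalarType::Float8_e4m3fn, ncclInt8},
  };
  bool fp8 = (t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn);
  if (fp8 && is_reduction(c)) {
    throw std::runtime_error("fp8 dtypes are not supported for reduction collectives");
  }
  return table.at(t);
}
```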
Force-pushed from 1ce229e to 2f89908 (Compare)
@jspark1105 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 2f89908 to dd4c868 (Compare)
@jspark1105 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```cpp
} else if (tensor.scalar_type() == ScalarType::Float8_e4m3fn) {
  at::Float8_e4m3fn nan(FP8_NAN, at::Float8_e4m3fn::from_bits_t{});
  tensor.fill_(nan);
} else if (tensor.is_floating_point() || tensor.is_complex()) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kBFloat16, kHalf, tensor.scalar_type(), "fill_empty_deterministic_", [&]() {
        tensor.fill_(std::numeric_limits<scalar_t>::quiet_NaN());
```
IIRC, for other types that aren't standard C++ types (e.g. float16, bfloat16) we just fill out the `std::numeric_limits` specialization, so this code works throughout the codebase and we don't need to special-case each place.
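To illustrate that approach with a toy stand-in (`MyFloat8` below is hypothetical, not the actual `at::Float8_e4m3fn` implementation): specializing `std::numeric_limits` for the custom type lets generic fill code run unchanged for builtin and custom floats alike.

```cpp
#include <cstdint>
#include <limits>

// Toy 8-bit float type; only the raw bits are modeled here.
struct MyFloat8 {
  uint8_t bits = 0;
  constexpr MyFloat8() = default;
  explicit constexpr MyFloat8(uint8_t b) : bits(b) {}
};

// Specializing std::numeric_limits for a user-defined type is allowed.
namespace std {
template <>
class numeric_limits<MyFloat8> {
 public:
  static constexpr bool has_quiet_NaN = true;
  // e4m3fn reserves S.1111.111 (0x7f) as its only NaN encoding
  static constexpr MyFloat8 quiet_NaN() { return MyFloat8(0x7f); }
};
}  // namespace std

// Generic fill helper: works for float, double, and MyFloat8 alike,
// with no per-type special casing at the call site.
template <typename T>
void fill_with_nan(T* data, int n) {
  for (int i = 0; i < n; ++i) data[i] = std::numeric_limits<T>::quiet_NaN();
}
```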
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Any update on this?
```cpp
    input.scalar_type() == at::kFloat8_e4m3fn) {
  nccl_dtype = ncclInt8;
} else {
  nccl_dtype = getNcclDataType(input.scalar_type());
```
Just curious: why not add the type mapping into `ncclDataType` above so that `getNcclDataType` returns it? Are we wanting to potentially have different behavior for the type depending on which collective we're in?
Yes, this is based on feedback from someone on the PyTorch team (I think it was in workplace chat).
Combo with facebookresearch/fairscale#1136