Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions torch/csrc/distributed/c10d/NanCheck.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,68 @@ struct CheckBytePack<T, /*EltPerPack*/8> {
}
};

// (v) Template specialization for Float8 types (Float8_e4m3fn, Float8_e5m2).
// EltPerPack = 16 / 1 = 16

// We want to check 8 x FP8 simultaneously, hence this template definition.
// Primary template of the per-type "does this 64-bit word hold a NaN" check.
// A 64-bit word carries eight 1-byte FP8 values. There is no generic
// implementation: every FP8 type must provide its own explicit specialization
// with a `check(uint64_t)` member.
template <class T>
struct HasNanFP8x8 {
  // Deleted on purpose, so that using an FP8 type without a specialization is
  // rejected at compile time instead of silently skipping the NaN check.
  //
  // The preferred form would be a body consisting of
  //   static_assert(false, "<message below>");
  // but `static_assert(false, ...)` inside a template definition requires
  // C++23 onwards. The message still applies if you find yourself here:
  //
  //   "You should never call this template definition because it is empty.
  //    You can follow the example of Float8_e4m3fn below to implement the
  //    check for your new datatype."
  static __device__ __forceinline__ bool check(uint64_t packed) = delete;
};

// isnan condition for Float8_e4m3fn:
// (x & 0b01111111) == 0b01111111
// i.e.
// (x & 0x7f) == 0x7f

// We want to check 8 x FP8 simultaneously. The algorithm is as follows:
// (1) Mask out the most significant bit of every byte with mask 0x7f.
// (2) If the result is 0x7f (is nan), the following arithmetic would cause the
//     8th (most significant) bit of that byte to be 1: x[i] = x[i] + 0x01.
//     Since 0x7f + 0x01 == 0x80 still fits in one byte, the addition never
//     carries into the neighboring byte.
// (3) Only leave the 8th bit of every byte by masking with 0x80.
// (4) If any x[i] is nan, then the whole x != 0.

// NaN check for eight packed Float8_e4m3fn values.
// A byte x is NaN iff (x & 0x7f) == 0x7f.
template <>
struct HasNanFP8x8<c10::Float8_e4m3fn> {
  // Returns true if at least one of the eight bytes of `fp8x8` is NaN.
  static __device__ __forceinline__ bool check(uint64_t fp8x8) {
    // Keep the low 7 bits of every byte.
    const uint64_t low7 = fp8x8 & 0x7F7F7F7F7F7F7F7FULL;
    // Only a byte equal to 0x7f reaches 0x80 when 0x01 is added; the sum
    // never exceeds 0x80, so nothing carries into the next byte.
    const uint64_t incremented = low7 + 0x0101010101010101ULL;
    // Isolate the top bit of every byte: non-zero means some byte was NaN.
    const uint64_t overflow = incremented & 0x8080808080808080ULL;
    return overflow != 0;
  }
};

// isnan condition for Float8_e5m2:
// (x & 0x7f) > 0x7c
// 0x7c itself does not carry into the 8th bit: 0x7c + 0x03 == 0x7f. Adding 0x03
// to anything greater than 0x7c (0x7d, 0x7e, 0x7f) sets the 8th bit (>= 0x80).

// NaN check for eight packed Float8_e5m2 values.
// A byte x is NaN iff (x & 0x7f) > 0x7c.
template <>
struct HasNanFP8x8<c10::Float8_e5m2> {
  // Returns true if at least one of the eight bytes of `fp8x8` is NaN.
  static __device__ __forceinline__ bool check(uint64_t fp8x8) {
    constexpr uint64_t kLow7Mask = 0x7F7F7F7F7F7F7F7FULL;
    constexpr uint64_t kAddend = 0x0303030303030303ULL;
    constexpr uint64_t kTopBitMask = 0x8080808080808080ULL;

    // After masking, every byte is <= 0x7f, so byte + 0x03 <= 0x82: the sum
    // stays inside its own byte, and its top bit is set only for bytes that
    // were greater than 0x7c.
    const uint64_t bumped = (fp8x8 & kLow7Mask) + kAddend;
    return (bumped & kTopBitMask) != 0;
  }
};

// Checks one pack of 1-byte (FP8) elements for NaN, 8 elements per 64-bit word.
// The per-type NaN test is delegated to HasNanFP8x8<T>, so T must have a
// HasNanFP8x8 specialization.
template <typename T>
struct CheckBytePack<T, /*EltPerPack*/16> {
  // Inspects the two 64-bit words `tmp->ul[0]` and `tmp->ul[1]` and calls
  // `__trap()` if either of them contains a NaN element.
  static __device__ __forceinline__ void check(BytePack* tmp) {
    if (HasNanFP8x8<T>::check(tmp->ul[0]) ||
        HasNanFP8x8<T>::check(tmp->ul[1])) {
      __trap();
    }
  }
};
Expand Down
Loading