Skip to content

Commit

Permalink
Speed up copysign for half and bfloat16 types
Browse files Browse the repository at this point in the history
This also fixes internal compiler errors on aarch64 platforms; see #47395.
  • Loading branch information
malfet committed Nov 5, 2020
1 parent da491d7 commit 62049b8
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,10 +803,28 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

// Generic sign-transfer helper: forwards to std::copysign, which yields a
// value carrying the magnitude of `a` and the sign of `b`. The result is
// converted to T on the way out.
template <typename T>
T copysign(T a, T b) {
  const T magnitude_with_sign = std::copysign(a, b);
  return magnitude_with_sign;
}


// Implement copysign for half precision floats using bit ops
// Sign is the most significant bit for both half and bfloat16 types
template <>
c10::Half copysign(c10::Half a, c10::Half b) {
  // Everything below the top bit of `a` (bits 0..14) is kept as-is.
  const auto magnitude_bits = a.x & 0x7fff;
  // Only the top bit (bit 15, the sign) is taken from `b`.
  const auto sign_bit = b.x & 0x8000;
  // Reassemble the 16-bit pattern and build the result directly from bits.
  return c10::Half(magnitude_bits | sign_bit, c10::Half::from_bits());
}

template <>
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
  // Everything below the top bit of `a` (bits 0..14) is kept as-is.
  const auto magnitude_bits = a.x & 0x7fff;
  // Only the top bit (bit 15, the sign) is taken from `b`.
  const auto sign_bit = b.x & 0x8000;
  // Reassemble the 16-bit pattern and build the result directly from bits.
  return c10::BFloat16(magnitude_bits | sign_bit, c10::BFloat16::from_bits());
}

// Element-wise copysign kernel for CPU.
//
// Dispatches on iter.common_dtype() over the floating types plus kBFloat16
// and kHalf, and applies copysign(a, b) to each pair of input elements via
// cpu_kernel.
void copysign_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      // Use the file-local copysign helper rather than std::copysign so that
      // c10::Half and c10::BFloat16 hit their bit-manipulation
      // specializations. The stale `return std::copysign(a, b);` that used to
      // precede this line made this call unreachable.
      return copysign(a, b);
    });
  });
}
Expand Down

0 comments on commit 62049b8

Please sign in to comment.