Skip to content

Commit

Permalink
Speed up copysign for half and bfloat16 types (#47413)
Browse files Browse the repository at this point in the history
Summary:
This also avoids internal compiler errors (ICEs) on aarch64 platforms and, as a consequence, fixes #47395.

Pull Request resolved: #47413

Reviewed By: walterddr

Differential Revision: D24745921

Pulled By: malfet

fbshipit-source-id: 790e5b91d9116670c882d838b3862d5b47178d68
  • Loading branch information
malfet authored and facebook-github-bot committed Nov 5, 2020
1 parent 3549141 commit eed4a57
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,10 +803,27 @@ void heaviside_kernel(TensorIterator& iter) {
});
}

// Generic sign-transfer helper: forwards to std::copysign and converts the
// result back to T. Explicit specializations of this template can replace the
// implementation for individual types.
template <typename T>
T copysign(T a, T b) {
  const T signed_value = std::copysign(a, b);
  return signed_value;
}

// copysign for c10::Half, done directly on the raw 16-bit pattern:
// bit 15 is the sign, bits 0-14 hold exponent and mantissa, so the result is
// the low 15 bits of `a` combined with the top bit of `b`.
template<>
c10::Half copysign(c10::Half a, c10::Half b) {
  const auto magnitude_bits = a.x & 0x7fff;  // everything except the sign of a
  const auto sign_bit = b.x & 0x8000;        // only the sign of b
  return c10::Half(magnitude_bits | sign_bit, c10::Half::from_bits());
}

// copysign for c10::BFloat16 via bit manipulation of the 16-bit storage:
// keep the low 15 bits (magnitude) of `a` and take the top bit (sign) of `b`.
template<>
c10::BFloat16 copysign(c10::BFloat16 a, c10::BFloat16 b) {
  const auto magnitude_bits = a.x & 0x7fff;  // clear the sign of a
  const auto sign_bit = b.x & 0x8000;        // isolate the sign of b
  return c10::BFloat16(magnitude_bits | sign_bit, c10::BFloat16::from_bits());
}

// CPU kernel for element-wise copysign over the operands of `iter`.
// Dispatches on iter.common_dtype() through the floating-point dispatch macro
// (with kBFloat16 and kHalf added) and applies copysign(a, b) to each pair of
// elements.
void copysign_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, iter.common_dtype(), "copysign_cpu", [&]() {
    cpu_kernel(iter, [](scalar_t a, scalar_t b) -> scalar_t {
      // Call the local copysign helper rather than std::copysign so that the
      // bit-op specializations for c10::Half and c10::BFloat16 are selected.
      // A stale `return std::copysign(a, b);` used to precede this line and
      // made it unreachable.
      return copysign(a, b);
    });
  });
}
Expand Down

0 comments on commit eed4a57

Please sign in to comment.