Description
`_foreach_div_.ScalarList` is not bitwise equal to a for loop of `div_.Scalar` due to the following optimization added in #29428: in a div op, we prefer `Tensor * (1/Scalar)` over `Tensor / Scalar`, because multiplication is faster to compute than division. This optimization is observable when there is a lot of math going on for the same memory (like in normalization). Otherwise, the speedup between div and mul may not be observable, since most of what's happening is reading and storing memory, in which case performance is bandwidth bound.
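To see why the reroute is not bitwise-neutral, here is a minimal, self-contained illustration in plain Python (double precision). The specific pair of values is just one example where the two roundings disagree:

```python
# Dividing by b and multiplying by the precomputed reciprocal 1/b can
# round differently: 1/10 is not exactly representable in binary
# floating point, so the rounded reciprocal carries an error that the
# subsequent multiply can push the result off by one ulp.
a, b = 7.0, 10.0

quotient = a / b                 # correctly rounded division
via_reciprocal = a * (1.0 / b)   # multiply by the rounded reciprocal

print(quotient == via_reciprocal)      # False: they differ in the last bit
print(abs(quotient - via_reciprocal))  # on the order of one ulp (~1e-16)
```

This is exactly the "may lose one bit of precision" caveat noted in the kernel comment below.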
To be clear, one should not rely on bitwise equality between foreach ops and their single-tensor equivalents, since foreach is an optimization that may take its own shortcuts. That said, it may be interesting considering the following so I made a tracker:
- Add this optimization for `_foreach_div_.Scalar`, as it should be trivial to reroute to `_foreach_mul_.Scalar`.
- Reroute `_foreach_div_.Tensor` to `_foreach_mul_.Scalar` directly when the ScalarTensor is on CPU (currently it routes to `_foreach_mul_.Scalar`).
- Would this still be considered an optimization for `_foreach_div_.ScalarList`? Is there a fast way to invert scalars in CUDA so we could avoid adding a for loop of reciprocals on the CPU side?
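As a rough sketch of the first bullet, the reroute could be expressed at the Python level as follows. The helper name is hypothetical (the real change would live in the native `_foreach_div_` dispatch), and it inherits the same last-bit precision caveat:

```python
import torch

def foreach_div_scalar_(tensors, scalar):
    # Hypothetical reroute: in-place divide-by-scalar expressed as an
    # in-place multiply by the precomputed reciprocal, mirroring the
    # single-tensor CUDA optimization. Results may differ from true
    # division in the last bit.
    torch._foreach_mul_(tensors, 1.0 / scalar)

ts = [torch.ones(3) * 6.0, torch.ones(2) * 9.0]
foreach_div_scalar_(ts, 3.0)  # ts is now approximately [[2, 2, 2], [3, 3]]
```

The reciprocal is computed once on the host, so the per-element work on device is a single multiply.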
Code example
The code that causes the discrepancy between `foreach_div` and normal `div`:
pytorch/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu
Lines 35 to 49 in aeb5fd5
```cpp
if (iter.is_cpu_scalar(2)) {
  // optimization for floating-point types: if the second operand is a CPU
  // scalar, compute a * reciprocal(b). Note that this may lose one bit of
  // precision compared to computing the division.
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        auto inv_b = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
        iter.remove_operand(2);
        gpu_kernel(
            iter,
            BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
                MulFunctor<opmath_t>(), inv_b));
      });
} else {
```
Interesting context
This discrepancy is the only reason default `Adam` and `Adam(foreach=False)` differ; otherwise the two optimizers would be bitwise equivalent. cc @vincentqb @jbschlosser @albanD @crcrpar @mcarilli @awgu
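The comparison in question can be sketched like this. Note this is an illustration, not a bitwise repro: the reciprocal shortcut lives in the CUDA kernel, so on CPU the two paths agree to within floating-point tolerance, while on a CUDA device the shortcut is what makes them drift by a bit:

```python
import torch

# Compare the foreach (horizontally fused) Adam against the plain
# for-loop implementation on identical parameters and gradients.
torch.manual_seed(0)
p1 = torch.randn(10, requires_grad=True)
p2 = p1.detach().clone().requires_grad_(True)

opt_foreach = torch.optim.Adam([p1], lr=1e-2, foreach=True)
opt_loop = torch.optim.Adam([p2], lr=1e-2, foreach=False)

g = torch.randn(10)
p1.grad = g.clone()
p2.grad = g.clone()
opt_foreach.step()
opt_loop.step()

# The two updates agree to within tolerance; on CUDA, torch.equal
# (bitwise equality) can fail because of the div-as-mul shortcut.
print(torch.allclose(p1.detach(), p2.detach()))
```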