Skip to content

Commit

Permalink
added support for half; added correct behavior with NAN
Browse files Browse the repository at this point in the history
  • Loading branch information
igm503 committed Sep 20, 2023
1 parent da37c67 commit 1392c39
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
#else
#include <ATen/ops/maximum.h>
#include <ATen/ops/minimum.h>
#include <ATen/ops/nextafter_native.h>
#include <ATen/ops/polar_native.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/nextafter_native.h>
#endif

namespace at::native {
Expand Down Expand Up @@ -183,7 +183,7 @@ kernel void complex_mul(constant void * input_ [[buffer(0)]],
REGISTER_COMPLEX_MUL_OP(float);
REGISTER_COMPLEX_MUL_OP(half);
template<typename T>
template<typename T, typename U>
kernel void nextafter_kernel(constant void * input_ [[buffer(0)]],
constant void * other_ [[buffer(1)]],
device void * out_ [[buffer(2)]],
Expand All @@ -193,38 +193,34 @@ kernel void nextafter_kernel(constant void * input_ [[buffer(0)]],
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
if (*input == *out)
if (*input == *other)
{
*out = *other;
}
else if (isnan(*input) || isnan(*other))
{
*out = NAN;
}
else if (*input == 0)
{
    // Zero cannot be handled by stepping the bit pattern: +0 is 0x0, so a
    // decrement wraps to all-ones (a NaN encoding), and -0 has only the sign
    // bit set, so an increment yields a *negative* subnormal even when
    // `other` is positive. The correct neighbour of either zero is the
    // smallest subnormal (bit pattern 1) carrying the sign of the direction.
    T smallest_subnormal = as_type<T>(static_cast<U>(1));
    *out = (*other > 0) ? smallest_subnormal : -smallest_subnormal;
}
else
{
    // IEEE-754 values are sign-magnitude encoded: adding 1 to the raw bits
    // always grows the magnitude, it does not always grow the value.
    // Reaching this branch implies input != other, input != 0 and neither
    // operand is NaN, so the step direction is fully determined by:
    //   input > 0 and other > input  -> away from zero  -> bits + 1
    //   input > 0 and other < input  -> toward zero     -> bits - 1
    //   input < 0 and other > input  -> toward zero     -> bits - 1
    //   input < 0 and other < input  -> away from zero  -> bits + 1
    U bits = as_type<U>(*input);
    bool away_from_zero = (*input > 0) == (*other > *input);
    if (away_from_zero)
    {
        bits++;
    }
    else
    {
        bits--;
    }
    *out = as_type<T>(bits);
}
}
#define REGISTER_NEXTAFTER_OP(DTYPE) \
#define REGISTER_NEXTAFTER_OP(DTYPE, UTYPE) \
template \
[[host_name("nextafter_kernel_" #DTYPE)]] \
kernel void nextafter_kernel<DTYPE>( \
[[host_name("nextafter_kernel_" #DTYPE)]] \
kernel void nextafter_kernel<DTYPE, UTYPE>( \
constant void * input, \
constant void * other, \
device void * out, \
constant uint3 * offsets, \
uint tid)
REGISTER_NEXTAFTER_OP(float);
REGISTER_NEXTAFTER_OP(float, uint);
REGISTER_NEXTAFTER_OP(half, ushort);
)BINARY_METAL";

using namespace mps;
Expand Down Expand Up @@ -381,6 +377,7 @@ static void copysign_mps_kernel(TensorIteratorBase& iter) {
}

// Entry point for nextafter over a TensorIterator: checks that the iterator's
// common dtype is a floating type, then hands the iterator to
// mps::binary_mps_impl together with the kernel base name "nextafter_kernel".
// NOTE(review): binary_mps_impl presumably appends a dtype suffix to pick the
// "nextafter_kernel_<dtype>" instantiation registered in the Metal source --
// confirm against its definition.
static void nextafter_mps_kernel(TensorIteratorBase& iter) {
  const auto common_dtype = iter.common_dtype();
  TORCH_CHECK_TYPE(isFloatingType(common_dtype), "nextafter_mps not implemented for non-floating types");
  mps::binary_mps_impl(iter, "nextafter_kernel");
}

Expand Down

0 comments on commit 1392c39

Please sign in to comment.