diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index ba88f3f7a9c89..5185968717e0b 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -13,6 +13,7 @@
 #else
 #include <ATen/ops/fmax_native.h>
 #include <ATen/ops/fmin_native.h>
+#include <ATen/ops/nextafter_native.h>
 #include <ATen/ops/polar_native.h>
 #include <ATen/ops/view_as_real.h>
 #endif
@@ -181,6 +182,45 @@ kernel void complex_mul(constant void * input_ [[buffer(0)]],
 
 REGISTER_COMPLEX_MUL_OP(float);
 REGISTER_COMPLEX_MUL_OP(half);
+
+template<typename T, typename U>
+kernel void nextafter_kernel(constant void * input_ [[buffer(0)]],
+                             constant void * other_ [[buffer(1)]],
+                             device void * out_ [[buffer(2)]],
+                             constant uint3 * offsets [[buffer(3)]],
+                             uint tid [[thread_position_in_grid]]) {
+  device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
+  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
+  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
+
+  if (*input == *other)
+  {
+    *out = *other;
+  }
+  else if (isnan(*input) || isnan(*other))
+  {
+    *out = NAN;
+  }
+  else
+  {
+    U bits = as_type<U>(*input);
+    bits = bits + ((*other > *input) ? 1 : -1);
+    *out = as_type<T>(bits);
+  }
+}
+
+#define REGISTER_NEXTAFTER_OP(DTYPE, UTYPE)    \
+template                                       \
+[[host_name("nextafter_kernel_" #DTYPE)]]      \
+kernel void nextafter_kernel<DTYPE, UTYPE>(    \
+    constant void * input,                     \
+    constant void * other,                     \
+    device void * out,                         \
+    constant uint3 * offsets,                  \
+    uint tid)
+
+REGISTER_NEXTAFTER_OP(float, uint);
+REGISTER_NEXTAFTER_OP(half, ushort);
 )BINARY_METAL";
 
 using namespace mps;
@@ -336,9 +376,15 @@ static void copysign_mps_kernel(TensorIteratorBase& iter) {
   mps::binary_mps_impl(iter, "copysign");
 }
 
+static void nextafter_mps_kernel(TensorIteratorBase& iter) {
+  TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "nextafter_mps not implemented for non-floating types");
+  mps::binary_mps_impl(iter, "nextafter_kernel");
+}
+
 REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel);
 REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel);
 REGISTER_DISPATCH(copysign_stub, &copysign_mps_kernel);
+REGISTER_DISPATCH(nextafter_stub, &nextafter_mps_kernel);
 
 Tensor& polar_out_mps(const Tensor& abs, const Tensor& angle, Tensor& output) {
   auto new_size = at::infer_size(abs.sizes(), angle.sizes());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 046cb82e2c715..5db47218938fa 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9541,7 +9541,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: nextafter_out
+    CPU, CUDA, MPS: nextafter_out
   tags: pointwise
 
- func: nextafter(Tensor self, Tensor other) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index f0c2d6c756ab3..b974f7e9b6978 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -93,6 +93,8 @@ def mps_ops_grad_modifier(ops):
         'exponential': [torch.float16, torch.float32],
 
         # CPU errors
+        # derivative for aten::nextafter is not implemented on CPU
+        'nextafter': None,
         # derivative for aten::floor_divide is not implemented on CPU
         'floor_divide': [torch.float16, torch.float32],
         # derivative for aten::narrow_copy is not implemented on CPU
@@ -556,7 +558,6 @@ def mps_ops_modifier(ops):
         'nanquantile': None,
         'nanmedian': None,
         'native_dropout_backward': None,
-        'nextafter': None,
         'normnuc': None,
         'nn.functional.fractional_max_pool2d': None,
         'nn.functional.fractional_max_pool3d': None,